# Source code for statetracker.manager

"""Manager - Context manager for handling State objects."""

import logging

logger = logging.getLogger(__name__)

# IPython is an optional dependency: use its rich ``display`` when it can be
# imported, otherwise fall back to the built-in ``print`` under the same name.
try:
    from IPython.display import display
except ImportError:
    display = print
import pandas as pd

from .utils import clean_df_int_columns


class Manager:
    """Context manager for handling State objects.

    A manager keeps the list of states registered with it and, while used in
    a ``with`` block, publishes itself as ``Manager._active_manager``.
    ``with`` blocks may be nested, including re-entering the same manager;
    every exit restores whichever manager was active before the matching
    enter.
    """

    # Manager of the innermost ``with`` block currently open, or None.
    _active_manager = None

    def __init__(self):
        """Initialize empty state manager."""
        self._states = []
        self._next_id = 0
        # One entry per open ``with`` block on this instance: the manager
        # that was active just before that block was entered.
        self._previous_managers = []
        # Manager displaced by the innermost open ``with`` block on this
        # instance (None when no block is open).
        self._previous_manager = None

    def __enter__(self):
        """Enter context and set as active manager, saving any previous.

        Returns:
            This manager, so it can be bound with ``with Manager() as m``.
        """
        displaced = Manager._active_manager
        # A stack (rather than a single slot) keeps the displaced manager of
        # every nesting level, so re-entering the same instance cannot
        # overwrite what an outer ``with`` block still has to restore.
        self._previous_managers.append(displaced)
        self._previous_manager = displaced
        Manager._active_manager = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context and restore previous active manager.

        An exit with no matching enter resets the active manager to None
        instead of raising.

        Returns:
            False, so an exception raised inside the block is not suppressed.
        """
        stack = self._previous_managers
        Manager._active_manager = stack.pop() if stack else None
        # Keep the single-slot attribute in step with the stack top.
        self._previous_manager = stack[-1] if stack else None
        return False

    def register(self, state):
        """Register a state with this manager.

        Assigns the next sequential id to ``state._id``, gives the state a
        default name ``State[<id>]`` when ``state._name`` is None, and appends
        it to the list of registered states.

        Args:
            state: Object to register; its ``_id`` is overwritten.
        """
        state._id = self._next_id
        self._next_id += 1
        if state._name is None:
            state._name = f"State[{state._id}]"
        self._states.append(state)
        logger.debug(
            "Registered state id=%s name=%s num_values=%s",
            state._id,
            state._name,
            state.num_values,
        )

    def get_ancestors(self, state, visited=None):
        """Recursively collect all ancestor states and sync group peers.

        Args:
            state: State to start from; it is included in the result.
            visited: Accumulator list used by the recursion. Callers normally
                omit it; when given, it is extended in place.

        Returns:
            The ``visited`` list in visit order: ``state`` first, then its
            sync group peers (each with its own ancestors), then the states
            reached through ``state._parents``.
        """
        if visited is None:
            visited = []
        logger.debug(
            "Getting ancestors for state id=%s name=%s", state._id, state._name
        )
        if state in visited:
            return visited
        visited.append(state)
        # Include sync group peers and recurse into their parents too.
        for peer in state._synced_group:
            if peer not in visited and peer is not state:
                self.get_ancestors(peer, visited)
        # Then recurse into computational parents.
        for parent in state._parents:
            self.get_ancestors(parent, visited)
        return visited

    def clear_all_values(self):
        """Directly clear all states and sync groups to None (bypasses setter).

        Writes the private ``_value`` attribute of every registered state and
        of each distinct ``_synced_group`` object, rather than assigning
        through ``state.value``.
        """
        logger.debug("Clearing all values for %d states", len(self._states))
        seen_groups = set()
        for state in self._states:
            state._value = None
            group = state._synced_group
            # Groups are tracked by identity so a group shared by several
            # states is cleared only once.
            if id(group) not in seen_groups:
                group._value = None
                seen_groups.add(id(group))

    def inactivate_all(self, states=None):
        """Set states to inactive value (None).

        Args:
            states: States to inactivate; defaults to every registered state.
        """
        targets = states if states is not None else self._states
        for state in targets:
            state.value = None

    def reset_all(self, states=None):
        """Call ``reset()`` on each state.

        The original documentation describes this as resetting states to
        value 0; the actual effect is whatever ``state.reset()`` does.

        Args:
            states: States to reset; defaults to every registered state.
        """
        targets = states if states is not None else self._states
        for state in targets:
            state.reset()

    def get_all_names(self):
        """Return list of names of all registered states."""
        return [s.name for s in self._states]

    def get_by_name(self, name):
        """Return the first registered state whose ``name`` equals ``name``.

        Raises:
            KeyError: If no registered state has that name.
        """
        for state in self._states:
            if state.name == name:
                return state
        raise KeyError(f"No state with name '{name}' found")

    def get_iteration_df(self, iter_state, states=None):
        """Return DataFrame showing value of states as iter_state is iterated.

        Every registered state (not only ``states``) is inactivated before
        the iteration starts and again after it finishes, and ``iter_state``
        is reset before it is iterated.

        Args:
            iter_state: State to iterate over; each step produces one row.
            states: States whose ``value`` is recorded on every step; defaults
                to every registered state.

        Returns:
            A DataFrame with one column per recorded state, named after the
            state, and an index named after ``iter_state``. Any column
            carrying ``iter_state``'s own name is dropped, and the frame is
            passed through ``clean_df_int_columns`` before being returned.
        """
        targets = states if states is not None else self._states
        col_names = [f"{s.name}" for s in targets]

        self.inactivate_all()
        iter_state.reset()

        rows = []
        for _ in iter_state:
            row = [s.value for s in targets]
            rows.append(row)

        self.inactivate_all()

        df = pd.DataFrame(rows, columns=col_names)
        df.index.name = f"{iter_state.name}"
        # The iterated state is already identified by the index name, so its
        # own column is dropped.
        if iter_state.name in df.columns:
            df = df.drop(columns=[iter_state.name])
        df = clean_df_int_columns(df)
        return df

    def print_graph(self, style: str = "clean"):
        """Print an ASCII tree visualization of the state dependency graph.

        Args:
            style: Display style - 'clean' (default), 'minimal', or 'repr'.
        """
        # Imported at call time rather than at module import time.
        from .text_viz import print_graph

        print_graph(self._states, style=style)