"""State class for composable iteration."""
import logging
logger = logging.getLogger(__name__)
from .imports import Integral, Operation_type, Optional, Real, Sequence, State_type, Union, beartype
from .manager import Manager
from .sync_group import SynchronizedGroup
@beartype
class ConflictingValueAssignmentError(RuntimeError):
    """Raised when a state receives conflicting value assignments during propagation."""
@beartype
class State:
    """A state that can be iterated and composed with other states.

    A state is either a leaf, created with an explicit ``num_values``, or a
    derived state, created from ``_parents`` and an ``_op`` whose
    ``compute_num_states`` determines ``num_values`` from the parents.
    Every state registers itself with the active ``Manager`` and starts in
    its own ``SynchronizedGroup``.
    """

    def __init__(
        self,
        num_values: Optional[Integral] = None,
        name: Optional[str] = None,
        value: Optional[Integral] = None,
        iter_order: Optional[Real] = None,
        *,
        _parents: Optional[Sequence[State_type]] = None,
        _op: Optional[Operation_type] = None,
    ):
        """Create a state.

        Args:
            num_values: Number of values of a leaf state. Defaults to 1 when
                None. Ignored when both ``_parents`` and ``_op`` are given.
            name: Optional name of the state.
            value: Initial value. None (default) leaves the state inactive;
                any other value is assigned through the ``value`` setter.
            iter_order: Iteration order. When None, the minimum
                ``iter_order`` of the parents is used, or 0 for a state
                without parents.
            _parents: Parent states of a derived state (internal).
            _op: Operation deriving this state from its parents (internal).

        Raises:
            RuntimeError: If there is no active Manager.
        """
        # Require an active Manager
        if Manager._active_manager is None:
            raise RuntimeError("State must be created within a Manager context")

        self._id = None
        self._name = name
        self._parents = tuple(_parents) if _parents else ()
        self._op = _op
        # Start inactive. Assigned before the sync group and the Manager see
        # this object so that `_value` always exists on a State they receive.
        self._value = None

        # Set iter_order
        if iter_order is None:
            if self._parents:
                iter_order = min(p.iter_order for p in self._parents)
            else:
                iter_order = 0
        self._iter_order = iter_order

        if self._parents and _op:
            parent_num_values = tuple(p.num_values for p in self._parents)
            self._num_values = _op.compute_num_states(parent_num_values)
        else:
            # Default to 1 if None is passed
            self._num_values = 1 if num_values is None else num_values

        # Each state starts in its own sync group
        self._synced_group = SynchronizedGroup(self)

        # Register with Manager
        self._manager = Manager._active_manager
        self._manager.register(self)

        # States stay inactive (value=None) unless a value is explicitly
        # provided; they become active (value=0) when iteration starts via
        # reset().
        if value is not None:
            self.value = value

    @property
    def iter_order(self):
        """Iteration order for this state."""
        return self._iter_order

    @iter_order.setter
    def iter_order(self, value: Real):
        """Set iteration order for this state."""
        self._iter_order = value

    @property
    def num_values(self):
        """Number of values this state can take (read-only)."""
        return self._num_values

    @property
    def id(self):
        """Unique ID assigned by Manager (None if not registered)."""
        return self._id

    @property
    def name(self):
        """Name of this state."""
        return self._name

    @name.setter
    def name(self, value):
        """Set the name of this state."""
        self._name = value

    def named(self, name):
        """Set name and return self for fluent chaining."""
        self._name = name
        return self

    def synced_parent(self, name: Optional[str] = None):
        """Create a synced state that shares this state's value.

        Delegates to ``ops.synced_to``; ``name`` is forwarded to it.
        """
        # NOTE(review): imported locally, presumably to avoid a circular
        # import between this module and `.ops` - confirm.
        from .ops import synced_to

        return synced_to(self, name=name)

    def sync_with(self, other: State_type) -> None:
        """Synchronize this state with another (bidirectional).

        Does nothing when both states already share a sync group. Otherwise
        the group with fewer members is merged into the other one; on a tie
        ``other``'s group is merged into this state's group.
        """
        if self._synced_group is other._synced_group:
            return
        logger.debug("Syncing state id=%s with state id=%s", self._id, other._id)
        if len(self._synced_group) >= len(other._synced_group):
            self._synced_group.merge(other._synced_group)
        else:
            other._synced_group.merge(self._synced_group)

    @property
    def value(self):
        """Current value of this state."""
        return self._value

    @value.setter
    def value(self, val: Optional[Integral]):
        """Set value and propagate to parents.

        Assigning None inactivates the trees of the sync group. If any
        state in the sync group has parents, all values held by the Manager
        are cleared and ``val`` is handed to the sync group for propagation.
        If every state in the group is a leaf, the group value is set
        directly and each member takes ``val``, or None when ``val`` is not
        below that member's ``num_values``.
        """
        logger.debug("Setting value=%s for state id=%s name=%s", val, self._id, self._name)
        group = self._synced_group
        if val is None:
            group.inactivate_trees()
        elif any(s._parents for s in group._states):
            # At least one state in the sync group has parents: clear all and propagate
            self._manager.clear_all_values()
            group.set_inactivated_values_in_trees(val)
        else:
            # All states are leaves: set sync group value directly (no propagation needed)
            group._value = val
            for state in group._states:
                state._value = val if val < state.num_values else None

    def advance(self):
        """Advance to next value (wraps around using this state's num_values).

        Raises:
            RuntimeError: If the sync group value is None (inactive state).
        """
        if self._synced_group._value is None:
            logger.warning(
                "Attempting to advance inactive state id=%s name=%s", self._id, self._name
            )
            raise RuntimeError("Cannot advance an inactive state (group value=None)")
        # Advance using this state's num_values for wrapping
        new_val = (self._synced_group._value + 1) % self._num_values
        self.value = new_val

    def reset(self, value: Integral = 0):
        """Reset to specified value (default 0)."""
        self.value = value

    @property
    def is_active(self) -> bool:
        """True if state is active (value is not None). Read-only."""
        return self._value is not None

    def __iter__(self):
        """Iterate through all values of this state.

        The state is reset before the first value is yielded and again after
        the last one; it is advanced after every yielded value.
        """
        self.reset()
        for _ in range(self._num_values):
            yield self._value
            self.advance()
        self.reset()

    def __getitem__(self, key: Union[Integral, slice]):
        """Create sliced state: B = A[1:5] or A[::2] or A[::-1].

        An integer key selects a single value. Negative integers count from
        the end, so ``A[-1]`` selects the last value. An integer outside the
        valid range produces an empty selection rather than an error.

        Args:
            key: Integer index or slice over this state's values.

        Returns:
            A new State with this state as its only parent and a ``SliceOp``
            built from the resolved ``(start, stop, step)``.
        """
        from .ops import SliceOp

        if not isinstance(key, slice):
            # Any integral key (not only builtin int) addresses one value.
            index = int(key)
            if index < 0:
                # Wrap explicitly: slice(-1, 0, 1) would select nothing.
                index += self._num_values
            if index < 0:
                # Still below the first value: keep the empty selection.
                key = slice(0, 0, 1)
            else:
                key = slice(index, index + 1, 1)
        start, stop, step = key.indices(self._num_values)
        return State(_parents=(self,), _op=SliceOp(start, stop, step))

    def copy(self, name: Optional[str] = None):
        """Create a shallow copy with the same parents but a new State object.

        The copy shares this state's parent objects and operation, or its
        ``num_values`` for a leaf, and receives this state's current value.
        ``iter_order`` is not passed on; the constructor recomputes it.
        """
        if self._parents:
            new_state = State(_parents=self._parents, _op=self._op, name=name)
        else:
            new_state = State(self._num_values, name=name)
        # Direct assignment to avoid triggering global clear
        new_state._synced_group._value = self._value
        new_state._value = self._value
        return new_state

    def deepcopy(self, name: Optional[str] = None, *, _memo: Optional[dict] = None):
        """Create a deep copy with all ancestors also copied.

        Each distinct ancestor is copied exactly once, so an ancestor that is
        reachable through several parents is shared by the copies of those
        parents just as it is shared in the original graph.

        Args:
            name: Name of the new state. Copied ancestors get no name.
            _memo: Internal map from ``id(original)`` to its copy, threaded
                through the recursion. Callers should leave it as None.
        """
        if _memo is None:
            _memo = {}
        if not self._parents:
            new_state = State(self._num_values, name=name)
        else:
            new_parents = []
            for parent in self._parents:
                parent_key = id(parent)
                if parent_key not in _memo:
                    _memo[parent_key] = parent.deepcopy(_memo=_memo)
                new_parents.append(_memo[parent_key])
            new_state = State(_parents=tuple(new_parents), _op=self._op, name=name)
        # Direct assignment to avoid triggering global clear
        new_state._synced_group._value = self._value
        new_state._value = self._value
        return new_state

    def __repr__(self):
        """Return a debug representation; derived states also show their op."""
        fields = [f"name={self._name!r}", f"id={self._id}"]
        if self._parents:
            fields.append(f"op={type(self._op).__name__}")
        fields.append(f"num_values={self._num_values}")
        fields.append(f"value={self._value}")
        fields.append(f"iter_order={self._iter_order}")
        return f"State({', '.join(fields)})"

    def print_dag(self, style: str = "clean"):
        """Print the ASCII tree visualization rooted at this state.

        Args:
            style: Display style - 'clean' (default), 'minimal', or 'repr'.
        """
        from .text_viz import print_dag

        print_dag(self, style=style)

    def get_iteration_df(self, **kwargs):
        """Return the Manager's iteration table for this state and its ancestors.

        The ancestors come from ``Manager.get_ancestors`` and are passed as
        ``states``; all keyword arguments are forwarded to
        ``Manager.get_iteration_df``.
        """
        # NOTE(review): the result is presumably a pandas DataFrame, judging
        # by the method name - confirm against Manager.get_iteration_df.
        ancestors = self._manager.get_ancestors(self)
        return self._manager.get_iteration_df(self, states=ancestors, **kwargs)

    def get_ancestors(self):
        """Return this state's ancestors as reported by the Manager."""
        return self._manager.get_ancestors(self)

    def _has_auto_name(self) -> bool:
        """Return True if this state has an auto-generated name (State[N])."""
        import re

        return bool(re.match(r"^State\[\d+\]$", self._name or ""))

    def get_states(self, include_inactive: bool = True, named_only: bool = True) -> dict:
        """Return dict of {name: value} for this state and all ancestors.

        Args:
            include_inactive: If True (default), include states with None value.
                If False, only include states that are currently active.
            named_only: If True (default), exclude states with auto-generated
                names (State[N]). If False, include all states.

        Returns:
            Dictionary mapping state names to their current values,
            in reverse topological order (derived states first, parents last).
        """
        from collections import deque

        # BFS for reverse topological order (self first, then parents, then grandparents...)
        ordered = []
        # Track visited states by object identity: `state.id` is None for a
        # state the Manager has not given an id, so it cannot tell such
        # states apart.
        visited = set()
        queue = deque([self])
        while queue:
            state = queue.popleft()
            marker = id(state)
            if marker in visited:
                continue
            visited.add(marker)
            ordered.append(state)
            queue.extend(state._parents)

        # Apply filters
        if named_only:
            ordered = [s for s in ordered if not s._has_auto_name()]
        if not include_inactive:
            ordered = [s for s in ordered if s.value is not None]
        return {s.name: s.value for s in ordered}

    def print_states(self, include_inactive: bool = True, named_only: bool = True) -> None:
        """Print current values of this state and its ancestors.

        Args:
            include_inactive: If True (default), include states with None value.
                If False, only include states that are currently active.
            named_only: If True (default), exclude states with auto-generated
                names (State[N]). If False, include all states.
        """
        states = self.get_states(include_inactive=include_inactive, named_only=named_only)
        print(", ".join(f"{name}={value}" for name, value in states.items()))