Source code for statetracker.ops.stack_op

"""StackOp - Disjoint union of N states."""

from ..imports import Optional, Sequence, State_type, beartype
from ..operation import Operation


[docs] @beartype def stack(states: Sequence[State_type], name: Optional[str] = None): """ Create a State representing the disjoint union of the provided States. Parameters ---------- states : Sequence[State_type] Sequence of parent States to combine into a disjoint union State. name : Optional[str], default=None Optional name for the resulting disjoint union State. Returns ------- State_type A State whose values correspond to the disjoint union of the input States' values. """ from ..state import State if len(states) == 0: result = State(0) else: result = State(_parents=states, _op=StackOp(), name=name) return result
[docs] @beartype class StackOp(Operation):
[docs] def compute_num_states(self, parent_num_values): # Treat None (fixed states) as 1 return sum(n if n is not None else 1 for n in parent_num_values)
[docs] def decompose(self, value, parent_num_values): if value is None: return tuple(None for _ in parent_num_values) cumsum = 0 for i, n in enumerate(parent_num_values): effective_n = n if n is not None else 1 if value < cumsum + effective_n: return tuple( value - cumsum if j == i else None for j in range(len(parent_num_values)) ) cumsum += effective_n raise ValueError(f"Invalid value {value}")