"""ProductOp - Cartesian product of N states."""
from ..imports import Literal, Optional, Sequence, State_type, beartype, math
from ..operation import Operation
# Module-level flag for product ordering mode.
# Read by ordered_product() at call time: among states with equal _iter_order,
# "first_state_slowest" (the default) sorts parents by descending _id and
# "first_state_fastest" sorts them by ascending _id.
# Change it through set_product_order_mode(), which validates the value;
# get_product_order_mode() returns it.
_product_order_mode: str = "first_state_slowest"
[docs]
def set_product_order_mode(mode: Literal["first_state_fastest", "first_state_slowest"]) -> None:
    """Select the global ordering mode used by ordered_product().

    Parameters
    ----------
    mode : {"first_state_fastest", "first_state_slowest"}
        New ordering mode. The setting is module-wide and applies to every
        subsequent ordered_product() call.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the two accepted strings; the current
        mode is left unchanged in that case.
    """
    global _product_order_mode
    accepted_modes = ("first_state_fastest", "first_state_slowest")
    if mode in accepted_modes:
        _product_order_mode = mode
        return
    raise ValueError(
        f"mode must be 'first_state_fastest' or 'first_state_slowest', got {mode!r}"
    )
[docs]
def get_product_order_mode() -> str:
    """Return the ordering mode ordered_product() currently uses.

    Returns
    -------
    str
        The module-wide mode last stored by set_product_order_mode()
        (``"first_state_slowest"`` unless it has been changed).
    """
    current_mode = _product_order_mode
    return current_mode
def _collect_product_bases(state: State_type) -> list:
    """
    Flatten ``state`` into its base states, descending through ProductOp only.

    Expanding nested products lets ordered_product() notice a state that is
    reachable along more than one path of a product hierarchy (the "diamond"
    pattern) and deduplicate it.

    Parameters
    ----------
    state : State_type
        Root state to flatten.

    Returns
    -------
    list
        Base states in left-to-right, depth-first order. A state without
        parents, or one produced by any non-product operation (stack, slice,
        etc.), is kept as an atomic entry and not expanded. The same base may
        appear several times; no deduplication happens here.
    """
    collected = []
    # Explicit LIFO stack. Parents are pushed in reverse so they are popped,
    # and therefore emitted, in their original left-to-right order.
    pending = [state]
    while pending:
        node = pending.pop()
        if node._parents and isinstance(node._op, ProductOp):
            pending.extend(reversed(tuple(node._parents)))
        else:
            # Leaf state, or a non-product operation treated as atomic.
            collected.append(node)
    return collected
[docs]
def ordered_product(states: Sequence[State_type], name: Optional[str] = None):
    """
    Build a canonical product State from ``states``.

    The inputs are flattened through any nested products. Each sync group then
    contributes a single state, and the survivors are sorted deterministically
    before the product is formed.

    Parameters
    ----------
    states : Sequence[State_type]
        States to combine. Nested product states are expanded into their
        bases. Repeated states, and states sharing a sync group, are kept
        only once (the first one encountered wins).
    name : Optional[str], default=None
        Name for the resulting product State.

    Returns
    -------
    State_type
        ``State(1, name=name)`` when ``states`` is empty. Otherwise, a State
        with a ProductOp over the deduplicated, sorted base states.

    Notes
    -----
    States are sorted by ``_iter_order`` first. Ties are broken by ascending
    ``_id`` in "first_state_fastest" mode and by descending ``_id`` in
    "first_state_slowest" mode (the module default; see
    set_product_order_mode()).

    Product is associative, so nested products are flattened:
    ``ordered_product([A*B, C, D*A])`` becomes ``ordered_product([A, B, C, D])``.
    Non-product operations (like stack, slice) are NOT flattened:
    ``ordered_product([stack(A,B), C])`` keeps ``stack(A,B)`` as an atomic unit.
    """
    from ..state import State

    if len(states) == 0:
        return State(1, name=name)

    # Expand every input down to its product bases, preserving input order.
    bases = [base for s in states for base in _collect_product_bases(s)]

    # Synced states represent a single dimension of variation and must not
    # multiply in the product: keep only the first state seen per sync group.
    # Groups are compared by object identity; dicts preserve insertion order.
    # NOTE(review): assumes every state carries its own _synced_group object
    # (unsynced states presumably get a singleton group) — confirm in State.
    first_in_group = {}
    for base in bases:
        first_in_group.setdefault(id(base._synced_group), base)
    unique_states = list(first_in_group.values())

    sign = 1 if _product_order_mode != "first_state_slowest" else -1
    ordered_states = sorted(unique_states, key=lambda st: (st._iter_order, sign * st._id))
    return State(_parents=ordered_states, _op=ProductOp(), name=name)
[docs]
@beartype
def product(states: Sequence[State_type], name: Optional[str] = None):
    """
    Form the cartesian product of ``states`` exactly as given.

    Unlike ordered_product(), the inputs are not flattened, deduplicated or
    sorted. They are passed unchanged as the parents of the new State.

    Parameters
    ----------
    states : Sequence[State_type]
        States to combine. Each state may appear at most once.
    name : Optional[str], default=None
        Optional name for the resulting product State.

    Returns
    -------
    State_type
        ``State(1, name=name)`` for an empty ``states``. Otherwise, a State
        with ``states`` as parents and a ProductOp, whose values index the
        cartesian product of the input states' values.

    Raises
    ------
    ValueError
        If the same state occurs more than once in ``states``.
    """
    from ..state import State

    has_duplicates = len(set(states)) != len(states)
    if has_duplicates:
        raise ValueError("product() does not allow duplicate states")
    if len(states) == 0:
        return State(1, name=name)
    return State(_parents=states, _op=ProductOp(), name=name)
[docs]
@beartype
class ProductOp(Operation):
    """Cartesian product of N states.

    A product value is treated as a mixed-radix number in which the first
    parent is the least-significant (fastest-varying) digit. Parents whose
    value count is ``None`` (fixed states) contribute a factor of 1.
    """

    def compute_num_states(self, parent_num_values: Sequence):
        """Return the number of values of the product state.

        Parameters
        ----------
        parent_num_values : Sequence
            Value count of each parent, or ``None`` for a fixed parent.

        Returns
        -------
        int or None
            Product of all non-``None`` counts, or ``None`` when every parent
            is fixed (including the case of no parents at all).
        """
        counts = [count for count in parent_num_values if count is not None]
        # Fixed parents act as a factor of 1; an all-fixed product stays fixed.
        return math.prod(counts) if counts else None

    def decompose(self, value, parent_num_values):
        """Split a product value into one value per parent.

        Parameters
        ----------
        value : int or None
            Index into the product's values, or ``None``. Presumably a
            non-negative integer below compute_num_states(...); it is not
            range-checked here.
        parent_num_values : Sequence
            Value count of each parent, or ``None`` for a fixed parent.

        Returns
        -------
        tuple
            One entry per parent:
            - every entry is ``None`` when ``value`` is ``None``;
            - fixed parents always get ``0``;
            - every other parent gets its mixed-radix digit of ``value``,
              with the first parent varying fastest.
        """
        if value is None:
            return tuple(None for _ in parent_num_values)
        digits = []
        remaining = value
        for radix in parent_num_values:
            if radix is None:
                # Fixed state: always 0 when active; consumes no digit.
                digits.append(0)
                continue
            digits.append(remaining % radix)
            remaining //= radix
        return tuple(digits)