Source code for statetracker.ops.shuffle_op

"""ShuffleOp - Randomly shuffle state values given a seed."""

import random

from ..imports import Integral, Optional, Sequence, State_type, beartype
from ..operation import Operation


@beartype
def shuffle(
    state: State_type,
    seed: Optional[Integral] = None,
    permutation: Optional[Sequence[Integral]] = None,
    name: Optional[str] = None,
):
    """
    Build a State whose values are a reordering of a parent State's values.

    Parameters
    ----------
    state : State_type
        Parent State providing the values to reorder.
    seed : Optional[Integral], default=None
        Seed from which the reordering is generated. Mutually exclusive with
        ``permutation``; when both are None, ShuffleOp draws a seed itself.
    permutation : Optional[Sequence[Integral]], default=None
        Explicit reordering given as parent value indices, one per parent
        value (length ``state.num_values``). Mutually exclusive with ``seed``.
    name : Optional[str], default=None
        Name assigned to the new State.

    Returns
    -------
    State_type
        Child State whose value order is a permutation of the parent's.

    Raises
    ------
    ValueError
        Propagated from ShuffleOp when both ``seed`` and ``permutation`` are
        given, or when ``permutation`` is not a valid permutation.
    """
    # Deferred import, presumably to avoid an import cycle with ..state —
    # TODO confirm.
    from ..state import State

    shuffle_op = ShuffleOp(state.num_values, seed=seed, permutation=permutation)
    return State(_parents=(state,), _op=shuffle_op, name=name)
[docs] @beartype class ShuffleOp(Operation): """Randomly shuffle state values using a deterministic seed."""
[docs] def __init__( self, num_parent_values: Integral, seed: Optional[Integral] = None, permutation: Optional[Sequence[Integral]] = None, ): match (seed, permutation): case (None, None): self.seed = random.randint(0, 2**32 - 1) indices = list(range(num_parent_values)) random.Random(self.seed).shuffle(indices) self.permutation = tuple(indices) case (_, None): self.seed = seed indices = list(range(num_parent_values)) random.Random(seed).shuffle(indices) self.permutation = tuple(indices) case (None, _): if len(permutation) != num_parent_values: raise ValueError( f"permutation has length {len(permutation)}, expected {num_parent_values}." ) if set(permutation) != set(range(num_parent_values)): raise ValueError( f"permutation must contain exactly the integers 0 to {num_parent_values - 1}." ) self.seed = None self.permutation = tuple(permutation) case (_, _): raise ValueError( "Cannot specify both 'seed' and 'permutation'; they are mutually exclusive." )
[docs] def compute_num_states(self, parent_num_values): return parent_num_values[0]
[docs] def decompose(self, value, parent_num_values): if value is None: return (None,) return (self.permutation[value],)