# Source code for statetracker.ops.sample_op

"""SampleOp - Sample values from a parent state."""

import random

from ..imports import Integral, Optional, Sequence, State_type, beartype
from ..operation import Operation


@beartype
def sample(
    state: State_type,
    num_values: Optional[Integral] = None,
    sampled_states: Optional[Sequence[Integral]] = None,
    seed: Optional[Integral] = None,
    with_replacement: bool = True,
    name: Optional[str] = None,
):
    """
    Build a State whose values are drawn from the values of a parent State.

    Exactly one of ``num_values`` and ``sampled_states`` must be supplied:
    either request ``num_values`` random draws from the parent, or pass the
    parent value indices to use directly.

    Parameters
    ----------
    state : State_type
        Parent State to sample from.
    num_values : Optional[Integral], default=None
        How many parent values to draw at random. Mutually exclusive with
        'sampled_states'.
    sampled_states : Optional[Sequence[Integral]], default=None
        Parent value indices to use verbatim, in order. Each index must lie
        in ``[0, state.num_values)``. Mutually exclusive with 'num_values'.
    seed : Optional[Integral], default=None
        Seed for the random draw. Only allowed together with 'num_values';
        when omitted, a seed is chosen at random.
    with_replacement : bool, default=True
        When drawing 'num_values' at random, whether the same parent value
        may be drawn more than once. Not consulted when 'sampled_states'
        is given.
    name : Optional[str], default=None
        Name for the resulting State.

    Returns
    -------
    State_type
        New State whose only parent is ``state`` and whose operation is a
        ``SampleOp``; its i-th value corresponds to parent value index
        ``sampled_states[i]`` (duplicates possible with replacement).

    Raises
    ------
    ValueError
        Raised by ``SampleOp`` when the sampling arguments are invalid, e.g.
        neither or both of 'num_values'/'sampled_states' are given, or 'seed'
        is combined with 'sampled_states'.
    """
    # Function-scope import; presumably avoids a circular import between the
    # state module and the ops modules (TODO confirm).
    from ..state import State

    op = SampleOp(
        state.num_values,
        num_values=num_values,
        sampled_states=sampled_states,
        seed=seed,
        with_replacement=with_replacement,
    )
    return State(_parents=(state,), _op=op, name=name)
[docs] @beartype class SampleOp(Operation): """Sample values from parent state."""
[docs] def __init__( self, num_parent_values: Integral, num_values: Optional[Integral] = None, sampled_states: Optional[Sequence[Integral]] = None, seed: Optional[Integral] = None, with_replacement: bool = True, ): # Validate mutually exclusive args match (num_values, sampled_states): case (None, None): raise ValueError("Must specify either 'num_values' or 'sampled_states'.") case (_, None): # Sample num_values from parent using seed if not with_replacement and num_values > num_parent_values: raise ValueError( f"num_values ({num_values}) exceeds parent.num_values ({num_parent_values}) " f"and with_replacement=False." ) self.seed = seed if seed is not None else random.randint(0, 2**32 - 1) rng = random.Random(self.seed) if with_replacement: self.sampled_states = tuple(rng.choices(range(num_parent_values), k=num_values)) else: self.sampled_states = tuple(rng.sample(range(num_parent_values), k=num_values)) case (None, _): # Explicit sampled_states provided if seed is not None: raise ValueError("Cannot specify 'seed' with 'sampled_states'.") # Validate states are in valid range for s in sampled_states: if s < 0 or s >= num_parent_values: raise ValueError(f"Value {s} out of range [0, {num_parent_values}).") self.seed = None self.sampled_states = tuple(sampled_states) case (_, _): raise ValueError("Cannot specify both 'num_values' and 'sampled_states'.")
[docs] def compute_num_states(self, parent_num_values): return len(self.sampled_states)
[docs] def decompose(self, value, parent_num_values): if value is None: return (None,) return (self.sampled_states[value],)