def __init__(self, copy_from=None, state=None, alg=None):
    """Creates a generator.

    The new generator will be initialized by one of the following ways, with
    decreasing precedence:
    (1) If `copy_from` is not None, the new generator is initialized by
        copying information from another generator.
    (2) If `state` and `alg` are not None (they must be set together), the
        new generator is initialized by a state.

    Args:
      copy_from: a generator to be copied from.
      state: a vector of dtype STATE_TYPE representing the initial state of
        the RNG, whose length and semantics are algorithm-specific. If it's a
        variable, the generator will reuse it instead of creating a new
        variable.
      alg: the RNG algorithm. Possible values are
        `tf.random.Algorithm.PHILOX` for the Philox algorithm and
        `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm (see paper
        'Parallel Random Numbers: As Easy as 1, 2, 3'
        [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
        The string names `"philox"` and `"threefry"` can also be used.
        Note `PHILOX` guarantees the same numbers are produced (given the
        same random state) across all architectures (CPU, GPU, XLA etc).

    Raises:
      AssertionError: if `copy_from` is given together with `state` or `alg`,
        or if `copy_from` is None and either `state` or `alg` is None.
      ValueError: if the active distribution strategy's class name contains
        "CentralStorage" or "ParameterServer".
    """
    # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
    if ds_context.has_strategy():
        self._distribution_strategy = ds_context.get_strategy()
    else:
        self._distribution_strategy = None
    if copy_from is not None:
        # All other arguments should be None. Each one is compared against
        # None explicitly: `(alg or state) is None` relies on truthiness and
        # would accept a falsy non-None `alg` (e.g. 0 or "") when `state` is
        # None, silently discarding it.
        assert alg is None and state is None
        self._state_var = self._create_variable(
            copy_from.state, dtype=STATE_TYPE, trainable=False)
        self._alg = copy_from.algorithm
    else:
        assert alg is not None and state is not None
        if ds_context.has_strategy():
            strat_name = type(ds_context.get_strategy()).__name__
            # TODO(b/174610856): Support CentralStorageStrategy and
            #   ParameterServerStrategy.
            if ("CentralStorage" in strat_name
                    or "ParameterServer" in strat_name):
                raise ValueError("%s is not supported yet" % strat_name)
        alg = stateless_random_ops.convert_alg_to_int(alg)
        if isinstance(state, variables.Variable):
            # Reuse the caller's variable as-is after validating its shape.
            _check_state_shape(state.shape, alg)
            self._state_var = state
        else:
            state = _convert_to_state_tensor(state)
            _check_state_shape(state.shape, alg)
            self._state_var = self._create_variable(
                state, dtype=STATE_TYPE, trainable=False)
        self._alg = alg
def from_non_deterministic_state(cls, alg=None):
    """Builds a generator whose initial state is chosen non-deterministically.

    The source of the non-determinism will be platform- and time-dependent.

    Args:
      alg: (optional) the RNG algorithm. When None, `DEFAULT_ALGORITHM` is
        used. See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    # TODO(b/170668986): more sophisticated algorithm selection
    chosen_alg = DEFAULT_ALGORITHM if alg is None else alg
    alg_id = stateless_random_ops.convert_alg_to_int(chosen_alg)
    # The state vector length depends on the selected algorithm.
    state_len = _get_state_size(alg_id)
    initial_state = non_deterministic_ints(shape=[state_len], dtype=SEED_TYPE)
    return cls(state=initial_state, alg=alg_id)
def rng_bit_generator(algorithm, initial_state, shape, dtype):
    """Stateless PRNG bit generator.

    Thin wrapper around the XLA RngBitGenerator operator, documented at
    https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

    Args:
      algorithm: The PRNG algorithm to use, one of
        tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
      initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
        should be a u64[2] and for PHILOX a u64[3].
      shape: The output shape of the generated data.
      dtype: The type of the tensor.

    Returns:
      a tuple with a new state and generated data of the given shape.
    """
    # The op takes the algorithm as its integer code, so normalize it first.
    return gen_xla_ops.xla_rng_bit_generator(
        stateless_random_ops.convert_alg_to_int(algorithm),
        initial_state,
        shape,
        dtype=dtype,
    )
def create_rng_state(seed, alg):
    """Creates a RNG state from an integer or a vector.

    Example:

    >>> tf.random.create_rng_state(
    ...     1234, "philox")
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1234,    0,    0])>
    >>> tf.random.create_rng_state(
    ...     [12, 34], "threefry")
    <tf.Tensor: shape=(2,), dtype=int64, numpy=array([12, 34])>

    Args:
      seed: an integer or 1-D numpy array.
      alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.

    Returns:
      a 1-D state vector (shown as a `tf.Tensor` in the example above) whose
      size depends on the algorithm.
    """
    # Normalize `alg` to its integer code before building the state.
    return _make_state_from_seed(
        seed, stateless_random_ops.convert_alg_to_int(alg))
def from_key_counter(cls, key, counter, alg):
    """Creates a generator from a key and a counter.

    This constructor only applies if the algorithm is a counter-based
    algorithm. See method `key` for the meaning of "key" and "counter".

    Args:
      key: the key for the RNG, a scalar of type STATE_TYPE.
      counter: a vector of dtype STATE_TYPE representing the initial counter
        for the RNG, whose length is algorithm-specific.
      alg: the RNG algorithm. If None, `DEFAULT_ALGORITHM` is used. See
        `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if alg is None:
        # Honor the documented "None means auto-select" contract, the same
        # way `from_seed` and `from_non_deterministic_state` do.
        # TODO(b/170668986): more sophisticated algorithm selection
        alg = DEFAULT_ALGORITHM
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    alg = stateless_random_ops.convert_alg_to_int(alg)
    # The state is the counter followed by a single key element, so the
    # counter must be one element shorter than the full state.
    counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
    key.shape.assert_is_compatible_with([])
    # Promote the scalar key to a length-1 vector so it can be concatenated.
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    return cls(state=state, alg=alg)
def from_seed(cls, seed, alg=None):
    """Creates a generator from a seed.

    A seed is a 1024-bit unsigned integer represented either as a Python
    integer or a vector of integers. Seeds shorter than 1024-bit will be
    padded. The padding, the internal structure of a seed and the way a seed
    is converted to a state are all opaque (unspecified). The only semantics
    specification of seeds is that two different seeds are likely to produce
    two independent generators (but no guarantee).

    Args:
      seed: the seed for the RNG.
      alg: (optional) the RNG algorithm. When None, `DEFAULT_ALGORITHM` is
        used. See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    # TODO(b/170668986): more sophisticated algorithm selection
    requested_alg = DEFAULT_ALGORITHM if alg is None else alg
    alg_id = stateless_random_ops.convert_alg_to_int(requested_alg)
    return cls(state=create_rng_state(seed, alg_id), alg=alg_id)