Exemplo n.º 1
0
    def __init__(self, copy_from=None, state=None, alg=None):
        """Creates a generator.

        The new generator will be initialized by one of the following ways, with
        decreasing precedence:
        (1) If `copy_from` is not None, the new generator is initialized by
            copying information from another generator. `state` and `alg` must
            then both be None.
        (2) If `state` and `alg` are not None (they must be set together), the
            new generator is initialized by a state.

        Args:
          copy_from: a generator to be copied from.
          state: a vector of dtype STATE_TYPE representing the initial state of
            the RNG, whose length and semantics are algorithm-specific. If it's
            a variable, the generator will reuse it instead of creating a new
            variable.
          alg: the RNG algorithm. Possible values are
            `tf.random.Algorithm.PHILOX` for the Philox algorithm and
            `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
            (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
            [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
            The string names `"philox"` and `"threefry"` can also be used.
            Note `PHILOX` guarantees the same numbers are produced (given
            the same random state) across all architectures (CPU, GPU, XLA etc).

        Raises:
          AssertionError: if `copy_from` is given together with `state` or
            `alg`, or if `copy_from` is None and `state` or `alg` is missing.
            These checks use `assert` and are skipped under `python -O`.
          ValueError: if a CentralStorage or ParameterServer distribution
            strategy is active when initializing from `state`.
        """
        # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
        if ds_context.has_strategy():
            self._distribution_strategy = ds_context.get_strategy()
        else:
            self._distribution_strategy = None

        if copy_from is not None:
            # All other arguments should be None. Compare each one to None
            # explicitly so that a falsy-but-set `alg` (e.g. 0 or "") is still
            # rejected and `alg` is never evaluated for truthiness.
            assert alg is None and state is None, (
                "`state` and `alg` must be None when `copy_from` is given")
            self._state_var = self._create_variable(copy_from.state,
                                                    dtype=STATE_TYPE,
                                                    trainable=False)
            self._alg = copy_from.algorithm
        else:
            assert alg is not None and state is not None, (
                "`state` and `alg` must be set together when `copy_from` is None")
            # Reuse the strategy captured above instead of querying ds_context
            # a second time.
            strategy = self._distribution_strategy
            if strategy is not None:
                strat_name = type(strategy).__name__
                # TODO(b/174610856): Support CentralStorageStrategy and
                #   ParameterServerStrategy.
                if "CentralStorage" in strat_name or "ParameterServer" in strat_name:
                    raise ValueError("%s is not supported yet" % strat_name)
            # Normalize string / enum / int algorithm specs to the integer form.
            alg = stateless_random_ops.convert_alg_to_int(alg)
            if isinstance(state, variables.Variable):
                # A caller-supplied variable is adopted as-is (no copy).
                _check_state_shape(state.shape, alg)
                self._state_var = state
            else:
                state = _convert_to_state_tensor(state)
                _check_state_shape(state.shape, alg)
                self._state_var = self._create_variable(state,
                                                        dtype=STATE_TYPE,
                                                        trainable=False)
            self._alg = alg
Exemplo n.º 2
0
  def from_non_deterministic_state(cls, alg=None):
    """Builds a generator whose initial state is drawn non-deterministically.

    Where the randomness comes from depends on the platform and on the time
    of the call.

    Args:
      alg: (optional) the RNG algorithm. When None, `DEFAULT_ALGORITHM` is
        used. See `__init__` for the accepted values.

    Returns:
      The new generator.
    """
    # TODO(b/170668986): more sophisticated algorithm selection
    requested = DEFAULT_ALGORITHM if alg is None else alg
    alg_int = stateless_random_ops.convert_alg_to_int(requested)
    # One SEED_TYPE element per slot of the chosen algorithm's state.
    state_len = _get_state_size(alg_int)
    initial_state = non_deterministic_ints(shape=[state_len], dtype=SEED_TYPE)
    return cls(state=initial_state, alg=alg_int)
Exemplo n.º 3
0
def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Generates random bits statelessly via XLA's RngBitGenerator op.

  This is a thin wrapper around the XLA RngBitGenerator operator, described at
    https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: which PRNG algorithm to use; one of
      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
    initial_state: starting state of the PRNG. THREEFRY expects a u64[2],
      PHILOX a u64[3].
    shape: shape of the generated output.
    dtype: element type of the generated output.

  Returns:
    A tuple of (new state, generated data with the requested shape).
  """
  return gen_xla_ops.xla_rng_bit_generator(
      stateless_random_ops.convert_alg_to_int(algorithm),
      initial_state,
      shape,
      dtype=dtype)
Exemplo n.º 4
0
def create_rng_state(seed, alg):
    """Creates an RNG state from an integer or a vector.

    Example:

    >>> tf.random.create_rng_state(
    ...     1234, "philox")
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1234,    0,    0])>
    >>> tf.random.create_rng_state(
    ...     [12, 34], "threefry")
    <tf.Tensor: shape=(2,), dtype=int64, numpy=array([12, 34])>

    Args:
      seed: an integer or a 1-D numpy array.
      alg: the RNG algorithm, given as a string, an `Algorithm` or an integer.

    Returns:
      A 1-D state vector whose length depends on the algorithm.
    """
    # NOTE(review): the examples display a `tf.Tensor`; confirm whether
    # `_make_state_from_seed` itself returns a tensor or a numpy array.
    return _make_state_from_seed(
        seed, stateless_random_ops.convert_alg_to_int(alg))
Exemplo n.º 5
0
    def from_key_counter(cls, key, counter, alg):
        """Creates a generator from a key and a counter.

        This constructor only applies if the algorithm is a counter-based
        algorithm. See method `key` for the meaning of "key" and "counter".

        Args:
          key: the key for the RNG, a scalar of type STATE_TYPE.
          counter: a vector of dtype STATE_TYPE representing the initial counter
            for the RNG. Its length must be one less than the algorithm's state
            size, because the key occupies the last slot of the state.
          alg: the RNG algorithm. If None, it will be auto-selected. See
            `__init__` for its possible values.

        Returns:
          The new generator.
        """
        counter = _convert_to_state_tensor(counter)
        key = _convert_to_state_tensor(key)
        if alg is None:
            # Honor the documented auto-selection, the same way `from_seed`
            # and `from_non_deterministic_state` do.
            # TODO(b/170668986): more sophisticated algorithm selection
            alg = DEFAULT_ALGORITHM
        alg = stateless_random_ops.convert_alg_to_int(alg)
        # The state is laid out as [counter..., key], so the counter fills all
        # but the last slot and the key must be a scalar.
        counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
        key.shape.assert_is_compatible_with([])
        key = array_ops.reshape(key, [1])
        state = array_ops.concat([counter, key], 0)
        return cls(state=state, alg=alg)
Exemplo n.º 6
0
    def from_seed(cls, seed, alg=None):
        """Builds a generator from a seed.

        A seed is a 1024-bit unsigned integer, given either as a Python integer
        or as a vector of integers; shorter seeds are padded. How seeds are
        padded, structured internally and turned into a state is deliberately
        unspecified. The only promise is that two different seeds are likely
        (though not guaranteed) to yield two independent generators.

        Args:
          seed: the seed for the RNG.
          alg: (optional) the RNG algorithm. When None, `DEFAULT_ALGORITHM` is
            used. See `__init__` for the accepted values.

        Returns:
          The new generator.
        """
        # TODO(b/170668986): more sophisticated algorithm selection
        requested = DEFAULT_ALGORITHM if alg is None else alg
        alg_int = stateless_random_ops.convert_alg_to_int(requested)
        return cls(state=create_rng_state(seed, alg_int), alg=alg_int)