コード例 #1
0
def _get_key_counter(seed, alg):
    """Derives the (key, counter) pair consumed by the V2 stateless RNG ops.

    The returned pair is meant to be passed to raw ops such as
    `StatelessRandomUniformV2`. How it is derived depends on `alg`:

    * `Algorithm.AUTO_SELECT.value`: delegated to the
      `stateless_random_get_key_counter` op, which decides at runtime based
      on the device type.
    * `Algorithm.PHILOX.value`: `seed` is scrambled by
      `_philox_scramble_seed`.
    * `Algorithm.THREEFRY.value`: `seed` is cast to uint32, combined into a
      uint64 by `uint32s_to_uint64` and reshaped to a key of shape [1]; the
      counter is a uint64 zero tensor of shape [1].

    Args:
      seed: An integer tensor of shape [2] from which the key and counter are
        derived.
      alg: The RNG algorithm identifier. See `tf.random.stateless_uniform`
        for an explanation.

    Returns:
      A `(key, counter)` tuple suitable for V2 stateless RNG ops like
      `StatelessRandomUniformV2`.

    Raises:
      ValueError: If `alg` is not one of the supported algorithm values.
    """
    if alg == Algorithm.AUTO_SELECT.value:
        # Let the op pick the key/counter layout for the current device.
        rng_key, rng_counter = (
            gen_stateless_random_ops_v2.stateless_random_get_key_counter(seed))
        return rng_key, rng_counter

    if alg == Algorithm.PHILOX.value:
        rng_key, rng_counter = _philox_scramble_seed(seed)
        return rng_key, rng_counter

    if alg == Algorithm.THREEFRY.value:
        # The key is the seed packed into one uint64; the counter starts at zero.
        seed_as_uint32 = math_ops.cast(seed, dtypes.uint32)
        rng_key = array_ops.reshape(uint32s_to_uint64(seed_as_uint32), [1])
        rng_counter = array_ops.zeros([1], dtypes.uint64)
        return rng_key, rng_counter

    raise ValueError(
        f"Argument `alg` got unsupported value {alg}. Supported values are "
        f"{Algorithm.PHILOX.value} for the Philox algorithm, "
        f"{Algorithm.THREEFRY.value} for the ThreeFry algorithm, and "
        f"{Algorithm.AUTO_SELECT.value} for auto-selection.")
コード例 #2
0
 def f():
   """Draws one scalar normal sample seeded from the current value of `v`."""
   # Snapshot `v` and reinterpret it as an int32 seed.
   int32_seed = math_ops.cast(v.read_value(), dtypes.int32)
   rng_key, rng_counter = (
       gen_stateless_random_ops_v2.stateless_random_get_key_counter(
           seed=int32_seed))
   rng_alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
   # shape=[] requests a scalar sample.
   return gen_stateless_random_ops_v2.stateless_random_normal_v2(
       shape=[], key=rng_key, counter=rng_counter, alg=rng_alg)
コード例 #3
0
 def testGetKeyCounterAlg(self):
   """Checks output shapes of the get-key-counter and get-alg ops."""
   seed = [1, 2]
   key_and_counter = (
       gen_stateless_random_ops_v2.stateless_random_get_key_counter(seed))
   rng_key, rng_counter = key_and_counter
   # Expected shapes: key [1], counter [2].
   self.assertAllEqual(rng_key.shape, [1])
   self.assertAllEqual(rng_counter.shape, [2])
   # The algorithm id is expected to be a scalar.
   rng_alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
   self.assertAllEqual(rng_alg.shape, [])
コード例 #4
0
def _get_key_counter_alg(seed):
    """Returns the key, counter and algorithm for the V2 stateless RNG ops.

    When `compat.forward_compatible(2021, 2, 2)` is true, the key and counter
    come from `stateless_random_get_key_counter` and the algorithm from
    `stateless_random_get_alg`, which are two separate ops. Otherwise the
    combined `stateless_random_get_key_counter_alg` op is used. Presumably
    this keeps graphs consumable by older binaries until the compatibility
    date passes (TODO confirm).

    Args:
      seed: The seed passed through to the underlying op(s).

    Returns:
      The (key, counter, alg) triple. In the legacy branch this is whatever
      `stateless_random_get_key_counter_alg` returns.
    """
    if not compat.forward_compatible(2021, 2, 2):
        # Legacy path: a single op produces all three outputs.
        return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg(
            seed)
    rng_key, rng_counter = (
        gen_stateless_random_ops_v2.stateless_random_get_key_counter(seed))
    rng_alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
    return rng_key, rng_counter, rng_alg
コード例 #5
0
def _get_key_counter_alg(seed):
    """Returns a (key, counter, alg) tuple for the V2 stateless RNG ops.

    Args:
      seed: The seed passed to `stateless_random_get_key_counter`.

    Returns:
      A 3-tuple `(key, counter, alg)`. The key and counter come from
      `stateless_random_get_key_counter`, and the algorithm comes from
      `stateless_random_get_alg`.
    """
    rng_key, rng_counter = (
        gen_stateless_random_ops_v2.stateless_random_get_key_counter(seed))
    rng_alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
    return rng_key, rng_counter, rng_alg