def f():
    """Draw a single scalar standard-normal sample seeded from `v`.

    NOTE(review): `v` is not defined in this block; presumably it is a
    variable captured from an enclosing scope -- confirm at the call site.
    """
    # The key/counter op is fed an int32 seed, so cast the variable's value.
    seed = math_ops.cast(v.read_value(), dtypes.int32)
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed=seed)
    algorithm = gen_stateless_random_ops_v2.stateless_random_get_alg()
    # shape=[] requests a scalar output.
    return gen_stateless_random_ops_v2.stateless_random_normal_v2(
        shape=[], key=key, counter=counter, alg=algorithm)
 def testLargeNormal(self):
     """Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
     with self.session() as sess, self.test_scope():
         seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
         key, counter, alg = (gen_stateless_random_ops_v2.
                              stateless_random_get_key_counter_alg(seed_t))
         x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
             shape=[1024, 32000],
             key=key,
             counter=counter,
             dtype=dtypes.float32,
             alg=alg)
         y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
         self.assertAllEqual([1024, 32000], y.shape)
         key, counter = (gen_stateless_random_ops_v2.
                         stateless_random_get_key_counter(seed_t))
         alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
         x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
             shape=[1024, 32000],
             key=key,
             counter=counter,
             dtype=dtypes.float32,
             alg=alg)
         y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
         self.assertAllEqual([1024, 32000], y.shape)
 def testGetKeyCounterAlg(self):
   seed = [1, 2]
   key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
       seed)
   self.assertAllEqual(key.shape, [1])
   self.assertAllEqual(counter.shape, [2])
   alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
   self.assertAllEqual(alg.shape, [])
def _get_key_counter_alg(seed):
    """Return a (key, counter, alg) triple derived from `seed`.

    When the forward-compatibility window for 2021-02-02 is open, the split
    `stateless_random_get_key_counter` and `stateless_random_get_alg` ops are
    used. Otherwise the fused `stateless_random_get_key_counter_alg` op is
    used.

    NOTE(review): a second `_get_key_counter_alg` also appears in this
    source; if both live in the same module, the later one shadows this one.
    """
    if not compat.forward_compatible(2021, 2, 2):
        # Outside the compat window: keep emitting the fused op.
        return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg(
            seed)
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
    return key, counter, gen_stateless_random_ops_v2.stateless_random_get_alg()
def _resolve_alg(alg):
    """Resolve the AUTO_SELECT sentinel; return any other `alg` unchanged.

    If `alg` equals `Algorithm.AUTO_SELECT.value`, the result of
    `stateless_random_get_alg()` is returned in its place.
    """
    auto_select = Algorithm.AUTO_SELECT.value
    return (gen_stateless_random_ops_v2.stateless_random_get_alg()
            if alg == auto_select else alg)
def _get_key_counter_alg(seed):
    """Return a (key, counter, alg) triple derived from `seed`.

    Always uses the split `stateless_random_get_key_counter` and
    `stateless_random_get_alg` ops.

    NOTE(review): another `_get_key_counter_alg` that branches on
    `compat.forward_compatible` also appears in this source; if both are in
    one module, this later definition is the one that takes effect.
    """
    random_ops = gen_stateless_random_ops_v2
    key, counter = random_ops.stateless_random_get_key_counter(seed)
    algorithm = random_ops.stateless_random_get_alg()
    return key, counter, algorithm