def f():
  """Returns one scalar (shape=[]) sample from StatelessRandomNormalV2.

  The seed is `v` (taken from the enclosing scope) read via `read_value()`
  and cast to int32. The key/counter come from StatelessRandomGetKeyCounter
  and the algorithm id comes from StatelessRandomGetAlg.
  """
  seed = math_ops.cast(v.read_value(), dtypes.int32)
  key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
      seed=seed)
  # The alg op is created here, after the key/counter op and before the
  # normal op, which is the same creation order as a separate statement.
  return gen_stateless_random_ops_v2.stateless_random_normal_v2(
      shape=[],
      key=key,
      counter=counter,
      alg=gen_stateless_random_ops_v2.stateless_random_get_alg())
def testLargeNormal(self):
  """Tests an OOM bug of StatelessRandomNormalV2 on TPU.

  A [1024, 32000] float32 normal sample is built and run twice with the
  same fed seed:
    1. key/counter/alg from the fused StatelessRandomGetKeyCounterAlg op;
    2. key/counter from StatelessRandomGetKeyCounter plus alg from
       StatelessRandomGetAlg.
  Each run must produce an output of the requested shape.
  """
  with self.session() as sess, self.test_scope():
    seed_t = array_ops.placeholder(dtypes.int32, shape=[2])

    def fused_inputs():
      # One op yields all three values.
      return (gen_stateless_random_ops_v2.
              stateless_random_get_key_counter_alg(seed_t))

    def split_inputs():
      # Separate ops for (key, counter) and for alg.
      key, counter = (gen_stateless_random_ops_v2.
                      stateless_random_get_key_counter(seed_t))
      alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
      return key, counter, alg

    for make_inputs in (fused_inputs, split_inputs):
      key, counter, alg = make_inputs()
      sample = gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape=[1024, 32000],
          key=key,
          counter=counter,
          dtype=dtypes.float32,
          alg=alg)
      result = sess.run(sample, {seed_t: [0x12345678, 0xabcdef1]})
      self.assertAllEqual([1024, 32000], result.shape)
def testGetKeyCounterAlg(self):
  """Checks output shapes of the split key/counter and alg ops.

  Asserts that seed [1, 2] gives a key of shape [1] and a counter of
  shape [2], and that the algorithm id is a scalar.
  """
  key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
      [1, 2])
  for tensor, expected_shape in ((key, [1]), (counter, [2])):
    self.assertAllEqual(tensor.shape, expected_shape)
  alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
  self.assertAllEqual(alg.shape, [])
def _get_key_counter_alg(seed):
  """Converts `seed` into a (key, counter, alg) triple for the V2 stateless ops.

  If `compat.forward_compatible(2021, 2, 2)` is False, the fused
  StatelessRandomGetKeyCounterAlg op's result is returned. Otherwise the
  split StatelessRandomGetKeyCounter and StatelessRandomGetAlg ops are used.
  """
  if not compat.forward_compatible(2021, 2, 2):
    # NOTE(review): presumably kept so graphs stay consumable by binaries
    # that predate the split ops; confirm before removing this branch.
    return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg(
        seed)
  key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
      seed)
  return key, counter, gen_stateless_random_ops_v2.stateless_random_get_alg()
def _resolve_alg(alg):
  """Replaces the AUTO_SELECT sentinel with the StatelessRandomGetAlg op.

  If `alg` equals `Algorithm.AUTO_SELECT.value`, the output of
  StatelessRandomGetAlg is returned. Any other `alg` is returned unchanged.
  """
  is_auto_select = alg == Algorithm.AUTO_SELECT.value
  return (gen_stateless_random_ops_v2.stateless_random_get_alg()
          if is_auto_select else alg)
def _get_key_counter_alg(seed):
  """Returns (key, counter, alg) for `seed` using the split V2 ops.

  key/counter come from StatelessRandomGetKeyCounter; alg comes from
  StatelessRandomGetAlg, created after the key/counter op.
  """
  rng_ops = gen_stateless_random_ops_v2
  key, counter = rng_ops.stateless_random_get_key_counter(seed)
  return key, counter, rng_ops.stateless_random_get_alg()