def _standard_normal(self, shape, dtype):
  """Draws standard-normal samples driven by this generator's state.

  Once `compat.forward_compatible(2020, 10, 25)` holds, the samples come from
  the stateless V2 kernel, fed with a key/counter pair obtained from
  `self._prepare_key_counter(shape)`. Otherwise the stateful V2 kernel is run
  directly against the state variable's handle.

  Args:
    shape: Shape of the requested output.
    dtype: Dtype of the requested output.

  Returns:
    The tensor produced by whichever standard-normal kernel was selected.
  """
  if not compat.forward_compatible(2020, 10, 25):
    # Before the compatibility horizon: use the stateful op on the state handle.
    return gen_stateful_random_ops.stateful_standard_normal_v2(
        self.state.handle, self.algorithm, shape, dtype=dtype)
  key, counter = self._prepare_key_counter(shape)
  return gen_stateless_random_ops_v2.stateless_random_normal_v2(
      shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
           name=None):
  """Outputs samples from a normal distribution with the given mean/stddev.

  Standard-normal samples are drawn through `self._standard_normal`, then
  scaled by `stddev` and shifted by `mean`.

  Args:
    shape: Output shape; normalized with `_shape_tensor`.
    mean: Mean of the distribution; converted to a tensor of `dtype`.
    stddev: Standard deviation of the distribution; converted to a tensor of
      `dtype`.
    dtype: Dtype of the output (and of `mean`/`stddev`).
    name: Optional name scope for the ops; defaults to "stateful_normal".

  Returns:
    A tensor equal to `standard_normal_sample * stddev + mean`.
  """
  with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
    shape = _shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    # Go through the shared sampling helper rather than hard-coding the
    # stateful kernel, so `normal` picks the same op as the rest of the
    # generator (including any forward-compat/stateless path it chooses).
    rnd = self._standard_normal(shape, dtype=dtype)
    return math_ops.add(rnd * stddev, mean, name=name)
def testStatefulStandardNormal(self):
  """Checks that the legacy op 'StatefulStandardNormal' still works.

  `var1` stores the algorithm id prepended to the seeded state and is fed to
  the V1 op. `var2` stores only the state and is fed to the V2 op together
  with an explicit algorithm argument. Over 100 successive draws both
  variables must yield identical samples.
  """
  shape = constant_op.constant([4, 7])
  dtype = dtypes.float64
  seed = 1234
  algorithm = random.RNG_ALG_PHILOX
  state = random._make_state_from_seed(seed, algorithm)
  with ops.device("/device:CPU:0"):
    alg_prefix = np.array([algorithm], dtype=random.STATE_TYPE)
    prefixed_state = np.concatenate((alg_prefix, state), axis=None)
    var1 = variables.Variable(prefixed_state, dtype=random.STATE_TYPE)
    var2 = variables.Variable(state, dtype=random.STATE_TYPE)
    for _ in range(100):
      legacy = gen_stateful_random_ops.stateful_standard_normal(
          var1.handle, shape, dtype)
      current = gen_stateful_random_ops.stateful_standard_normal_v2(
          var2.handle, algorithm, shape, dtype)
      self.assertAllEqual(legacy, current)
def testErrors(self):
  """Checks that `stateful_standard_normal_v2` rejects malformed inputs.

  First passes bad `algorithm` values against a valid generator state. Then
  passes state variables of the wrong dtype, rank or size with the Philox
  algorithm.
  """
  shape = [2, 3]
  gen = random.Generator(seed=1234)
  normal_v2 = gen_stateful_random_ops.stateful_standard_normal_v2

  # (expected error type, message pattern, bad algorithm argument)
  bad_algorithms = [
      (errors.InvalidArgumentError,
       r"algorithm must be of shape \[\], not", [0, 0]),
      (TypeError, "Requested dtype: int64", 1.1),
      (errors.InvalidArgumentError, "Unsupported algorithm id", 123),
  ]
  for error_type, pattern, bad_alg in bad_algorithms:
    with self.assertRaisesWithPredicateMatch(error_type, pattern):
      normal_v2(gen.state.handle, bad_alg, shape)

  # (initial value, variable dtype, message pattern) for malformed states.
  bad_states = [
      ([0, 0], dtypes.int32,
       "dtype of RNG state variable must be int64, not"),
      ([[0]], dtypes.int64,
       "RNG state must have one and only one dimension, not"),
      ([0], dtypes.int64,
       "For the Philox algorithm, the size of state must be at least"),
  ]
  for initial_value, state_dtype, pattern in bad_states:
    var = variables.Variable(initial_value, dtype=state_dtype)
    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                             pattern):
      normal_v2(var.handle, random.RNG_ALG_PHILOX, shape)
def _standard_normal(self, shape, dtype):
  """Runs the stateful standard-normal V2 kernel on this generator's state.

  Args:
    shape: Shape of the requested output.
    dtype: Dtype of the requested output.

  Returns:
    The tensor produced by `stateful_standard_normal_v2`.
  """
  sampler = gen_stateful_random_ops.stateful_standard_normal_v2
  return sampler(self.state.handle, self.algorithm, shape, dtype=dtype)
def testErrors(self):
  """Checks `stateful_standard_normal_v2` input validation on the XLA device.

  Everything runs under `ops.device(xla_device_name())` with a ThreeFry
  generator. The test covers a non-scalar algorithm, a non-integer
  algorithm, an unknown algorithm id, and state variables whose dtype, rank
  or size is wrong.
  """
  shape = [2, 3]
  with ops.device(xla_device_name()):
    gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
    normal_v2 = gen_stateful_random_ops.stateful_standard_normal_v2

    # (expected error type, message pattern, bad algorithm argument)
    bad_algorithms = [
        (errors_impl.InvalidArgumentError,
         r"algorithm must be of shape \[\], not", [0, 0]),
        (TypeError, "Requested dtype: int64", 1.1),
        (errors_impl.InvalidArgumentError, "Unsupported algorithm id", 123),
    ]
    for error_type, pattern, bad_alg in bad_algorithms:
      with self.assertRaisesWithPredicateMatch(error_type, pattern):
        normal_v2(gen.state.handle, bad_alg, shape)

    # (initial value, variable dtype, message pattern) for malformed states.
    bad_states = [
        ([0, 0], dtypes.uint32,
         "Type mismatch for read of variable .* Expected int64; got"),
        ([[0]], dtypes.int64,
         "RNG state must have one and only one dimension, not"),
        ([0], dtypes.int64,
         "For the ThreeFry algorithm, the size of state must be at least"),
    ]
    for initial_value, state_dtype, pattern in bad_states:
      var = variables.Variable(initial_value, dtype=state_dtype)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError, pattern):
        normal_v2(var.handle, random.RNG_ALG_THREEFRY, shape)
def testErrors(self):
  """Checks error reporting of the stateful RNG ops and `Generator.uniform`.

  The test covers:
    * malformed `algorithm` arguments to `stateful_standard_normal_v2` and
      `rng_skip`;
    * malformed state variables used with the Philox algorithm;
    * non-scalar `minval`/`maxval` passed to `gen.uniform` inside a
      `def_function.function`.
  """
  shape = [2, 3]
  gen = random.Generator.from_seed(1234)
  normal_v2 = gen_stateful_random_ops.stateful_standard_normal_v2

  # (expected error type, message pattern, call expected to fail). The calls
  # are made in list order, each inside its own assertion context.
  bad_algorithm_calls = [
      (errors.InvalidArgumentError, r"must have shape \[\], not",
       lambda: normal_v2(gen.state.handle, [0, 0], shape)),
      (errors.InvalidArgumentError, r"must have shape \[\], not",
       lambda: gen_stateful_random_ops.rng_skip(
           gen.state.handle, gen.algorithm, [0, 0])),
      (TypeError, "EagerTensor of dtype int64",
       lambda: normal_v2(gen.state.handle, 1.1, shape)),
      (errors.InvalidArgumentError, "Unsupported algorithm id",
       lambda: normal_v2(gen.state.handle, 123, shape)),
  ]
  for error_type, pattern, failing_call in bad_algorithm_calls:
    with self.assertRaisesWithPredicateMatch(error_type, pattern):
      failing_call()

  # (initial value, variable dtype, message pattern) for malformed states.
  bad_states = [
      ([0, 0], dtypes.int32,
       "dtype of RNG state variable must be int64, not"),
      ([[0]], dtypes.int64,
       "RNG state must have one and only one dimension, not"),
      ([0], dtypes.int64,
       "For the Philox algorithm, the size of state must be at least"),
  ]
  for initial_value, state_dtype, pattern in bad_states:
    var = variables.Variable(initial_value, dtype=state_dtype)
    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                             pattern):
      normal_v2(var.handle, random.RNG_ALG_PHILOX, shape)

  with self.assertRaisesWithPredicateMatch(
      ValueError, "minval must be a scalar; got a tensor of shape "):

    @def_function.function
    def f():
      gen.uniform(shape=shape, minval=array_ops.zeros(shape, "int32"),
                  maxval=100, dtype="int32")

    f()

  with self.assertRaisesWithPredicateMatch(
      ValueError, "maxval must be a scalar; got a tensor of shape "):

    @def_function.function
    def f2():
      gen.uniform(shape=shape, minval=0,
                  maxval=array_ops.ones(shape, "int32") * 100, dtype="int32")

    f2()
def standard_normal(self, shape, dtype=dtypes.float32):
  """Returns the output of the stateful standard-normal V2 kernel.

  Args:
    shape: Shape of the requested output.
    dtype: Dtype of the requested output (defaults to `dtypes.float32`).

  Returns:
    The tensor produced by `stateful_standard_normal_v2` for this
    generator's state handle and algorithm.
  """
  return gen_stateful_random_ops.stateful_standard_normal_v2(
      self.state.handle, self.algorithm, shape, dtype=dtype)
def testErrors(self):
  """Verifies that the XLA kernel raises the expected errors on bad inputs.

  A ThreeFry generator is created on `xla_device_name()`. The test first
  passes invalid `algorithm` values with the generator's state, then passes
  state variables whose dtype, rank or length is invalid.
  """
  shape = [2, 3]
  with ops.device(xla_device_name()):
    gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
    sampler = gen_stateful_random_ops.stateful_standard_normal_v2

    # Invalid `algorithm` arguments: (error type, pattern, algorithm value).
    algorithm_cases = [
        (errors_impl.InvalidArgumentError,
         r"algorithm must be of shape \[\], not", [0, 0]),
        (TypeError, "Requested dtype: int64", 1.1),
        (errors_impl.InvalidArgumentError, "Unsupported algorithm id", 123),
    ]
    for expected_error, regex, alg_value in algorithm_cases:
      with self.assertRaisesWithPredicateMatch(expected_error, regex):
        sampler(gen.state.handle, alg_value, shape)

    # Invalid state variables: (initial value, dtype, pattern).
    state_cases = [
        ([0, 0], dtypes.uint32,
         "Type mismatch for read of variable .* Expected int64; got"),
        ([[0]], dtypes.int64,
         "RNG state must have one and only one dimension, not"),
        ([0], dtypes.int64,
         "For the ThreeFry algorithm, the size of state must be at least"),
    ]
    for init, var_dtype, regex in state_cases:
      var = variables.Variable(init, dtype=var_dtype)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError, regex):
        sampler(var.handle, random.RNG_ALG_THREEFRY, shape)