Exemplo n.º 1
0
 def _standard_normal(self, shape, dtype):
   """Samples a standard normal tensor of the given `shape` and `dtype`.

   When `compat.forward_compatible(2020, 10, 25)` holds, sampling uses the
   stateless V2 kernel with a key/counter pair obtained from
   `self._prepare_key_counter(shape)`. Otherwise it falls back to the
   stateful V2 kernel, which reads `self.state` through its resource handle.
   """
   if not compat.forward_compatible(2020, 10, 25):
     # Fallback: stateful op operating directly on the state variable's handle.
     return gen_stateful_random_ops.stateful_standard_normal_v2(
         self.state.handle, self.algorithm, shape, dtype=dtype)
   key, counter = self._prepare_key_counter(shape)
   return gen_stateless_random_ops_v2.stateless_random_normal_v2(
       shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
Exemplo n.º 2
0
 def normal(self,
            shape,
            mean=0.0,
            stddev=1.0,
            dtype=dtypes.float32,
            name=None):
     """Draws normally distributed samples with the given mean and stddev.

     Args:
       shape: Output shape; normalized through `_shape_tensor`.
       mean: Mean of the distribution, converted to a tensor of `dtype`.
       stddev: Standard deviation, converted to a tensor of `dtype`.
       dtype: Output dtype.
       name: Optional op name. It opens the "stateful_normal" name scope, and
         the resolved scope name is also given to the returned tensor.

     Returns:
       A tensor equal to `standard_normal_samples * stddev + mean`.
     """
     with ops.name_scope(name, "stateful_normal",
                         [shape, mean, stddev]) as scope_name:
         shape_t = _shape_tensor(shape)
         mean_t = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
         stddev_t = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
         unit_samples = gen_stateful_random_ops.stateful_standard_normal_v2(
             self.state.handle, self.algorithm, shape_t, dtype=dtype)
         # Scale the unit-variance draws, then shift by the mean.
         scaled = unit_samples * stddev_t
         return math_ops.add(scaled, mean_t, name=scope_name)
 def testStatefulStandardNormal(self):
     """Checks that the legacy 'StatefulStandardNormal' op still works.

     The legacy op gets a state variable prefixed with the algorithm id. The
     V2 op gets the bare state plus an explicit algorithm argument. Both start
     from the same seeded Philox state, so every pair of draws must match.
     """
     shape = constant_op.constant([4, 7])
     dtype = dtypes.float64
     alg = random.RNG_ALG_PHILOX
     state = random._make_state_from_seed(1234, alg)
     # axis=None flattens both pieces into one 1-D array: [alg, *state].
     alg_prefixed_state = np.concatenate(
         (np.array([alg], dtype=random.STATE_TYPE), state), axis=None)
     with ops.device("/device:CPU:0"):
         legacy_var = variables.Variable(alg_prefixed_state,
                                         dtype=random.STATE_TYPE)
         v2_var = variables.Variable(state, dtype=random.STATE_TYPE)
         for _ in range(100):
             legacy = gen_stateful_random_ops.stateful_standard_normal(
                 legacy_var.handle, shape, dtype)
             current = gen_stateful_random_ops.stateful_standard_normal_v2(
                 v2_var.handle, alg, shape, dtype)
             self.assertAllEqual(legacy, current)
Exemplo n.º 4
0
 def testErrors(self):
   """Checks that stateful_standard_normal_v2 rejects malformed arguments.

   Each case passes one bad input and expects a specific error message. The
   bad inputs cover the algorithm argument (shape, type, unknown id) and the
   RNG state variable (dtype, rank, size).
   """
   shape = [2, 3]
   gen = random.Generator(seed=1234)
   sample = gen_stateful_random_ops.stateful_standard_normal_v2
   philox = random.RNG_ALG_PHILOX
   invalid = errors.InvalidArgumentError

   def expect_error(error_type, pattern, handle, alg):
     with self.assertRaisesWithPredicateMatch(error_type, pattern):
       sample(handle, alg, shape)

   # Bad algorithm arguments against a valid generator state.
   expect_error(invalid, r"algorithm must be of shape \[\], not",
                gen.state.handle, [0, 0])
   expect_error(TypeError, "Requested dtype: int64", gen.state.handle, 1.1)
   expect_error(invalid, "Unsupported algorithm id", gen.state.handle, 123)

   # Malformed state variables with a valid algorithm id. `var` stays bound
   # until the next case so the variable is alive while its handle is used.
   var = variables.Variable([0, 0], dtype=dtypes.int32)
   expect_error(invalid, "dtype of RNG state variable must be int64, not",
                var.handle, philox)
   var = variables.Variable([[0]], dtype=dtypes.int64)
   expect_error(invalid, "RNG state must have one and only one dimension, not",
                var.handle, philox)
   var = variables.Variable([0], dtype=dtypes.int64)
   expect_error(invalid,
                "For the Philox algorithm, the size of state must be at least",
                var.handle, philox)
Exemplo n.º 5
0
 def _standard_normal(self, shape, dtype):
     """Samples a standard normal tensor via the stateful V2 kernel.

     The kernel receives this generator's state variable handle and its
     algorithm id.
     """
     handle, alg = self.state.handle, self.algorithm
     return gen_stateful_random_ops.stateful_standard_normal_v2(
         handle, alg, shape, dtype=dtype)
 def testErrors(self):
   """Checks stateful_standard_normal_v2 argument validation on the XLA device.

   Uses a ThreeFry generator. Every op call runs under `xla_device_name()`
   and is expected to raise the listed error type and message.
   """
   shape = [2, 3]
   invalid = errors_impl.InvalidArgumentError

   def check(exc_type, regex, handle, alg):
     with self.assertRaisesWithPredicateMatch(exc_type, regex):
       gen_stateful_random_ops.stateful_standard_normal_v2(handle, alg, shape)

   with ops.device(xla_device_name()):
     threefry = random.RNG_ALG_THREEFRY
     gen = random.Generator(seed=1234, algorithm=threefry)
     gen_handle = gen.state.handle

     # Invalid algorithm arguments.
     check(invalid, r"algorithm must be of shape \[\], not", gen_handle, [0, 0])
     check(TypeError, "Requested dtype: int64", gen_handle, 1.1)
     check(invalid, "Unsupported algorithm id", gen_handle, 123)

     # Invalid state variables. `var` is kept bound so the variable outlives
     # the op call that uses its handle.
     var = variables.Variable([0, 0], dtype=dtypes.uint32)
     check(invalid, "Type mismatch for read of variable .* Expected int64; got",
           var.handle, threefry)
     var = variables.Variable([[0]], dtype=dtypes.int64)
     check(invalid, "RNG state must have one and only one dimension, not",
           var.handle, threefry)
     var = variables.Variable([0], dtype=dtypes.int64)
     check(invalid,
           "For the ThreeFry algorithm, the size of state must be at least",
           var.handle, threefry)
Exemplo n.º 7
0
    def testErrors(self):
        """Checks argument validation for stateful RNG ops and `uniform`.

        Covers the following cases:
          * stateful_standard_normal_v2 and rng_skip rejecting malformed
            algorithm or state inputs;
          * `Generator.uniform` rejecting non-scalar integer `minval` or
            `maxval` inside a `def_function.function`.
        """
        shape = [2, 3]
        gen = random.Generator.from_seed(1234)
        normal_v2 = gen_stateful_random_ops.stateful_standard_normal_v2
        philox = random.RNG_ALG_PHILOX
        invalid = errors.InvalidArgumentError
        raises = self.assertRaisesWithPredicateMatch

        # The algorithm argument must be a scalar for both ops.
        with raises(invalid, r"must have shape \[\], not"):
            normal_v2(gen.state.handle, [0, 0], shape)
        with raises(invalid, r"must have shape \[\], not"):
            gen_stateful_random_ops.rng_skip(gen.state.handle, gen.algorithm,
                                             [0, 0])
        # The algorithm argument must be an int64 and a known id.
        with raises(TypeError, "EagerTensor of dtype int64"):
            normal_v2(gen.state.handle, 1.1, shape)
        with raises(invalid, "Unsupported algorithm id"):
            normal_v2(gen.state.handle, 123, shape)

        # Malformed state variables: wrong dtype, wrong rank, too small.
        var = variables.Variable([0, 0], dtype=dtypes.int32)
        with raises(invalid, "dtype of RNG state variable must be int64, not"):
            normal_v2(var.handle, philox, shape)
        var = variables.Variable([[0]], dtype=dtypes.int64)
        with raises(invalid,
                    "RNG state must have one and only one dimension, not"):
            normal_v2(var.handle, philox, shape)
        var = variables.Variable([0], dtype=dtypes.int64)
        with raises(
                invalid,
                "For the Philox algorithm, the size of state must be at least"
        ):
            normal_v2(var.handle, philox, shape)

        # `uniform` with integer dtype must reject non-scalar bounds.
        with raises(ValueError,
                    "minval must be a scalar; got a tensor of shape "):

            @def_function.function
            def f():
                gen.uniform(shape=shape,
                            minval=array_ops.zeros(shape, "int32"),
                            maxval=100,
                            dtype="int32")

            f()
        with raises(ValueError,
                    "maxval must be a scalar; got a tensor of shape "):

            @def_function.function
            def f2():
                gen.uniform(shape=shape,
                            minval=0,
                            maxval=array_ops.ones(shape, "int32") * 100,
                            dtype="int32")

            f2()
 def standard_normal(self, shape, dtype=dtypes.float32):
     """Returns standard normal samples drawn from this generator's state.

     Args:
       shape: Output shape.
       dtype: Output dtype; defaults to `dtypes.float32`.

     Returns:
       The output of `stateful_standard_normal_v2`, called with this
       generator's state handle and algorithm.
     """
     return gen_stateful_random_ops.stateful_standard_normal_v2(
         self.state.handle, self.algorithm, shape, dtype=dtype)
 def testErrors(self):
   """Checks argument validation of stateful_standard_normal_v2 under XLA.

   A ThreeFry generator's handle is paired with bad algorithm arguments.
   Then a valid ThreeFry id is paired with malformed state variables. Each
   case must raise the listed error and message.
   """
   shape = [2, 3]
   invalid = errors_impl.InvalidArgumentError
   with ops.device(xla_device_name()):
     gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)

     bad_algorithm_cases = (
         (invalid, r"algorithm must be of shape \[\], not", [0, 0]),
         (TypeError, "Requested dtype: int64", 1.1),
         (invalid, "Unsupported algorithm id", 123),
     )
     for exc_type, pattern, alg in bad_algorithm_cases:
       with self.assertRaisesWithPredicateMatch(exc_type, pattern):
         gen_stateful_random_ops.stateful_standard_normal_v2(
             gen.state.handle, alg, shape)

     bad_state_cases = (
         ([0, 0], dtypes.uint32,
          "Type mismatch for read of variable .* Expected int64; got"),
         ([[0]], dtypes.int64,
          "RNG state must have one and only one dimension, not"),
         ([0], dtypes.int64,
          "For the ThreeFry algorithm, the size of state must be at least"),
     )
     for init_value, state_dtype, pattern in bad_state_cases:
       # `var` stays bound through the assertion so its handle remains valid.
       var = variables.Variable(init_value, dtype=state_dtype)
       with self.assertRaisesWithPredicateMatch(invalid, pattern):
         gen_stateful_random_ops.stateful_standard_normal_v2(
             var.handle, random.RNG_ALG_THREEFRY, shape)