コード例 #1
0
 def testErrors(self):
   """Tests that malformed inputs to the stateful RNG ops raise proper errors.

   Covers non-scalar `algorithm`/`delta` arguments, a non-integer and an
   unknown algorithm id, RNG state variables with the wrong dtype, rank or
   size, and non-scalar `minval`/`maxval` given to `Generator.uniform`
   inside a `tf.function`.
   """
   out_shape = [2, 3]
   rng = random.Generator.from_seed(1234)

   def check(error_type, pattern, op, *args):
     # Calls `op(*args)` and asserts it raises `error_type` with a message
     # matching the regex `pattern`.
     with self.assertRaisesWithPredicateMatch(error_type, pattern):
       op(*args)

   # Malformed arguments against the generator's own state handle.
   check(errors.InvalidArgumentError, r"must have shape \[\], not",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, [0, 0], out_shape)
   check(errors.InvalidArgumentError, r"must have shape \[\], not",
         gen_stateful_random_ops.rng_skip,
         rng.state.handle, rng.algorithm, [0, 0])
   check(TypeError, "EagerTensor of dtype int64",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, 1.1, out_shape)
   check(errors.InvalidArgumentError, "Unsupported algorithm id",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, 123, out_shape)

   # Badly formed state variables; each entry is
   # (initial value, dtype, expected message regex).
   bad_states = (
       ([0, 0], dtypes.int32,
        "dtype of RNG state variable must be int64, not"),
       ([[0]], dtypes.int64,
        "RNG state must have one and only one dimension, not"),
       ([0], dtypes.int64,
        "For the Philox algorithm, the size of state must be at least"),
   )
   for initial_value, state_dtype, pattern in bad_states:
     state_var = variables.Variable(initial_value, dtype=state_dtype)
     check(errors.InvalidArgumentError, pattern,
           gen_stateful_random_ops.stateful_standard_normal_v2,
           state_var.handle, random.RNG_ALG_PHILOX, out_shape)

   # Non-scalar bounds passed to `uniform` inside a `tf.function` must
   # raise ValueError when the function is called.
   def uniform_with_tensor_minval():
     rng.uniform(shape=out_shape, minval=array_ops.zeros(out_shape, "int32"),
                 maxval=100, dtype="int32")

   def uniform_with_tensor_maxval():
     rng.uniform(
         shape=out_shape, minval=0,
         maxval=array_ops.ones(out_shape, "int32") * 100, dtype="int32")

   check(ValueError, "minval must be a scalar; got a tensor of shape ",
         def_function.function(uniform_with_tensor_minval))
   check(ValueError, "maxval must be a scalar; got a tensor of shape ",
         def_function.function(uniform_with_tensor_maxval))
コード例 #2
0
  def skip(self, delta):
    """Advances the counter of a counter-based RNG.

    Args:
      delta: the amount of advancement. After `skip(n)` the RNG state is the
        same as after `normal([n])` (or any other distribution). The actual
        increment added to the counter is an unspecified implementation
        detail.
    """
    state_handle = self.state.handle
    algorithm_id = self.algorithm
    # The op's result is not returned; this method always returns None.
    gen_stateful_random_ops.rng_skip(state_handle, algorithm_id, delta)
コード例 #3
0
    def skip(self, delta):
        """Advances the counter of a counter-based RNG.

        Args:
          delta: the amount of advancement. After `skip(n)` the RNG state is
            the same as after `normal([n])` (or any other distribution). The
            actual increment added to the counter is an unspecified
            implementation detail.

        Returns:
          The result of `self._skip(delta)` when the 2020-10-25 forward
          compatibility window is active; otherwise `None`.
        """
        if not compat.forward_compatible(2020, 10, 25):
            # Legacy path: both the algorithm id and the increment are cast
            # to int64 before calling the `rng_skip` op, whose result is
            # discarded.
            state_handle = self.state.handle
            algorithm_id = math_ops.cast(self.algorithm, dtypes.int64)
            increment = math_ops.cast(delta, dtypes.int64)
            gen_stateful_random_ops.rng_skip(
                state_handle, algorithm_id, increment)
            return None
        return self._skip(delta)
コード例 #4
0
 def testErrors(self):
   """Tests that malformed inputs to the stateful RNG ops raise proper errors.

   Covers non-scalar `algorithm`/`delta` arguments, a non-integer and an
   unknown algorithm id, and RNG state variables with the wrong dtype, rank
   or size.
   """
   out_shape = [2, 3]
   rng = random.Generator.from_seed(1234)

   def check(error_type, pattern, op, *args):
     # Calls `op(*args)` and asserts it raises `error_type` with a message
     # matching the regex `pattern`.
     with self.assertRaisesWithPredicateMatch(error_type, pattern):
       op(*args)

   # Malformed arguments against the generator's own state handle.
   check(errors.InvalidArgumentError, r"must have shape \[\], not",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, [0, 0], out_shape)
   check(errors.InvalidArgumentError, r"must have shape \[\], not",
         gen_stateful_random_ops.rng_skip,
         rng.state.handle, rng.algorithm, [0, 0])
   check(TypeError, "EagerTensor of dtype int64",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, 1.1, out_shape)
   check(errors.InvalidArgumentError, "Unsupported algorithm id",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, 123, out_shape)

   # Badly formed state variables; each entry is
   # (initial value, dtype, expected message regex).
   bad_states = (
       ([0, 0], dtypes.int32,
        "dtype of RNG state variable must be int64, not"),
       ([[0]], dtypes.int64,
        "RNG state must have one and only one dimension, not"),
       ([0], dtypes.int64,
        "For the Philox algorithm, the size of state must be at least"),
   )
   for initial_value, state_dtype, pattern in bad_states:
     state_var = variables.Variable(initial_value, dtype=state_dtype)
     check(errors.InvalidArgumentError, pattern,
           gen_stateful_random_ops.stateful_standard_normal_v2,
           state_var.handle, random.RNG_ALG_PHILOX, out_shape)
コード例 #5
0
 def testErrors(self):
   """Tests that malformed inputs to the stateful RNG ops raise proper errors.

   Covers non-scalar `algorithm`/`delta` arguments, a non-integer and an
   unknown algorithm id, and RNG state variables with the wrong dtype, rank
   or size.
   """
   out_shape = [2, 3]
   rng = random.Generator.from_seed(1234)

   def check(error_type, pattern, op, *args):
     # Calls `op(*args)` and asserts it raises `error_type` with a message
     # matching the regex `pattern`.
     with self.assertRaisesWithPredicateMatch(error_type, pattern):
       op(*args)

   # Malformed arguments against the generator's own state handle.
   check(errors.InvalidArgumentError, r"must have shape \[\], not",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, [0, 0], out_shape)
   check(errors.InvalidArgumentError, r"must have shape \[\], not",
         gen_stateful_random_ops.rng_skip,
         rng.state.handle, rng.algorithm, [0, 0])
   check(TypeError, "Requested dtype: int64",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, 1.1, out_shape)
   check(errors.InvalidArgumentError, "Unsupported algorithm id",
         gen_stateful_random_ops.stateful_standard_normal_v2,
         rng.state.handle, 123, out_shape)

   # Badly formed state variables; each entry is
   # (initial value, dtype, expected message regex).
   bad_states = (
       ([0, 0], dtypes.int32,
        "dtype of RNG state variable must be int64, not"),
       ([[0]], dtypes.int64,
        "RNG state must have one and only one dimension, not"),
       ([0], dtypes.int64,
        "For the Philox algorithm, the size of state must be at least"),
   )
   for initial_value, state_dtype, pattern in bad_states:
     state_var = variables.Variable(initial_value, dtype=state_dtype)
     check(errors.InvalidArgumentError, pattern,
           gen_stateful_random_ops.stateful_standard_normal_v2,
           state_var.handle, random.RNG_ALG_PHILOX, out_shape)