コード例 #1
0
    def testSetGlobalGeneratorBadWithDefun(self):
        """Shows that set_global_generator does not affect a compiled tf.function.

        `f` draws normal samples from the global generator. After `f` has run
        once, installing another generator seeded identically does not make `f`
        repeat its earlier output: its second call returns different samples.
        """
        sample_shape = (3, )

        @def_function.function
        def f():
            return random.get_global_generator().normal(sample_shape)

        def reseed():
            # Install a brand-new global generator with a fixed seed.
            random.set_global_generator(random.Generator.from_seed(50))

        reseed()
        first_draw = f()
        # Re-installing an identically seeded global generator has no effect on
        # the already-compiled tf.function...
        reseed()
        second_draw = f()
        # ...so new samples come back rather than a repeat of the first draw.
        self.assertNotAllEqual(first_draw, second_draw)
コード例 #2
0
    def testSetGlobalGeneratorBadWithDefun(self):
        """Demonstrates that set_global_generator doesn't work properly with defun.

        `f` is run once while a seeded global generator is installed. After the
        global generator is replaced, calling `f` again is expected to fail with
        NotFoundError ("Resource ... does not exist"). This suggests the traced
        function still refers to the replaced generator's resource.
        """
        shape = (3, )

        @def_function.function
        def f():
            return random.get_global_generator().normal(shape)

        random.set_global_generator(random.Generator.from_seed(50))
        # The first call must succeed. Keep it outside the assertion context:
        # otherwise an unexpected NotFoundError here would satisfy the
        # assertion by accident and hide a real regression.
        _ = f()
        random.set_global_generator(random.Generator.from_seed(50))
        # Only the call made after replacing the global generator should fail.
        with self.assertRaisesWithPredicateMatch(errors.NotFoundError,
                                                 "Resource .+ does not exist"):
            _ = f()
コード例 #3
0
  def testSetGlobalGeneratorBadWithDefun(self):
    """Demonstrates that set_global_generator doesn't work properly with defun.

    `f` is run once while a seeded global generator is installed. After the
    global generator is replaced, calling `f` again is expected to fail with
    NotFoundError ("Resource ... does not exist"). This suggests the traced
    function still refers to the replaced generator's resource.
    """
    shape = (3,)

    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.set_global_generator(random.Generator.from_seed(50))
    # The first call must succeed. Keep it outside the assertion context:
    # otherwise an unexpected NotFoundError here would satisfy the assertion
    # by accident and hide a real regression.
    _ = f()
    random.set_global_generator(random.Generator.from_seed(50))
    # Only the call made after replacing the global generator should fail.
    with self.assertRaisesWithPredicateMatch(
        errors.NotFoundError, "Resource .+ does not exist"):
      _ = f()
コード例 #4
0
 def testDeterministicOpsErrors(self):
     """Checks which generator APIs are rejected while op determinism is on."""
     # Expected-error patterns, matched against the raised RuntimeError.
     get_global_msg = (
         '"get_global_generator" cannot be called if determinism is enabled')
     non_det_state_msg = ('"from_non_deterministic_state" cannot be called when '
                          'determinism is enabled.')
     try:
         config.enable_op_determinism()
         # With no global generator installed, fetching it must fail.
         random.set_global_generator(None)
         with self.assertRaisesWithPredicateMatch(RuntimeError, get_global_msg):
             random.get_global_generator()
         # After a seeded generator is installed explicitly, fetching it works.
         random.set_global_generator(random.Generator.from_seed(50))
         random.get_global_generator()
         # Building a generator from non-deterministic state is expected to
         # be rejected while determinism is enabled.
         with self.assertRaisesWithPredicateMatch(RuntimeError,
                                                  non_det_state_msg):
             random.Generator.from_non_deterministic_state()
     finally:
         # Turn determinism back off so later tests are unaffected.
         config.disable_op_determinism()
コード例 #5
0
  def testGetGlobalGeneratorWithXla(self):
    """Demonstrates using the global generator with XLA.

    Inside an XLA-compiled function, reads the global generator's state and
    draws a 2-element int32 seed. It then resets the generator to that state
    and checks that recompiling/re-running reproduces the same seed.
    """
    if not config.list_physical_devices("XLA_CPU"):
      self.skipTest("No XLA_CPU device available.")

    # Clear any previously installed global generator.
    random.set_global_generator(None)

    # NOTE(review): newer TF releases rename `experimental_compile` to
    # `jit_compile`; switch once the minimum supported TF version allows.
    @def_function.function(experimental_compile=True)
    def make_seed():
      generator = random.get_global_generator()
      # Snapshot the state so the draw below can be replayed after `reset`.
      state = array_ops.identity(generator.state, name="state")
      return generator.uniform_full_int((2,), dtypes.int32, name="seed"), state

    with ops.device("/device:XLA_CPU:0"):
      seed, state = make_seed()
      # `seed` is an integer tensor, for which np.isfinite is always True, so
      # check properties that can actually fail instead.
      self.assertEqual(seed.dtype, dtypes.int32)
      self.assertEqual(seed.shape.as_list(), [2])
      random.get_global_generator().reset(state)
      self.assertAllEqual(make_seed()[0], seed)
コード例 #6
0
  def testGetGlobalGeneratorWithXla(self):
    """Demonstrates using the global generator with XLA.

    Inside an XLA-compiled function, reads the global generator's state and
    draws a 2-element int32 seed. It then resets the generator to that state
    and checks that re-running reproduces the same seed.
    """
    # This test was passing before because soft placement silently picked the
    # CPU kernel.
    # TODO(wangpeng): Remove this skip
    self.skipTest("NonDeterministicInts lacks XLA kernel.")

    if not config.list_physical_devices("XLA_CPU"):
      self.skipTest("No XLA_CPU device available.")

    # Clear any previously installed global generator.
    random.set_global_generator(None)

    @def_function.function(jit_compile=True)
    def make_seed():
      generator = random.get_global_generator()
      # Snapshot the state so the draw below can be replayed after `reset`.
      state = array_ops.identity(generator.state, name="state")
      return generator.uniform_full_int((2,), dtypes.int32, name="seed"), state

    with ops.device("/device:XLA_CPU:0"):
      seed, state = make_seed()
      # `seed` is an integer tensor, for which np.isfinite is always True, so
      # check properties that can actually fail instead.
      self.assertEqual(seed.dtype, dtypes.int32)
      self.assertEqual(seed.shape.as_list(), [2])
      random.get_global_generator().reset(state)
      self.assertAllEqual(make_seed()[0], seed)