def testSetGlobalGeneratorBadWithDefun(self):
  """Demonstrates set_global_generator does not affect compiled tf.function."""
  sample_shape = (3,)

  @def_function.function
  def draw():
    return random.get_global_generator().normal(sample_shape)

  random.set_global_generator(random.Generator.from_seed(50))
  first_samples = draw()
  # Installing a fresh global generator has no effect on the compiled
  # tf.function: it keeps using the generator captured at trace time.
  random.set_global_generator(random.Generator.from_seed(50))
  # The captured generator's state has advanced, so new samples are returned.
  self.assertNotAllEqual(first_samples, draw())
def testSetGlobalGeneratorBadWithDefun(self):
  """Demonstrates that set_global_generator doesn't work properly with defun.

  The traced tf.function captures the resource of the generator that is
  global at trace time; replacing the global generator afterwards destroys
  that resource, so the next call fails with NotFoundError.
  """
  shape = (3, )

  @def_function.function
  def f():
    return random.get_global_generator().normal(shape)

  random.set_global_generator(random.Generator.from_seed(50))
  with self.assertRaisesWithPredicateMatch(errors.NotFoundError,
                                           "Resource .+ does not exist"):
    # First call traces `f` and succeeds; the reset below deletes the
    # generator variable `f` captured, so the second call raises.
    _ = f()
    random.set_global_generator(random.Generator.from_seed(50))
    _ = f()
def testSetGlobalGeneratorBadWithDefun(self):
  """Demonstrates that set_global_generator doesn't work properly with defun.

  The compiled function keeps a reference to the generator resource captured
  during tracing; swapping in a new global generator invalidates it.
  """
  sample_shape = (3,)

  @def_function.function
  def sample():
    return random.get_global_generator().normal(sample_shape)

  random.set_global_generator(random.Generator.from_seed(50))
  with self.assertRaisesWithPredicateMatch(
      errors.NotFoundError, "Resource .+ does not exist"):
    # The first call succeeds. Replacing the global generator then destroys
    # the captured resource, making the second call fail.
    _ = sample()
    random.set_global_generator(random.Generator.from_seed(50))
    _ = sample()
def testDeterministicOpsErrors(self):
  """Tests the errors raised by generator APIs when op determinism is on."""
  try:
    config.enable_op_determinism()
    # With no global generator set, fetching one would have to create a
    # nondeterministically-seeded generator, which is disallowed.
    random.set_global_generator(None)
    with self.assertRaisesWithPredicateMatch(
        RuntimeError,
        '"get_global_generator" cannot be called if determinism is enabled'
    ):
      random.get_global_generator()
    # An explicitly seeded generator is deterministic, so this is allowed.
    random.set_global_generator(random.Generator.from_seed(50))
    random.get_global_generator()
    with self.assertRaisesWithPredicateMatch(
        RuntimeError,
        '"from_non_deterministic_state" cannot be called when determinism '
        "is enabled."):
      random.Generator.from_non_deterministic_state()
  finally:
    # Always restore the default so other tests are unaffected.
    config.disable_op_determinism()
def testGetGlobalGeneratorWithXla(self):
  """Demonstrates using the global generator with XLA.

  Draws a random seed inside an XLA-compiled function, then resets the
  global generator to the state captured before the draw and checks that
  the same seed is reproduced.
  """
  if not config.list_physical_devices("XLA_CPU"):
    self.skipTest("No XLA_CPU device available.")

  random.set_global_generator(None)

  # `jit_compile` is the current name of the deprecated
  # `experimental_compile` argument (matches its use elsewhere in this file).
  @def_function.function(jit_compile=True)
  def make_seed():
    generator = random.get_global_generator()
    # Snapshot the generator state *before* drawing so it can be restored.
    state = array_ops.identity(generator.state, name="state")
    return generator.uniform_full_int((2,), dtypes.int32, name="seed"), state

  with ops.device("/device:XLA_CPU:0"):
    seed, state = make_seed()
    self.assertTrue(np.all(np.isfinite(seed.numpy())))
    # Restoring the captured state must reproduce the identical seed.
    random.get_global_generator().reset(state)
    self.assertAllEqual(make_seed()[0], seed)
def testGetGlobalGeneratorWithXla(self):
  """Demonstrates using the global generator with XLA."""
  # Soft placement used to let this pass by silently falling back to the
  # CPU kernel.
  # TODO(wangpeng): Remove this skip
  self.skipTest("NonDeterministicInts lacks XLA kernel.")
  if not config.list_physical_devices("XLA_CPU"):
    self.skipTest("No XLA_CPU device available.")

  random.set_global_generator(None)

  @def_function.function(jit_compile=True)
  def draw_seed_and_state():
    gen = random.get_global_generator()
    # Capture the pre-draw state so the draw can be replayed later.
    pre_draw_state = array_ops.identity(gen.state, name="state")
    drawn = gen.uniform_full_int((2,), dtypes.int32, name="seed")
    return drawn, pre_draw_state

  with ops.device("/device:XLA_CPU:0"):
    seed, state = draw_seed_and_state()
    self.assertTrue(np.all(np.isfinite(seed.numpy())))
    # Rewinding the generator to the captured state must yield the same seed.
    random.get_global_generator().reset(state)
    self.assertAllEqual(draw_seed_and_state()[0], seed)