def testNewStatePhilox(self):
  """Checks the Philox state left behind by integer draws on the XLA device.

  Skipped when forward compatibility with 2020-10-25 is enabled: per the skip
  message, the expected values here disagree with CPU/GPU in that version,
  and testXLAEqualsCPU covers the new behaviour instead.
  """
  if compat.forward_compatible(2020, 10, 25):
    self.skipTest("The expected values in this test is inconsistent with "
                  "CPU/GPU. testXLAEqualsCPU has the correct checks of the "
                  "new states for the new version.")
  with ops.device(xla_device_name()):
    key = 0x1234
    size = 47
    high = 283
    # Expected low-counter advance after drawing `size` uint32 / uint64 values.
    advance32 = (size + 3) // 4
    advance64 = (size + 1) // 2
    cases = [
        # (initial state, state after uint32 draw, state after uint64 draw)
        ([57, high, key],
         [57 + advance32, high, key],
         [57 + advance64, high, key]),
        # A low word of -1 (i.e. 0xffffffffffffffff) must overflow and carry
        # one into the high word.
        ([-1, high, key],
         [advance32 - 1, high + 1, key],
         [advance64 - 1, high + 1, key]),
    ]
    for initial, after_uint32, after_uint64 in cases:
      gen = random.Generator(state=initial, alg=random.RNG_ALG_PHILOX)
      gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
      self.assertAllEqual(after_uint32, gen.state.read_value())
      gen.reset(initial)
      gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
      self.assertAllEqual(after_uint64, gen.state.read_value())
def testTF1(self):
  """Runs seeded and unseeded generators through a TF1-style session.

  Each `sess.run` of the seeded generator's normal op must match the next
  hard-coded sample; the unseeded generator is only checked for its shape.
  """
  seed = 1234
  shape = [2, 3]
  expected_normal1 = constant_op.constant(
      [[0.9356609, 1.0854305, -0.93788373],
       [-0.50615472, 1.31697023, 0.71375787]], dtype=dtypes.float32)
  expected_normal2 = constant_op.constant(
      [[-0.3964749, 0.8369565, -0.30946946],
       [1.1206646, 1.00852597, -0.10185789]], dtype=dtypes.float32)
  with self.cached_session() as sess:
    seeded = random.Generator(seed=seed)
    unseeded = random.Generator()
    sess.run((seeded._state_var.initializer, unseeded._state_var.initializer))
    seeded_normal = seeded.normal(shape, dtype=dtypes.float32)
    unseeded_normal = unseeded.normal(shape, dtype=dtypes.float32)
    # Two consecutive runs: the seeded stream must yield these samples in order.
    for expected in (expected_normal1, expected_normal2):
      seeded_value, unseeded_value = sess.run((seeded_normal, unseeded_normal))
      self.assertAllClose(expected, seeded_value, rtol=1e-5, atol=1e-5)
      self.assertAllEqual(shape, unseeded_value.shape)
def testNewStatePhilox(self):
  """Verifies the Philox state after integer draws on the XLA device.

  Drawing `size` uint32 values must advance the low counter word by
  ceil(size / 4), and `size` uint64 values by ceil(size / 2). A low word of
  -1 (all ones) must wrap around and carry one into the high word.
  """
  key = 0x1234
  size = 47
  shape = (size,)
  uint32_steps = (size + 3) // 4
  uint64_steps = (size + 1) // 2

  def check(initial_state, want_after_uint32, want_after_uint64):
    # Draw uint32s, compare the state, rewind, then do the same with uint64s.
    gen = random.Generator(state=initial_state, alg=random.RNG_ALG_PHILOX)
    gen.uniform_full_int(shape=shape, dtype=dtypes.uint32)
    self.assertAllEqual(want_after_uint32, gen.state.read_value())
    gen.reset(initial_state)
    gen.uniform_full_int(shape=shape, dtype=dtypes.uint64)
    self.assertAllEqual(want_after_uint64, gen.state.read_value())

  with ops.device(xla_device_name()):
    check([57, 283, key],
          [57 + uint32_steps, 283, key],
          [57 + uint64_steps, 283, key])
    # -1 is the same as 0xffffffffffffffff, so the low word overflows into
    # the high word.
    check([-1, 283, key],
          [uint32_steps - 1, 283 + 1, key],
          [uint64_steps - 1, 283 + 1, key])
def testGPUEqualsCPU(self, dtype):
  """Same seed on CPU and GPU must produce identical full-range integers."""
  seed = 1234
  shape = [315, 49]

  def draw_on(device_name):
    # Build a fresh generator on `device_name` and draw one batch from it.
    with ops.device(device_name):
      gen = random.Generator(seed=seed)
      return gen.uniform_full_int(shape=shape, dtype=dtype)

  on_cpu = draw_on("/device:CPU:0")
  on_gpu = draw_on(test_util.gpu_device_name())
  self.assertAllEqual(on_cpu, on_gpu)
def testXLAEqualsCPU(self, dtype):
  """Same seed under Philox on CPU and XLA must give identical integers."""
  seed = 1234
  shape = [315, 49]

  def draw_on(device_name):
    # Fresh Philox generator per device so both start from the same state.
    with ops.device(device_name):
      gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_PHILOX)
      return gen.uniform_full_int(shape=shape, dtype=dtype)

  on_cpu = draw_on("/device:CPU:0")
  on_xla = draw_on(xla_device_name())
  self.assertAllEqual(on_cpu, on_xla)
def f(): global g_seeded # defun'ed function should only create variables once if g_seeded is None: g_seeded = random.Generator(g) self.assertAllEqual(g.algorithm, g_seeded.algorithm) self.assertAllEqual(g.state.read_value(), g_seeded.state.read_value())
def testGeneratorCreation(self):
  """Compares generators built eagerly with ones built inside a tf.function.

  For each constructor, a generator created inside the function (only on
  its first call, mirroring how tf.Variable creation behaves) must yield the
  same normal samples as an eagerly created one.
  """
  global g_seeded
  shape = [2, 3]
  alg = random.RNG_ALG_PHILOX
  constructors = (
      lambda: random.Generator(state=[1, 2, 3], alg=alg),
      lambda: random.Generator.from_seed(1234),
      lambda: random.Generator.from_key_counter(  # pylint: disable=g-long-lambda
          key=1, counter=[2, 3], alg=alg),
  )
  for make_gen in constructors:
    reference = make_gen()
    expected_samples = [reference.normal(shape), reference.normal(shape)]

    g_seeded = None

    @def_function.function
    def f(constructor):
      global g_seeded
      # Create the generator's variable only on the first call.
      if g_seeded is None:
        g_seeded = constructor()
      return g_seeded.normal(shape)

    for expected in expected_samples:
      self.assertAllEqual(expected, f(make_gen))
def testBatchSeeds(self):
  """Successive keys, seeds and split children of one generator all differ."""
  shape = [2, 3]
  count = 6
  gen = random.Generator(seed=1234)

  first_keys = gen._make_int64_keys(shape=shape)
  second_keys = gen._make_int64_keys(shape=shape)
  self.assertAllDifferent([first_keys, second_keys])

  first_seeds = gen.make_seeds(count=count)
  second_seeds = gen.make_seeds(count=count)
  self.assertAllDifferent([first_seeds[0, :], second_seeds[0, :]])

  children = gen.split(count=count)
  self.assertAllEqual(count, len(children))
  self.assertAllDifferent([
      child.uniform_full_int(shape=shape, dtype=dtypes.int32)
      for child in children
  ])

  # make_seeds must also be callable from a traced (graph-mode) function.
  @def_function.function
  def traced_make_seeds():
    return gen.make_seeds(count=count)

  for _ in range(3):
    traced_make_seeds()
def testSimple(self, alg):
  """Smoke test: normal, bounded uint32 uniform and full-range ints on XLA."""
  with ops.device(xla_device_name()):
    gen = random.Generator(seed=0, algorithm=alg)
    vector_shape = (3,)
    gen.normal(shape=vector_shape)
    gen.uniform(shape=vector_shape, minval=0, maxval=10, dtype=dtypes.uint32)
    gen.uniform_full_int(shape=vector_shape)
def testNormalIsFinite(self):
  """ThreeFry normal samples on XLA contain no inf/NaN for each float dtype."""
  with ops.device(xla_device_name()):
    gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
    for dtype in self._floats:
      samples = gen.normal(shape=[10000], dtype=dtype).numpy()
      every_value_finite = np.all(np.isfinite(samples))
      self.assertTrue(every_value_finite)
def testNormalIsNotConstant(self, alg):
  """Runs `_testRngIsNotConstant` on XLA normal draws for each float dtype."""
  with ops.device(xla_device_name()):
    gen = random.Generator(seed=1234, algorithm=alg)

    def draw_normal(dtype):
      return gen.normal(shape=[2], dtype=dtype)

    for dtype in self._floats:
      self._testRngIsNotConstant(draw_normal, dtype)
def testSkip(self):
  """`skip(delta)` must move the first state word forward by `delta * 256`."""
  key = 1234
  counter = 5678
  delta = 432
  gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
  gen.skip(delta)
  self.assertAllEqual(counter + delta * 256, gen._state_var[0])
def testTruncatedNormal(self, alg):
  """Validates XLA truncated-normal samples via random_test_util per dtype."""
  n = 10000000
  with ops.device(xla_device_name()):
    for dtype in self._floats:
      gen = random.Generator(seed=123, algorithm=alg)
      samples = gen.truncated_normal(shape=[n], dtype=dtype).numpy()
      random_test_util.test_truncated_normal(
          self.assertEqual, self.assertAllClose, dtype, n, samples)
def testKey(self):
  """`gen.key` reports the state's key both eagerly and in a tf.function."""
  key = 1234
  gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX)
  self.assertAllEqual(key, gen.key)

  @def_function.function
  def f():
    return gen.key

  self.assertAllEqual(key, f())
def testDefun(self, alg):
  """Draws normal, uint32 uniform and full-range ints inside a tf.function."""
  with ops.device(xla_device_name()):
    gen = random.Generator(seed=0, algorithm=alg)

    @def_function.function
    def f():
      # Tuple elements are evaluated left to right, preserving draw order.
      return (
          gen.normal(shape=(3,)),
          gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32),
          gen.uniform_full_int(shape=(3,)),
      )

    f()
def testUniformIsInRange(self, alg):
  """XLA uniform samples lie within [minval, maxval] for int and float dtypes.

  The upper-bound check is inclusive (`<=`), so a sample equal to maxval
  passes.
  """
  minval, maxval = 2, 33
  size = 1000
  with ops.device(xla_device_name()):
    for dtype in self._ints + self._floats:
      gen = random.Generator(seed=1234, algorithm=alg)
      samples = gen.uniform(
          shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
      self.assertTrue(np.all(samples >= minval))
      self.assertTrue(np.all(samples <= maxval))
def testDistributionOfNormal(self):
  """Anderson-Darling check that XLA ThreeFry normal samples look normal."""
  n = 1000
  with ops.device(xla_device_name()):
    for dtype in (dtypes.float32,):
      gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
      samples = gen.normal(shape=[n], dtype=dtype).numpy()
      # 2.492 is the 5% critical value of the Anderson-Darling statistic when
      # mean and variance are known. The test is probabilistic, so the seed is
      # fixed to keep it from flaking.
      statistic = self._anderson_darling(samples.astype(float))
      self.assertLess(statistic, 2.492)
def testUniformIsNotConstant(self, alg):
  """Runs `_testRngIsNotConstant` on XLA uniform draws for int/float dtypes."""
  with ops.device(xla_device_name()):
    gen = random.Generator(seed=1234, algorithm=alg)

    def draw_uniform(dtype):
      # Workaround for b/125364959: uint64 uses a smaller upper bound.
      if dtype == dtypes.uint64:
        upper = 10000000
      else:
        upper = dtype.max
      return gen.uniform(shape=[2], dtype=dtype, maxval=upper)

    for dtype in self._ints + self._floats:
      self._testRngIsNotConstant(draw_uniform, dtype)
def testKey(self):
  """`gen.key` echoes the key from the seed list, eagerly and in tf.function."""
  key = 1234
  gen = random.Generator(seed=[0, 0, key])
  self.assertAllEqual(key, gen.key)

  @def_function.function
  def f():
    return gen.key

  self.assertAllEqual(key, f())