def testRngRandomBits(self):
    """Pin the exact output of ``random._random_bits`` for a fixed key.

    Hard-coded expected values make any change to the generated bit stream
    visible, so random values stay consistent between JAX versions.
    """
    key = random.PRNGKey(1701)

    # (bit width, expected 3-element draw) pairs, checked in order below.
    narrow_cases = []
    # U8 and U16 are not supported on TPU, so those widths are only
    # exercised on other backends.
    if jtu.device_under_test() != "tpu":
        narrow_cases.append((8, np.array([216, 115, 43], dtype=np.uint8)))
        narrow_cases.append(
            (16, np.array([41682, 1300, 55017], dtype=np.uint16)))
    narrow_cases.append(
        (32, np.array([56197195, 4200222568, 961309823], dtype=np.uint32)))

    for width, want in narrow_cases:
        got = random._random_bits(key, width, (3, ))
        self.assertArraysEqual(got, want)

    got64 = random._random_bits(key, 64, (3, ))
    # The expected 64-bit result depends on whether x64 mode is enabled;
    # without it the test expects uint32 values instead.
    if not FLAGS.jax_enable_x64:
        want64 = np.array([676898860, 3164047411, 4010691890],
                          dtype=np.uint32)
    else:
        want64 = np.array(
            [3982329540505020460, 16822122385914693683, 7882654074788531506],
            dtype=np.uint64)
    self.assertArraysEqual(got64, want64)
def testRngRandomBitsViewProperty(self):
    """Assert that 8/16/32-bit draws from one key share the same raw bits.

    For each width ``n``, ``N * 64 // n`` values are drawn (``N * 64`` bits
    in total), reinterpreted as ``uint32`` words, and every row is required
    to equal the first one.

    NOTE(review): a method with this exact name is defined again right after
    this one. If both live in the same class body, the later ``def`` rebinds
    the name and this definition never runs; one of the two should be
    removed. This copy is kept consistent (TPU guard included) so that
    whichever one survives is correct.
    """
    # TODO: add 64-bit if it ever supports this property.
    # TODO: will this property hold across endian-ness?
    N = 10
    key = random.PRNGKey(1701)
    nbits = [8, 16, 32]
    if jtu.device_under_test() == "tpu":
        # U8 and U16 are not supported on TPU.
        nbits = [32]
    rand_bits = [random._random_bits(key, n, (N * 64 // n, )) for n in nbits]
    # Reinterpret each draw's raw bytes as uint32 so the widths are comparable.
    rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
    # assertTrue instead of a bare assert: bare asserts vanish under -O.
    self.assertTrue(np.all(rand_bits_32 == rand_bits_32[0]))
def testRngRandomBitsViewProperty(self):
    """Assert that 8/16/32-bit draws from one key share the same raw bits.

    For each width ``n``, ``N * 64 // n`` values are drawn (``N * 64`` bits
    in total), reinterpreted as ``uint32`` words, and every row is required
    to equal the first one. On TPU only the 32-bit width is checked.
    """
    # TODO: add 64-bit if it ever supports this property.
    # TODO: will this property hold across endian-ness?
    N = 10
    key = random.PRNGKey(1701)
    nbits = [8, 16, 32]
    if jtu.device_under_test() == "tpu":
        # U8 and U16 are not supported on TPU.
        nbits = [32]
    rand_bits = [random._random_bits(key, n, (N * 64 // n, )) for n in nbits]
    # Reinterpret each draw's raw bytes as uint32 so the widths are comparable.
    rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
    # Leftover debug print removed; assertTrue is used because a bare assert
    # is stripped under `python -O`.
    self.assertTrue(np.all(rand_bits_32 == rand_bits_32[0]))