コード例 #1
0
ファイル: random_test.py プロジェクト: samuela/jax
    def testRngRandomBits(self):
        """Pin ``random._random_bits`` outputs for a fixed key.

        Hard-coded expected values make sure the generated random bits stay
        identical between JAX versions. Widths 8 and 16 are only checked off
        TPU; the expected 64-bit result depends on whether x64 is enabled.
        """
        key = random.PRNGKey(1701)
        on_tpu = jtu.device_under_test() == "tpu"

        # (bit width, expected result of _random_bits(key, width, (3,))),
        # checked in order below.
        expectations = []
        if not on_tpu:
            # U8 and U16 are not supported on TPU.
            expectations += [
                (8, np.array([216, 115, 43], dtype=np.uint8)),
                (16, np.array([41682, 1300, 55017], dtype=np.uint16)),
            ]
        expectations.append(
            (32, np.array([56197195, 4200222568, 961309823], dtype=np.uint32)))

        # With x64 disabled the expected 64-bit result is a uint32 array.
        if FLAGS.jax_enable_x64:
            expected64 = np.array(
                [3982329540505020460, 16822122385914693683, 7882654074788531506],
                dtype=np.uint64)
        else:
            expected64 = np.array([676898860, 3164047411, 4010691890],
                                  dtype=np.uint32)
        expectations.append((64, expected64))

        for width, expected in expectations:
            actual = random._random_bits(key, width, (3, ))
            self.assertArraysEqual(actual, expected)
コード例 #2
0
ファイル: random_test.py プロジェクト: tigerneil/jax
 def testRngRandomBitsViewProperty(self):
   """Check that random bit streams of different widths share the same bits.

   For every width n in ``nbits``, draws N * 64 // n elements of n bits from
   the same key, reinterprets each result as uint32, and requires all
   reinterpretations to be equal.
   """
   # TODO: add 64-bit if it ever supports this property.
   # TODO: will this property hold across endian-ness?
   N = 10
   key = random.PRNGKey(1701)
   nbits = [8, 16, 32]
   if jtu.device_under_test() == "tpu":
     # U8 and U16 are not supported on TPU, so only the 32-bit stream can be
     # drawn there and the cross-width comparison below is trivially true.
     nbits = [32]
   # N * 64 // n elements of n bits is always N * 64 bits in total, so every
   # stream views as the same number of uint32 words.
   rand_bits = [random._random_bits(key, n, (N * 64 // n,)) for n in nbits]
   rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
   # Compare each width against the first one separately so a failure reports
   # which width diverged (a bare `assert` is also stripped under -O).
   for n, bits in zip(nbits, rand_bits_32):
     np.testing.assert_array_equal(
         bits, rand_bits_32[0],
         err_msg=f"{n}-bit stream viewed as uint32 differs from "
                 f"the {nbits[0]}-bit stream")
コード例 #3
0
ファイル: random_test.py プロジェクト: susannaaz/jax
 def testRngRandomBitsViewProperty(self):
     """Check that random bit streams of different widths share the same bits.

     For every width n in ``nbits``, draws N * 64 // n elements of n bits from
     the same key, reinterprets each result as uint32, and requires all
     reinterpretations to be equal.
     """
     # TODO: add 64-bit if it ever supports this property.
     # TODO: will this property hold across endian-ness?
     N = 10
     key = random.PRNGKey(1701)
     nbits = [8, 16, 32]
     if jtu.device_under_test() == "tpu":
         # U8 and U16 are not supported on TPU, so only the 32-bit stream can
         # be drawn there and the cross-width comparison is trivially true.
         nbits = [32]
     # N * 64 // n elements of n bits is always N * 64 bits in total, so
     # every stream views as the same number of uint32 words.
     rand_bits = [
         random._random_bits(key, n, (N * 64 // n, )) for n in nbits
     ]
     rand_bits_32 = np.array(
         [np.array(r).view(np.uint32) for r in rand_bits])
     # Compare each width against the first one separately so a failure
     # reports which width diverged, instead of printing debug output and
     # relying on a bare `assert` (stripped under -O).
     for n, bits in zip(nbits, rand_bits_32):
         np.testing.assert_array_equal(
             bits, rand_bits_32[0],
             err_msg=f"{n}-bit stream viewed as uint32 differs from "
                     f"the {nbits[0]}-bit stream")