Example #1
0
 def testNewStatePhilox(self):
   """Tests that the new state is correct (for Philox).

   Drawing `size` values advances the two-word counter [counter_low,
   counter_high] by ceil(size / 4) for uint32 output and by ceil(size / 2)
   for uint64 output; a wrap of counter_low must carry into counter_high.
   The key entry of the state is expected to stay unchanged.
   """
   if compat.forward_compatible(2020, 10, 25):
     self.skipTest("The expected values in this test is inconsistent with "
                   "CPU/GPU. testXLAEqualsCPU has the correct checks of the "
                   "new states for the new version.")
   key = 0x1234
   size = 47
   uint32_steps = (size + 3) // 4  # ceil(size / 4)
   uint64_steps = (size + 1) // 2  # ceil(size / 2)

   def check_advance(initial_state, after_uint32, after_uint64):
     # Draw uint32s from `initial_state`, then rewind and draw uint64s,
     # comparing the generator state after each draw.
     gen = random.Generator(state=initial_state, alg=random.RNG_ALG_PHILOX)
     gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
     self.assertAllEqual(after_uint32, gen.state.read_value())
     gen.reset(initial_state)
     gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
     self.assertAllEqual(after_uint64, gen.state.read_value())

   with ops.device(xla_device_name()):
     counter_high = 283
     # No carry: only counter_low moves.
     counter_low = 57
     check_advance([counter_low, counter_high, key],
                   [counter_low + uint32_steps, counter_high, key],
                   [counter_low + uint64_steps, counter_high, key])
     # counter_low = -1 is the same as 0xffffffffffffffff, so advancing it
     # wraps around to (steps - 1) and carries one into counter_high.
     counter_low = -1
     check_advance([counter_low, counter_high, key],
                   [uint32_steps - 1, counter_high + 1, key],
                   [uint64_steps - 1, counter_high + 1, key])
Example #2
0
    def testTF1(self):
        """Checks graph-mode (TF1 session) normal draws from two generators.

        The seeded generator's successive draws are compared against fixed
        golden values. The unseeded generator is only checked for output
        shape.
        """
        seed = 1234
        shape = [2, 3]
        golden = [
            constant_op.constant(
                [[0.9356609, 1.0854305, -0.93788373],
                 [-0.50615472, 1.31697023, 0.71375787]],
                dtype=dtypes.float32),
            constant_op.constant(
                [[-0.3964749, 0.8369565, -0.30946946],
                 [1.1206646, 1.00852597, -0.10185789]],
                dtype=dtypes.float32),
        ]
        with self.cached_session() as sess:
            seeded = random.Generator(seed=seed)
            unseeded = random.Generator()
            sess.run((seeded._state_var.initializer,
                      unseeded._state_var.initializer))
            seeded_op = seeded.normal(shape, dtype=dtypes.float32)
            unseeded_op = unseeded.normal(shape, dtype=dtypes.float32)
            # Running the same ops again is expected to yield the next values,
            # so the i-th run is compared with the i-th golden tensor.
            for expected in golden:
                seeded_val, unseeded_val = sess.run((seeded_op, unseeded_op))
                self.assertAllClose(expected, seeded_val, rtol=1e-5, atol=1e-5)
                self.assertAllEqual(shape, unseeded_val.shape)
Example #3
0
 def testNewStatePhilox(self):
     """Tests that the new state is correct (for Philox).

     Drawing `size` values advances the two-word counter [counter_low,
     counter_high] by ceil(size / 4) for uint32 output and by ceil(size / 2)
     for uint64 output; a wrap of counter_low must carry into counter_high.
     The key entry of the state is expected to stay unchanged.
     """
     key = 0x1234
     size = 47
     uint32_steps = (size + 3) // 4  # ceil(size / 4)
     uint64_steps = (size + 1) // 2  # ceil(size / 2)

     def check_advance(initial_state, after_uint32, after_uint64):
         # Draw uint32s from `initial_state`, then rewind and draw uint64s,
         # comparing the generator state after each draw.
         gen = random.Generator(state=initial_state,
                                alg=random.RNG_ALG_PHILOX)
         gen.uniform_full_int(shape=(size, ), dtype=dtypes.uint32)
         self.assertAllEqual(after_uint32, gen.state.read_value())
         gen.reset(initial_state)
         gen.uniform_full_int(shape=(size, ), dtype=dtypes.uint64)
         self.assertAllEqual(after_uint64, gen.state.read_value())

     with ops.device(xla_device_name()):
         counter_high = 283
         # No carry: only counter_low moves.
         counter_low = 57
         check_advance([counter_low, counter_high, key],
                       [counter_low + uint32_steps, counter_high, key],
                       [counter_low + uint64_steps, counter_high, key])
         # counter_low = -1 is the same as 0xffffffffffffffff, so advancing
         # it wraps around to (steps - 1) and carries one into counter_high.
         counter_low = -1
         check_advance([counter_low, counter_high, key],
                       [uint32_steps - 1, counter_high + 1, key],
                       [uint64_steps - 1, counter_high + 1, key])
 def testGPUEqualsCPU(self, dtype):
     """Tests that GPU and CPU generate the same integer outputs.

     The test is skipped when no GPU device name can be found. Without this
     guard, ``test_util.gpu_device_name()`` returns an empty string. The
     second generator would then not be pinned to any GPU, and the test
     could pass vacuously by comparing CPU output with CPU output.

     Args:
       dtype: dtype forwarded to ``uniform_full_int`` (presumably supplied by
         a parameterization decorator outside this view).
     """
     gpu_name = test_util.gpu_device_name()
     if not gpu_name:
         self.skipTest("No GPU is available.")
     seed = 1234
     shape = [315, 49]
     with ops.device("/device:CPU:0"):
         cpu = random.Generator(seed=seed).uniform_full_int(shape=shape,
                                                            dtype=dtype)
     with ops.device(gpu_name):
         gpu = random.Generator(seed=seed).uniform_full_int(shape=shape,
                                                            dtype=dtype)
     self.assertAllEqual(cpu, gpu)
Example #5
0
 def testXLAEqualsCPU(self, dtype):
     """Checks that the XLA and CPU Philox kernels produce identical integers.

     Both generators use the same seed and algorithm. Only the device scope
     in which the draw is created differs.
     """
     seed = 1234
     shape = [315, 49]

     def draw():
         # Build a fresh Philox generator under the current device scope and
         # draw full-range integers of the requested dtype.
         generator = random.Generator(seed=seed,
                                      algorithm=random.RNG_ALG_PHILOX)
         return generator.uniform_full_int(shape=shape, dtype=dtype)

     with ops.device("/device:CPU:0"):
         cpu = draw()
     with ops.device(xla_device_name()):
         xla = draw()
     self.assertAllEqual(cpu, xla)
Example #6
0
 def f():
   """Lazily copies `g` into the global `g_seeded` and checks they match.

   A defun'ed function should only create variables once, so the copy is
   made only while `g_seeded` is still None.
   """
   global g_seeded
   if g_seeded is None:
     g_seeded = random.Generator(g)
   copied = g_seeded
   self.assertAllEqual(g.algorithm, copied.algorithm)
   self.assertAllEqual(g.state.read_value(), copied.state.read_value())
Example #7
0
  def testGeneratorCreation(self):
    """Tests generator creation, in both eager and tf.function.

    For each constructor, a generator built eagerly supplies the expected
    normal draws. A generator built from the same constructor inside a
    tf.function must reproduce them. The interaction between Generator
    creation and defun should be the same as tf.Variable.
    """
    global g_seeded
    shape = [2, 3]
    alg = random.RNG_ALG_PHILOX
    constructors = (
        lambda: random.Generator(state=[1, 2, 3], alg=alg),
        lambda: random.Generator.from_seed(1234),
        lambda: random.Generator.from_key_counter(  # pylint: disable=g-long-lambda
            key=1, counter=[2, 3], alg=alg),
    )
    for make_generator in constructors:
      eager_gen = make_generator()
      expected_first = eager_gen.normal(shape)
      expected_second = eager_gen.normal(shape)
      g_seeded = None

      # A defun'ed function should only create variables once, so the
      # generator is built lazily and cached in a global. The function is
      # redefined for each constructor so every one gets its own first call.
      @def_function.function
      def f(constructor):
        global g_seeded
        if g_seeded is None:
          g_seeded = constructor()
        return g_seeded.normal(shape)

      self.assertAllEqual(expected_first, f(make_generator))
      self.assertAllEqual(expected_second, f(make_generator))
Example #8
0
    def testBatchSeeds(self):
        """Test for batch seeds.

        Checks four things: successive batched int64 keys differ, successive
        make_seeds batches differ, split() returns `count` generators whose
        outputs differ, and make_seeds also runs inside a tf.function.
        """
        shape = [2, 3]
        count = 6
        gen = random.Generator(seed=1234)
        self.assertAllDifferent(
            [gen._make_int64_keys(shape=shape) for _ in range(2)])
        seed_batches = [gen.make_seeds(count=count) for _ in range(2)]
        self.assertAllDifferent([batch[0, :] for batch in seed_batches])
        children = gen.split(count=count)
        self.assertAllEqual(count, len(children))
        self.assertAllDifferent([
            child.uniform_full_int(shape=shape, dtype=dtypes.int32)
            for child in children
        ])

        # Tests graph mode.
        @def_function.function
        def f():
            return gen.make_seeds(count=count)

        for _ in range(3):
            f()
Example #9
0
 def testSimple(self, alg):
     """Smoke test: normal, bounded uniform and full-range draws on XLA."""
     shape = (3, )
     with ops.device(xla_device_name()):
         rng = random.Generator(seed=0, algorithm=alg)
         rng.normal(shape=shape)
         rng.uniform(shape=shape, minval=0, maxval=10, dtype=dtypes.uint32)
         rng.uniform_full_int(shape=shape)
Example #10
0
 def testNormalIsFinite(self):
     """Checks that XLA ThreeFry normal samples contain no inf/NaN values.

     One generator is shared across all float dtypes in ``self._floats``.
     """
     with ops.device(xla_device_name()):
         rng = random.Generator(seed=1234,
                                algorithm=random.RNG_ALG_THREEFRY)
         for dtype in self._floats:
             samples = rng.normal(shape=[10000], dtype=dtype).numpy()
             self.assertTrue(np.isfinite(samples).all())
Example #11
0
  def testNormalIsNotConstant(self, alg):
    """Feeds a 2-sample XLA normal draw to _testRngIsNotConstant per dtype."""
    with ops.device(xla_device_name()):
      shared_gen = random.Generator(seed=1234, algorithm=alg)

      def rng(dtype):
        # Each call draws two fresh normal samples of the given dtype.
        return shared_gen.normal(shape=[2], dtype=dtype)

      for float_dtype in self._floats:
        self._testRngIsNotConstant(rng, float_dtype)
Example #12
0
 def testSkip(self):
   """Checks that skip(delta) moves the first counter word by delta * 256."""
   key, counter, delta = 1234, 5678, 432
   gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
   gen.skip(delta)
   self.assertAllEqual(counter + delta * 256, gen._state_var[0])
Example #13
0
 def testTruncatedNormal(self, alg):
   """Validates XLA truncated-normal samples with random_test_util.

   A fresh generator (seed 123) is created for each float dtype, and
   ``n`` samples are drawn from it.
   """
   n = 10000000
   with ops.device(xla_device_name()):
     for dtype in self._floats:
       gen = random.Generator(seed=123, algorithm=alg)
       samples = gen.truncated_normal(shape=[n], dtype=dtype).numpy()
       random_test_util.test_truncated_normal(
           self.assertEqual, self.assertAllClose, dtype, n, samples)
Example #14
0
 def testKey(self):
   """Checks that Generator.key returns the key placed in its state.

   The key is read both eagerly and from inside a tf.function.
   """
   key = 1234
   gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX)
   self.assertAllEqual(key, gen.key)

   @def_function.function
   def read_key():
     return gen.key

   self.assertAllEqual(key, read_key())
Example #15
0
 def testDefun(self, alg):
   """Checks that XLA sampling ops can be traced and run in a tf.function."""
   with ops.device(xla_device_name()):
     gen = random.Generator(seed=0, algorithm=alg)

     @def_function.function
     def f():
       shape = (3,)
       normal = gen.normal(shape=shape)
       bounded = gen.uniform(
           shape=shape, minval=0, maxval=10, dtype=dtypes.uint32)
       full_range = gen.uniform_full_int(shape=shape)
       return normal, bounded, full_range

     f()
Example #16
0
 def testUniformIsInRange(self, alg):
   """Checks that XLA uniform samples lie within [minval, maxval].

   A fresh generator is used for each integer and float dtype.
   NOTE(review): the upper bound is checked inclusively. This is presumably
   deliberate, since low-precision floats (e.g. bfloat16) can round a value
   just below maxval up to maxval.
   """
   minval, maxval, size = 2, 33, 1000
   with ops.device(xla_device_name()):
     for dtype in self._ints + self._floats:
       gen = random.Generator(seed=1234, algorithm=alg)
       samples = gen.uniform(shape=[size], dtype=dtype, minval=minval,
                             maxval=maxval).numpy()
       self.assertTrue(np.all(samples >= minval))
       self.assertTrue(np.all(samples <= maxval))
 def testDistributionOfNormal(self):
   """Use Anderson-Darling test to test distribution appears normal."""
   n = 1000
   dtype = dtypes.float32  # only float32 is checked
   with ops.device(xla_device_name()):
     gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
     samples = gen.normal(shape=[n], dtype=dtype).numpy()
     # The constant 2.492 is the 5% critical value for the Anderson-Darling
     # test where the mean and variance are known. This test is probabilistic
     # so to avoid flakiness the seed is fixed.
     statistic = self._anderson_darling(samples.astype(float))
     self.assertLess(statistic, 2.492)
Example #18
0
  def testUniformIsNotConstant(self, alg):
    """Feeds a 2-sample XLA uniform draw to _testRngIsNotConstant per dtype."""
    with ops.device(xla_device_name()):
      shared_gen = random.Generator(seed=1234, algorithm=alg)

      def rng(dtype):
        # Workaround for b/125364959: uint64 uses a capped upper bound
        # instead of its dtype maximum.
        upper = 10000000 if dtype == dtypes.uint64 else dtype.max
        return shared_gen.uniform(shape=[2], dtype=dtype, maxval=upper)

      for candidate in self._ints + self._floats:
        self._testRngIsNotConstant(rng, candidate)
    def testKey(self):
        """Checks Generator.key for a generator seeded with [0, 0, key].

        NOTE(review): the expectation relies on this Generator version using
        a list seed directly as the state, so that the key is its last entry.
        Confirm this against the Generator version under test.
        """
        key = 1234
        gen = random.Generator(seed=[0, 0, key])
        self.assertAllEqual(key, gen.key)

        @def_function.function
        def read_key():
            return gen.key

        self.assertAllEqual(key, read_key())