예제 #1
0
def prng(s):
  """Creates RNG state from seed.

  This implementation doesn't pass RNG states explicitly. Instead it calls
  `random.seed(s)`, a global side effect: later draws from `random` depend
  on `s`. It then returns a dummy "state" that is always 0 and carries no
  information.

  Args:
    s: the seed, an integer.

  Returns:
    A dummy state equal to 0, built with `asarray(0, dtype=np.int64)`. This is
    an int64 array value, not a Python int.
  """
  # TODO(wangpeng): change it to use stateless random ops to truly mimic JAX
  #   RNGs.
  random.seed(s)
  # Return a real (dummy) value rather than None: returning None will cause
  # errors in some layer/optimizer libraries based on JAX.
  return asarray(0, dtype=np.int64)
예제 #2
0
        def run_test(*args):
            """Checks `random.randn(*args)` for shape, dtype, statistics and seeding.

            Seeds with 10 and draws `num_samples` arrays. It verifies that each
            array has shape `args` and dtype `random.DEFAULT_RANDN_DTYPE`.

            For shapes with a non-zero element count, it also checks three
            things:
            - the per-element sample mean and standard deviation are within
              `tol` of 0 and 1;
            - reseeding with 20 gives different samples;
            - reseeding with 10 reproduces the first batch.

            Args:
                *args: the dimensions forwarded to `random.randn`.
            """
            num_samples = 1000
            # Deliberately loose tolerance: a tighter one would need many more
            # samples and make the test slow to run.
            tol = 0.1
            expected_shape = tuple(args)

            def draw(seed):
                # Reseed the global generator, then draw `num_samples` arrays in
                # sequence. The seed-then-draw order is what the same-seed and
                # different-seed comparisons below rely on.
                random.seed(seed)
                return [random.randn(*args) for _ in range(num_samples)]

            samples = draw(10)

            # Every sample must have the requested shape and the default dtype.
            for sample in samples:
                self.assertEqual(sample.shape, expected_shape)
                self.assertEqual(sample.dtype.type, random.DEFAULT_RANDN_DTYPE)

            # Distribution and seeding checks are skipped when the shape has no
            # elements (np.prod(args) == 0).
            if np.prod(args):
                baseline = [sample.tolist() for sample in samples]

                # Per-element statistics should match a standard normal.
                mean = np.mean(baseline, axis=0)
                stddev = np.std(baseline, axis=0)
                self.assertAllClose(mean, np.zeros(args), atol=tol)
                self.assertAllClose(stddev, np.ones(args), atol=tol)

                # A different seed must produce different samples...
                other_seed = [sample.tolist() for sample in draw(20)]
                self.assertNotAllClose(baseline, other_seed)

                # ...while the original seed must reproduce the first batch.
                same_seed = [sample.tolist() for sample in draw(10)]
                self.assertAllClose(baseline, same_seed)