def testGPUSameAsOldRandomOps(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The GPU version. Both the old op and the new global generator are pinned
    to the first GPU and must produce identical float64 samples over 100
    consecutive calls. The test is skipped when no GPU is available; otherwise
    an empty device name would silently place every op on the default device
    and the GPU kernels would never be exercised.
    """
    # gpu_device_name() yields an empty string when no GPU device exists.
    gpu_name = test_util.gpu_device_name()
    if not gpu_name:
        self.skipTest("GPU is required")

    seed1, seed2 = 79, 25
    # The old op's (seed, seed2) pair corresponds to the new generator's
    # state [0, seed2, seed1].
    with ops.device(gpu_name):
        random.reset_global_generator([0, seed2, seed1])
    shape = constant_op.constant([4, 7])
    dtype = dtypes.float64

    # Wrapped in a function so the old op is built once and can be called
    # many times.
    @def_function.function
    def old():
        with ops.device(gpu_name):
            return gen_random_ops.random_standard_normal(
                shape, dtype=dtype, seed=seed1, seed2=seed2)

    def new():
        with ops.device(gpu_name):
            return random.get_global_generator().standard_normal(
                shape, dtype=dtype)

    for _ in range(100):
        self.assertAllEqual(old(), new())
  def testThreefry2x32(self):
    """Checks that ThreeFry2x32 reproduces known-answer test vectors.

    Each vector gives a counter and a key (each as two 32-bit halves, low half
    first) plus the two expected 32-bit outputs. The outputs are checked both
    as a pair of uint32 values and packed into a single uint64.
    """
    # Based on
    # https://github.com/google/jax/blob/8565a3486adf16beb388b2364c9cd930d7a0d92d/tests/random_test.py#L65-L85
    # which is in turn based on
    # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
    known_answers = (
        # (counter halves, key halves, expected output halves)
        ((0x00000000, 0x00000000), (0x00000000, 0x00000000),
         (0x6b200159, 0x99ba4efe)),
        ((0xffffffff, 0xffffffff), (0xffffffff, 0xffffffff),
         (0x1cb996fc, 0xbb002be7)),
        ((0x243f6a88, 0x85a308d3), (0x13198a2e, 0x03707344),
         (0xc4923a9c, 0x483df7a0)),
    )

    def pack_uint64(low, high):
      # Combine two 32-bit halves into one integer, `low` in the low bits.
      return (high << 32) | low

    def check(counter_halves, key_halves, expected_halves):
      counter = pack_uint64(*counter_halves)
      key = pack_uint64(*key_halves)
      generator = random.get_global_generator()
      # Draw the two outputs as a pair of uint32 values.
      generator.reset([counter, key])
      as_uint32 = generator.uniform_full_int(shape=(2,), dtype=dtypes.uint32)
      self.assertAllEqual(list(expected_halves), as_uint32)
      # Rewind and draw a single uint64 scalar, expected to equal the packed
      # pair.
      generator.reset([counter, key])
      as_uint64 = generator.uniform_full_int(shape=(), dtype=dtypes.uint64)
      self.assertAllEqual(pack_uint64(*expected_halves), as_uint64)

    with ops.device(xla_device_name()):
      random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      for counter_halves, key_halves, expected_halves in known_answers:
        check(counter_halves, key_halves, expected_halves)
    def testCPUSameAsOldRandomOps(self):
        """Checks the new global generator against the old random_ops.py on CPU.

        Both the old stateful op and the new generator are pinned to the CPU
        and must yield identical float64 samples over 100 consecutive calls.
        """
        cpu = "/device:CPU:0"
        seed1, seed2 = 79, 25
        # The old op's two seeds map onto the new generator's state as
        # [0, seed2, seed1].
        with ops.device(cpu):
            random.reset_global_generator([0, seed2, seed1])
        shape = constant_op.constant([4, 7])
        dtype = dtypes.float64

        # Built as a function so the old op can be invoked many times.
        @def_function.function
        def old():
            with ops.device(cpu):
                return gen_random_ops.random_standard_normal(
                    shape, dtype=dtype, seed=seed1, seed2=seed2)

        def new():
            with ops.device(cpu):
                generator = random.get_global_generator()
                return generator.standard_normal(shape, dtype=dtype)

        for _ in range(100):
            self.assertAllEqual(old(), new())
  def testThreefry2x32(self):
    """Verifies ThreeFry2x32 output against reference test vectors.

    For every vector the generator state is set from a packed counter and key,
    and two 32-bit outputs are compared, first as uint32 values and then as a
    single packed uint64.
    """
    # Based on
    # https://github.com/google/jax/blob/8565a3486adf16beb388b2364c9cd930d7a0d92d/tests/random_test.py#L65-L85
    # which is in turn based on
    # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
    vectors = [
        # (counter lo, counter hi, key lo, key hi, output 0, output 1)
        (0x00000000, 0x00000000, 0x00000000, 0x00000000,
         0x6b200159, 0x99ba4efe),
        (0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
         0x1cb996fc, 0xbb002be7),
        (0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344,
         0xc4923a9c, 0x483df7a0),
    ]

    def combine(lo, hi):
      # Low 32-bit half in the low bits, high half shifted above it.
      return lo | (hi << 32)

    with ops.device(xla_device_name()):
      random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      for c_lo, c_hi, k_lo, k_hi, out0, out1 in vectors:
        counter, key = combine(c_lo, c_hi), combine(k_lo, k_hi)

        random.get_global_generator().reset([counter, key])
        pair = random.get_global_generator().uniform_full_int(
            shape=(2,), dtype=dtypes.uint32)
        self.assertAllEqual([out0, out1], pair)

        random.get_global_generator().reset([counter, key])
        scalar = random.get_global_generator().uniform_full_int(
            shape=(), dtype=dtypes.uint64)
        self.assertAllEqual(combine(out0, out1), scalar)
Example #5
0
  def testResetGlobalGeneratorBadWithDefun(self):
    """Demonstrates that reset_global_generator doesn't work properly with defun.

    A function traced while one global generator is active is expected to fail
    with NotFoundError ("Resource ... does not exist") once
    reset_global_generator has replaced that generator. Only the call made
    after the reset is allowed to raise; the initial call must succeed.
    """
    shape = (3,)

    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.reset_global_generator(50)
    # Kept outside the assertion scope: a failure here is a real error, not
    # the behavior this test documents.
    _ = f()
    random.reset_global_generator(50)
    with self.assertRaisesWithPredicateMatch(
        errors.NotFoundError, "Resource .+ does not exist"):
      _ = f()
    def testResetGlobalGeneratorBadWithDefun(self):
        """Demonstrates that reset_global_generator doesn't work properly with defun.

        A function traced while one global generator is active is expected to
        raise an AssertionError matching "variable.*deleted" when called again
        after reset_global_generator has replaced that generator. Only the
        post-reset call (and the comparison that depends on it) may raise; the
        initial call must succeed.
        """
        shape = (3, )

        @def_function.function
        def f():
            return random.get_global_generator().normal(shape)

        random.reset_global_generator(50)
        # Kept outside the assertion scope: a failure here is a real error,
        # not the behavior this test documents.
        a = f()
        random.reset_global_generator(50)
        with self.assertRaisesWithPredicateMatch(AssertionError,
                                                 "variable.*deleted"):
            b = f()
            self.assertAllEqual(a, b)
Example #7
0
  def testThreefry2x32(self):
    """Checks ThreeFry2x32 against published known-answer vectors.

    Each table entry is forwarded unchanged to `_compareToKnownOutputs`.
    """
    # Based on
    # https://github.com/google/jax/blob/8565a3486adf16beb388b2364c9cd930d7a0d92d/tests/random_test.py#L65-L85
    # which is in turn based on
    # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
    # Presumably (counter words, key words, expected output words) -- confirm
    # against _compareToKnownOutputs.
    known_answers = [
        ([0x00000000, 0x00000000], [0x00000000, 0x00000000],
         [0x6b200159, 0x99ba4efe]),
        ([0xffffffff, 0xffffffff], [0xffffffff, 0xffffffff],
         [0x1cb996fc, 0xbb002be7]),
        ([0x243f6a88, 0x85a308d3], [0x13198a2e, 0x03707344],
         [0xc4923a9c, 0x483df7a0]),
    ]
    with ops.device(xla_device_name()):
      random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
      for counter, key, expected in known_answers:
        self._compareToKnownOutputs(counter, key, expected)
Example #8
0
  def testPhilox4x32(self):
    """Checks Philox4x32 against published known-answer vectors.

    Each table entry is forwarded unchanged to `_compareToKnownOutputs`.
    """
    # Based on
    # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_philox.cpp#L50-L52
    # Presumably (four counter words, two key words, four expected output
    # words) -- confirm against _compareToKnownOutputs.
    known_answers = [
        ([0x00000000, 0x00000000, 0x00000000, 0x00000000],
         [0x00000000, 0x00000000],
         [0x6627e8d5, 0xe169c58d, 0xbc57ac4c, 0x9b00dbd8]),
        ([0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff],
         [0xffffffff, 0xffffffff],
         [0x408f276d, 0x41c83b0e, 0xa20bc7c6, 0x6d5451fd]),
        ([0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344],
         [0xa4093822, 0x299f31d0],
         [0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1]),
    ]
    with ops.device(xla_device_name()):
      random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_PHILOX)
      for counter, key, expected in known_answers:
        self._compareToKnownOutputs(counter, key, expected)