Example #1
0
 def test_make_array(self, jit):
   """Check that `jnp.arange(10.0)` follows the active x64 context.

   The dtype produced outside any context is recorded first; the function
   must yield float64 inside `enable_x64`, float32 inside `disable_x64`,
   and the recorded dtype again once both contexts have exited.
   """
   make_array = _maybe_jit(jit, lambda: jnp.arange(10.0))
   baseline_dtype = make_array().dtype

   with enable_x64():
     dtype_in_x64 = make_array().dtype
     self.assertEqual(dtype_in_x64, "float64")

   with disable_x64():
     dtype_in_x32 = make_array().dtype
     self.assertEqual(dtype_in_x32, "float32")

   # Leaving the contexts must bring back whatever was in effect before.
   dtype_after = make_array().dtype
   self.assertEqual(dtype_after, baseline_dtype)
Example #2
0
 def test_make_array(self, jit):
     """Check that a NumPy float64 scalar converts per the active x64 context.

     `jit` is applied directly to the array-building lambda. The dtype seen
     before entering any context must be seen again after leaving them.
     """
     make_array = jit(lambda: jnp.array(np.float64(0)))
     baseline_dtype = make_array().dtype

     # Each context manager is paired with the dtype expected inside it.
     contexts_and_dtypes = (
         (enable_x64, "float64"),
         (disable_x64, "float32"),
     )
     for x64_context, expected_dtype in contexts_and_dtypes:
         with x64_context():
             self.assertEqual(make_array().dtype, expected_dtype)

     self.assertEqual(make_array().dtype, baseline_dtype)
Example #3
0
  def test_while_loop(self, jit):
    """Check the dtype carried through `lax.while_loop` under each x64 mode.

    The loop starts at 0.0 and adds 1.0 while the value is below `N`. The
    result for N=10 must equal float64(10) with x64 enabled and float32(10)
    with x64 disabled, dtypes included.
    """
    def count_to(N):
      below_limit = lambda x: x < N
      add_one = lambda x: x + 1.0
      return lax.while_loop(below_limit, add_one, 0.0)

    count_to = _maybe_jit(jit, count_to)

    with enable_x64():
      # The expected value is built inside the context as well.
      counted = count_to(10)
      self.assertArraysEqual(counted, jnp.float64(10), check_dtypes=True)

    with disable_x64():
      counted = count_to(10)
      self.assertArraysEqual(counted, jnp.float32(10), check_dtypes=True)
Example #4
0
 def test_make_array(self, jit):
     """Check that `jnp.arange(10.0)` follows the active x64 context.

     Skipped for the "cpp" jit variant when omnistaging is not enabled.
     """
     wants_cpp_jit = jit == "cpp"
     if wants_cpp_jit and not config.omnistaging_enabled:
         self.skipTest("cpp_jit requires omnistaging")

     make_array = _maybe_jit(jit, lambda: jnp.arange(10.0))
     baseline_dtype = make_array().dtype

     # Each context manager is paired with the dtype expected inside it.
     for x64_context, expected_dtype in (
         (enable_x64, 'float64'),
         (disable_x64, 'float32'),
     ):
         with x64_context():
             self.assertEqual(make_array().dtype, expected_dtype)

     # After both contexts exit, the dtype recorded at the start must return.
     self.assertEqual(make_array().dtype, baseline_dtype)
Example #5
0
  def test_convert_element_type(self):
    """Cast an int64 value created under `enable_x64` down to int32.

    The cast is exercised both eagerly and through `api.jit`, in each case
    after the x64 context has already exited.
    """
    # Regression test for part of https://github.com/google/jax/issues/5982
    with enable_x64():
      wide_value = jnp.int64(1)
    self.assertEqual(wide_value.dtype, jnp.int64)

    eager_cast = wide_value.astype(jnp.int32)
    self.assertEqual(eager_cast.dtype, jnp.int32)

    jitted_cast = api.jit(lambda x: x.astype(jnp.int32))
    traced_cast = jitted_cast(wide_value)
    self.assertEqual(traced_cast.dtype, jnp.int32)
Example #6
0
  def test_jit_cache(self):
    """Call a float64 `random.uniform` draw twice under each x64 mode.

    The same partial is invoked twice with x64 disabled and then twice with
    x64 enabled; the test passes if none of the calls raises. Skipped on TPU.
    """
    on_tpu = jtu.device_under_test() == "tpu"
    if on_tpu:
      self.skipTest("64-bit random not available on TPU")

    key = random.PRNGKey(0)
    draw = partial(random.uniform, key, (1,), 'float64', -1, 1)

    for x64_context in (disable_x64, enable_x64):
      with x64_context():
        # Two calls per mode, so the second one repeats identical arguments.
        draw()
        draw()
Example #7
0
  def test_while_loop(self, jit):
    """Check the dtype carried through `lax.while_loop` under each x64 mode.

    Skipped for the "cpp" jit variant when omnistaging is not enabled. The
    loop starts at 0.0 and adds 1.0 while the value is below `N`.
    """
    wants_cpp_jit = jit == "cpp"
    if wants_cpp_jit and not config.omnistaging_enabled:
      self.skipTest("cpp_jit requires omnistaging")

    def count_to(N):
      keep_going = lambda x: x < N
      step = lambda x: x + 1.0
      return lax.while_loop(keep_going, step, 0.0)

    count_to = _maybe_jit(jit, count_to)

    with enable_x64():
      # The expected value is built inside the context as well.
      total = count_to(10)
      self.assertArraysEqual(total, jnp.float64(10), check_dtypes=True)

    with disable_x64():
      total = count_to(10)
      self.assertArraysEqual(total, jnp.float32(10), check_dtypes=True)
Example #8
0
  def test_correctly_capture_default(self, jit, enable_or_disable):
    """Check that the context a function was created in does not stick to it.

    The function is built and first called inside `enable_or_disable()`.
    Afterwards its output dtype must follow whichever x64 setting is active
    at call time.
    """
    # The fact we defined a jitted function with a block with a different value
    # of `config.enable_x64` has no impact on the output.
    with enable_or_disable():
      make_array = _maybe_jit(jit, lambda: jnp.arange(10.0))
      make_array()

    # Outside any context the dtype follows the configured x64 flag.
    if config._read("jax_enable_x64"):
      default_dtype = "float64"
    else:
      default_dtype = "float32"
    self.assertEqual(make_array().dtype, default_dtype)

    for x64_context, forced_dtype in (
        (enable_x64, "float64"),
        (disable_x64, "float32"),
    ):
      with x64_context():
        self.assertEqual(make_array().dtype, forced_dtype)
Example #9
0
    def test_jit_cache(self):
        """Call a float64 `random.uniform` draw twice under each x64 mode.

        Skipped whenever the `experimental_cpp_jit` flag is set. Otherwise
        the test passes if none of the four calls raises.
        """
        # TODO(jakevdp): enable this test when CPP jit cache is fixed.
        cpp_jit_active = FLAGS.experimental_cpp_jit
        if cpp_jit_active:
            self.skipTest(
                "Known failure due to https://github.com/google/jax/issues/5532"
            )

        key = random.PRNGKey(0)
        draw = partial(random.uniform, key, (1, ), 'float64', -1, 1)

        for x64_context in (disable_x64, enable_x64):
            with x64_context():
                # Two calls per mode: the second repeats identical arguments.
                draw()
                draw()
Example #10
0
  def test_jit_cache(self):
    """Call a float64 `random.uniform` draw twice under each x64 mode.

    Skipped on TPU, and also when the XLA extension version is below 4 while
    the `experimental_cpp_jit` flag is set.
    """
    on_tpu = jtu.device_under_test() == "tpu"
    if on_tpu:
      self.skipTest("64-bit random not available on TPU")

    old_extension = jax.lib._xla_extension_version < 4
    if old_extension and FLAGS.experimental_cpp_jit:
      self.skipTest(
          "Known failure due to https://github.com/google/jax/issues/5532")

    key = random.PRNGKey(0)
    draw = partial(random.uniform, key, (1,), 'float64', -1, 1)

    for x64_context in (disable_x64, enable_x64):
      with x64_context():
        # Two calls per mode: the second repeats identical arguments.
        draw()
        draw()
Example #11
0
    def test_near_singular_inverse(self, jit):
        """Invert a near-singular matrix under both x64 settings.

        A matrix drawn from `rng` has its last row scaled by `eps`. The
        inverse must be entirely finite with x64 enabled and entirely
        non-finite with x64 disabled.
        """
        rng = jtu.rand_default(self.rng())

        def near_singular_inverse(N=5, eps=1E-40):
            sample = rng((N, N), dtype='float64')
            # Shrinking the last row makes the matrix close to singular.
            shrunk = jnp.asarray(sample).at[-1].mul(eps)
            return jnp.linalg.inv(shrunk)

        near_singular_inverse = _maybe_jit(
            jit, near_singular_inverse, static_argnums=1)

        with enable_x64():
            result_64 = near_singular_inverse()
            every_entry_finite = jnp.all(jnp.isfinite(result_64))
            self.assertTrue(every_entry_finite)

        with disable_x64():
            result_32 = near_singular_inverse()
            no_entry_finite = jnp.all(~jnp.isfinite(result_32))
            self.assertTrue(no_entry_finite)
Example #12
0
  def test_near_singular_inverse(self, jit):
    """Invert a near-singular matrix under both x64 settings.

    A uniform random N x N matrix has its last row scaled by `eps`. The
    inverse must be entirely finite with x64 enabled and entirely non-finite
    with x64 disabled. Skipped on TPU.
    """
    on_tpu = jtu.device_under_test() == "tpu"
    if on_tpu:
      self.skipTest("64-bit inverse not available on TPU")

    def near_singular_inverse(key, N, eps):
      # Shrinking the last row makes the matrix close to singular.
      shrunk = random.uniform(key, (N, N)).at[-1].mul(eps)
      return jnp.linalg.inv(shrunk)

    near_singular_inverse = _maybe_jit(
        jit, near_singular_inverse, static_argnums=1)

    N = 5
    eps = 1E-40
    key = random.PRNGKey(1701)

    with enable_x64():
      result_64 = near_singular_inverse(key, N, eps)
      every_entry_finite = jnp.all(jnp.isfinite(result_64))
      self.assertTrue(every_entry_finite)

    with disable_x64():
      result_32 = near_singular_inverse(key, N, eps)
      no_entry_finite = jnp.all(~jnp.isfinite(result_32))
      self.assertTrue(no_entry_finite)
Example #13
0
    def test_near_singular_inverse(self, jit):
        """Invert a near-singular matrix under both x64 settings.

        Skipped for the "cpp" jit variant when omnistaging is not enabled.
        The inverse must be entirely finite with x64 enabled and entirely
        non-finite with x64 disabled.
        """
        wants_cpp_jit = jit == "cpp"
        if wants_cpp_jit and not config.omnistaging_enabled:
            self.skipTest("cpp_jit requires omnistaging")

        def near_singular_inverse(key, N, eps):
            # Shrinking the last row makes the matrix close to singular.
            shrunk = random.uniform(key, (N, N)).at[-1].mul(eps)
            return jnp.linalg.inv(shrunk)

        near_singular_inverse = _maybe_jit(
            jit, near_singular_inverse, static_argnums=1)

        N = 5
        eps = 1E-40
        key = random.PRNGKey(1701)

        with enable_x64():
            result_64 = near_singular_inverse(key, N, eps)
            every_entry_finite = jnp.all(jnp.isfinite(result_64))
            self.assertTrue(every_entry_finite)

        with disable_x64():
            result_32 = near_singular_inverse(key, N, eps)
            no_entry_finite = jnp.all(~jnp.isfinite(result_32))
            self.assertTrue(no_entry_finite)
Example #14
0
 def func_x64():
   """Return the dtype of `jnp.arange(10)` built with x64 enabled.

   Sleeps for 0.1 s inside the context before creating the array.
   """
   with enable_x64():
     time.sleep(0.1)
     observed_dtype = jnp.arange(10).dtype
   return observed_dtype
Example #15
0
def _jax_enable64() -> Generator[None, None, None]:
    """Yield once with the `enable_x64` context active.

    The context is entered before the single `yield` and exited when the
    generator is resumed or closed.
    """
    x64_context = enable_x64()
    with x64_context:
        yield None
Example #16
0
 def func_x64():
     """Return the dtype of a `np.int64` scalar converted with x64 enabled.

     Sleeps for 0.1 s inside the context before creating the array.
     """
     with enable_x64():
         time.sleep(0.1)
         converted = jnp.array(np.int64(0))
         observed_dtype = converted.dtype
     return observed_dtype