def test_make_array(self, jit):
  # The dtype of `jnp.arange(10.0)` must track whichever x64 mode is active at
  # call time, and return to the original dtype once the contexts exit.
  arange_fn = _maybe_jit(jit, lambda: jnp.arange(10.0))
  initial_dtype = arange_fn().dtype
  for x64_mode, want in ((enable_x64, "float64"), (disable_x64, "float32")):
    with x64_mode():
      self.assertEqual(arange_fn().dtype, want)
  self.assertEqual(arange_fn().dtype, initial_dtype)
def test_make_array(self, jit):
  # `jit` is called directly here; presumably the test parameterization
  # supplies a jit-like transform (TODO confirm against the decorator).
  make_zero = jit(lambda: jnp.array(np.float64(0)))
  # Remember the ambient dtype so we can confirm it is restored afterwards.
  initial_dtype = make_zero().dtype
  for x64_mode, want in ((enable_x64, "float64"), (disable_x64, "float32")):
    with x64_mode():
      self.assertEqual(make_zero().dtype, want)
  self.assertEqual(make_zero().dtype, initial_dtype)
def test_while_loop(self, jit):
  # The loop carry starts from the Python float 0.0, so its dtype should follow
  # the x64 mode that is active when `count_to` is called.
  def count_to(limit):
    return lax.while_loop(lambda i: i < limit, lambda i: i + 1.0, 0.0)
  count_to = _maybe_jit(jit, count_to)

  with enable_x64():
    result = count_to(10)
    # Build the expected scalar inside the context so it is float64 too.
    self.assertArraysEqual(result, jnp.float64(10), check_dtypes=True)
  with disable_x64():
    result = count_to(10)
    self.assertArraysEqual(result, jnp.float32(10), check_dtypes=True)
def test_make_array(self, jit):
  # Per the skip message, the "cpp" jit variant requires omnistaging.
  if jit == "cpp" and not config.omnistaging_enabled:
    self.skipTest("cpp_jit requires omnistaging")

  arange_fn = _maybe_jit(jit, lambda: jnp.arange(10.0))
  dtype_before = arange_fn().dtype
  for x64_mode, want in ((enable_x64, 'float64'), (disable_x64, 'float32')):
    with x64_mode():
      self.assertEqual(arange_fn().dtype, want)
  # Leaving both contexts must restore the ambient behaviour.
  self.assertEqual(arange_fn().dtype, dtype_before)
def test_convert_element_type(self):
  # Regression test for part of https://github.com/google/jax/issues/5982:
  # with x64 enabled, casting an int64 value to int32 must give int32 both
  # eagerly and under jit.
  with enable_x64():
    x = jnp.int64(1)
    self.assertEqual(x.dtype, jnp.int64)

    eager_cast = x.astype(jnp.int32)
    self.assertEqual(eager_cast.dtype, jnp.int32)

    to_int32 = api.jit(lambda v: v.astype(jnp.int32))
    self.assertEqual(to_int32(x).dtype, jnp.int32)
def test_jit_cache(self):
  # Regression test: calling the same function under different x64 modes must
  # not reuse a result compiled for the other mode. The original version only
  # called `f()` and never checked the output, so a stale cache hit returning
  # the wrong dtype would have passed silently. Each call's dtype is now checked.
  if jtu.device_under_test() == "tpu":
    self.skipTest("64-bit random not available on TPU")
  f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1)
  with disable_x64():
    for _ in range(2):
      # With x64 disabled, the requested float64 is expected to be
      # canonicalized to float32.
      self.assertEqual(f().dtype, "float32")
  with enable_x64():
    for _ in range(2):
      self.assertEqual(f().dtype, "float64")
def test_while_loop(self, jit):
  # Per the skip message, the "cpp" jit variant requires omnistaging.
  if jit == "cpp" and not config.omnistaging_enabled:
    self.skipTest("cpp_jit requires omnistaging")

  # The carry starts from the Python float 0.0, so its dtype should follow the
  # x64 mode active at call time.
  def count_to(limit):
    return lax.while_loop(lambda i: i < limit, lambda i: i + 1.0, 0.0)
  count_to = _maybe_jit(jit, count_to)

  with enable_x64():
    counted = count_to(10)
    self.assertArraysEqual(counted, jnp.float64(10), check_dtypes=True)
  with disable_x64():
    counted = count_to(10)
    self.assertArraysEqual(counted, jnp.float32(10), check_dtypes=True)
def test_correctly_capture_default(self, jit, enable_or_disable):
  # Defining (and first calling) a jitted function inside a block with a
  # different `config.enable_x64` value must not pin its output dtype. Later
  # calls follow whichever mode is active at call time.
  with enable_or_disable():
    arange_fn = _maybe_jit(jit, lambda: jnp.arange(10.0))
    arange_fn()

  # Outside any context, the result should match the ambient configuration.
  ambient_x64 = config._read("jax_enable_x64")
  self.assertEqual(arange_fn().dtype, "float64" if ambient_x64 else "float32")

  with enable_x64():
    self.assertEqual(arange_fn().dtype, "float64")
  with disable_x64():
    self.assertEqual(arange_fn().dtype, "float32")
def test_jit_cache(self):
  # Regression test: calling the same function under different x64 modes must
  # not reuse a result compiled for the other mode. The original version never
  # checked the outputs, so a stale cache hit returning the wrong dtype would
  # have passed. Each call's dtype is now asserted.
  # TODO(jakevdp): enable this test when CPP jit cache is fixed.
  if FLAGS.experimental_cpp_jit:
    self.skipTest(
        "Known failure due to https://github.com/google/jax/issues/5532"
    )
  f = partial(random.uniform, random.PRNGKey(0), (1, ), 'float64', -1, 1)
  with disable_x64():
    for _ in range(2):
      # With x64 disabled, the requested float64 is expected to be
      # canonicalized to float32.
      self.assertEqual(f().dtype, "float32")
  with enable_x64():
    for _ in range(2):
      self.assertEqual(f().dtype, "float64")
def test_jit_cache(self):
  # Regression test: calling the same function under different x64 modes must
  # not reuse a result compiled for the other mode. The original version never
  # checked the outputs, so a stale cache hit returning the wrong dtype would
  # have passed. Each call's dtype is now asserted.
  if jtu.device_under_test() == "tpu":
    self.skipTest("64-bit random not available on TPU")
  # Older xla_extension builds combined with the C++ jit hit issue #5532.
  if jax.lib._xla_extension_version < 4 and FLAGS.experimental_cpp_jit:
    self.skipTest(
        "Known failure due to https://github.com/google/jax/issues/5532")
  f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1)
  with disable_x64():
    for _ in range(2):
      # With x64 disabled, the requested float64 is expected to be
      # canonicalized to float32.
      self.assertEqual(f().dtype, "float32")
  with enable_x64():
    for _ in range(2):
      self.assertEqual(f().dtype, "float64")
def test_near_singular_inverse(self, jit):
  # Inverting a nearly singular matrix: the test expects a finite result with
  # x64 enabled and a fully non-finite result with x64 disabled.
  rng = jtu.rand_default(self.rng())

  def near_singular_inverse(N=5, eps=1E-40):
    # The random matrix is drawn at trace/call time and converted to a JAX
    # array, which adopts the dtype allowed by the active x64 mode.
    mat = jnp.asarray(rng((N, N), dtype='float64'))
    # Scale the last row by a tiny factor to make the matrix nearly singular.
    mat = mat.at[-1].mul(eps)
    return jnp.linalg.inv(mat)
  near_singular_inverse = _maybe_jit(jit, near_singular_inverse,
                                     static_argnums=1)

  with enable_x64():
    inv64 = near_singular_inverse()
    self.assertTrue(jnp.all(jnp.isfinite(inv64)))
  with disable_x64():
    inv32 = near_singular_inverse()
    self.assertTrue(jnp.all(~jnp.isfinite(inv32)))
def test_near_singular_inverse(self, jit):
  # Inverting a nearly singular matrix: the test expects a finite result with
  # x64 enabled and a fully non-finite result with x64 disabled.
  if jtu.device_under_test() == "tpu":
    self.skipTest("64-bit inverse not available on TPU")

  def near_singular_inverse(key, N, eps):
    mat = random.uniform(key, (N, N))
    # Scale the last row by a tiny factor to make the matrix nearly singular.
    mat = mat.at[-1].mul(eps)
    return jnp.linalg.inv(mat)
  # N (positional argument 1) sets the matrix shape, so it is marked static.
  near_singular_inverse = _maybe_jit(jit, near_singular_inverse,
                                     static_argnums=1)

  key, N, eps = random.PRNGKey(1701), 5, 1E-40

  with enable_x64():
    inv64 = near_singular_inverse(key, N, eps)
    self.assertTrue(jnp.all(jnp.isfinite(inv64)))
  with disable_x64():
    inv32 = near_singular_inverse(key, N, eps)
    self.assertTrue(jnp.all(~jnp.isfinite(inv32)))
def test_near_singular_inverse(self, jit):
  # Inverting a nearly singular matrix: the test expects a finite result with
  # x64 enabled and a fully non-finite result with x64 disabled.
  # Per the skip message, the "cpp" jit variant requires omnistaging.
  if jit == "cpp" and not config.omnistaging_enabled:
    self.skipTest("cpp_jit requires omnistaging")

  def near_singular_inverse(key, N, eps):
    mat = random.uniform(key, (N, N))
    # Scale the last row by a tiny factor to make the matrix nearly singular.
    mat = mat.at[-1].mul(eps)
    return jnp.linalg.inv(mat)
  # N (positional argument 1) sets the matrix shape, so it is marked static.
  near_singular_inverse = _maybe_jit(jit, near_singular_inverse,
                                     static_argnums=1)

  key, N, eps = random.PRNGKey(1701), 5, 1E-40

  with enable_x64():
    inv64 = near_singular_inverse(key, N, eps)
    self.assertTrue(jnp.all(jnp.isfinite(inv64)))
  with disable_x64():
    inv32 = near_singular_inverse(key, N, eps)
    self.assertTrue(jnp.all(~jnp.isfinite(inv32)))
def func_x64():
  """Return the dtype of `jnp.arange(10)` computed under `enable_x64()`.

  The 0.1 s sleep happens while the context is active. Presumably this is to
  overlap with other threads in a concurrency test (confirm with the caller).
  """
  with enable_x64():
    time.sleep(0.1)
    range_dtype = jnp.arange(10).dtype
  return range_dtype
def _jax_enable64() -> Generator[None, None, None]:
  """Yield once while `enable_x64()` is active.

  Presumably wrapped as a context manager or fixture by a decorator outside
  this view (TODO confirm).
  """
  x64_context = enable_x64()
  with x64_context:
    yield
def func_x64():
  """Return the dtype of a JAX array built from `np.int64(0)` under `enable_x64()`.

  The 0.1 s sleep happens while the context is active. Presumably this is to
  overlap with other threads in a concurrency test (confirm with the caller).
  """
  with enable_x64():
    time.sleep(0.1)
    zero_dtype = jnp.array(np.int64(0)).dtype
  return zero_dtype