def test_correctly_capture_default(self, jit, enable_or_disable):
  """Check the result dtype tracks the x64 setting active at call time.

  The function is built (and called once) inside an `enable_or_disable()`
  block; the setting in force at that point must not leak into later calls.
  """
  with enable_or_disable():
    fn = _maybe_jit(jit, lambda: jnp.arange(10.0))
    # Invoke once while the block's setting is still active.
    fn()

  # Outside any override, the dtype follows the current global flag value.
  if config._read("jax_enable_x64"):
    default_dtype = "float64"
  else:
    default_dtype = "float32"
  self.assertEqual(fn().dtype, default_dtype)

  # Explicit overrides at call time decide the dtype in both directions.
  with enable_x64():
    self.assertEqual(fn().dtype, "float64")
  with disable_x64():
    self.assertEqual(fn().dtype, "float32")
def setUp(self):
  """Record the current `jax_debug_nans` value, then turn the flag on."""
  flag_name = "jax_debug_nans"
  # Keep the value read before the update on the instance as `self.cfg`.
  self.cfg = config._read(flag_name)
  config.update(flag_name, True)
def setUp(self):
  """Record the current `jax_experimental_name_stack` value, then turn it on."""
  flag_name = "jax_experimental_name_stack"
  # Keep the value read before the update on the instance as `self.cfg`.
  self.cfg = config._read(flag_name)
  config.update(flag_name, True)