Ejemplo n.º 1
0
  def test_correctly_capture_default(self, jit, enable_or_disable):
    """Check that the result dtype follows the x64 flag active at call time.

    The function under test is created, and invoked once, inside the
    ``enable_or_disable()`` context. Having been defined under a different
    ``jax_enable_x64`` value must not affect later results. Each later call
    should produce the dtype implied by whatever setting is active when it
    runs.
    """
    with enable_or_disable():
      # NOTE(review): `_maybe_jit` is assumed to wrap the lambda with `jit`
      # when requested; confirm against its definition.
      arange_fn = _maybe_jit(jit, lambda: jnp.arange(10.0))
      # Call once under the temporary setting so any state recorded at this
      # point (e.g. a trace) was produced with that flag value.
      arange_fn()

    # Outside the block, the dtype must match the ambient flag value.
    x64_on = config._read("jax_enable_x64")
    if x64_on:
      default_dtype = "float64"
    else:
      default_dtype = "float32"
    self.assertEqual(arange_fn().dtype, default_dtype)

    # Explicitly forcing the flag either way must be reflected immediately.
    # Order is preserved: enable first, then disable.
    for flag_context, forced_dtype in ((enable_x64, "float64"),
                                       (disable_x64, "float32")):
      with flag_context():
        self.assertEqual(arange_fn().dtype, forced_dtype)
Ejemplo n.º 2
0
 def setUp(self):
     """Turn on ``jax_debug_nans`` for each test, remembering the prior value.

     The value in effect before this test is stored on ``self.cfg``.
     NOTE(review): presumably a ``tearDown`` elsewhere in the class restores
     it from ``self.cfg``; confirm.
     """
     # Chain to the base-class fixture first. Test-case base classes may do
     # their own per-test setup (including config handling), and skipping it
     # breaks cooperative setUp/tearDown. Running it first also ensures it
     # cannot override the flag change made below.
     super().setUp()
     self.cfg = config._read("jax_debug_nans")
     config.update("jax_debug_nans", True)
Ejemplo n.º 3
0
 def setUp(self):
     """Turn on ``jax_experimental_name_stack`` for each test, saving the old value.

     The value in effect before this test is stored on ``self.cfg``.
     NOTE(review): presumably a ``tearDown`` elsewhere in the class restores
     it from ``self.cfg``; confirm.
     """
     # Chain to the base-class fixture first. Test-case base classes may do
     # their own per-test setup (including config handling), and skipping it
     # breaks cooperative setUp/tearDown. Running it first also ensures it
     # cannot override the flag change made below.
     super().setUp()
     self.cfg = config._read("jax_experimental_name_stack")
     config.update("jax_experimental_name_stack", True)