Example #1
  def testDebugNansDoesntReturnDeoptimizedResult(self):
    @jax.jit
    def f(x):
      x + 2  # avoid trivial dispatch path by adding some eqn
      return jnp.nan

    with self.assertRaisesRegex(FloatingPointError, "de-optimized"):
      with jax.debug_nans(True):
        f(3)
Example #2
    def testLogSumExpNans(self):
        # Regression test for https://github.com/google/jax/issues/7634
        with jax.debug_nans(True):
            with jax.disable_jit():
                result = lsp_special.logsumexp(1.0)
                self.assertEqual(result, 1.0)

                result = lsp_special.logsumexp(1.0, b=1.0)
                self.assertEqual(result, 1.0)
Example #3
  def testJitComputationNaNContextManager(self):
    config.update("jax_debug_nans", False)
    A = jnp.array(0.)
    f = jax.jit(lambda x: 0. / x)
    ans = f(A)
    ans = f(A)
    with self.assertRaises(FloatingPointError):
      with jax.debug_nans(True):
        ans = f(A)
      ans.block_until_ready()
Example #4
  def testDebugNansDoesntCorruptCaches(self):
    # https://github.com/google/jax/issues/6614
    @jax.jit
    def f(x):
      return jnp.divide(x, x)

    for _ in range(2):
      try:
        with jax.debug_nans(True):
          jax.grad(f)(0.)
      except FloatingPointError:
        pass
Example #5
def test_grad_norm(self):
    e = None
    try:
        with jax.debug_nans(True):
            jax.grad(jnp.linalg.norm)(jnp.zeros((3, 3), jnp.float32))
    except FloatingPointError as exc:
        e = exc
    self.assertIsNot(e, None)
    self.assertIn("invalid value", str(e))
    # The raised error chains back to the stack trace of the original
    # (pre-transformation) user code.
    self.assertIsInstance(
        e.__cause__.__cause__,
        source_info_util.JaxStackTraceBeforeTransformation)
Example #6
def test_grad_norm(self):
    e = None
    try:
        with jax.debug_nans(True):
            jax.grad(jnp.linalg.norm)(jnp.zeros((3, 3), jnp.float32))
    except FloatingPointError as exc:
        e = exc
    self.assertIsNot(e, None)
    self.assertIn("invalid value", str(e))
    # TODO(phawkins): make this test unconditional after jaxlib 0.1.66 is the
    # minimum.
    if jax.lib._xla_extension_version >= 19:
        self.assertIsInstance(
            e.__cause__.__cause__,
            source_info_util.JaxStackTraceBeforeTransformation)
Example #7
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    # Replace non-finite maxima with 0 so the shift a - amax below stays
    # finite; stop_gradient treats the shift as a constant under
    # differentiation.
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        # sign is NaN where out is NaN, 0 where the sum is zero
        # (out == -inf), and 1 otherwise.
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            # A negative weighted sum has no real logarithm, so NaN is the
            # correct result here; debug_nans is disabled locally so this
            # intentional NaN does not trigger a FloatingPointError.
            with jax.debug_nans(False):
                out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype),
                                out)
    return out
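
For reference, a minimal sketch of the pattern every snippet above exercises, assuming only that jax and jax.numpy are importable: with jax.debug_nans enabled, an operation that produces a NaN raises FloatingPointError immediately instead of letting the NaN propagate through the computation.

import jax
import jax.numpy as jnp

with jax.debug_nans(True):
    try:
        jnp.log(-1.0)  # log of a negative number produces a NaN
    except FloatingPointError as err:
        print("caught:", err)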