def testDebugNansDoesntReturnDeoptimizedResult(self):
  """A NaN produced under jit must raise, mentioning the de-optimized rerun."""

  @jax.jit
  def returns_nan(x):
    # This arithmetic is discarded; it only ensures the traced function has at
    # least one equation, so the call avoids the trivial dispatch path.
    _ = x + 2
    return jnp.nan

  expected = self.assertRaisesRegex(FloatingPointError, "de-optimized")
  with expected, jax.debug_nans(True):
    returns_nan(3)
def testLogSumExpNans(self):
  """Scalar logsumexp (with and without ``b``) must not trip debug_nans.

  Regression test for https://github.com/google/jax/issues/7634; runs eagerly
  via ``disable_jit`` with NaN checking switched on.
  """
  with jax.debug_nans(True), jax.disable_jit():
    self.assertEqual(lsp_special.logsumexp(1.0), 1.0)
    self.assertEqual(lsp_special.logsumexp(1.0, b=1.0), 1.0)
def testJitComputationNaNContextManager(self):
  """debug_nans enabled via the context manager must flag a NaN from a jitted
  function that was previously called with NaN checking disabled."""
  config.update("jax_debug_nans", False)
  zero = jnp.array(0.)
  divide_into_zero = jax.jit(lambda x: 0. / x)
  # Two calls with checking off first (presumably so the already-compiled
  # computation is what gets reused once checking is turned on).
  for _ in range(2):
    divide_into_zero(zero)
  with self.assertRaises(FloatingPointError):
    with jax.debug_nans(True):
      divide_into_zero(zero).block_until_ready()
def testDebugNansDoesntCorruptCaches(self):
  """Repeated debug_nans failures must keep raising FloatingPointError.

  Regression test for https://github.com/google/jax/issues/6614. The NaN check
  is asserted on every iteration, so a cache that made a repeat call skip the
  check (or fail with some other error) makes this test fail.
  """

  @jax.jit
  def f(x):
    return jnp.divide(x, x)  # 0. / 0. -> NaN

  for _ in range(2):
    with self.assertRaises(FloatingPointError):
      with jax.debug_nans(True):
        jax.grad(f)(0.)
def test_grad_norm(self):
  """grad of jnp.linalg.norm at zero raises a NaN error whose cause chain
  carries the stack trace from before the transformation."""
  with self.assertRaises(FloatingPointError) as ctx:
    with jax.debug_nans(True):
      jax.grad(jnp.linalg.norm)(jnp.zeros((3, 3), jnp.float32))
  err = ctx.exception
  self.assertIn("invalid value", str(err))
  self.assertIsInstance(
      err.__cause__.__cause__,
      source_info_util.JaxStackTraceBeforeTransformation)
def test_grad_norm(self):
  """grad of jnp.linalg.norm at zero raises a NaN error whose cause chain
  carries the stack trace from before the transformation."""
  # NOTE(review): another ``test_grad_norm`` with the same checks appears in
  # this file; if both live in the same TestCase, the later definition shadows
  # the earlier one, so one of them should be removed or renamed.
  with self.assertRaises(FloatingPointError) as ctx:
    with jax.debug_nans(True):
      jax.grad(jnp.linalg.norm)(jnp.zeros((3, 3), jnp.float32))
  err = ctx.exception
  self.assertIn("invalid value", str(err))
  # Checked unconditionally: the minimum jaxlib now supports it, and the
  # private version attribute the old guard read is not available in newer
  # jaxlib releases.
  self.assertIsInstance(
      err.__cause__.__cause__,
      source_info_util.JaxStackTraceBeforeTransformation)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  """Compute ``log(sum(b * exp(a)))`` over ``axis`` in a numerically stable way.

  Args:
    a: input values.
    axis: axis or axes to reduce over; ``None`` reduces over all axes.
    b: optional weights multiplying ``exp(a)``. Entries where ``b == 0`` are
      excluded by replacing the matching ``a`` with ``-inf``.
    keepdims: if True, reduced axes are kept with size one.
    return_sign: if True, also return ``sign`` such that ``sign * exp(out)``
      recovers the (weighted) sum.

  Returns:
    ``out``, or ``(out, sign)`` when ``return_sign`` is True. For real inputs
    with ``b`` given and ``return_sign`` False, entries whose weighted sum is
    negative are NaN.
  """
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    # A zero weight then contributes exp(-inf) * 0 = 0 rather than, e.g.,
    # inf * 0 = nan when ``a`` is inf.
    a = jnp.where(b != 0, a, -jnp.inf)
  else:
    a, = _promote_args_inexact("logsumexp", a)
  pos_dims, dims = _reduction_dims(a, axis)
  # Shift by the max for numerical stability. The shift is excluded from
  # differentiation, and a non-finite max falls back to 0.
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  amax = lax.stop_gradient(
      lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  # Fast path if the result cannot be negative.
  if b is None and not np.issubdtype(a.dtype, np.complexfloating):
    out = lax.add(
        lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims, keepdims=keepdims)),
        amax)
    sign = jnp.where(jnp.isnan(out), out, 1.0)
    sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
  else:
    expsub = lax.exp(lax.sub(a, amax_with_dims))
    if b is not None:
      expsub = lax.mul(expsub, b)
    sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)
    sign = lax.stop_gradient(jnp.sign(sumexp))
    if np.issubdtype(sumexp.dtype, np.complexfloating):
      if return_sign:
        # Remove the phase carried by ``sign`` so that sign * exp(out) equals
        # the sum. Multiplying by conj(sign) is correct whether jnp.sign
        # returns z / |z| (NumPy >= 2 convention) or a real +/-1 (older
        # convention); plain ``sign * sumexp`` is only right for the latter.
        # A zero sum stays zero (conj(0) * 0).
        sumexp = jnp.conj(sign) * sumexp
      out = lax.add(lax.log(sumexp), amax)
    else:
      out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None and not np.issubdtype(out.dtype, np.complexfloating):
    # A negative real sum has no real logarithm. The NaN is intentional, so
    # NaN checking is disabled while it is produced.
    with jax.debug_nans(False):
      out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out