def test_while_loop_body_and_cond_error(self):
  """NaNs from either the cond or the body of a while_loop are reported.

  When both the cond and the body would produce a NaN, the cond's error is
  the one reported.
  """

  def cond_fun(carry):
    step, c, _ = carry
    _ = jnp.sin(c)  # NaN here when c is inf
    return step < 2

  def body_fun(carry):
    step, c, b = carry
    maybe_nan = jnp.cos(b)  # NaN here when b is inf
    return step + 1., c, maybe_nan

  @jax.jit
  def f(cond_val, body_val):
    return lax.while_loop(cond_fun, body_fun, (0., cond_val, body_val))

  cases = [
      (jnp.inf, 1., "nan generated by primitive sin"),
      (1., jnp.inf, "nan generated by primitive cos"),
      # first error which occurs is in cond
      (jnp.inf, jnp.inf, "nan generated by primitive sin"),
  ]
  for cond_val, body_val, expected_prefix in cases:
    err, _ = checkify.checkify(f)(cond_val, body_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), expected_prefix)
def test_cond_of_named_call(self):
  """checkify handles a lax.cond whose branches are named_call-wrapped."""

  def g(x):
    identity = jax.named_call(lambda x: x)
    return jax.lax.cond(True, identity, identity, x)

  # The test passes as long as this call does not raise.
  checkify.checkify(g)(0.)
def test_while_loop_body_error(self):
  """A NaN produced in a while_loop body is reported.

  The checkified loop must return the same outputs as the plain loop.
  """

  def cond_fun(carry):
    step, _ = carry
    return step < 2

  def body_fun(carry):
    step, acc = carry
    maybe_nan = jnp.sin(1. / step)  # step == 0 gives sin(inf) == nan
    return step + 1., acc + maybe_nan

  @jax.jit
  def f(init_val):
    return lax.while_loop(cond_fun, body_fun, (init_val, 0.))

  checked_f = checkify.checkify(f)

  # Starting from 1 the divisor is never zero: no error.
  err, checked_out = checked_f(1.)
  expected_out = f(1.)
  self.assertIs(err.get(), None)
  self.assertArraysEqual(checked_out, expected_out)

  # Starting from 0 the first iteration yields sin(1/0) -> NaN.
  err, checked_out = checked_f(0.)
  expected_out = f(0.)
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), "nan generated by primitive sin")
  self.assertArraysEqual(checked_out, expected_out)
def test_scan_carry(self):
  """NaNs that depend on the scan carry are reported on any iteration.

  checkify must also leave the scan's carry and stacked outputs unchanged.
  """

  def scan_body(carry, x):
    carry = carry - 1.
    maybe_nan = jnp.sin(1. / carry)  # NaN once the carry hits zero
    return carry, x + maybe_nan

  @jax.jit
  def f(carry, xs):
    return lax.scan(scan_body, carry, xs)

  checked_f = checkify.checkify(f)

  cases = [
      (3., jnp.ones((2, )), None),
      # error happens on first iteration
      (1., jnp.ones((2, )), "nan generated by primitive sin"),
      # error happens on second iteration
      (2., jnp.ones((4, )), "nan generated by primitive sin"),
  ]
  for init_carry, xs, expected_prefix in cases:
    err, (ch_out_carry, ch_outs) = checked_f(init_carry, xs)
    out_carry, outs = f(init_carry, xs)
    if expected_prefix is None:
      self.assertIs(err.get(), None)
    else:
      self.assertIsNotNone(err.get())
      self.assertStartsWith(err.get(), expected_prefix)
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)
def test_custom_vjp(self):
  """checkify and custom_vjp interact as expected.

  vjp-of-checkify goes through the custom rule and adds no checks.
  checkify-of-grad does instrument the backward pass.
  """

  @jax.custom_vjp
  def sin(x):
    return jnp.sin(x)

  def sin_fwd(x):
    return jnp.sin(x), 2. * x

  def sin_bwd(residual, g):
    # custom_vjp backward rules return a tuple of cotangents.
    return (jnp.cos(residual / 2.) * g,)

  sin.defvjp(sin_fwd, sin_bwd)

  checked_sin = checkify.checkify(sin, errors=checkify.float_checks)

  # no differentiation, no error
  err, _ = checked_sin(3.)
  self.assertIsNone(err.get())

  # no differentiation, yes error
  err, _ = checked_sin(jnp.inf)
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), 'nan generated by primitive sin')

  # vjp-of-checkify hits the custom vjp rule, so no checks are added.
  (err, _), _ = jax.vjp(checked_sin, 3.)
  self.assertIsNone(err.get())  # no error
  self.assertEmpty(err.msgs)  # and no checks were added!

  # checkify-of-grad adds checks (unlike vjp-of-checkify above).
  checked_grad = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)
  err, _ = checked_grad(3.)
  self.assertIsNone(err.get())  # no error
  self.assertNotEmpty(err.msgs)  # but checks were added!
  err, _ = checked_grad(jnp.inf)
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), "nan generated by primitive sin")
def test_jit_nan(self, jit):
  """A NaN produced by sin is reported, both eagerly and under jit."""

  def f(x1, x2):
    return jnp.sin(x1) + jnp.sin(x2)

  if jit:
    f = jax.jit(f)
  checked_f = checkify.checkify(f)

  err, _ = checked_f(3., 4.)
  self.assertIs(err.get(), None)

  err, _ = checked_f(3., jnp.inf)
  self.assertStartsWith(err.get(), 'nan generated by primitive sin')
def test_cond_basic(self):
  """Only the branch that lax.cond actually takes can report an error."""

  @jax.jit
  def f(x):
    return lax.cond(x > 0, lambda: jnp.sin(x), lambda: x)

  checked_f = checkify.checkify(f)

  # Positive, finite: sin branch, no NaN.
  err, _ = checked_f(3.)
  self.assertIs(err.get(), None)

  # +inf: sin branch, sin(inf) is NaN.
  err, _ = checked_f(jnp.inf)
  self.assertStartsWith(err.get(), 'nan generated by primitive sin')

  # -inf: identity branch, no error.
  err, _ = checked_f(-jnp.inf)
  self.assertIs(err.get(), None)
def test_jit_oob(self, jit):
  """Out-of-bounds indexing is reported, both eagerly and under jit."""

  def f(x, i):
    return jnp.cos(jnp.sin(x)[i])

  if jit:
    f = jax.jit(f)
  checked_f = checkify.checkify(f)

  # Index 2 is in bounds for a length-3 array.
  err, _ = checked_f(jnp.arange(3), 2)
  self.assertIs(err.get(), None)

  # Index 5 is out of bounds.
  err, _ = checked_f(jnp.arange(3), 5)
  self.assertStartsWith(err.get(), 'out-of-bounds indexing')
def test_assert_discharging_scan(self):
  """A checkify.check inside a scan body is discharged and reported."""

  def body(carry, x):
    checkify.check(jnp.all(x > 0), "must be positive")
    return carry, x

  def f(x):
    return jax.lax.scan(body, (None,), x)

  checked_f = checkify.checkify(f)
  # Every input contains at least one non-positive element.
  for xs in (jnp.array([-1]), jnp.array([1, 0, -1])):
    err, _ = checked_f(xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "must be positive")
def test_assert_batching_rule(self):
  """checkify.check under vmap raises eagerly and is reported when checkified.

  This holds whether one or several batch elements fail the check.
  """

  @jax.vmap
  def f(x):
    checkify.check(jnp.sum(x) == 1., "x must sum to one.")
    return x

  no_failures = jnp.array([[0.5, 0.5], [1., 0.]])
  one_batch_fails = jnp.array([[0.5, 0.5], [1, 1]])
  mult_batch_fail = jnp.array([[0.5, 0.5], [1, 1], [2, 2]])
  failing_inputs = (one_batch_fails, mult_batch_fail)

  # Without checkify: valid input runs, invalid inputs raise.
  f(no_failures)
  for bad_input in failing_inputs:
    with self.assertRaisesRegex(ValueError, "x must sum to one."):
      f(bad_input)

  # With checkify: the error is returned rather than raised.
  checked_f = checkify.checkify(f)
  err, _ = checked_f(no_failures)
  self.assertIsNone(err.get())
  for bad_input in failing_inputs:
    err, _ = checked_f(bad_input)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "x must sum to one")
def raises_oob(fn, idx, *expected_strs):
  """Assert that `fn(x, idx)` reports an out-of-bounds indexing error.

  The error message must also contain every string in `expected_strs`.

  NOTE(review): `self` and `x` are free variables taken from the enclosing
  test method's scope.
  """
  checked_fn = checkify.checkify(fn, errors=checkify.index_checks)
  err, _ = checked_fn(x, idx)
  message = err.get()
  self.assertIsNotNone(message)
  self.assertStartsWith(message, "out-of-bounds indexing")
  for fragment in expected_strs:
    self.assertIn(fragment, message)
def test_while_loop_cond_error(self):
  """Division by zero in a while_loop cond is reported under float_checks.

  The checkified loop must return the same output as the plain loop.
  """

  def while_cond(val):
    _ = jnp.sin(1. / val)  # val == 0 divides by zero
    return val < 2.

  def while_body(val):
    return val + 1.

  @jax.jit
  def f(init_val):
    return lax.while_loop(while_cond, while_body, init_val)

  checked_f = checkify.checkify(f, errors=checkify.float_checks)

  for init_val, should_fail in ((1., False), (0., True)):
    err, ch_out = checked_f(init_val)
    out = f(init_val)
    if should_fail:
      self.assertIsNotNone(err.get())
      self.assertStartsWith(err.get(), "divided by zero")
    else:
      self.assertIs(err.get(), None)
    self.assertArraysEqual(ch_out, out)
def test_while_loop_body_error(self):
  """Division by zero in a while_loop body is reported under float_checks.

  The checkified loop must return the same outputs as the plain loop.
  """

  def cond_fun(carry):
    step, _ = carry
    return step < 2

  def body_fun(carry):
    step, acc = carry
    maybe_nan = jnp.sin(1. / step)  # step == 0 divides by zero
    return step + 1., acc + maybe_nan

  @jax.jit
  def f(init_val):
    return lax.while_loop(cond_fun, body_fun, (init_val, 0.))

  checked_f = checkify.checkify(f, errors=checkify.float_checks)

  for init_val, should_fail in ((1., False), (0., True)):
    err, ch_out = checked_f(init_val)
    out = f(init_val)
    if should_fail:
      self.assertIsNotNone(err.get())
      self.assertStartsWith(err.get(), "divided by zero")
    else:
      self.assertIs(err.get(), None)
    self.assertArraysEqual(ch_out, out)
def ejit(f):
  """Return a jitted wrapper of `f` that surfaces checkify errors.

  `f` is checkified and then jitted once, up front. Each call runs the
  compiled function and passes the returned error to `checkify.check_error`
  before returning the output.
  """
  compiled = jax.jit(checkify.checkify(f))

  def jitted_f(*args):
    err, out = compiled(*args)
    checkify.check_error(err)
    return out

  return jitted_f
def ejit(f):
  """Return a jitted wrapper of `f` that re-asserts its checkify error.

  `f` is checkified and then jitted once, up front. Each call runs the
  compiled function and forwards the error's fields to `checkify.assert2_`.

  NOTE(review): `assert2_` is an internal helper whose argument semantics
  are not visible here; presumably `err.err` is a boolean "error happened"
  flag, so its negation is the predicate that must hold.
  """
  compiled = jax.jit(checkify.checkify(f))

  def jitted_f(*args):
    err, out = compiled(*args)
    checkify.assert2_(~err.err, err.code, err.msgs)
    return out

  return jitted_f
def test_multiple_payloads(self):
  """With two OOB accesses, the first (index 5) is the one reported."""

  def f(x):
    for idx in (5, 6):
      _ = x[idx]

  checked_f = checkify.checkify(f, errors=checkify.index_checks)
  err, _ = checked_f(jnp.ones((2,)))
  self.assertIsNotNone(err.get())
  self.assertIn("index 5", err.get())
def test_pmap_basic(self):
  """A NaN produced inside a pmapped function is reported."""
  if len(jax.devices()) < 2:
    raise unittest.SkipTest("requires at least 2 devices")

  @jax.pmap
  def f(x1, x2):
    return jnp.sin(x1) + jnp.sin(x2)

  checked_f = checkify.checkify(f)
  finite = jnp.array([0., 2.])

  err, _ = checked_f(finite, finite)
  self.assertIs(err.get(), None)

  with_inf = jnp.array([3., jnp.inf])
  err, _ = checked_f(finite, with_inf)
  self.assertStartsWith(err.get(), 'nan generated by primitive sin')
def test_mapped_error_one_payload(self):
  """Under vmap, errors from different batch elements all appear in the message.

  Element 0 divides by zero and element 1 indexes out of bounds; both errors
  must be mentioned.
  """

  def f(x, i):
    return x[i] / 0

  checked_f = checkify.checkify(f, errors=checkify.automatic_checks)
  batched = jax.vmap(checked_f)
  errs, _ = batched(jnp.ones((2, 1)), jnp.array([0, 100]))
  message = errs.get()
  self.assertIsNotNone(message)
  self.assertIn("divided by zero", message)
  self.assertIn("index 100", message)
def test_jit_multi(self, jit):
  """Both OOB and NaN errors are detected, eagerly and under jit."""

  def f(x, i):
    return jnp.cos(x[i])

  if jit:
    f = jax.jit(f)
  checked_f = checkify.checkify(f)

  # no error: the inf sits at an index that is never read
  err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
  self.assertIs(err.get(), None)

  # oob error
  err, _ = checked_f(jnp.array([0., 1., 2.]), 5)
  self.assertStartsWith(err.get(), 'out-of-bounds indexing')

  # nan error: cos(inf)
  err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2)
  self.assertStartsWith(err.get(), 'nan generated by primitive cos')
def test_custom_jvp(self):
  """checkify and custom_jvp interact as expected.

  jvp-of-checkify and grad-of-checkify add no checks; checkify-of-jvp does.
  """

  @jax.custom_jvp
  def sin(x):
    return jnp.sin(x)

  @sin.defjvp
  def sin_jvp(primals, tangents):
    x, = primals
    xdot, = tangents
    return sin(x), jnp.cos(x) * xdot

  checked_sin = checkify.checkify(sin, errors=checkify.float_checks)

  # Plain evaluation.
  err, _ = checked_sin(3.)
  self.assertIsNone(err.get())
  err, _ = checked_sin(jnp.inf)
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), 'nan generated by primitive sin')

  # jvp-of-checkify hits the custom jvp rule, so no checks are added.
  (err, y), (err_dot, y_dot) = jax.jvp(checked_sin, (3.,), (1.,))
  self.assertIsNone(err.get())  # no error
  self.assertEmpty(err.msgs)  # and no checks were added!
  self.assertEmpty(err_dot.msgs)
  y_ref, y_dot_ref = jax.jvp(jnp.sin, (3.,), (1.,))
  self.assertAllClose(y, y_ref)
  self.assertAllClose(y_dot, y_dot_ref)

  # grad-of-checkify doesn't crash either.
  x_bar = jax.grad(lambda x: checked_sin(x)[1])(3.)
  self.assertAllClose(x_bar, jnp.cos(3.))

  # checkify-of-jvp adds checks (unlike jvp-of-checkify above).
  def primal_and_tangent(x, xdot):
    return jax.jvp(sin, (x,), (xdot,))

  checked_jvp = checkify.checkify(primal_and_tangent,
                                  errors=checkify.float_checks)
  err, (y, y_dot) = checked_jvp(3., 1.)  # doesn't crash
  self.assertIsNone(err.get())  # no error
  self.assertNotEmpty(err.msgs)  # but checks were added!
  self.assertAllClose(y, jnp.sin(3.))
  self.assertAllClose(y_dot, jnp.cos(3.))
  err, _ = checked_jvp(jnp.inf, 1.)
  self.assertIsNotNone(err.get())  # yes error
  self.assertStartsWith(err.get(), 'nan generated by primitive sin')
def test_empty_enabled_errors(self):
  """With an empty set of enabled error categories, nothing is reported."""

  def multi_errors(x):
    y = jnp.sin(x / 0)[500]  # DIV, then NAN, then OOB
    checkify.check(y < 0, "must be negative!")  # ASSERT
    return y

  checked = checkify.checkify(multi_errors, errors=set())
  err, _ = checked(jnp.ones((2,)))
  self.assertIsNone(err.get())
def test_scan_consts(self):
  """OOB indexing into an array closed over by a scan body is reported.

  The carry starts at 1 and the scan runs len(xs) == 7 steps, so the last
  step reads xs[7], which is out of bounds.
  """

  def f(xs):
    def scan_body(i, _):
      # Closes over `xs` rather than receiving it as a scanned input.
      return i + 1, xs[i]
    _, ys = lax.scan(scan_body, 1, xs)
    return ys

  checked_f = checkify.checkify(f, errors=checkify.index_checks)
  err, _ = checked_f(jnp.ones((7, 3)))
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), "out-of-bounds indexing")
def test_scan_map(self):
  """A NaN in a map-style scan (no carry) is reported.

  checkify must leave the scan's stacked outputs unchanged.
  """

  def scan_body(_, x):
    return None, jnp.sin(x)

  @jax.jit
  def f(xs):
    return lax.scan(scan_body, None, xs)

  checked_f = checkify.checkify(f)

  def run(xs):
    err, (_, checked_ys) = checked_f(xs)
    _, ref_ys = f(xs)
    return err, checked_ys, ref_ys

  err, checked_ys, ref_ys = run(jnp.array([0., 2.]))
  self.assertIs(err.get(), None)
  self.assertArraysEqual(checked_ys, ref_ys)

  err, checked_ys, ref_ys = run(jnp.array([3., jnp.inf]))
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), "nan generated by primitive sin")
  self.assertArraysEqual(checked_ys, ref_ys)
def test_jit_ordering(self, jit):
  """When both an OOB and a NaN error occur, the earlier OOB error is reported."""

  def f(x, i):
    gathered = x[i]
    return gathered * jnp.sin(x)

  if jit:
    f = jax.jit(f)

  # both oob and nan error, but oob happens first
  err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 5)
  self.assertStartsWith(err.get(), 'out-of-bounds indexing')
def test_enabled_errors(self, error_set, expected_error):
  """With only `error_set` enabled, the reported error starts with `expected_error`.

  The function under test can trigger assert, division, NaN and OOB errors.
  """

  def multi_errors(x):
    checkify.check(jnp.all(x < 0), "must be negative!")  # ASSERT
    return jnp.sin(x / 0)[500]  # DIV, then NAN, then OOB

  checked = checkify.checkify(multi_errors, errors=error_set)
  err, _ = checked(jnp.ones((2,)))
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), expected_error)
def test_scan_consts2(self):
  """Like test_scan_consts, but the scan body closes over extra constants."""

  def f(xs):
    def scan_body(i, _):
      # add more consts!
      _ = xs[i], xs[i], jnp.sin(np.arange(11.))
      return i + 1, xs[i]
    _, ys = lax.scan(scan_body, 1, xs)
    return ys

  checked_f = checkify.checkify(f, errors=checkify.index_checks)
  err, _ = checked_f(jnp.ones((7, 3)))
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), "out-of-bounds indexing")
def test_check_error_scanned(self):
  """An outer checkify reports errors re-raised from a scan.

  The errors are collected per iteration by an inner checkify and re-raised
  with checkify.check_error.
  """

  def body(carry, x):
    checkify.check(jnp.all(x > 0), "should be positive")
    return carry, x

  def checked_body(carry, x):
    err, (carry, x) = checkify.checkify(body)(carry, x)
    return carry, (x, err)

  def f(x):
    _, (xs, errs) = jax.lax.scan(checked_body, (None,), x)
    checkify.check_error(errs)
    return xs

  outer = checkify.checkify(f)
  for inputs in (jnp.array([-1]), jnp.array([1, 0, -1])):
    err, _ = outer(inputs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "should be positive")
def test_empty_enabled_errors(self):
  """With an empty set of enabled error categories, nothing is reported."""

  def multi_errors(x):
    x = x / 0  # DIV
    x = jnp.sin(x)  # NAN
    x = x[500]  # OOB
    # TODO(lenamartens): this error should also be disabled.
    # checkify.check(x < 0, "must be negative!")  # ASSERT
    return x

  x = jnp.ones((2, ))
  # `errors` takes a set of error categories. Use `set()`: the literal `{}`
  # is an empty dict, not an empty set.
  err, _ = checkify.checkify(multi_errors, errors=set())(x)
  self.assertIsNone(err.get())
def test_while_loop_cond_error_and_false(self):
  """An error is still reported when the cond that produces it returns False."""

  def while_cond(val):
    maybe_nan = jnp.sin(1. / val)
    return jnp.logical_not(jnp.isnan(maybe_nan))

  def while_body(val):
    return val - 1

  @jax.jit
  def f(init_val):
    return lax.while_loop(while_cond, while_body, init_val)

  checked_f = checkify.checkify(f)
  # 0. errors on the first cond.
  # 1. errors on the second cond, after one decrement to 0.
  for init_val in (0., 1.):
    err, _ = checked_f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive sin")
def test_jit_oob_update(self, update_fn):
  """Out-of-bounds `x.at[i].<update_fn>(1)` updates are reported under jit."""

  @jax.jit
  def f(x, i):
    apply_update = getattr(x.at[i], update_fn)
    return apply_update(1)

  checked_f = checkify.checkify(f, errors=checkify.index_checks)

  # Index 2 is the last valid position of a length-3 array.
  err, _ = checked_f(jnp.arange(3), 2)
  self.assertIs(err.get(), None)

  # Index 3 is one past the end.
  err, _ = checked_f(jnp.arange(3), 3)
  self.assertIsNotNone(err.get())
  self.assertStartsWith(err.get(), "out-of-bounds indexing")