Example #1
0
    def test_while_loop_body_and_cond_error(self):
        """NaNs in both the while-loop cond and body are reported by checkify.

        When the cond and the body would both produce a NaN, the reported
        error is the one from the cond.
        """
        def while_cond(carry):
            counter, cond_input, _ = carry
            _ = jnp.sin(cond_input)  # NaN when cond_input is inf
            return counter < 2

        def while_body(carry):
            counter, cond_input, body_input = carry
            # cos(inf) is NaN, so an infinite body_input triggers an error.
            return counter + 1., cond_input, jnp.cos(body_input)

        @jax.jit
        def f(cond_val, body_val):
            init = (0., cond_val, body_val)
            return lax.while_loop(while_cond, while_body, init)

        checked_f = checkify.checkify(f)

        cases = [
            # Only the cond input is infinite.
            (jnp.inf, 1., "nan generated by primitive sin"),
            # Only the body input is infinite.
            (1., jnp.inf, "nan generated by primitive cos"),
            # Both are infinite: the first error, which occurs in cond, wins.
            (jnp.inf, jnp.inf, "nan generated by primitive sin"),
        ]
        for cond_val, body_val, expected_prefix in cases:
            err, _ = checked_f(cond_val, body_val)
            self.assertIsNotNone(err.get())
            self.assertStartsWith(err.get(), expected_prefix)
Example #2
0
    def test_cond_of_named_call(self):
        """checkify must not crash on a cond whose branches are named calls."""
        def g(x):
            branch_fn = jax.named_call(lambda y: y)
            return jax.lax.cond(True, branch_fn, branch_fn, x)

        # Passing means tracing and running through checkify did not raise.
        checkify.checkify(g)(0.)
Example #3
0
    def test_while_loop_body_error(self):
        """A NaN produced inside a while-loop body is reported by checkify.

        The checkified function must also return the same outputs as the
        unchecked one, whether or not an error occurs.
        """
        def while_cond(state):
            counter, _ = state
            return counter < 2

        def while_body(state):
            counter, acc = state
            # 1/0 is inf and sin(inf) is NaN, so a zero counter triggers it.
            return counter + 1., acc + jnp.sin(1. / counter)

        @jax.jit
        def f(init_val):
            return lax.while_loop(while_cond, while_body, (init_val, 0.))

        checked_f = checkify.checkify(f)

        # Counter starts at 1 and never reaches zero: no error.
        err, ch_out = checked_f(1.)
        out = f(1.)
        self.assertIs(err.get(), None)
        self.assertArraysEqual(ch_out, out)

        # Counter starts at 0: the first body iteration produces a NaN.
        err, ch_out = checked_f(0.)
        out = f(0.)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive sin")
        self.assertArraysEqual(ch_out, out)
Example #4
0
    def test_scan_carry(self):
        """checkify reports a NaN that arises from a scan carry.

        Each case also checks that the checked and unchecked outputs agree.
        """
        def scan_body(carry, x):
            carry = carry - 1.
            # Once the decremented carry hits zero, sin(1/0) is NaN.
            return carry, x + jnp.sin(1. / carry)

        @jax.jit
        def f(carry, xs):
            return lax.scan(scan_body, carry, xs)

        checked_f = checkify.checkify(f)

        cases = [
            # No error: with two steps the carry goes 2, 1 and never hits 0.
            (3., jnp.ones((2, )), None),
            # Error happens on the first iteration.
            (1., jnp.ones((2, )), "nan generated by primitive sin"),
            # Error happens on the second iteration.
            (2., jnp.ones((4, )), "nan generated by primitive sin"),
        ]
        for init_carry, xs, expected_prefix in cases:
            err, (ch_out_carry, ch_outs) = checked_f(init_carry, xs)
            out_carry, outs = f(init_carry, xs)
            if expected_prefix is None:
                self.assertIs(err.get(), None)
            else:
                self.assertIsNotNone(err.get())
                self.assertStartsWith(err.get(), expected_prefix)
            self.assertArraysEqual(ch_outs, outs)
            self.assertArraysEqual(ch_out_carry, out_carry)
Example #5
0
  def test_custom_vjp(self):
    """checkify interacts correctly with a custom_vjp function.

    Plain checkify reports NaNs from the primal. vjp-of-checkify goes through
    the custom rule and adds no checks. checkify-of-grad does add checks.
    """
    @jax.custom_vjp
    def sin(x):
      return jnp.sin(x)

    def sin_fwd(x):
      # The residual is 2*x; sin_bwd halves it back before taking cos.
      return jnp.sin(x), 2. * x

    def sin_bwd(x2, g):
      return (jnp.cos(x2 / 2.) * g,)

    sin.defvjp(sin_fwd, sin_bwd)

    checked_sin = checkify.checkify(sin, errors=checkify.float_checks)

    # No differentiation, finite input: no error.
    err, _ = checked_sin(3.)
    self.assertIsNone(err.get())

    # No differentiation, infinite input: sin(inf) is NaN.
    err, _ = checked_sin(jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive sin')

    # vjp-of-checkify hits the custom vjp rule, so no checks are added.
    (err, _), _ = jax.vjp(checked_sin, 3.)
    self.assertIsNone(err.get())
    self.assertEmpty(err.msgs)

    # checkify-of-grad does add checks, unlike vjp-of-checkify above.
    checked_grad = checkify.checkify(
        jax.grad(sin), errors=checkify.float_checks)
    err, _ = checked_grad(3.)
    self.assertIsNone(err.get())
    self.assertNotEmpty(err.msgs)
    err, _ = checked_grad(jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive sin")
Example #6
0
    def test_jit_nan(self, jit):
        """checkify reports a NaN produced by sin, with and without jit.

        Args:
          jit: if true, ``f`` is wrapped in ``jax.jit`` before checkifying.
        """
        def f(x1, x2):
            y1 = jnp.sin(x1)
            y2 = jnp.sin(x2)
            return y1 + y2

        f = jax.jit(f) if jit else f

        err, _ = checkify.checkify(f)(3., 4.)
        self.assertIs(err.get(), None)

        # sin(inf) is NaN.
        err, _ = checkify.checkify(f)(3., jnp.inf)
        # Guard against a missing error before inspecting the message, rather
        # than handing None to assertStartsWith (matches the other tests here).
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')
Example #7
0
    def test_cond_basic(self):
        """checkify reports errors only from the cond branch actually taken."""
        @jax.jit
        def f(x):
            return lax.cond(x > 0, lambda: jnp.sin(x), lambda: x)

        err, _ = checkify.checkify(f)(3.)
        self.assertIs(err.get(), None)

        # x > 0 selects the sin branch, and sin(inf) is NaN.
        err, _ = checkify.checkify(f)(jnp.inf)
        # Guard first so a missing error fails cleanly instead of handing
        # None to assertStartsWith.
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')

        # x <= 0 selects the identity branch, so sin is never evaluated.
        err, _ = checkify.checkify(f)(-jnp.inf)
        self.assertIs(err.get(), None)
Example #8
0
    def test_jit_oob(self, jit):
        """checkify reports out-of-bounds indexing, with and without jit.

        Args:
          jit: if true, ``f`` is wrapped in ``jax.jit`` before checkifying.
        """
        def f(x, i):
            y = jnp.sin(x)
            z = y[i]
            w = jnp.cos(z)
            return w

        f = jax.jit(f) if jit else f

        # Index 2 is in bounds for a length-3 array.
        err, _ = checkify.checkify(f)(jnp.arange(3), 2)
        self.assertIs(err.get(), None)

        # Index 5 is out of bounds.
        err, _ = checkify.checkify(f)(jnp.arange(3), 5)
        # Guard first so a missing error fails cleanly instead of handing
        # None to assertStartsWith.
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'out-of-bounds indexing')
Example #9
0
    def test_assert_discharging_scan(self):
        """A checkify.check inside a scan body is reported by checkify."""
        def body(carry, x):
            checkify.check(jnp.all(x > 0), "must be positive")
            return carry, x

        def f(x):
            return jax.lax.scan(body, (None, ), x)

        checked_f = checkify.checkify(f)
        # The first input fails on its only element; the second fails after
        # one passing element.
        for xs in (jnp.array([-1]), jnp.array([1, 0, -1])):
            err, _ = checked_f(xs)
            self.assertIsNotNone(err.get())
            self.assertStartsWith(err.get(), "must be positive")
Example #10
0
    def test_assert_batching_rule(self):
        """checkify.check works under vmap, both unchecked and checkified.

        Without checkify, a failing batch element raises ValueError. Under
        checkify, the failure is reported in the returned error instead.
        """
        @jax.vmap
        def f(x):
            checkify.check(jnp.sum(x) == 1., "x must sum to one.")
            return x

        no_failures = jnp.array([[0.5, 0.5], [1., 0.]])
        failing_batches = (
            jnp.array([[0.5, 0.5], [1, 1]]),          # one batch fails
            jnp.array([[0.5, 0.5], [1, 1], [2, 2]]),  # several batches fail
        )

        f(no_failures)
        for bad in failing_batches:
            with self.assertRaisesRegex(ValueError, "x must sum to one."):
                f(bad)

        checked_f = checkify.checkify(f)
        err, _ = checked_f(no_failures)
        self.assertIsNone(err.get())

        for bad in failing_batches:
            err, _ = checked_f(bad)
            self.assertIsNotNone(err.get())
            self.assertStartsWith(err.get(), "x must sum to one")
Example #11
0
 def raises_oob(fn, idx, *expected_strs):
   """Assert checkify(fn)(x, idx) reports OOB indexing mentioning each string.

   ``x`` and ``self`` are free variables taken from the enclosing scope.
   """
   checked_fn = checkify.checkify(fn, errors=checkify.index_checks)
   err, _ = checked_fn(x, idx)
   message = err.get()
   self.assertIsNotNone(message)
   self.assertStartsWith(message, "out-of-bounds indexing")
   for fragment in expected_strs:
     self.assertIn(fragment, message)
Example #12
0
    def test_while_loop_cond_error(self):
        """A division by zero in a while-loop cond is reported by checkify.

        Each case also checks that the checked and unchecked outputs agree.
        """
        def while_cond(val):
            # Divides by zero when val == 0.
            _ = jnp.sin(1. / val)
            return val < 2.

        def while_body(val):
            return val + 1.

        @jax.jit
        def f(init_val):
            return lax.while_loop(while_cond, while_body, init_val)

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        # Starting at 1 never evaluates the cond at zero; starting at 0 does.
        for init_val, expected_prefix in ((1., None), (0., "divided by zero")):
            err, ch_out = checked_f(init_val)
            out = f(init_val)
            if expected_prefix is None:
                self.assertIsNone(err.get())
            else:
                self.assertIsNotNone(err.get())
                self.assertStartsWith(err.get(), expected_prefix)
            self.assertArraysEqual(ch_out, out)
Example #13
0
    def test_while_loop_body_error(self):
        """A division by zero inside a while-loop body is reported.

        The checkified outputs must match the unchecked ones either way.
        """
        def while_cond(state):
            counter, _ = state
            return counter < 2

        def while_body(state):
            counter, acc = state
            # Divides by zero when counter == 0.
            step = jnp.sin(1. / counter)
            return counter + 1., acc + step

        @jax.jit
        def f(init_val):
            return lax.while_loop(while_cond, while_body, (init_val, 0.))

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        # Counter starts at 1: no division by zero.
        err, ch_out = checked_f(1.)
        out = f(1.)
        self.assertIsNone(err.get())
        self.assertArraysEqual(ch_out, out)

        # Counter starts at 0: the first body iteration divides by zero.
        err, ch_out = checked_f(0.)
        out = f(0.)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")
        self.assertArraysEqual(ch_out, out)
Example #14
0
 def ejit(f):
   """Return a wrapper that runs a checkified, jitted ``f``.

   The wrapper calls the jitted function, passes the resulting error to
   ``checkify.check_error``, and returns only the function's output.
   """
   compiled = jax.jit(checkify.checkify(f))

   def jitted_f(*args):
     err, out = compiled(*args)
     checkify.check_error(err)
     return out

   return jitted_f
Example #15
0
 def ejit(f):
   """Return a wrapper that runs a checkified, jitted ``f``.

   The wrapper calls the jitted function, hands the error's negated flag, its
   code and its messages to ``checkify.assert2_``, and returns only the
   function's output.
   """
   compiled = jax.jit(checkify.checkify(f))

   def jitted_f(*args):
     err, out = compiled(*args)
     checkify.assert2_(~err.err, err.code, err.msgs)
     return out

   return jitted_f
Example #16
0
  def test_multiple_payloads(self):
    """With two OOB accesses, the error mentions index 5 (the first access)."""
    def f(x):
      _ = x[5]
      _ = x[6]

    checked_f = checkify.checkify(f, errors=checkify.index_checks)
    err, _ = checked_f(jnp.ones((2,)))
    message = err.get()
    self.assertIsNotNone(message)
    self.assertIn("index 5", message)
Example #17
0
    def test_pmap_basic(self):
        """checkify reports a NaN produced by one of the pmapped inputs.

        Skipped when fewer than two devices are available.
        """
        if len(jax.devices()) < 2:
            raise unittest.SkipTest("requires at least 2 devices")

        @jax.pmap
        def f(x1, x2):
            y1 = jnp.sin(x1)
            y2 = jnp.sin(x2)
            return y1 + y2

        xs = jnp.array([0., 2.])
        err, _ = checkify.checkify(f)(xs, xs)
        self.assertIs(err.get(), None)

        # The second mapped input is inf, and sin(inf) is NaN.
        ys = jnp.array([3., jnp.inf])
        err, _ = checkify.checkify(f)(xs, ys)
        # Guard first so a missing error fails cleanly instead of handing
        # None to assertStartsWith.
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')
Example #18
0
    def test_mapped_error_one_payload(self):
        """Under vmap, the error mentions a division and an OOB index.

        Every batch element divides by zero. The second element also indexes
        out of bounds (index 100). The vmapped error message mentions both.
        """
        def f(x, i):
            return x[i] / 0

        cf = checkify.checkify(f, errors=checkify.automatic_checks)
        xs = jnp.ones((2, 1))
        idxs = jnp.array([0, 100])
        errs, _ = jax.vmap(cf)(xs, idxs)
        message = errs.get()
        self.assertIsNotNone(message)
        self.assertIn("divided by zero", message)
        self.assertIn("index 100", message)
Example #19
0
    def test_jit_multi(self, jit):
        """checkify distinguishes no error, OOB errors, and NaN errors.

        Args:
          jit: if true, ``f`` is wrapped in ``jax.jit`` before checkifying.
        """
        def f(x, i):
            y = x[i]
            z = jnp.cos(y)
            return z

        f = jax.jit(f) if jit else f

        # No error: the inf at index 1 is never read.
        err, _ = checkify.checkify(f)(jnp.array([0., jnp.inf, 2.]), 2)
        self.assertIs(err.get(), None)

        # OOB error: index 5 into a length-3 array.
        err, _ = checkify.checkify(f)(jnp.array([0., 1., 2.]), 5)
        # Guard first so a missing error fails cleanly instead of handing
        # None to assertStartsWith.
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'out-of-bounds indexing')

        # NaN error: cos(inf) is NaN.
        err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 2)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive cos')
Example #20
0
    def test_custom_jvp(self):
        """checkify interacts correctly with a custom_jvp function.

        jvp-of-checkify goes through the custom rule and adds no checks, and
        grad-of-checkify does not crash. checkify-of-jvp does add checks and
        reports NaNs.
        """
        @jax.custom_jvp
        def sin(x):
            return jnp.sin(x)

        @sin.defjvp
        def sin_jvp(primals, tangents):
            (x, ), (xdot, ) = primals, tangents
            return sin(x), jnp.cos(x) * xdot

        checked_sin = checkify.checkify(sin, errors=checkify.float_checks)

        # Plain checkify: finite input is fine, inf yields a NaN from sin.
        err, _ = checked_sin(3.)
        self.assertIsNone(err.get())
        err, _ = checked_sin(jnp.inf)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')

        # jvp-of-checkify hits the custom jvp rule, so no checks are added.
        (err, y), (errdot, ydot) = jax.jvp(checked_sin, (3., ), (1., ))
        self.assertIsNone(err.get())
        self.assertEmpty(err.msgs)
        self.assertEmpty(errdot.msgs)
        y_expected, ydot_expected = jax.jvp(jnp.sin, (3., ), (1., ))
        self.assertAllClose(y, y_expected)
        self.assertAllClose(ydot, ydot_expected)

        # grad-of-checkify does not crash and yields cos(x).
        x_bar = jax.grad(lambda x: checked_sin(x)[1])(3.)
        self.assertAllClose(x_bar, jnp.cos(3.))

        # checkify-of-jvp adds checks, unlike jvp-of-checkify above.
        def primal_and_tangent(x, xdot):
            return jax.jvp(sin, (x, ), (xdot, ))

        g = checkify.checkify(primal_and_tangent, errors=checkify.float_checks)
        err, (y, ydot) = g(3., 1.)
        self.assertIsNone(err.get())
        self.assertNotEmpty(err.msgs)
        self.assertAllClose(y, jnp.sin(3.))
        self.assertAllClose(ydot, jnp.cos(3.))
        err, _ = g(jnp.inf, 1.)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')
Example #21
0
    def test_empty_enabled_errors(self):
        """With an empty set of enabled errors, checkify reports no error.

        Each line of ``multi_errors`` is tagged with the error category it is
        meant to trigger.
        """
        def multi_errors(x):
            x = x / 0  # DIV
            x = jnp.sin(x)  # NAN
            x = x[500]  # OOB
            checkify.check(x < 0, "must be negative!")  # ASSERT
            return x

        unchecked = checkify.checkify(multi_errors, errors=set())
        err, _ = unchecked(jnp.ones((2, )))
        self.assertIsNone(err.get())
Example #22
0
  def test_scan_consts(self):
    """An OOB index into an array closed over by a scan body is reported."""
    def f(xs):
      def scan_body(carry, _):
        # Closes over xs. The carry starts at 1 and grows by one each step,
        # so by the last step it indexes past the end of xs.
        return carry + 1, xs[carry]
      _, ys = lax.scan(scan_body, 1, xs)
      return ys

    checked_f = checkify.checkify(f, errors=checkify.index_checks)
    err, _ = checked_f(jnp.ones((7, 3)))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "out-of-bounds indexing")
Example #23
0
    def test_scan_map(self):
        """A NaN produced by a carry-less scan (a map) is reported by checkify.

        Each case also checks that the checked and unchecked outputs agree.
        """
        def scan_body(_, x):
            return None, jnp.sin(x)

        @jax.jit
        def f(xs):
            return lax.scan(scan_body, None, xs)

        checked_f = checkify.checkify(f)
        cases = (
            (jnp.array([0., 2.]), None),
            # sin(inf) is NaN.
            (jnp.array([3., jnp.inf]), "nan generated by primitive sin"),
        )
        for xs, expected_prefix in cases:
            err, (_, ch_outs) = checked_f(xs)
            _, outs = f(xs)
            if expected_prefix is None:
                self.assertIs(err.get(), None)
            else:
                self.assertIsNotNone(err.get())
                self.assertStartsWith(err.get(), expected_prefix)
            self.assertArraysEqual(ch_outs, outs)
Example #24
0
    def test_jit_ordering(self, jit):
        """When several errors occur, the one that happens first is reported.

        Args:
          jit: if true, ``f`` is wrapped in ``jax.jit`` before checkifying.
        """
        def f(x, i):
            y = x[i]
            z = jnp.sin(x)
            return y * z

        f = jax.jit(f) if jit else f

        # Both an OOB and a NaN error occur, but the OOB indexing comes first.
        err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 5)
        # Guard first so a missing error fails cleanly instead of handing
        # None to assertStartsWith.
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'out-of-bounds indexing')
Example #25
0
    def test_enabled_errors(self, error_set, expected_error):
        """The reported error depends on which error categories are enabled.

        Args:
          error_set: the ``errors=`` argument passed to ``checkify.checkify``.
          expected_error: prefix the reported error message must start with.
        """
        def multi_errors(x):
            checkify.check(jnp.all(x < 0), "must be negative!")  # ASSERT
            x = x / 0  # DIV
            x = jnp.sin(x)  # NAN
            x = x[500]  # OOB
            return x

        checked = checkify.checkify(multi_errors, errors=error_set)
        err, _ = checked(jnp.ones((2, )))
        message = err.get()
        self.assertIsNotNone(message)
        self.assertStartsWith(message, expected_error)
Example #26
0
  def test_scan_consts2(self):
    """An OOB index is reported from a scan body with several extra consts."""
    def f(xs):
      def scan_body(carry, _):
        # Extra consts: repeated xs lookups plus a sin over a NumPy array.
        _ = xs[carry], xs[carry], jnp.sin(np.arange(11.))
        return carry + 1, xs[carry]
      _, ys = lax.scan(scan_body, 1, xs)
      return ys

    checked_f = checkify.checkify(f, errors=checkify.index_checks)
    err, _ = checked_f(jnp.ones((7, 3)))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "out-of-bounds indexing")
Example #27
0
    def test_check_error_scanned(self):
        """Per-step errors from a scan can be re-raised with check_error.

        Each scan step checkifies ``body`` and emits its error as a scan
        output. ``f`` passes the stacked errors to ``checkify.check_error``,
        and an outer checkify then reports them.
        """
        def body(carry, x):
            checkify.check(jnp.all(x > 0), "should be positive")
            return carry, x

        def checked_body(carry, x):
            step_err, (new_carry, y) = checkify.checkify(body)(carry, x)
            return new_carry, (y, step_err)

        def f(x):
            _, (xs, errs) = jax.lax.scan(checked_body, (None, ), x)
            checkify.check_error(errs)
            return xs

        checked_f = checkify.checkify(f)
        for inputs in (jnp.array([-1]), jnp.array([1, 0, -1])):
            err, _ = checked_f(inputs)
            self.assertIsNotNone(err.get())
            self.assertStartsWith(err.get(), "should be positive")
Example #28
0
    def test_empty_enabled_errors(self):
        """With no error categories enabled, checkify reports no error.

        Each line of ``multi_errors`` is tagged with the error category it is
        meant to trigger.
        """
        def multi_errors(x):
            x = x / 0  # DIV
            x = jnp.sin(x)  # NAN
            x = x[500]  # OOB
            # TODO(lenamartens): this error should also be disabled.
            # checkify.check(x < 0, "must be negative!")  # ASSERT
            return x

        x = jnp.ones((2, ))
        # Use set(): the literal `{}` is an empty dict, not an empty set of
        # error categories.
        err, _ = checkify.checkify(multi_errors, errors=set())(x)
        self.assertIsNone(err.get())
Example #29
0
  def test_while_loop_cond_error_and_false(self):
    """An error is reported even from the cond evaluation that returns False."""
    def while_cond(val):
      # 1/0 is inf and sin(inf) is NaN; the cond is False exactly when NaN.
      possible_nan = jnp.sin(1./val)
      return jnp.logical_not(jnp.isnan(possible_nan))

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, lambda val: val-1, init_val)

    checked_f = checkify.checkify(f)
    # 0. errors on the first cond; 1. steps down to 0. and errors on the
    # second cond.
    for init_val in (0., 1.):
      err, _ = checked_f(init_val)
      self.assertIsNotNone(err.get())
      self.assertStartsWith(err.get(), "nan generated by primitive sin")
Example #30
0
    def test_jit_oob_update(self, update_fn):
        """An out-of-bounds ``x.at[i]`` update under jit is reported.

        Args:
          update_fn: name of the ``x.at[i]`` method to call with the value 1
            (presumably an update such as "set" or "add"; confirm against the
            test parameterization).
        """
        def f(x, i):
            updater = getattr(x.at[i], update_fn)
            return updater(1)

        checked_f = checkify.checkify(jax.jit(f), errors=checkify.index_checks)

        # Index 2 is in bounds for a length-3 array.
        err, _ = checked_f(jnp.arange(3), 2)
        self.assertIsNone(err.get())

        # Index 3 is one past the end.
        err, _ = checked_f(jnp.arange(3), 3)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "out-of-bounds indexing")