def check_grads(f, args, order, atol=None, rtol=None, eps=None):
    """Check `api.jvp` and `api.vjp` of `f` at `args` with the jtu checkers.

    Args:
      f: the function whose derivatives are checked.
      args: positional arguments at which `f` is evaluated.
      order: requested derivative order. Currently unused; only the
        first-order JVP and VJP checks below are run.
      atol: absolute tolerance forwarded to the checkers. ``None`` selects
        a default of 1e-6 when ``FLAGS.jax_enable_x64`` is set, else 1e-2.
      rtol: relative tolerance; same defaulting rule as `atol`.
      eps: step size forwarded to the checkers (presumably the numerical
        differencing step -- confirm against jtu); same defaulting rule.
    """
    # TODO(mattjj,dougalm): add higher-order check
    default_tol = 1e-6 if FLAGS.jax_enable_x64 else 1e-2
    # Compare against None instead of using `or`: with `or`, an explicit
    # tolerance of 0 is falsy and would silently be replaced by the default.
    atol = default_tol if atol is None else atol
    rtol = default_tol if rtol is None else rtol
    eps = default_tol if eps is None else eps
    jtu.check_jvp(f, partial(api.jvp, f), args, atol, rtol, eps)
    jtu.check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps)
def test_vjp(self, f, args):
    """Run `jtu.check_vjp` on `f` and `vjp(f)` at `args` with per-dtype tolerances.

    float32 is given much looser tolerances than float64.
    """
    rtol_by_dtype = {np.float32: 3e-1, np.float64: 1e-5}
    atol_by_dtype = {np.float32: 1e-2, np.float64: 1e-5}
    f_vjp = partial(vjp, f)
    jtu.check_vjp(f, f_vjp, args, rtol=rtol_by_dtype, atol=atol_by_dtype)
def test_odeint_vjp():
    """Check `vjp_odeint` against `odeint` with `check_vjp`.

    Two dynamics functions are checked in turn, `pend` and then `swoop`.
    Each is checked at a fixed initial state, 11 evenly spaced times on
    [0, 10], and a pair of extra scalar parameters.
    """
    def _check_dynamics(dynamics, y0, *params):
        # Shared time grid for every case.
        ts = np.linspace(0., 10., 11)

        def forward(y, t, *args):
            return odeint(dynamics, y, t, *args)

        def backward(y, t, *args):
            return vjp_odeint(dynamics, y, t, *args)

        check_vjp(forward, backward, (y0, ts) + params)

    # pend(): state [pi - 0.1, 0.0], parameters b=0.25, c=9.8.
    _check_dynamics(pend, np.array([np.pi - 0.1, 0.0]), 0.25, 9.8)
    # swoop(): state [0.1], parameters 0.1 and 0.2.
    _check_dynamics(swoop, np.array([0.1]), 0.1, 0.2)
def test_vjp(self, f, args):
    """Run `jtu.check_vjp` on `f` and `vjp(f)` at `args`.

    No explicit tolerances are passed, so whatever defaults `jtu.check_vjp`
    defines are used.
    """
    jtu.check_vjp(f, partial(vjp, f), args)