예제 #1
0
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
  """Check first-order forward- and reverse-mode derivatives of ``f``.

  Runs ``jtu.check_jvp`` against ``api.jvp`` and ``jtu.check_vjp`` against
  ``api.vjp`` at the point ``args``.

  Args:
    f: function whose derivatives are checked.
    args: arguments at which ``f`` is checked (passed straight through to
      both checkers).
    order: requested derivative order. Currently ignored; only first-order
      checks are performed (see TODO below).
    atol: absolute tolerance. ``None`` selects the default (1e-6 when
      ``FLAGS.jax_enable_x64`` is set, otherwise 1e-2). An explicit ``0`` is
      honored.
    rtol: relative tolerance, with the same default rule as ``atol``.
    eps: value forwarded to both checkers as their ``eps`` argument
      (presumably a finite-difference step; confirm in jtu). ``None`` or
      ``0`` selects the default.
  """
  # TODO(mattjj,dougalm): add higher-order check
  default_tol = 1e-6 if FLAGS.jax_enable_x64 else 1e-2
  # Use `is None` rather than `or` so that explicitly passing 0 (an exact
  # check) is not silently replaced by the default tolerance.
  atol = default_tol if atol is None else atol
  rtol = default_tol if rtol is None else rtol
  # A zero step is degenerate, so any falsy eps still falls back to default.
  eps = eps or default_tol
  jtu.check_jvp(f, partial(api.jvp, f), args, atol, rtol, eps)
  jtu.check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps)
예제 #2
0
 def test_vjp(self, f, args):
     """Check the VJP of ``f`` at ``args`` with per-dtype tolerances.

     float32 gets looser bounds (rtol 3e-1, atol 1e-2) than float64
     (1e-5 for both rtol and atol).
     """
     relative_tol = {np.float32: 3e-1, np.float64: 1e-5}
     absolute_tol = {np.float32: 1e-2, np.float64: 1e-5}
     f_vjp = partial(vjp, f)
     jtu.check_vjp(f, f_vjp, args, rtol=relative_tol, atol=absolute_tol)
예제 #3
0
def test_odeint_vjp():
    """Use check_vjp to check odeint VJP calculations.

    Each case pairs a dynamics function with an initial state and extra
    positional arguments. ``odeint`` and ``vjp_odeint`` are both wrapped so
    that the dynamics function is fixed. The two wrappers are handed to
    ``check_vjp`` with the argument tuple ``(y0, t, *extra)``, where ``t`` is
    the grid 0, 1, ..., 10.
    """
    def fix_dynamics(solver, dynamics):
        # Bind `dynamics` as the solver's first argument and forward
        # (y, t, *args) unchanged.
        return lambda y, t, *args: solver(dynamics, y, t, *args)

    cases = [
        # pend() with b=0.25, c=9.8
        (pend, np.array([np.pi - 0.1, 0.0]), (0.25, 9.8)),
        # swoop() with arg1=0.1, arg2=0.2
        (swoop, np.array([0.1]), (0.1, 0.2)),
    ]
    for dynamics, y0, extra in cases:
        times = np.linspace(0., 10., 11)
        check_vjp(fix_dynamics(odeint, dynamics),
                  fix_dynamics(vjp_odeint, dynamics),
                  (y0, times) + extra)
예제 #4
0
 def test_vjp(self, f, args):
     """Check the VJP of ``f`` at ``args`` via ``jtu.check_vjp``.

     No tolerances are passed, so ``jtu.check_vjp``'s own defaults apply.
     """
     jtu.check_vjp(f, partial(vjp, f), args)