def test_odeint_2_linearize():
    """Wrap ``odeint`` with fixed tolerances and define a JVP rule for it.

    ``odeint2`` calls ``odeint(f, y0, ts, fargs)`` with ``atol=rtol=1e-8``;
    its ``custom_transforms`` primitive is taken, and ``odeint2_jvp``
    forwards primals and tangents to ``jvp_odeint``.
    """
    # NOTE(review): `f`, `odeint`, `custom_transforms` and `jvp_odeint` are not
    # defined in this function; presumably module-level names — confirm.

    def odeint2(y0, ts, fargs):
        return odeint(f, y0, ts, fargs, atol=1e-8, rtol=1e-8)

    odeint2_prim = custom_transforms(odeint2).primitive

    # Explicit unpacking replaces Python 2 tuple parameters (removed by
    # PEP 3113); callers still pass two 3-tuples, and a wrong length still
    # raises ValueError exactly as the tuple-parameter form did.
    def odeint2_jvp(primals, tangents):
        y0, ts, fargs = primals
        tan_y, tan_ts, tan_fargs = tangents
        return jvp_odeint(f, (y0, ts, fargs), (tan_y, tan_ts, tan_fargs))

    # NOTE(review): neither `odeint2_prim` nor `odeint2_jvp` is registered or
    # called in the visible code, so nothing is asserted yet — TODO confirm
    # whether the rest of this test was truncated.
def test_odeint_jvp():
    """Define an ``odeint`` wrapper over ``[t0, t1]`` and a matching JVP rule.

    ``odeint2`` packs the two endpoint times into an array and calls
    ``odeint`` with ``func=f`` and ``atol=rtol=1e-8``. ``odeint2_jvp`` packs
    the time tangents the same way and forwards everything to ``jvp_odeint``.
    """
    # NOTE(review): `f`, `odeint` and `jvp_odeint` are not defined in this
    # function; presumably module-level names — confirm.

    def odeint2(y0, t0, t1, fargs):
        return odeint(y0, np.array([t0, t1]), fargs, func=f,
                      atol=1e-8, rtol=1e-8)

    # Explicit unpacking replaces Python 2 tuple parameters (removed by
    # PEP 3113); callers still pass two 4-tuples.
    def odeint2_jvp(primals, tangents):
        y0, t0, t1, fargs = primals
        tan_y, tan_t0, tan_t1, tan_fargs = tangents
        return jvp_odeint((y0, np.array([t0, t1]), fargs),
                          (tan_y, np.array([tan_t0, tan_t1]), tan_fargs),
                          func=f)

    # NOTE(review): neither helper is invoked in the visible code, so this
    # test asserts nothing yet — TODO confirm whether the body was truncated.
def test_odeint_jvp_z():
    """Set up a one-argument ``odeint`` primitive on ``[0.1, 0.2]`` and its JVP.

    The state is ``D = 10`` values spaced over ``[0.1, 0.9]``. The dynamics
    are ``dy/dt = -sqrt(t) - y``. ``onearg_odeint`` returns the solution at
    ``t1``, i.e. row ``[1]`` of the ``odeint`` output.
    """
    D = 10
    t0 = 0.1
    t1 = 0.2
    y0 = np.linspace(0.1, 0.9, D)
    arg = np.zeros((0, ))  # empty extra-argument array

    def f(y, t, args):
        return -np.sqrt(t) - y

    @custom_transforms
    def onearg_odeint(y0):
        return odeint(f, y0, np.array([t0, t1]), atol=1e-8, rtol=1e-8)[1]

    # Explicit unpacking replaces Python 2 tuple parameters (removed by
    # PEP 3113); `y0` and `arg` are locals here and shadow the outer names,
    # as they did in the tuple-parameter form.
    def onearg_jvp(primals, tangents):
        # NOTE(review): `onearg_odeint` takes a single argument, yet this rule
        # unpacks two primals. If the JVP rule receives one primal per
        # positional argument, this raises ValueError — TODO confirm arity
        # (the outer `arg` may have been intended here).
        y0, arg = primals
        (tangent_all,) = tangents
        return jvp_odeint(tangent_all, f, y0, t0, t1, arg)

    # NOTE(review): `onearg_jvp` is never registered or called in the visible
    # code — TODO confirm whether the rest of this test was truncated.
def test_odeint_jvp_all():
    """Set up a two-argument ``odeint`` primitive on ``[0.1, 0.2]`` and its JVP.

    The dynamics ``f(y, t, arg1, arg2)`` depend on two extra arguments.
    ``twoarg_odeint`` passes them through ``args=`` and returns the solution
    at ``t1`` (row ``[1]``). ``twoarg_jvp`` forwards both primals and the
    tangent to ``jvp_odeint``.
    """
    D = 10
    t0 = 0.1
    t1 = 0.2
    y0 = np.linspace(0.1, 0.9, D)
    # NOTE(review): `fargs` is not used in the visible code; presumably meant
    # as the `args` input to `twoarg_odeint` — confirm.
    fargs = (0.1, 0.2)

    def f(y, t, arg1, arg2):
        return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)

    @custom_transforms
    def twoarg_odeint(y0, args):
        return odeint(f, y0, np.array([t0, t1]), args=args,
                      atol=1e-8, rtol=1e-8)[1]

    # Explicit unpacking replaces the Python 2 tuple parameter (removed by
    # PEP 3113); the first argument must still be a 2-tuple (y0, args).
    def twoarg_jvp(primals, tangent_all):
        y0, args = primals
        return jvp_odeint(tangent_all, f, y0, t0, t1, args)

    # NOTE(review): `twoarg_jvp` is never registered or called in the visible
    # code — TODO confirm whether the rest of this test was truncated.
def onearg_jvp(y0, tangent_y0):
    """Return ``jvp_odeint(tangent_y0, f, y0, t0, t1)`` unchanged.

    :param y0: primal initial state passed through to ``jvp_odeint``.
    :param tangent_y0: tangent of ``y0``, passed as the first argument.
    """
    # NOTE(review): `f`, `t0` and `t1` are not bound at this level in the
    # visible code; presumably this was nested inside a test that defines
    # them — confirm.
    solver_jvp = jvp_odeint
    result = solver_jvp(tangent_y0, f, y0, t0, t1)
    return result