예제 #1
0
def test_odeint_2_linearize():
    def odeint2(y0, ts, fargs):
        return odeint(f, y0, ts, fargs, atol=1e-8, rtol=1e-8)

    odeint2_prim = custom_transforms(odeint2).primitive

    def odeint2_jvp((y0, ts, fargs), (tan_y, tan_ts, tan_fargs)):
        return jvp_odeint(f, (y0, ts, fargs), (tan_y, tan_ts, tan_fargs))
예제 #2
0
파일: test_jvp.py 프로젝트: duvenaud/jaxde
def test_odeint_jvp():
    """Define an ``odeint`` wrapper over ``[t0, t1]`` and a matching JVP rule.

    NOTE(review): neither nested function is used in the visible body;
    presumably the checks that exercise them follow elsewhere -- confirm.
    """

    def odeint2(y0, t0, t1, fargs):
        """Call ``odeint`` over the two time points with tolerances of 1e-8."""
        return odeint(y0,
                      np.array([t0, t1]),
                      fargs,
                      func=f,
                      atol=1e-8,
                      rtol=1e-8)

    def odeint2_jvp(primals, tangents):
        """Forward the primals and their tangents to ``jvp_odeint``.

        Args:
            primals: 4-tuple ``(y0, t0, t1, fargs)``.
            tangents: 4-tuple ``(tan_y, tan_t0, tan_t1, tan_fargs)``.

        Returns:
            Whatever ``jvp_odeint`` returns for these inputs.
        """
        # Tuple parameters in a ``def`` were removed by PEP 3113 and are a
        # SyntaxError on Python 3, so the two tuples are unpacked explicitly.
        y0, t0, t1, fargs = primals
        tan_y, tan_t0, tan_t1, tan_fargs = tangents
        # The two endpoint times (and their tangents) are packed into arrays,
        # mirroring how ``odeint2`` passes its time argument.
        return jvp_odeint((y0, np.array([t0, t1]), fargs),
                          (tan_y, np.array([tan_t0, tan_t1]), tan_fargs),
                          func=f)
예제 #3
0
def test_odeint_jvp_z():
    D = 10
    t0 = 0.1
    t1 = 0.2
    y0 = np.linspace(0.1, 0.9, D)
    arg = np.zeros((0, ))

    def f(y, t, args):
        return -np.sqrt(t) - y

    @custom_transforms
    def onearg_odeint(y0):
        return odeint(f, y0, np.array([t0, t1]), atol=1e-8, rtol=1e-8)[1]

    def onearg_jvp((y0, arg), (tangent_all, )):
        return jvp_odeint(tangent_all, f, y0, t0, t1, arg)
예제 #4
0
def test_odeint_jvp_all():
    D = 10
    t0 = 0.1
    t1 = 0.2
    y0 = np.linspace(0.1, 0.9, D)
    fargs = (0.1, 0.2)

    def f(y, t, arg1, arg2):
        return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)

    @custom_transforms
    def twoarg_odeint(y0, args):
        return odeint(f,
                      y0,
                      np.array([t0, t1]),
                      args=args,
                      atol=1e-8,
                      rtol=1e-8)[1]

    def twoarg_jvp((y0, args), tangent_all):
        return jvp_odeint(tangent_all, f, y0, t0, t1, args)
예제 #5
0
 def onearg_jvp(y0, tangent_y0):
     """Return ``jvp_odeint`` evaluated for ``y0`` with tangent ``tangent_y0``.

     ``f``, ``t0`` and ``t1`` are resolved from outside this function.
     """
     jvp_result = jvp_odeint(tangent_y0, f, y0, t0, t1)
     return jvp_result