Ejemplo n.º 1
0
def test_fwd_back():
    """Integrating forwards and then backwards must recover the initial state.

    The dynamics are solved from ``t = 0.1`` to ``t = 2.2`` and then from the
    resulting end state back to ``t = 0.1``; the final value of the reverse
    solve is compared against the starting vector.
    """
    dim = 10
    start, end = 0.1, 2.2
    initial = np.linspace(0.1, 0.9, dim)
    tols = dict(atol=1e-8, rtol=1e-8)

    def dynamics(y, t):
        return -np.sqrt(t) - y + 0.1 - np.mean((y + 0.2)**2)

    forward = odeint(dynamics, initial, np.array([start, end]), **tols)
    backward = odeint(dynamics, forward[-1], np.array([end, start]), **tols)

    assert np.allclose(initial, backward[-1])
Ejemplo n.º 2
0
    def vjp_all(g):
        """Vector-Jacobian product of the ODE solution with the cotangent ``g``.

        Walks the observation times from the last (index ``T - 1``) down to
        index 1, integrating ``augmented_dynamics`` backwards between each pair
        of consecutive times, and accumulates cotangents for the initial
        state, for every entry of ``t``, and for the flattened extra arguments.

        Relies on names from the enclosing scope (not visible here): ``yt``,
        ``t``, ``T``, ``func``, ``func_args``, ``flat_args``,
        ``augmented_dynamics``, ``unpack`` and ``unravel``.

        Args:
            g: Cotangent of the solution, indexed as ``g[i, :]`` per
                observation ``i`` (presumably the same ``(T, D)`` shape as
                ``yt`` -- TODO confirm against the caller).

        Returns:
            Tuple ``(None, vjp_y, vjp_times, unravel(vjp_args))``:
            ``None`` in the first slot (presumably for the non-differentiable
            ``func`` argument -- verify), the cotangent for the initial state,
            one cotangent per entry of ``t`` in original time order, and the
            extra-argument cotangent restored to its original structure.
        """

        # Seed the state cotangent with the cotangent of the last observation.
        vjp_y = g[-1, :]
        # Running cotangent for the earliest time; updated every step.
        vjp_t0 = 0
        # Per-time cotangents, collected from the last time to the first.
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        # i runs T-1, T-2, ..., 1; each step integrates from t[i] back to t[i-1].
        for i in range(T - 1, 0, -1):

            # Compute effect of moving measurement time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run augmented system backwards to the previous observation.
            # The augmented state packs (state, state cotangent, t0 cotangent,
            # args cotangent); ``unpack`` below splits it in the same order.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.stack([t[i], t[i - 1]]), (flat_args, ))
            # Row 1 is the value at t[i - 1]; the reconstructed state is discarded.
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]

        # The accumulated vjp_t0 is the cotangent for t[0]; the list was built
        # last-time-first, so reverse it to match the order of ``t``.
        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unravel(vjp_args)
Ejemplo n.º 3
0
 def odeint2(y0, t0, t1, fargs):
     """Solve ``f`` from ``t0`` to ``t1`` with absolute/relative tolerance 1e-8.

     ``f`` is taken from the enclosing scope and passed to ``odeint`` as the
     ``func`` keyword; ``fargs`` is forwarded as the third positional argument.
     """
     tolerance = 1e-8
     span = np.array([t0, t1])
     return odeint(y0, span, fargs, func=f, atol=tolerance, rtol=tolerance)
Ejemplo n.º 4
0
 def twoarg_odeint(y0, args):
     """Return row 1 of the solution of ``f`` over ``[t0, t1]`` (the value at ``t1``).

     ``f``, ``t0`` and ``t1`` come from the enclosing scope; ``args`` is
     forwarded to ``odeint`` as its ``args`` keyword.
     """
     trajectory = odeint(f, y0, np.array([t0, t1]), args=args,
                         atol=1e-8, rtol=1e-8)
     return trajectory[1]
Ejemplo n.º 5
0
def test_odeint_vjp():
    """The analytic ODE vector-Jacobian product must match a numerical gradient.

    The gradient of ``sum(odeint(...))`` with respect to the initial state,
    the time span and the extra arguments is computed once numerically with
    ``nd`` and once via ``grad_odeint`` applied to an all-ones cotangent.
    """
    dim = 3
    t_start, t_end = 0.1, 0.2
    init = np.linspace(0.1, 0.9, dim)
    extra = (0.1, 0.2)
    tols = dict(atol=1e-8, rtol=1e-8)

    def dynamics(y, t, arg1, arg2):
        return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)

    def summed_solution(packed):
        # ``packed`` is (y0, times, fargs), spread into odeint positionally.
        return np.sum(odeint(dynamics, *packed, **tols))

    numeric = nd(summed_solution, (init, np.array([t_start, t_end]), extra))

    solution = odeint(dynamics, init, np.array([t_start, t_end]), extra, **tols)
    vjp_fn = grad_odeint(solution, dynamics, init,
                         np.array([t_start, t_end]), extra)
    cotangent = np.ones_like(solution)
    analytic, _ = ravel_pytree(vjp_fn(cotangent))

    assert np.allclose(numeric, analytic)
Ejemplo n.º 6
0
    init_state, unpack = ravel_pytree((y0, tan_y0))

    def augmented_dynamics(augmented_state, t, fargs):
        """Time derivative of the concatenated (state, sensitivity) system.

        Splits ``augmented_state`` with ``unpack`` into the primal state and
        its tangent, pushes the tangents ``(a, tan_t0, tan_fargs)`` through
        ``func`` with ``jvp``, and returns both derivatives concatenated in the
        same (state, tangent) order that ``unpack`` splits.

        Because the time slot carries ``tan_t0``, the tangent derivative also
        contains a ``d func / d t * tan_t0`` term. This is presumably
        compensated by the ``(tan_t1 - tan_t0) * func(...)`` correction
        computed after the solve -- confirm.
        """
        state, tangent = unpack(augmented_state)
        # Presumably replaces a symbolic-zero tangent with concrete zeros shaped
        # like ``state`` so that ``jvp`` can consume it -- confirm.
        tangent = ad.instantiate_zeros(state, tangent)

        primals = (state, t, fargs)
        tangents = (tangent, tan_t0, tan_fargs)
        state_dot, tangent_dot = jvp(func, primals, tangents)

        return np.concatenate([state_dot, tangent_dot])

    # Solve augmented dynamics
    aug_sol = odeint(init_state,
                     np.array([t0, t1]),
                     fargs,
                     func=augmented_dynamics)
    yt, at = unpack(aug_sol[1])

    # Sensitivities of y(t1) wrt t0 and t1
    jvp_t_total = (tan_t1 - tan_t0) * func(yt, t1, fargs)

    # Combine sensitivities
    tan_yt = jvp_t_total if at is ad_util.zero else at + jvp_t_total
    return (np.array([y0, yt]), np.array([tan_y0, tan_yt]))


# Register jvp_odeint as the forward-mode (JVP) rule for odeint's primitive.
ad.primitive_jvps[odeint.primitive] = jvp_odeint
Ejemplo n.º 7
0
 def odeint_fwrap(y0, ts, fargs):
     """Call ``odeint`` with the enclosing-scope dynamics ``f`` bound as ``func``."""
     options = {"func": f, "fargs": fargs}
     return odeint(y0, ts, **options)
Ejemplo n.º 8
0
 def odeint2(y0, ts, fargs):
     """Solve ``f`` over ``ts`` with extra arguments ``fargs`` at 1e-8 tolerances."""
     tolerance = 1e-8
     return odeint(f, y0, ts, fargs, atol=tolerance, rtol=tolerance)
Ejemplo n.º 9
0
def ode_w_linear_part(func, y0, a0, t0, t1, func_args):
    """Jointly integrate ``y0`` and ``a0`` from ``t0`` to ``t1`` under ``func``.

    Intended for dynamics that are linear in ``a0`` but not in ``y0``. The two
    pieces are flattened into one vector with ``ravel_pytree``, solved together
    by ``odeint`` (with ``func_args`` as extra arguments), and split back apart.

    Returns:
        ``(yt, jvp_all)``: the ``y`` and ``a`` components of the solution at ``t1``.
    """
    packed, split = ravel_pytree((y0, a0))
    trajectory = odeint(func, packed, np.array([t0, t1]), func_args)
    final_y, final_a = split(trajectory[1])
    return final_y, final_a
Ejemplo n.º 10
0
 def onearg_odeint(args):
     """Sum of the ODE solution, with ``odeint``'s positional inputs packed in ``args``."""
     solution = odeint(f, *args, atol=1e-8, rtol=1e-8)
     return np.sum(solution)
Ejemplo n.º 11
0
 def onearg_odeint(y0):
     """Return row 1 of the solution of ``f`` over ``[t0, t1]`` as a function of ``y0``."""
     span = np.array([t0, t1])
     solution = odeint(f, y0, span, atol=1e-8, rtol=1e-8)
     return solution[1]