def inner_loop(opt, _):
    """Take one optimizer step on the parameters held in ``opt``.

    Builds the cost integrator, differentiates it with ``jax.vjp`` at
    ``opt.value`` and pulls back a cotangent that is one on the two forward
    terminal cost outputs and zero on every other output. The resulting
    gradient with respect to ``opt.value`` is handed to ``opt.update``.

    Args:
        opt: Object exposing ``value`` (the current parameters) and
            ``update(grad)``.
        _: Ignored.

    Returns:
        A pair ``(opt.update(grad), Record(...))`` where the record holds the
        forward terminal outputs followed by the backward initial outputs.
    """
    integrate_cost = policy_integrate_cost(
        dynamics_fn, position_cost_fn, control_cost_fn, gamma, policy)

    outputs, pullback = jax.vjp(integrate_cost, opt.value, x0, total_time)
    fwd_start, fwd_end, bwd_start = outputs
    fwd_x_cost, fwd_u_cost, fwd_x_final = fwd_end
    bwd_x_cost, bwd_u_cost, bwd_x_initial = bwd_start

    # Seed only the two forward terminal costs; the forward terminal state
    # and the other two outputs receive zero cotangents.
    fwd_end_cotangent = (jnp.ones(()), jnp.ones(()), jnp.zeros_like(x0))
    output_cotangent = (
        zeros_like_tree(fwd_start),
        fwd_end_cotangent,
        zeros_like_tree(bwd_start),
    )
    # Only the gradient w.r.t. opt.value is used; those for x0 and
    # total_time are discarded.
    param_grad, _x0_grad, _time_grad = pullback(output_cotangent)

    updated_opt = opt.update(param_grad)
    record = Record(fwd_x_cost, fwd_u_cost, fwd_x_final,
                    bwd_x_cost, bwd_u_cost, bwd_x_initial)
    return updated_opt, record
示例#2
0
def bvp_fwd_bwd(f, y0, t0, t1, f_args, adj_y_t1, init_num_nodes=2):
    """Solve an ODE and its adjoint together as one boundary value problem.

    The augmented state is the tuple ``(y, adj_y, adj_t, adj_args)``. Its
    time derivative is ``f(y, t, args)`` for ``y`` and the negated VJP of
    ``f`` (seeded with ``adj_y``) for the three adjoint components. The
    boundary conditions pin ``y`` at ``t0`` to ``y0`` and the adjoints at
    ``t1`` to ``(adj_y_t1, 0.0, zeros_like_tree(f_args))``.

    Args:
        f: Dynamics function called as ``f(y, t, args)``.
        y0: Value of ``y`` at ``t0`` (a pytree).
        t0: Start of the time interval.
        t1: End of the time interval.
        f_args: Extra argument pytree passed to ``f``.
        adj_y_t1: Value of the adjoint of ``y`` at ``t1``.
        init_num_nodes: Number of points in the initial mesh.

    Returns:
        ``(y_t1, adj_y_t0, adj_t_t0, adj_args_t0, bvp_soln)``: ``y`` at the
        last mesh node, the three adjoints at the first mesh node, and the
        raw object returned by ``solve_bvp``.
    """
    boundary_values = (y0, adj_y_t1, 0.0, zeros_like_tree(f_args))
    boundary_flat, unflatten = ravel_pytree(boundary_values)

    def aug_dynamics(t, aug, args):
        # Only y and adj_y enter the right-hand side.
        y, adj_y, _adj_t, _adj_args = aug
        y_dot, pullback = vjp(f, y, t, args)
        adjoint_dots = tree_map(jnp.negative, pullback(adj_y))
        return (y_dot, *adjoint_dots)

    def aug_dynamics_flat(t, aug_flat, args):
        return ravel_pytree(aug_dynamics(t, unflatten(aug_flat), args))[0]

    @jit
    def batched_dynamics(ts, augs, args):
        # Mesh nodes run along axis 0 of ts and axis 1 of augs; transpose
        # so the result also has nodes on its last axis.
        per_node = vmap(aug_dynamics_flat, in_axes=(0, 1, None))
        return per_node(ts, augs, args).T

    @jit
    def boundary_residual_flat(aug_t0_flat, aug_t1_flat):
        at_t0 = unflatten(aug_t0_flat)
        at_t1 = unflatten(aug_t1_flat)
        # y is constrained at t0, the three adjoints at t1.
        constrained = (at_t0[0], *at_t1[1:])
        residual = tree_multimap(jnp.subtract, constrained, boundary_values)
        return ravel_pytree(residual)[0]

    node_jacobian = jacrev(aug_dynamics_flat, argnums=1)

    @jit
    def batched_dynamics_jac(ts, augs, args):
        jac_by_node = vmap(node_jacobian, in_axes=(0, 1, None))(ts, augs, args)
        # Move the node axis from the front to the back.
        return jnp.transpose(jac_by_node, axes=(1, 2, 0))

    def rhs(ts, augs):
        return batched_dynamics(ts, augs, f_args)

    def rhs_jac(ts, augs):
        return batched_dynamics_jac(ts, augs, f_args)

    initial_mesh = jnp.linspace(t0, t1, num=init_num_nodes)
    # Initial guess: the boundary values repeated at every mesh node.
    initial_guess = jnp.array([boundary_flat] * init_num_nodes).T

    # Without fun_jac the number of nodes blows up and we hit memory errors,
    # even on a machine with 90G. See the full error for more info:
    # https://gist.github.com/samuela/8c5f6463e08d15c9ffad1f352d1a5513.
    #
    # Supplying bc_jac is super important for numerical stability.
    bvp_soln = solve_bvp(
        rhs,
        boundary_residual_flat,
        initial_mesh,
        initial_guess,
        fun_jac=rhs_jac,
        bc_jac=jit(jacrev(boundary_residual_flat, argnums=(0, 1))))

    y_t1, _, _, _ = unflatten(bvp_soln.y[:, -1])
    _, adj_y_t0, adj_t_t0, adj_args_t0 = unflatten(bvp_soln.y[:, 0])
    return y_t1, adj_y_t0, adj_t_t0, adj_args_t0, bvp_soln
示例#3
0
    def run(policy_params, x0, total_time):
        """Run the forward solve, then the backward solve, for one rollout.

        Args:
            policy_params: Parameters forwarded to both solvers.
            x0: Initial state; paired with a zero scalar to form the initial
                value handed to ``solve_ivp_fwd``.
            total_time: End time of the forward solve, which starts at 0.0.

        Returns:
            ``(cost, g)``: the first component of the forward result and
            whatever ``solve_ivp_bwd`` returns (presumably the gradient of
            ``cost`` w.r.t. ``policy_params`` -- confirm against its
            definition).
        """
        # Forward pass: the state is (scalar starting at zero, x).
        initial_value = (jnp.zeros(()), x0)
        (cost, _x_final), y_fn = solve_ivp_fwd(
            0.0, total_time, initial_value, policy_params)

        # Backward pass: cotangent of one on the scalar, zeros on x.
        output_cotangent = (jnp.ones(()), zeros_like_tree(x0))
        grad = solve_ivp_bwd(policy_params, y_fn, output_cotangent)
        return cost, grad
示例#4
0
    def bwd(args, y_fn, g):
        """Run the backward pass segment by segment, last segment first.

        Args:
            args: Argument pytree; its zero-filled copy is the initial value
                of the accumulated adjoint that is returned.
            y_fn: Object exposing ``ts``, ``Q`` and ``y_old``; there is one
                segment per leading entry of ``Q``, spanning ``ts[i]`` to
                ``ts[i + 1]``.
            g: Value placed in the second slot of the initial augmented
                state.

        Returns:
            The third component of the augmented state after every segment
            has been processed.
        """
        state = (jnp.zeros(()), g, zeros_like_tree(args))
        segment_count = y_fn.Q.shape[0]
        for seg in reversed(range(segment_count)):
            # Pulling these out into locals is measurably faster.
            t_start = y_fn.ts[seg]
            t_end = y_fn.ts[seg + 1]

            # solve_ivp sometimes returns time steps only ~1e-16 apart, which
            # do not play nicely with Runge-Kutta, so skip such segments.
            if t_end - t_start > 1e-8:
                state = bwd_spline_segment(
                    t_start, t_end, args, y_fn.Q[seg], y_fn.y_old[seg], state)

        _, _, args_adjoint = state
        return args_adjoint
示例#5
0
 def run(x0, args):
     """Solve the forward/adjoint BVP from ``x0`` and return loss and adjoint.

     Calls ``bvp_fwd_bwd`` on ``f`` over the interval ``[0, T]`` with initial
     state ``(0.0, x0)`` and terminal adjoint ``(1.0, zeros_like_tree(x0))``.

     Args:
         x0: Initial value of the second state component.
         args: Argument pytree forwarded to ``bvp_fwd_bwd``.

     Returns:
         ``(loss, adj_args_t0)``: the first state component at the final
         time and the adjoint of ``args`` at the initial time (presumably
         the gradient of ``loss`` w.r.t. ``args`` -- confirm with callers).

     Raises:
         AssertionError: If the solver reports that it did not succeed.
     """
     (loss, _), _, _, adj_args_t0, bvp_soln = bvp_fwd_bwd(
         f, (0.0, x0), 0, T, args, (1.0, zeros_like_tree(x0)))
     # Explicit check instead of a bare `assert`: asserts are stripped under
     # `python -O`, which would silently return results from a failed solve.
     # AssertionError is kept so existing callers see the same exception
     # type; the solver's message (if it exposes one) explains the failure.
     if not bvp_soln.success:
         raise AssertionError(
             getattr(bvp_soln, "message", "BVP solver did not succeed"))
     return loss, adj_args_t0