def inner_loop(opt, _):
    """Run one optimization step of the policy parameters.

    Builds the cost-integrating rollout, differentiates it with `jax.vjp`,
    and feeds the gradient to the optimizer. The second argument is an
    ignored scan/loop carry placeholder.

    NOTE(review): relies on closure variables (`policy_integrate_cost`,
    `dynamics_fn`, `position_cost_fn`, `control_cost_fn`, `gamma`, `policy`,
    `x0`, `total_time`, `Record`, `zeros_like_tree`) defined in the enclosing
    scope — confirm against the full file.

    Returns:
        A pair of (updated optimizer state, Record of per-step diagnostics).
    """
    runny_run = policy_integrate_cost(dynamics_fn, position_cost_fn, control_cost_fn, gamma, policy)
    # Forward pass plus a VJP function; the primal output is a triple of
    # pytrees: values at t=0 (forward), values at t=T (forward), and values
    # at t=0 (backward).
    (y0_fwd, yT_fwd, y0_bwd), vjp = jax.vjp(runny_run, opt.value, x0, total_time)
    x_cost_T_fwd, u_cost_T_fwd, xT_fwd = yT_fwd
    x_cost_0_bwd, u_cost_0_bwd, x0_bwd = y0_bwd
    # Cotangent for the t=T forward output: ones on both scalar cost terms,
    # zeros on the state — i.e. differentiate the total (position + control)
    # cost only.
    yT_fwd_bar = (jnp.ones(()), jnp.ones(()), jnp.zeros_like(x0))
    # Zero cotangents on the other two outputs so only yT_fwd contributes;
    # keep only the gradient w.r.t. opt.value (drop x0 and total_time grads).
    g, _, _ = vjp((zeros_like_tree(y0_fwd), yT_fwd_bar, zeros_like_tree(y0_bwd)))
    return opt.update(g), Record(x_cost_T_fwd, u_cost_T_fwd, xT_fwd, x_cost_0_bwd, u_cost_0_bwd, x0_bwd)
def bvp_fwd_bwd(f, y0, t0, t1, f_args, adj_y_t1, init_num_nodes=2):
    """Solve the forward ODE and its adjoint simultaneously as a two-point BVP.

    The augmented state is ``(y, adj_y, adj_t, adj_args)``. The boundary
    conditions pin ``y`` at ``t0`` and the adjoint variables at ``t1``, and
    the whole system is handed to ``scipy.integrate.solve_bvp`` in flattened
    form (via ``ravel_pytree``).

    Args:
        f: Dynamics function ``f(y, t, args) -> ydot``.
        y0: Initial state pytree at ``t0``.
        t0, t1: Integration interval endpoints.
        f_args: Extra args pytree passed through to ``f``.
        adj_y_t1: Adjoint (cotangent) of the state at ``t1``.
        init_num_nodes: Initial mesh size for the BVP solver.

    Returns:
        ``(z_t1, adj_y_t0, adj_t_t0, adj_args_t0, bvp_soln)``: the state at
        ``t1``, the adjoint quantities at ``t0``, and the raw solver result.
    """
    # Boundary-condition target: state at t0, adjoint state at t1, zero
    # time-adjoint and zero args-adjoint at t1.
    z_bc = (y0, adj_y_t1, 0.0, zeros_like_tree(f_args))
    z_bc_flat, unravel = ravel_pytree(z_bc)

    def dynamics_one(t, aug, args):
        # Augmented dynamics for a single time point: forward dynamics plus
        # the negated VJP for the adjoint variables.
        y, adj_y, _, _ = aug
        ydot, vjpfun = vjp(f, y, t, args)
        return (ydot, *tree_map(jnp.negative, vjpfun(adj_y)))

    def dynamics_one_flat(t, aug, args):
        flat, _ = ravel_pytree(dynamics_one(t, unravel(aug), args))
        return flat

    @jit
    def dynamics_many_flat(ts, augs, args):
        # solve_bvp supplies augs with shape (state_dim, num_nodes); vmap over
        # nodes and transpose back to solve_bvp's layout.
        return vmap(dynamics_one_flat, in_axes=(0, 1, None))(ts, augs, args).T

    def bc(aug_t0, aug_t1):
        y_t0_, _, _, _ = aug_t0
        _, adj_y_t1_, adj_t_t1_, adj_args_t1_ = aug_t1
        # FIX: `tree_multimap` was deprecated and later removed from JAX;
        # `tree_map` accepts multiple trees and is already used above.
        return tree_map(jnp.subtract, (y_t0_, adj_y_t1_, adj_t_t1_, adj_args_t1_), z_bc)

    @jit
    def bc_flat(aug_t0, aug_t1):
        error_flat, _ = ravel_pytree(bc(unravel(aug_t0), unravel(aug_t1)))
        return error_flat

    dynamics_one_jac = jacrev(dynamics_one_flat, argnums=1)

    @jit
    def dynamics_jac(ts, augs, args):
        # solve_bvp expects the Jacobian with shape (dim, dim, num_nodes).
        return jnp.transpose(vmap(dynamics_one_jac, in_axes=(0, 1, None))(ts, augs, args), axes=(1, 2, 0))

    # If fun_jac isn't provided then the number of nodes blows up, and we reach
    # memory errors, even on a machine with 90G. See the full error for more info:
    # https://gist.github.com/samuela/8c5f6463e08d15c9ffad1f352d1a5513.
    # Adding the bc_jac is super important for numerical stability.
    bvp_soln = solve_bvp(
        lambda ts, augs: dynamics_many_flat(ts, augs, f_args),
        bc_flat,
        jnp.linspace(t0, t1, num=init_num_nodes),
        jnp.array([z_bc_flat] * init_num_nodes).T,
        fun_jac=lambda ts, augs: dynamics_jac(ts, augs, f_args),
        bc_jac=jit(jacrev(bc_flat, argnums=(0, 1))))
    z_t1, _, _, _ = unravel(bvp_soln.y[:, -1])
    _, adj_y_t0, adj_t_t0, adj_args_t0 = unravel(bvp_soln.y[:, 0])
    return z_t1, adj_y_t0, adj_t_t0, adj_args_t0, bvp_soln
def run(policy_params, x0, total_time):
    """Integrate the cost forward in time, then adjoint-integrate backward.

    Returns:
        ``(cost, grad)`` — the accumulated scalar cost over ``[0, total_time]``
        and the gradient pytree produced by the backward solve.
    """
    # Augmented initial condition: zero accumulated cost paired with the state.
    initial_aug = (jnp.zeros(()), x0)
    (cost, _), y_fn = solve_ivp_fwd(0.0, total_time, initial_aug, policy_params)

    # Seed the adjoint with d(cost)/d(cost) = 1 and zeros for the state.
    seed = (jnp.ones(()), zeros_like_tree(x0))
    grad = solve_ivp_bwd(policy_params, y_fn, seed)
    return cost, grad
def bwd(args, y_fn, g):
    """Adjoint (backward) pass over the piecewise spline solution `y_fn`.

    Walks the solver's time segments in reverse, propagating the augmented
    adjoint state through each one, and returns the accumulated gradient
    w.r.t. ``args``.

    NOTE(review): assumes ``y_fn`` exposes ``ts`` (node times), ``Q`` and
    ``y_old`` (per-segment spline data), as produced by the forward solver —
    confirm against the full file.
    """
    aug = (jnp.zeros(()), g, zeros_like_tree(args))
    # FIX: `reversed(range(n))` is the idiomatic reverse iteration; the
    # original sliced the range (`range(n)[::-1]`).
    for i in reversed(range(y_fn.Q.shape[0])):
        # Believe it or not it's faster to pull these out into variables.
        ta = y_fn.ts[i]
        tb = y_fn.ts[i + 1]
        # solve_ivp can hand us time steps that are only ~1e-16 apart, and
        # those don't play nicely with Runge-Kutta — skip degenerate segments.
        if tb - ta > 1e-8:
            aug = bwd_spline_segment(ta, tb, args, y_fn.Q[i], y_fn.y_old[i], aug)
    (_, _, adj_args) = aug
    return adj_args
def run(x0, args):
    """Compute the scalar loss and its gradient w.r.t. ``args`` via the BVP
    forward/adjoint solve.

    Returns:
        ``(loss, adj_args_t0)`` — the loss at the final time and the adjoint
        of ``args`` at the initial time.
    """
    # Augment the state with a running cost of zero, and seed the adjoint
    # with 1 on the cost component and zeros on the state.
    augmented_y0 = (0.0, x0)
    cotangent = (1.0, zeros_like_tree(x0))
    final_state, _, _, adj_args_t0, bvp_soln = bvp_fwd_bwd(
        f, augmented_y0, 0, T, args, cotangent)
    loss, _ = final_state
    assert bvp_soln.success
    return loss, adj_args_t0