def make_default_ddpg_train_config(env_spec: GymEnvSpec):
  """Usually decent parameters."""
  # Actions must be bounded [-1, 1].
  actor_init, actor = stax.serial(
      Dense(64),
      Relu,
      Dense(64),
      Relu,
      Dense(env_spec.action_shape[0]),
      Tanh,
  )
  critic_init, critic = stax.serial(
      FanInConcat(),
      Dense(64),
      Relu,
      Dense(64),
      Relu,
      Dense(64),
      Relu,
      Dense(1),
      Scalarify,
  )
  return DDPGTrainConfig(
      gamma=0.99,
      tau=1e-4,
      buffer_size=2**20,
      batch_size=128,
      optimizer_init=make_optimizer(optimizers.adam(step_size=1e-3)),
      # For some reason using DiagMVN here is ~100x slower.
      noise=lambda _1, _2: Normal(
          jp.zeros(env_spec.action_shape),
          0.1 * jp.ones(env_spec.action_shape),
      ),
      actor_init=actor_init,
      actor=actor,
      critic_init=critic_init,
      critic=critic,
  )
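# A minimal sketch of `make_optimizer` (imported from `research.utils`; its
# implementation is not shown here), assuming it wraps a
# `jax.experimental.optimizers` (init, update, get_params) triple in a small
# immutable container exposing `.value`, `.update(grad)`, and `.iteration`,
# which is how it is used throughout these scripts. The name
# `make_optimizer_sketch` is hypothetical.
from typing import Any, NamedTuple


def make_optimizer_sketch(optimizer_triple):
  opt_init, opt_update, get_params = optimizer_triple

  class Optimizer(NamedTuple):
    iteration: int
    opt_state: Any

    @property
    def value(self):
      # Current parameters held in the optimizer state.
      return get_params(self.opt_state)

    def update(self, grad):
      # One gradient step; returns a new Optimizer with an advanced counter.
      return Optimizer(self.iteration + 1,
                       opt_update(self.iteration, grad, self.opt_state))

  return lambda init_params: Optimizer(0, opt_init(init_params))


# Example usage mirroring the config above:
#   optimizer_init = make_optimizer_sketch(optimizers.adam(step_size=1e-3))
#   opt = optimizer_init(init_params)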
def main():
  total_secs = 10.0
  gamma = 0.9
  rng = random.PRNGKey(0)

  ### Set up the problem/environment
  # xdot = Ax + Bu
  # u = - Kx
  # cost = xQx + uRu + 2xNu
  A = jp.eye(2)
  B = jp.eye(2)
  Q = jp.eye(2)
  R = jp.eye(2)
  N = jp.zeros((2, 2))

  # rngA, rngB, rngQ, rngR, rng = random.split(rng, 5)
  # # A = random.normal(rngA, (2, 2))
  # A = -1 * random_psd(rngA, 2)
  # B = random.normal(rngB, (2, 2))
  # Q = random_psd(rngQ, 2) + 0.1 * jp.eye(2)
  # R = random_psd(rngR, 2) + 0.1 * jp.eye(2)
  # N = jp.zeros((2, 2))
  # x_dim, u_dim = B.shape

  dynamics_fn = lambda x, u: A @ x + B @ u
  cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

  ### Solve the Riccati equation to get the infinite-horizon optimal solution.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jp.array(K)

  t0 = time.time()
  rng_eval, rng = random.split(rng)
  x0_eval = random.normal(rng_eval, (1000, 2))
  opt_all_costs = vmap(lambda x0: policy_integrate_cost(
      dynamics_fn, cost_fn, lambda _, x: -K @ x, gamma)(None, x0, total_secs))(
          x0_eval)
  opt_cost = jp.mean(opt_all_costs)
  print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

  ### Set up the learned policy model.
  policy_init, policy = stax.serial(
      Dense(64),
      Relu,
      Dense(64),
      Relu,
      Dense(2),
  )
  # policy_init, policy = DenseNoBias(2)

  rng_init_params, rng = random.split(rng)
  _, init_policy_params = policy_init(rng_init_params, (2, ))

  cost_and_grad = jit(
      value_and_grad(
          policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)))
  opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

  def multiple_steps(num_steps):
    """Return a jit-able function that runs `num_steps` iterations."""
    def body(_, stuff):
      rng, _, opt = stuff
      rng_x0, rng = random.split(rng)
      x0 = random.normal(rng_x0, (2, ))
      cost, g = cost_and_grad(opt.value, x0, total_secs)
      # Gradient clipping
      # g = tree_map(lambda x: jp.clip(x, -10, 10), g)
      # g = optimizers.clip_grads(g, 64)
      return rng, cost, opt.update(g)

    return lambda rng, opt: lax.fori_loop(0, num_steps, body,
                                          (rng, jp.zeros(()), opt))

  multi_steps = 1
  run = jit(multiple_steps(multi_steps))

  ### Main optimization loop.
  costs = []
  for i in range(25000):
    t0 = time.time()
    rng, cost, opt = run(rng, opt)
    print(f"Episode {(i + 1) * multi_steps}:")
    print(f"  excess cost = {cost - opt_cost}")
    print(f"  elapsed = {time.time() - t0}")
    costs.append(float(cost))

  print(f"Opt solution cost from starting point: {opt_cost}")
  # print(f"Gradient at opt solution: {opt_g}")

  # Print the identified and optimal policy. Note that layers multiply on the
  # right instead of the left so we need a transpose.
  print(f"Est solution parameters: {opt.value}")
  print(f"Opt solution parameters: {K.T}")

  est_all_costs = vmap(
      lambda x0: policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)
      (opt.value, x0, total_secs))(x0_eval)

  ### Scatter plot of learned policy performance vs optimal policy performance.
  plt.figure()
  plt.scatter(est_all_costs, opt_all_costs)
  plt.plot([-100, 100], [-100, 100], color="gray")
  plt.xlim(0, jp.max(est_all_costs))
  plt.ylim(0, jp.max(opt_all_costs))
  plt.xlabel("Learned policy cost")
  plt.ylabel("Optimal cost")
  plt.title("Performance relative to the direct LQR solution")

  ### Plot performance per iteration, incl. average optimal policy performance.
  plt.figure()
  plt.plot(costs)
  plt.axhline(opt_cost, linestyle="--", color="gray")
  plt.yscale("log")
  plt.xlabel("Iteration")
  plt.ylabel(f"Cost (T = {total_secs}s)")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("ODE control of LQR problem")

  ### Example rollout plots (learned policy vs optimal policy).
  x0 = jp.array([1.0, 2.0])
  framerate = 30
  timesteps = jp.linspace(0, total_secs, num=int(total_secs * framerate))

  est_policy_rollout_states = ode.odeint(
      lambda x, _: dynamics_fn(x, policy(opt.value, x)), y0=x0, t=timesteps)
  est_policy_rollout_controls = vmap(lambda x: policy(opt.value, x))(
      est_policy_rollout_states)

  opt_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, -K @ x),
                                         y0=x0,
                                         t=timesteps)
  opt_policy_rollout_controls = vmap(lambda x: -K @ x)(
      opt_policy_rollout_states)

  plt.figure()
  plt.plot(est_policy_rollout_states[:, 0],
           est_policy_rollout_states[:, 1],
           marker='.')
  plt.plot(opt_policy_rollout_states[:, 0],
           opt_policy_rollout_states[:, 1],
           marker='.')
  plt.xlabel("x_1")
  plt.ylabel("x_2")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("Phase space trajectory")

  plt.figure()
  plt.plot(timesteps, jp.sqrt(jp.sum(est_policy_rollout_controls**2, axis=-1)))
  plt.plot(timesteps, jp.sqrt(jp.sum(opt_policy_rollout_controls**2, axis=-1)))
  plt.xlabel("time")
  plt.ylabel("control input (L2 norm)")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("Policy control over time")

  ### Plot quiver field showing dynamics under learned policy.
  plot_policy_dynamics(dynamics_fn, cost_fn, lambda x: policy(opt.value, x))

  plt.show()
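# A minimal sketch of `policy_integrate_cost` (defined elsewhere in the
# repository), assuming it integrates the dynamics with `ode.odeint` while
# accumulating the discounted cost as an extra state variable and returns the
# accumulated cost at `total_secs`. The name `policy_integrate_cost_sketch` is
# hypothetical; this is an inference from the call sites above, not the actual
# implementation.
import jax.numpy as jp
from jax.experimental import ode


def policy_integrate_cost_sketch(dynamics_fn, cost_fn, policy, gamma):
  def run(policy_params, x0, total_secs):
    def aug_dynamics(y, t):
      x, _accumulated_cost = y
      u = policy(policy_params, x)
      # d/dt of the accumulated cost is the discounted instantaneous cost.
      return dynamics_fn(x, u), (gamma**t) * cost_fn(x, u)

    _, costs = ode.odeint(aug_dynamics, (x0, jp.zeros(())),
                          jp.array([0.0, total_secs]))
    return costs[-1]

  return run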
def main():
  num_iter = 50000

  # Most people run 1000 steps and the OpenAI gym pendulum is 0.05s per step.
  # The max torque that can be applied is also 2 in their setup.
  T = 1000
  time_delta = 0.05
  max_torque = 2.0
  total_secs = T * time_delta

  rng = random.PRNGKey(0)

  dynamics = pendulum_dynamics(
      mass=1.0,
      length=1.0,
      gravity=9.8,
      friction=0.0,
  )

  policy_init, policy_nn = stax.serial(
      Dense(64),
      Tanh,
      Dense(64),
      Tanh,
      Dense(1),
      Tanh,
      stax.elementwise(lambda x: max_torque * x),
  )

  # Should it matter whether theta is wrapped into [0, 2pi]?
  policy = lambda params, x: policy_nn(
      params,
      jnp.array([x[0] % (2 * jnp.pi), x[1], jnp.cos(x[0]), jnp.sin(x[0])]))

  def loss(policy_params, x0):
    x = x0
    acc_cost = 0.0
    for _ in range(T):
      u = policy(policy_params, x)
      x += time_delta * dynamics(x, u)
      acc_cost += time_delta * cost(x, u)
    return acc_cost

  rng_init_params, rng = random.split(rng)
  _, init_policy_params = policy_init(rng_init_params, (4, ))
  opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

  loss_and_grad = jit(value_and_grad(loss))

  loss_per_iter = []
  elapsed_per_iter = []
  x0s = vmap(sample_x0)(random.split(rng, num_iter))
  for i in range(num_iter):
    t0 = time.time()
    # Use a separate name so the `loss` function above is not shadowed.
    loss_i, g = loss_and_grad(opt.value, x0s[i])
    opt = opt.update(g)
    elapsed = time.time() - t0

    loss_per_iter.append(loss_i)
    elapsed_per_iter.append(elapsed)

    print(f"Episode {i}")
    print(f"  loss = {loss_i}")
    print(f"  elapsed = {elapsed}")

  blt.remember({
      "loss_per_iter": loss_per_iter,
      "elapsed_per_iter": elapsed_per_iter,
      "final_params": opt.value
  })

  plt.figure()
  plt.plot(loss_per_iter)
  plt.yscale("log")
  plt.title("ODE control of an inverted pendulum")
  plt.xlabel("Iteration")
  plt.ylabel(f"Policy cost (T = {total_secs}s)")

  # Viz
  num_viz_rollouts = 50
  framerate = 30
  timesteps = jnp.linspace(0, total_secs, num=int(total_secs * framerate))
  rollout = lambda x0: ode.odeint(
      lambda x, _: dynamics(x, policy(opt.value, x)), y0=x0, t=timesteps)

  plt.figure()
  states = rollout(jnp.zeros(2))
  plt.plot(states[:, 0], states[:, 1], marker=".")
  plt.xlabel("theta")
  plt.ylabel("theta dot")
  plt.title("Swing up trajectory")

  plt.figure()
  states = vmap(rollout)(x0s[:num_viz_rollouts])
  for i in range(num_viz_rollouts):
    plt.plot(states[i, :, 0], states[i, :, 1], marker='.', alpha=0.5)
  plt.xlabel("theta")
  plt.ylabel("theta dot")
  plt.title("Phase space trajectory")

  plot_control_contour(lambda x: policy(opt.value, x))
  plot_policy_dynamics(dynamics, lambda x: policy(opt.value, x))

  blt.show()
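# A minimal sketch of `pendulum_dynamics` (defined elsewhere in the repository;
# `cost` and `sample_x0` are likewise external), assuming the standard damped,
# torque-driven pendulum with state x = [theta, theta_dot] and control
# u = [torque]. Signs and constants here are illustrative placeholders, and the
# name `pendulum_dynamics_sketch` is hypothetical.
import jax.numpy as jnp


def pendulum_dynamics_sketch(mass, length, gravity, friction):
  def dxdt(x, u):
    theta, theta_dot = x[0], x[1]
    theta_dotdot = (-(gravity / length) * jnp.sin(theta)
                    - (friction / (mass * length**2)) * theta_dot
                    + u[0] / (mass * length**2))
    return jnp.array([theta_dot, theta_dotdot])

  return dxdt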
from jax.experimental import stax
from jax.experimental.stax import FanInConcat, Dense, Relu, Tanh

from research.estop import ddpg
from research.estop.pendulum import config
from research.estop.pendulum.env import viz_pendulum_rollout
from research.estop.utils import Scalarify
from research.statistax import Deterministic, Normal
from research.utils import make_optimizer
from research.estop import mdp

tau = 1e-4
buffer_size = 2**15
batch_size = 64
num_eval_rollouts = 128
opt_init = make_optimizer(optimizers.adam(step_size=1e-3))

init_noise = Normal(jp.array(0.0), jp.array(0.0))
noise = lambda _1, _2: Normal(jp.array(0.0), jp.array(0.5))

actor_init, actor = stax.serial(
    Dense(64),
    Relu,
    Dense(1),
    Tanh,
    stax.elementwise(lambda x: config.max_torque * x),
)

critic_init, critic = stax.serial(
    FanInConcat(),
    Dense(64),
    Relu,
def main():
  num_keypoints = 64
  gamma = 0.9
  rng = random.PRNGKey(0)

  ### Set up the problem/environment
  # xdot = Ax + Bu
  # u = - Kx
  # cost = xQx + uRu + 2xNu
  A = -0.1 * jp.eye(2)
  B = jp.eye(2)
  Q = jp.eye(2)
  R = jp.eye(2)
  N = jp.zeros((2, 2))

  dynamics_fn = lambda x, u: A @ x + B @ u
  cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

  ### Solve the Riccati equation to get the infinite-horizon optimal solution.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jp.array(K)

  t0 = time.time()
  rng_x0_eval, rng_eval_keypoints, rng = random.split(rng, 3)
  x0_eval = random.normal(rng_x0_eval, (1000, 2))
  opt_all_costs = vmap(
      lambda rng, x0: policy_cost(dynamics_fn, cost_fn, lambda _, x: -K @ x,
                                  num_keypoints)(None, rng, x0, gamma),
      in_axes=(0, 0))(random.split(rng_eval_keypoints, x0_eval.shape[0]),
                      x0_eval)
  opt_cost = jp.mean(opt_all_costs)
  print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

  ### Set up the learned policy model.
  policy_init, policy = stax.serial(
      Dense(64),
      Relu,
      Dense(64),
      Relu,
      Dense(2),
  )
  # policy_init, policy = DenseNoBias(2)

  rng_init_params, rng = random.split(rng)
  _, init_policy_params = policy_init(rng_init_params, (2, ))

  cost_and_grad = value_and_grad(
      policy_cost(dynamics_fn, cost_fn, policy, num_keypoints))
  opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

  ### Main optimization loop.
  costs = []
  for i in range(1000):
    t0 = time.time()
    # gamma_i = gamma * (1 - 1 / (i / 1000 + 1.01))
    gamma_i = gamma
    rng_x0, rng_keypoints, rng = random.split(rng, 3)
    x0 = random.normal(rng_x0, (2, ))
    cost, g = cost_and_grad(opt.value, rng_keypoints, x0, gamma_i)
    opt = opt.update(g)
    print(f"Episode {i}:")
    print(f"  excess cost = {cost - opt_cost}")
    print(f"  gamma = {gamma_i}")
    print(f"  elapsed = {time.time() - t0}")
    costs.append(float(cost))
    if not jp.isfinite(cost):
      break

  print(f"Opt solution cost from starting point: {opt_cost}")
  # print(f"Gradient at opt solution: {opt_g}")

  # Print the identified and optimal policy. Note that layers multiply on the
  # right instead of the left so we need a transpose.
  print(f"Est solution parameters: {opt.value}")
  print(f"Opt solution parameters: {K.T}")

  rng_eval_keypoints, rng = random.split(rng)
  est_all_costs = vmap(
      lambda rng, x0: policy_cost(dynamics_fn, cost_fn, policy,
                                  num_keypoints)(opt.value, rng, x0, gamma),
      in_axes=(0, 0))(random.split(rng_eval_keypoints, x0_eval.shape[0]),
                      x0_eval)

  ### Scatter plot of learned policy performance vs optimal policy performance.
  plt.figure()
  plt.scatter(est_all_costs, opt_all_costs)
  plt.plot([-100, 100], [-100, 100], color="gray")
  plt.xlim(0, jp.max(est_all_costs))
  plt.ylim(0, jp.max(opt_all_costs))
  plt.xlabel("Learned policy cost")
  plt.ylabel("Optimal cost")
  plt.title("Performance relative to the direct LQR solution")

  ### Plot performance per iteration, incl. average optimal policy performance.
  plt.figure()
  plt.plot(costs)
  plt.axhline(opt_cost, linestyle="--", color="gray")
  plt.yscale("log")
  plt.xlabel("Iteration")
  plt.ylabel(f"Cost (gamma = {gamma}, num_keypoints = {num_keypoints})")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("ODE control of LQR problem (keypoint sampling)")

  ### Example rollout plots (learned policy vs optimal policy).
  x0 = jp.array([1.0, 2.0])
  framerate = 30
  total_secs = 10.0
  timesteps = jp.linspace(0, total_secs, num=int(total_secs * framerate))

  est_policy_rollout_states = ode.odeint(
      lambda x, _: dynamics_fn(x, policy(opt.value, x)), y0=x0, t=timesteps)
  est_policy_rollout_controls = vmap(lambda x: policy(opt.value, x))(
      est_policy_rollout_states)

  opt_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, -K @ x),
                                         y0=x0,
                                         t=timesteps)
  opt_policy_rollout_controls = vmap(lambda x: -K @ x)(
      opt_policy_rollout_states)

  plt.figure()
  plt.plot(est_policy_rollout_states[:, 0],
           est_policy_rollout_states[:, 1],
           marker='.')
  plt.plot(opt_policy_rollout_states[:, 0],
           opt_policy_rollout_states[:, 1],
           marker='.')
  plt.xlabel("x_1")
  plt.ylabel("x_2")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("Phase space trajectory")

  plt.figure()
  plt.plot(timesteps, jp.sqrt(jp.sum(est_policy_rollout_controls**2, axis=-1)))
  plt.plot(timesteps, jp.sqrt(jp.sum(opt_policy_rollout_controls**2, axis=-1)))
  plt.xlabel("time")
  plt.ylabel("control input (L2 norm)")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("Policy control over time")

  ### Plot quiver field showing dynamics under learned policy.
  plot_policy_dynamics(dynamics_fn, cost_fn, lambda x: policy(opt.value, x))

  blt.show()
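# A minimal sketch of `policy_cost` (defined elsewhere in the repository),
# assuming a keypoint-sampling estimator: sample times t ~ Exp(-log(gamma)),
# roll the dynamics out to those times with `ode.odeint`, and average the
# instantaneous costs. Up to the constant factor -log(gamma), this is a
# Monte-Carlo estimate of the discounted cost integral. The name
# `policy_cost_sketch` is hypothetical; this is an inference from the call
# sites above, not the actual implementation.
import jax.numpy as jp
from jax import random, vmap
from jax.experimental import ode


def policy_cost_sketch(dynamics_fn, cost_fn, policy, num_keypoints):
  def run(policy_params, rng, x0, gamma):
    # odeint expects increasing times starting at the initial time, so sort
    # the sampled keypoints and prepend t = 0.
    t = random.exponential(rng, shape=(num_keypoints, )) / -jp.log(gamma)
    ts = jp.concatenate([jp.zeros((1, )), jp.sort(t)])

    xs = ode.odeint(lambda x, _: dynamics_fn(x, policy(policy_params, x)),
                    x0, ts)
    costs = vmap(lambda x: cost_fn(x, policy(policy_params, x)))(xs[1:])
    return jp.mean(costs)

  return run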
def main():
  rng = random.PRNGKey(0)

  x_dim = 2
  T = 20.0
  policy_init, policy = stax.serial(
      Dense(64),
      Tanh,
      Dense(x_dim),
  )
  x0 = jnp.ones(x_dim)

  A, B, Q, R, N = fixed_env(x_dim)
  print("System dynamics:")
  print(f"  A = {A}")
  print(f"  B = {B}")
  print(f"  Q = {Q}")
  print(f"  R = {R}")
  print(f"  N = {N}")
  print()

  dynamics_fn = lambda x, u: A @ x + B @ u
  cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

  ### Evaluate LQR solution to get a sense of optimal cost.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jnp.array(K)
  opt_policy_cost_fn = policy_cost_and_grad(dynamics_fn,
                                            cost_fn,
                                            lambda KK, x: -KK @ x,
                                            example_x=x0)
  opt_loss, _opt_K_grad = opt_policy_cost_fn(K, x0, T)

  # This is true for longer time horizons, but not true for shorter time
  # horizons due to the LQR solution being an infinite-time solution.
  # assert jnp.allclose(opt_K_grad, 0)

  ### Training loop.
  rng_init_params, rng = random.split(rng)
  _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
  opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
  loss_and_grad = policy_cost_and_grad(dynamics_fn,
                                       cost_fn,
                                       policy,
                                       example_x=x0)

  loss_per_iter = []
  elapsed_per_iter = []
  for iteration in range(10000):
    t0 = time.time()
    loss, g = loss_and_grad(opt.value, x0, T)
    opt = opt.update(g)
    elapsed = time.time() - t0

    loss_per_iter.append(loss)
    elapsed_per_iter.append(elapsed)

    print(f"Iteration {iteration}")
    print(f"  excess loss = {loss - opt_loss}")
    print(f"  elapsed = {elapsed}")

  blt.remember({
      "loss_per_iter": loss_per_iter,
      "elapsed_per_iter": elapsed_per_iter,
      "opt_loss": opt_loss
  })

  _, ax1 = plt.subplots()
  ax1.set_xlabel("Iteration")
  ax1.set_ylabel("Cost", color="tab:blue")
  ax1.set_yscale("log")
  ax1.tick_params(axis="y", labelcolor="tab:blue")
  ax1.plot(loss_per_iter, color="tab:blue", label="Total rollout cost")
  plt.axhline(opt_loss, linestyle="--", color="gray")
  ax1.legend(loc="upper left")

  plt.title("Combined fwd-bwd BVP problem")
  blt.show()
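# A minimal sketch of `fixed_env` (defined elsewhere in the repository),
# assuming it returns a fixed LQR problem (A, B, Q, R, N) of the requested
# dimension. The matrices below are illustrative placeholders, not the ones
# used in the original experiments; the name `fixed_env_sketch` is
# hypothetical.
import jax.numpy as jnp


def fixed_env_sketch(x_dim):
  A = -0.1 * jnp.eye(x_dim)  # mildly stable open-loop dynamics
  B = jnp.eye(x_dim)
  Q = jnp.eye(x_dim)
  R = jnp.eye(x_dim)
  N = jnp.zeros((x_dim, x_dim))
  return A, B, Q, R, N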
def main():
  total_time = 20.0
  gamma = 1.0
  x_dim = 2
  outer_loop_count = 10
  inner_loop_count = 1000

  rng = random.PRNGKey(0)
  x0 = jnp.array([2.0, 1.0])

  ### Set up the problem/environment
  # xdot = Ax + Bu
  # u = - Kx
  # cost = xQx + uRu + 2xNu
  A, B, Q, R, N = fixed_env(x_dim)
  print("System dynamics:")
  print(f"  A = {A}")
  print(f"  B = {B}")
  print(f"  Q = {Q}")
  print(f"  R = {R}")
  print(f"  N = {N}")
  print()

  dynamics_fn = lambda x, u: A @ x + B @ u
  # cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u
  position_cost_fn = lambda x: x.T @ Q @ x
  control_cost_fn = lambda u: u.T @ R @ u

  ### Solve the Riccati equation to get the infinite-horizon optimal solution.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jnp.array(K)

  _, (opt_x_cost_fwd, opt_u_cost_fwd, opt_xT_fwd), (
      opt_x_cost_bwd, opt_u_cost_bwd, opt_x0_bwd) = policy_integrate_cost(
          dynamics_fn, position_cost_fn, control_cost_fn, gamma,
          lambda _, x: -K @ x)(None, x0, total_time)
  opt_cost_fwd = opt_x_cost_fwd + opt_u_cost_fwd
  print("LQR solution:")
  print(f"  K = {K}")
  print(f"  opt_x_cost_fwd = {opt_x_cost_fwd}")
  print(f"  opt_u_cost_fwd = {opt_u_cost_fwd}")
  print(f"  opt_x_cost_bwd = {opt_x_cost_bwd}")
  print(f"  opt_u_cost_bwd = {opt_u_cost_bwd}")
  print(f"  opt_cost_fwd = {opt_cost_fwd}")
  print(f"  opt_xT_fwd = {opt_xT_fwd}")
  print(f"  opt_x0_bwd = {opt_x0_bwd}")
  print(f"  ||x0 - opt_x0_bwd||^2 = {jnp.sum((x0 - opt_x0_bwd)**2)}")
  print()

  ### Set up the learned policy model.
  policy_init, policy = stax.serial(
      Dense(64),
      Tanh,
      Dense(x_dim),
  )
  rng_init_params, rng = random.split(rng)
  _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
  init_opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

  def inner_loop(opt, _):
    runny_run = policy_integrate_cost(dynamics_fn, position_cost_fn,
                                      control_cost_fn, gamma, policy)
    (y0_fwd, yT_fwd, y0_bwd), vjp = jax.vjp(runny_run, opt.value, x0,
                                            total_time)
    x_cost_T_fwd, u_cost_T_fwd, xT_fwd = yT_fwd
    x_cost_0_bwd, u_cost_0_bwd, x0_bwd = y0_bwd

    yT_fwd_bar = (jnp.ones(()), jnp.ones(()), jnp.zeros_like(x0))
    g, _, _ = vjp((zeros_like_tree(y0_fwd), yT_fwd_bar,
                   zeros_like_tree(y0_bwd)))

    return opt.update(g), Record(x_cost_T_fwd, u_cost_T_fwd, xT_fwd,
                                 x_cost_0_bwd, u_cost_0_bwd, x0_bwd)

  def outer_loop(opt, last: Record, elapsed=None):
    (x_cost_T_fwd, u_cost_T_fwd, xT_fwd, x_cost_0_bwd, u_cost_0_bwd,
     x0_bwd) = last
    print(f"Episode {opt.iteration}:")
    print(f"  excess fwd cost = {(x_cost_T_fwd + u_cost_T_fwd) - opt_cost_fwd}")
    print(f"  excess fwd x cost = {x_cost_T_fwd - opt_x_cost_fwd}")
    print(f"  excess fwd u cost = {u_cost_T_fwd - opt_u_cost_fwd}")
    print(f"  bwd cost = {x_cost_0_bwd + u_cost_0_bwd}")
    print(f"  bwd x cost = {x_cost_0_bwd}")
    print(f"  bwd u cost = {u_cost_0_bwd}")
    print(f"  bwd x0 - x0 = {x0_bwd - x0}")
    print(f"  fwd xT = {xT_fwd}")
    print(f"  fwd xT norm sq. = {jnp.sum(xT_fwd**2)}")
    print(f"  elapsed/iter = {elapsed / inner_loop_count}s")

  ### Main optimization loop.
  t1 = time.time()
  _, history = fruity_loops(outer_loop, inner_loop, outer_loop_count,
                            inner_loop_count, init_opt)
  print(f"total elapsed = {time.time() - t1}s")

  blt.remember({"history": history})

  cost_T_fwd_per_iter = (history.x_cost_T_fwd_per_iter +
                         history.u_cost_T_fwd_per_iter)
  cost_0_bwd_per_iter = (history.x_cost_0_bwd_per_iter +
                         history.u_cost_0_bwd_per_iter)

  ### Plot performance per iteration, incl. average optimal policy performance.
  _, ax1 = plt.subplots()
  ax1.set_xlabel("Iteration")
  ax1.set_ylabel("Cost", color="tab:blue")
  ax1.set_yscale("log")
  ax1.tick_params(axis="y", labelcolor="tab:blue")
  ax1.plot(cost_T_fwd_per_iter, color="tab:blue", label="Total rollout cost")
  ax1.plot(history.x_cost_T_fwd_per_iter,
           linestyle="dotted",
           color="tab:blue",
           label="Position cost")
  ax1.plot(history.u_cost_T_fwd_per_iter,
           linestyle="dashed",
           color="tab:blue",
           label="Control cost")
  plt.axhline(opt_cost_fwd, linestyle="--", color="gray")
  ax1.legend(loc="upper left")

  ax2 = ax1.twinx()
  ax2.set_ylabel("Error", color="tab:red")
  ax2.set_yscale("log")
  ax2.tick_params(axis="y", labelcolor="tab:red")
  ax2.plot(cost_0_bwd_per_iter**2,
           alpha=0.5,
           color="tab:red",
           label="Cost rewind error")
  ax2.plot(jnp.sum((history.x0_bwd_per_iter - x0)**2, axis=-1),
           alpha=0.5,
           color="tab:purple",
           label="x(0) rewind error")
  ax2.plot(jnp.sum(history.xT_fwd_per_iter**2, axis=-1),
           alpha=0.5,
           color="tab:brown",
           label="x(T) squared norm")
  ax2.legend(loc="upper right")

  plt.title("ODE control of LQR problem")
  blt.show()
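# A minimal sketch of `zeros_like_tree` (used in `inner_loop` above), assuming
# it builds a pytree of zeros matching the structure and shapes of its
# argument, which is what the cotangents passed to `vjp` require. The name
# `zeros_like_tree_sketch` is hypothetical.
import jax
import jax.numpy as jnp


def zeros_like_tree_sketch(tree):
  return jax.tree_util.tree_map(jnp.zeros_like, tree)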
def main():
  num_keypoints = 64
  train_batch_size = 64
  eval_batch_size = 1024
  gamma = 0.9
  rng = random.PRNGKey(0)

  ### Set up the problem/environment
  # xdot = Ax + Bu
  # u = - Kx
  # cost = xQx + uRu + 2xNu
  A = -0.1 * jp.eye(1)
  B = jp.eye(1)
  Q = jp.eye(1)
  R = jp.eye(1)
  N = jp.zeros((1, 1))

  cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

  ### Solve the Riccati equation to get the infinite-horizon optimal solution.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jp.array(K)

  def loss(batch_size):
    def lossy_loss(KK, rng):
      rng_t, rng_x0 = random.split(rng)
      x0 = random.normal(rng_x0, shape=(batch_size, ))
      t = random.exponential(rng_t, shape=(num_keypoints, )) / -jp.log(gamma)
      x_t = jp.outer(jp.exp(t * jp.squeeze(A - B @ KK)), x0)
      costs = vmap(lambda x: cost_fn(x, -KK @ x))(jp.reshape(x_t, (-1, 1)))
      return jp.mean(costs)

    return lossy_loss

  t0 = time.time()
  rng_eval_keypoints, rng = random.split(rng)
  opt_all_costs = loss(eval_batch_size)(K, rng_eval_keypoints)
  opt_cost = jp.mean(opt_all_costs)
  print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

  ### Set up the learned policy model.
  rng_init_params, rng = random.split(rng)
  opt = make_optimizer(optimizers.adam(1e-3))(random.normal(rng_init_params,
                                                            shape=(1, 1)))
  cost_and_grad = jit(value_and_grad(loss(train_batch_size)))

  ### Main optimization loop.
  costs = []
  for i in range(10000):
    t0 = time.time()
    rng_iter, rng = random.split(rng)
    cost, g = cost_and_grad(opt.value, rng_iter)
    opt = opt.update(g)
    print(f"Episode {i}: excess cost = {cost - opt_cost}, "
          f"elapsed = {time.time() - t0}")
    costs.append(float(cost))

  print(f"Opt solution cost from starting point: {opt_cost}")
  print(f"Est solution parameters: {opt.value}")
  print(f"Opt solution parameters: {K}")

  ### Plot performance per iteration, incl. average optimal policy performance.
  plt.figure()
  plt.plot(costs)
  plt.axhline(opt_cost, linestyle="--", color="gray")
  plt.yscale("log")
  plt.xlabel("Iteration")
  plt.ylabel(f"Cost (gamma = {gamma}, num_keypoints = {num_keypoints})")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("ODE control of LQR problem (keypoint sampling)")

  plt.show()
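# A small self-contained check of the keypoint-sampling idea behind `loss`
# above (illustrative only): for t ~ Exp(lambda) with lambda = -log(gamma),
#   E[c(x(t))] = lambda * integral_0^inf gamma^t c(x(t)) dt,
# so averaging the cost at sampled keypoint times estimates the discounted
# cost up to the constant factor -log(gamma). The example below uses a cost
# that decays as exp(2 * decay * t), for which the integral has a closed form;
# the function name and parameters are hypothetical.
import jax.numpy as jp
from jax import random


def keypoint_identity_check(rng, gamma=0.9, decay=-0.5, num_samples=100_000):
  lam = -jp.log(gamma)
  t = random.exponential(rng, shape=(num_samples, )) / lam
  mc_estimate = jp.mean(jp.exp(2 * decay * t))  # Monte-Carlo over keypoints
  closed_form = lam / (lam - 2 * decay)         # analytic value of the integral
  return mc_estimate, closed_form  # these should agree to roughly 1e-2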
def main():
  total_time = 20.0
  gamma = 1.0
  x_dim = 2
  z_dim = 32
  rng = random.PRNGKey(0)
  x0 = jp.array([2.0, 1.0])

  ### Set up the problem/environment
  # xdot = Ax + Bu
  # u = - Kx
  # cost = xQx + uRu + 2xNu
  A, B, Q, R, N = fixed_env(x_dim)
  dynamics_fn = lambda x, u: A @ x + B @ u
  cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u
  policy_loss = policy_integrate_cost(x_dim, z_dim, dynamics_fn, cost_fn,
                                      gamma)

  ### Solve the Riccati equation to get the infinite-horizon optimal solution.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jp.array(K)

  t0 = time.time()
  opt_y_fwd, opt_y_bwd = policy_loss(lambda _, x: -K @ x)(None, x0, total_time)
  opt_cost = opt_y_fwd[1, 0]
  print(f"opt_cost = {opt_cost} in {time.time() - t0}s")
  print(opt_y_fwd)
  print(opt_y_bwd)
  print(f"l2 error: {jp.sqrt(jp.sum((opt_y_fwd - opt_y_bwd)**2))}")

  ### Set up the learned policy model.
  policy_init, policy = stax.serial(
      Dense(64),
      Tanh,
      Dense(x_dim),
  )
  rng_init_params, rng = random.split(rng)
  _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
  opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

  runny_run = jit(policy_loss(policy))

  ### Main optimization loop.
  costs = []
  bwd_errors = []
  for i in range(5000):
    t0 = time.time()
    (y_fwd, y_bwd), vjp = jax.vjp(runny_run, opt.value, x0, total_time)
    cost = y_fwd[1, 0]

    y_fwd_bar = jax.ops.index_update(jp.zeros_like(y_fwd), (1, 0), 1)
    g, _, _ = vjp((y_fwd_bar, jp.zeros_like(y_bwd)))
    opt = opt.update(g)

    bwd_err = jp.sqrt(jp.sum((y_fwd - y_bwd)**2))
    bwd_errors.append(bwd_err)

    print(f"Episode {i}:")
    print(f"  excess cost = {cost - opt_cost}")
    print(f"  bwd error = {bwd_err}")
    print(f"  elapsed = {time.time() - t0}")
    costs.append(float(cost))

  print(f"Opt solution cost from starting point: {opt_cost}")

  ### Plot performance per iteration, incl. average optimal policy performance.
  _, ax1 = plt.subplots()
  ax1.set_xlabel("Iteration")
  ax1.set_ylabel("Cost", color="tab:blue")
  ax1.set_yscale("log")
  ax1.tick_params(axis="y", labelcolor="tab:blue")
  ax1.plot(costs, color="tab:blue")
  plt.axhline(opt_cost, linestyle="--", color="gray")

  ax2 = ax1.twinx()
  ax2.set_ylabel("Backward solve L2 error", color="tab:red")
  ax2.set_yscale("log")
  ax2.tick_params(axis="y", labelcolor="tab:red")
  ax2.plot(bwd_errors, color="tab:red")

  plt.title("ODE control of LQR problem")
  blt.show()
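# Note: `jax.ops.index_update` above is the older JAX indexed-update API. On
# newer JAX versions the equivalent one-hot cotangent would be written as:
#   y_fwd_bar = jp.zeros_like(y_fwd).at[1, 0].set(1.0)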