Example #1
def make_default_ddpg_train_config(env_spec: GymEnvSpec):
    """Usually decent parameters."""
    # Actions must be bounded [-1, 1].
    actor_init, actor = stax.serial(
        Dense(64),
        Relu,
        Dense(64),
        Relu,
        Dense(env_spec.action_shape[0]),
        Tanh,
    )

    critic_init, critic = stax.serial(
        FanInConcat(),
        Dense(64),
        Relu,
        Dense(64),
        Relu,
        Dense(64),
        Relu,
        Dense(1),
        Scalarify,
    )
    return DDPGTrainConfig(
        gamma=0.99,
        tau=1e-4,
        buffer_size=2**20,
        batch_size=128,
        optimizer_init=make_optimizer(optimizers.adam(step_size=1e-3)),
        # For some reason using DiagMVN here is ~100x slower.
        noise=lambda _1, _2: Normal(
            jp.zeros(env_spec.action_shape),
            0.1 * jp.ones(env_spec.action_shape),
        ),
        actor_init=actor_init,
        actor=actor,
        critic_init=critic_init,
        critic=critic,
    )
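The comment above notes that actions must be bounded in [-1, 1]; the trailing Tanh is what enforces that. A standalone sketch (with a hypothetical 5-dimensional state and 3-dimensional action space, not taken from the example) verifying the bound:

import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, Tanh

# Hypothetical dimensions, for illustration only.
state_dim, action_dim = 5, 3
actor_init, actor = stax.serial(Dense(64), Relu, Dense(64), Relu, Dense(action_dim), Tanh)
_, params = actor_init(random.PRNGKey(0), (state_dim, ))
action = actor(params, 100.0 * jnp.ones(state_dim))  # even for a large input...
assert jnp.all(jnp.abs(action) <= 1.0)               # ...Tanh keeps the action in [-1, 1]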
Example #2
def main():
    total_secs = 10.0
    gamma = 0.9
    rng = random.PRNGKey(0)

    ### Set up the problem/environment
    # xdot = Ax + Bu
    # u = - Kx
    # cost = xQx + uRu + 2xNu

    A = jp.eye(2)
    B = jp.eye(2)
    Q = jp.eye(2)
    R = jp.eye(2)
    N = jp.zeros((2, 2))

    # rngA, rngB, rngQ, rngR, rng = random.split(rng, 5)
    # # A = random.normal(rngA, (2, 2))
    # A = -1 * random_psd(rngA, 2)
    # B = random.normal(rngB, (2, 2))
    # Q = random_psd(rngQ, 2) + 0.1 * jp.eye(2)
    # R = random_psd(rngR, 2) + 0.1 * jp.eye(2)
    # N = jp.zeros((2, 2))

    # x_dim, u_dim = B.shape

    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    ### Solve the Riccati equation to get the infinite-horizon optimal solution.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jp.array(K)

    t0 = time.time()
    rng_eval, rng = random.split(rng)
    x0_eval = random.normal(rng_eval, (1000, 2))
    opt_all_costs = vmap(lambda x0: policy_integrate_cost(
        dynamics_fn, cost_fn, lambda _, x: -K @ x, gamma)
                         (None, x0, total_secs))(x0_eval)
    opt_cost = jp.mean(opt_all_costs)
    print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

    ### Set up the learned policy model.
    policy_init, policy = stax.serial(
        Dense(64),
        Relu,
        Dense(64),
        Relu,
        Dense(2),
    )
    # policy_init, policy = DenseNoBias(2)

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (2, ))

    cost_and_grad = jit(
        value_and_grad(
            policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

    def multiple_steps(num_steps):
        """Return a jit-able function that runs `num_steps` iterations."""
        def body(_, stuff):
            rng, _, opt = stuff
            rng_x0, rng = random.split(rng)
            x0 = random.normal(rng_x0, (2, ))
            cost, g = cost_and_grad(opt.value, x0, total_secs)

            # Gradient clipping
            # g = tree_map(lambda x: jp.clip(x, -10, 10), g)
            # g = optimizers.clip_grads(g, 64)

            return rng, cost, opt.update(g)

        return lambda rng, opt: lax.fori_loop(0, num_steps, body,
                                              (rng, jp.zeros(()), opt))

    multi_steps = 1
    run = jit(multiple_steps(multi_steps))

    ### Main optimization loop.
    costs = []
    for i in range(25000):
        t0 = time.time()
        rng, cost, opt = run(rng, opt)
        print(f"Episode {(i + 1) * multi_steps}:")
        print(f"    excess cost = {cost - opt_cost}")
        print(f"    elapsed = {time.time() - t0}")
        costs.append(float(cost))

    print(f"Opt solution cost from starting point: {opt_cost}")
    # print(f"Gradient at opt solution: {opt_g}")

    # Print the identified and optimal policy. Note that layers multiply on
    # the right instead of the left, so we need a transpose (see the sketch
    # after this example).
    print(f"Est solution parameters: {opt.value}")
    print(f"Opt solution parameters: {K.T}")

    est_all_costs = vmap(
        lambda x0: policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)
        (opt.value, x0, total_secs))(x0_eval)

    ### Scatter plot of learned policy performance vs optimal policy performance.
    plt.figure()
    plt.scatter(est_all_costs, opt_all_costs)
    plt.plot([-100, 100], [-100, 100], color="gray")
    plt.xlim(0, jp.max(est_all_costs))
    plt.ylim(0, jp.max(opt_all_costs))
    plt.xlabel("Learned policy cost")
    plt.ylabel("Optimal cost")
    plt.title("Performance relative to the direct LQR solution")

    ### Plot performance per iteration, incl. average optimal policy performance.
    plt.figure()
    plt.plot(costs)
    plt.axhline(opt_cost, linestyle="--", color="gray")
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel(f"Cost (T = {total_secs}s)")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("ODE control of LQR problem")

    ### Example rollout plots (learned policy vs optimal policy).
    x0 = jp.array([1.0, 2.0])
    framerate = 30
    timesteps = jp.linspace(0, total_secs, num=int(total_secs * framerate))
    est_policy_rollout_states = ode.odeint(
        lambda x, _: dynamics_fn(x, policy(opt.value, x)), y0=x0, t=timesteps)
    est_policy_rollout_controls = vmap(lambda x: policy(opt.value, x))(
        est_policy_rollout_states)

    opt_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, -K @ x),
                                           y0=x0,
                                           t=timesteps)
    opt_policy_rollout_controls = vmap(lambda x: -K @ x)(
        opt_policy_rollout_states)

    plt.figure()
    plt.plot(est_policy_rollout_states[:, 0],
             est_policy_rollout_states[:, 1],
             marker='.')
    plt.plot(opt_policy_rollout_states[:, 0],
             opt_policy_rollout_states[:, 1],
             marker='.')
    plt.xlabel("x_1")
    plt.ylabel("x_2")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("Phase space trajectory")

    plt.figure()
    plt.plot(timesteps, jp.sqrt(jp.sum(est_policy_rollout_controls**2,
                                       axis=-1)))
    plt.plot(timesteps, jp.sqrt(jp.sum(opt_policy_rollout_controls**2,
                                       axis=-1)))
    plt.xlabel("time")
    plt.ylabel("control input (L2 norm)")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("Policy control over time")

    ### Plot quiver field showing dynamics under learned policy.
    plot_policy_dynamics(dynamics_fn, cost_fn, lambda x: policy(opt.value, x))

    plt.show()
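The transpose remark in the comments above comes from the fact that stax Dense layers compute inputs @ W + b, i.e. they multiply their parameters on the right, whereas the LQR policy is written as u = -K x. A minimal sketch (toy gain matrix, not from the example) showing that a right-multiplying linear layer reproduces -K x when its weight is -K.T:

import jax.numpy as jnp

K = jnp.array([[1.0, 2.0], [3.0, 4.0]])   # toy gain matrix for illustration
x = jnp.array([0.5, -1.0])
# Left multiplication (the control convention) vs. right multiplication
# (the stax Dense convention, inputs @ W):
assert jnp.allclose(-K @ x, x @ (-K.T))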
Example #3
def main():
    num_iter = 50000
    # Most people run 1000 steps and the OpenAI gym pendulum is 0.05s per step.
    # The max torque that can be applied is also 2 in their setup.
    T = 1000
    time_delta = 0.05
    max_torque = 2.0
    rng = random.PRNGKey(0)

    dynamics = pendulum_dynamics(
        mass=1.0,
        length=1.0,
        gravity=9.8,
        friction=0.0,
    )

    policy_init, policy_nn = stax.serial(
        Dense(64),
        Tanh,
        Dense(64),
        Tanh,
        Dense(1),
        Tanh,
        stax.elementwise(lambda x: max_torque * x),
    )

    # Should it matter whether theta is wrapped into [0, 2pi]?
    policy = lambda params, x: policy_nn(
        params,
        jnp.array([x[0] % (2 * jnp.pi), x[1],
                   jnp.cos(x[0]),
                   jnp.sin(x[0])]))

    def loss(policy_params, x0):
        x = x0
        acc_cost = 0.0
        for _ in range(T):
            u = policy(policy_params, x)
            x += time_delta * dynamics(x, u)
            acc_cost += time_delta * cost(x, u)
        return acc_cost

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (4, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    loss_and_grad = jit(value_and_grad(loss))

    loss_per_iter = []
    elapsed_per_iter = []
    x0s = vmap(sample_x0)(random.split(rng, num_iter))
    for i in range(num_iter):
        t0 = time.time()
        # Use a distinct name so we don't shadow the loss() function above.
        loss_val, g = loss_and_grad(opt.value, x0s[i])
        opt = opt.update(g)
        elapsed = time.time() - t0

        loss_per_iter.append(loss_val)
        elapsed_per_iter.append(elapsed)

        print(f"Episode {i}")
        print(f"    loss = {loss_val}")
        print(f"    elapsed = {elapsed}")

    blt.remember({
        "loss_per_iter": loss_per_iter,
        "elapsed_per_iter": elapsed_per_iter,
        "final_params": opt.value
    })

    plt.figure()
    plt.plot(loss_per_iter)
    plt.yscale("log")
    plt.title("ODE control of an inverted pendulum")
    plt.xlabel("Iteration")
    plt.ylabel(f"Policy cost (T = {total_secs}s)")

    # Viz
    num_viz_rollouts = 50
    framerate = 30
    timesteps = jnp.linspace(0,
                             int(T * time_delta),
                             num=int(T * time_delta * framerate))
    rollout = lambda x0: ode.odeint(
        lambda x, _: dynamics(x, policy(opt.value, x)), y0=x0, t=timesteps)

    plt.figure()
    states = rollout(jnp.zeros(2))
    plt.plot(states[:, 0], states[:, 1], marker=".")
    plt.xlabel("theta")
    plt.ylabel("theta dot")
    plt.title("Swing up trajectory")

    plt.figure()
    states = vmap(rollout)(x0s[:num_viz_rollouts])
    for i in range(num_viz_rollouts):
        plt.plot(states[i, :, 0], states[i, :, 1], marker='.', alpha=0.5)
    plt.xlabel("theta")
    plt.ylabel("theta dot")
    plt.title("Phase space trajectory")

    plot_control_contour(lambda x: policy(opt.value, x))
    plot_policy_dynamics(dynamics, lambda x: policy(opt.value, x))

    blt.show()
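The loss above integrates the pendulum dynamics with a plain Python loop of T = 1000 explicit-Euler steps, which jit has to unroll. A sketch of the same rollout written with lax.scan instead, assuming the same dynamics, policy, and cost as in the example; this is an alternative formulation, not code from the original:

import jax.numpy as jnp
from jax import lax

def loss_scan(policy_params, x0, T=1000, time_delta=0.05):
    def step(x, _):
        u = policy(policy_params, x)
        x_next = x + time_delta * dynamics(x, u)
        # Matches the original loop: the cost is evaluated at the updated state.
        return x_next, time_delta * cost(x_next, u)

    _, step_costs = lax.scan(step, x0, xs=None, length=T)
    return jnp.sum(step_costs)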
Example #4
from jax.experimental import stax
from jax.experimental.stax import FanInConcat, Dense, Relu, Tanh

from research.estop import ddpg
from research.estop.pendulum import config
from research.estop.pendulum.env import viz_pendulum_rollout
from research.estop.utils import Scalarify
from research.statistax import Deterministic, Normal
from research.utils import make_optimizer
from research.estop import mdp

tau = 1e-4
buffer_size = 2**15
batch_size = 64
num_eval_rollouts = 128
opt_init = make_optimizer(optimizers.adam(step_size=1e-3))
init_noise = Normal(jp.array(0.0), jp.array(0.0))
noise = lambda _1, _2: Normal(jp.array(0.0), jp.array(0.5))

actor_init, actor = stax.serial(
    Dense(64),
    Relu,
    Dense(1),
    Tanh,
    stax.elementwise(lambda x: config.max_torque * x),
)

critic_init, critic = stax.serial(
    FanInConcat(),
    Dense(64),
    Relu,
    # The source listing is truncated at this point; the closing layers below
    # are an assumed completion following the critic pattern from Example #1.
    Dense(1),
    Scalarify,
)
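The actor above bounds its output with Tanh and then rescales it with stax.elementwise so that torques lie in [-max_torque, max_torque]. A standalone sketch of that output-scaling pattern (toy dimensions, max_torque = 2.0 as in the pendulum examples; not taken verbatim from this file):

import jax.numpy as jnp
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Tanh

max_torque = 2.0
init, apply = stax.serial(
    Dense(1),
    Tanh,
    stax.elementwise(lambda x: max_torque * x),  # rescale (-1, 1) -> (-max_torque, max_torque)
)
_, params = init(random.PRNGKey(0), (3, ))
u = apply(params, 10.0 * jnp.ones(3))
assert jnp.all(jnp.abs(u) <= max_torque)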
Example #5
def main():
    num_keypoints = 64
    gamma = 0.9
    rng = random.PRNGKey(0)

    ### Set up the problem/environment
    # xdot = Ax + Bu
    # u = - Kx
    # cost = xQx + uRu + 2xNu

    A = -0.1 * jp.eye(2)
    B = jp.eye(2)
    Q = jp.eye(2)
    R = jp.eye(2)
    N = jp.zeros((2, 2))

    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    ### Solve the Riccati equation to get the infinite-horizon optimal solution.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jp.array(K)

    t0 = time.time()
    rng_x0_eval, rng_eval_keypoints, rng = random.split(rng, 3)
    x0_eval = random.normal(rng_x0_eval, (1000, 2))
    opt_all_costs = vmap(lambda rng, x0: policy_cost(
        dynamics_fn, cost_fn, lambda _, x: -K @ x, num_keypoints)
                         (None, rng, x0, gamma),
                         in_axes=(0, 0))(random.split(rng_eval_keypoints,
                                                      x0_eval.shape[0]),
                                         x0_eval)
    opt_cost = jp.mean(opt_all_costs)
    print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

    ### Set up the learned policy model.
    policy_init, policy = stax.serial(
        Dense(64),
        Relu,
        Dense(64),
        Relu,
        Dense(2),
    )
    # policy_init, policy = DenseNoBias(2)

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (2, ))

    cost_and_grad = value_and_grad(
        policy_cost(dynamics_fn, cost_fn, policy, num_keypoints))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

    ### Main optimization loop.
    costs = []
    for i in range(1000):
        t0 = time.time()
        # gamma_i = gamma * (1 - 1 / (i / 1000 + 1.01))
        gamma_i = gamma
        rng_x0, rng_keypoints, rng = random.split(rng, 3)
        x0 = random.normal(rng_x0, (2, ))
        cost, g = cost_and_grad(opt.value, rng_keypoints, x0, gamma_i)
        opt = opt.update(g)
        print(f"Episode {i}:")
        print(f"    excess cost = {cost - opt_cost}")
        print(f"    gamma = {gamma_i}")
        print(f"    elapsed = {time.time() - t0}")
        costs.append(float(cost))

        if not jp.isfinite(cost):
            break

    print(f"Opt solution cost from starting point: {opt_cost}")
    # print(f"Gradient at opt solution: {opt_g}")

    # Print the identified and optimal policy. Note that layers multiply on
    # the right instead of the left, so we need a transpose.
    print(f"Est solution parameters: {opt.value}")
    print(f"Opt solution parameters: {K.T}")

    rng_eval_keypoints, rng = random.split(rng)
    est_all_costs = vmap(lambda rng, x0: policy_cost(dynamics_fn, cost_fn,
                                                     policy, num_keypoints)
                         (opt.value, rng, x0, gamma),
                         in_axes=(0, 0))(random.split(rng_eval_keypoints,
                                                      x0_eval.shape[0]),
                                         x0_eval)

    ### Scatter plot of learned policy performance vs optimal policy performance.
    plt.figure()
    plt.scatter(est_all_costs, opt_all_costs)
    plt.plot([-100, 100], [-100, 100], color="gray")
    plt.xlim(0, jp.max(est_all_costs))
    plt.ylim(0, jp.max(opt_all_costs))
    plt.xlabel("Learned policy cost")
    plt.ylabel("Optimal cost")
    plt.title("Performance relative to the direct LQR solution")

    ### Plot performance per iteration, incl. average optimal policy performance.
    plt.figure()
    plt.plot(costs)
    plt.axhline(opt_cost, linestyle="--", color="gray")
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel(f"Cost (gamma = {gamma}, num_keypoints = {num_keypoints})")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("ODE control of LQR problem (keypoint sampling)")

    ### Example rollout plots (learned policy vs optimal policy).
    x0 = jp.array([1.0, 2.0])
    framerate = 30
    total_secs = 10.0
    timesteps = jp.linspace(0, total_secs, num=int(total_secs * framerate))
    est_policy_rollout_states = ode.odeint(
        lambda x, _: dynamics_fn(x, policy(opt.value, x)), y0=x0, t=timesteps)
    est_policy_rollout_controls = vmap(lambda x: policy(opt.value, x))(
        est_policy_rollout_states)

    opt_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, -K @ x),
                                           y0=x0,
                                           t=timesteps)
    opt_policy_rollout_controls = vmap(lambda x: -K @ x)(
        opt_policy_rollout_states)

    plt.figure()
    plt.plot(est_policy_rollout_states[:, 0],
             est_policy_rollout_states[:, 1],
             marker='.')
    plt.plot(opt_policy_rollout_states[:, 0],
             opt_policy_rollout_states[:, 1],
             marker='.')
    plt.xlabel("x_1")
    plt.ylabel("x_2")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("Phase space trajectory")

    plt.figure()
    plt.plot(timesteps, jp.sqrt(jp.sum(est_policy_rollout_controls**2,
                                       axis=-1)))
    plt.plot(timesteps, jp.sqrt(jp.sum(opt_policy_rollout_controls**2,
                                       axis=-1)))
    plt.xlabel("time")
    plt.ylabel("control input (L2 norm)")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("Policy control over time")

    ### Plot quiver field showing dynamics under learned policy.
    plot_policy_dynamics(dynamics_fn, cost_fn, lambda x: policy(opt.value, x))

    blt.show()
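The "keypoints" in policy_cost refer to sampling evaluation times from an exponential distribution with rate -log(gamma), which turns the discounted cost integral into a Monte Carlo average (this sampling is explicit in the loss of Example #8). A small self-contained check of that identity with a toy cost c(t) = exp(-t); the cost function and numbers here are illustrative only:

import jax.numpy as jnp
from jax import random

gamma = 0.9
rate = -jnp.log(gamma)
c = lambda t: jnp.exp(-t)                        # toy instantaneous cost

# Monte Carlo estimate of \int_0^inf gamma^t c(t) dt = (1/rate) E_{t~Exp(rate)}[c(t)].
t = random.exponential(random.PRNGKey(0), (100_000, )) / rate
mc_estimate = jnp.mean(c(t)) / rate

exact = 1.0 / (rate + 1.0)                       # closed form for this toy c(t)
print(mc_estimate, exact)                        # the two should roughly agree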
Example #6
def main():
    rng = random.PRNGKey(0)
    x_dim = 2
    T = 20.0

    policy_init, policy = stax.serial(
        Dense(64),
        Tanh,
        Dense(x_dim),
    )

    x0 = jnp.ones(x_dim)

    A, B, Q, R, N = fixed_env(x_dim)
    print("System dynamics:")
    print(f"  A = {A}")
    print(f"  B = {B}")
    print(f"  Q = {Q}")
    print(f"  R = {R}")
    print(f"  N = {N}")
    print()

    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    ### Evaluate LQR solution to get a sense of optimal cost.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jnp.array(K)
    opt_policy_cost_fn = policy_cost_and_grad(dynamics_fn,
                                              cost_fn,
                                              lambda KK, x: -KK @ x,
                                              example_x=x0)
    opt_loss, _opt_K_grad = opt_policy_cost_fn(K, x0, T)

    # This is true for longer time horizons, but not true for shorter time
    # horizons due to the LQR solution being an infinite-time solution.
    # assert jnp.allclose(opt_K_grad, 0)

    ### Training loop.
    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    loss_and_grad = policy_cost_and_grad(dynamics_fn,
                                         cost_fn,
                                         policy,
                                         example_x=x0)

    loss_per_iter = []
    elapsed_per_iter = []
    for iteration in range(10000):
        t0 = time.time()
        loss, g = loss_and_grad(opt.value, x0, T)
        opt = opt.update(g)
        elapsed = time.time() - t0

        loss_per_iter.append(loss)
        elapsed_per_iter.append(elapsed)

        print(f"Iteration {iteration}")
        print(f"    excess loss = {loss - opt_loss}")
        print(f"    elapsed = {elapsed}")

    blt.remember({
        "loss_per_iter": loss_per_iter,
        "elapsed_per_iter": elapsed_per_iter,
        "opt_loss": opt_loss
    })

    _, ax1 = plt.subplots()
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Cost", color="tab:blue")
    ax1.set_yscale("log")
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.plot(loss_per_iter, color="tab:blue", label="Total rollout cost")
    plt.axhline(opt_loss, linestyle="--", color="gray")
    ax1.legend(loc="upper left")
    plt.title("Combined fwd-bwd BVP problem")
    blt.show()
Example #7
def main():
  total_time = 20.0
  gamma = 1.0
  x_dim = 2
  outer_loop_count = 10
  inner_loop_count = 1000
  rng = random.PRNGKey(0)

  x0 = jnp.array([2.0, 1.0])

  ### Set up the problem/environment
  # xdot = Ax + Bu
  # u = - Kx
  # cost = xQx + uRu + 2xNu
  A, B, Q, R, N = fixed_env(x_dim)
  print("System dynamics:")
  print(f"  A = {A}")
  print(f"  B = {B}")
  print(f"  Q = {Q}")
  print(f"  R = {R}")
  print(f"  N = {N}")
  print()

  dynamics_fn = lambda x, u: A @ x + B @ u
  # cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u
  position_cost_fn = lambda x: x.T @ Q @ x
  control_cost_fn = lambda u: u.T @ R @ u

  ### Solve the Riccati equation to get the infinite-horizon optimal solution.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jnp.array(K)

  _, (opt_x_cost_fwd, opt_u_cost_fwd,
      opt_xT_fwd), (opt_x_cost_bwd, opt_u_cost_bwd,
                    opt_x0_bwd) = policy_integrate_cost(dynamics_fn, position_cost_fn,
                                                        control_cost_fn, gamma,
                                                        lambda _, x: -K @ x)(None, x0, total_time)
  opt_cost_fwd = opt_x_cost_fwd + opt_u_cost_fwd
  print("LQR solution:")
  print(f"  K                     = {K}")
  print(f"  opt_x_cost_fwd        = {opt_x_cost_fwd}")
  print(f"  opt_u_cost_fwd        = {opt_u_cost_fwd}")
  print(f"  opt_x_cost_bwd        = {opt_x_cost_bwd}")
  print(f"  opt_u_cost_bwd        = {opt_u_cost_bwd}")
  print(f"  opt_cost_fwd          = {opt_cost_fwd}")
  print(f"  opt_xT_fwd            = {opt_xT_fwd}")
  print(f"  opt_x0_bwd            = {opt_x0_bwd}")
  print(f"  ||x0 - opt_x0_bwd||^2 = {jnp.sum((x0 - opt_x0_bwd)**2)}")
  print()

  ### Set up the learned policy model.
  policy_init, policy = stax.serial(
      Dense(64),
      Tanh,
      Dense(x_dim),
  )

  rng_init_params, rng = random.split(rng)
  _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
  init_opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

  def inner_loop(opt, _):
    runny_run = policy_integrate_cost(dynamics_fn, position_cost_fn, control_cost_fn, gamma, policy)

    (y0_fwd, yT_fwd, y0_bwd), vjp = jax.vjp(runny_run, opt.value, x0, total_time)
    x_cost_T_fwd, u_cost_T_fwd, xT_fwd = yT_fwd
    x_cost_0_bwd, u_cost_0_bwd, x0_bwd = y0_bwd

    yT_fwd_bar = (jnp.ones(()), jnp.ones(()), jnp.zeros_like(x0))
    g, _, _ = vjp((zeros_like_tree(y0_fwd), yT_fwd_bar, zeros_like_tree(y0_bwd)))

    return opt.update(g), Record(x_cost_T_fwd, u_cost_T_fwd, xT_fwd, x_cost_0_bwd, u_cost_0_bwd,
                                 x0_bwd)

  def outer_loop(opt, last: Record, elapsed=None):
    x_cost_T_fwd, u_cost_T_fwd, xT_fwd, x_cost_0_bwd, u_cost_0_bwd, x0_bwd = last
    print(f"Episode {opt.iteration}:")
    print(f"  excess fwd cost = {(x_cost_T_fwd + u_cost_T_fwd) - opt_cost_fwd}")
    print(f"    excess fwd x cost = {x_cost_T_fwd - opt_x_cost_fwd}")
    print(f"    excess fwd u cost = {u_cost_T_fwd - opt_u_cost_fwd}")
    print(f"  bwd cost        = {x_cost_0_bwd + u_cost_0_bwd}")
    print(f"    bwd x cost        = {x_cost_0_bwd}")
    print(f"    bwd u cost        = {u_cost_0_bwd}")
    print(f"  bwd x0 - x0     = {x0_bwd - x0}")
    print(f"  fwd xT          = {xT_fwd}")
    print(f"  fwd xT norm sq. = {jnp.sum(xT_fwd**2)}")
    print(f"  elapsed/iter    = {elapsed/inner_loop_count}s")

  ### Main optimization loop.
  t1 = time.time()
  _, history = fruity_loops(outer_loop, inner_loop, outer_loop_count, inner_loop_count, init_opt)
  print(f"total elapsed = {time.time() - t1}s")

  blt.remember({"history": history})

  cost_T_fwd_per_iter = history.x_cost_T_fwd_per_iter + history.u_cost_T_fwd_per_iter
  cost_0_bwd_per_iter = history.x_cost_0_bwd_per_iter + history.u_cost_0_bwd_per_iter

  ### Plot performance per iteration, incl. average optimal policy performance.
  _, ax1 = plt.subplots()
  ax1.set_xlabel("Iteration")
  ax1.set_ylabel("Cost", color="tab:blue")
  ax1.set_yscale("log")
  ax1.tick_params(axis="y", labelcolor="tab:blue")
  ax1.plot(cost_T_fwd_per_iter, color="tab:blue", label="Total rollout cost")
  ax1.plot(history.x_cost_T_fwd_per_iter,
           linestyle="dotted",
           color="tab:blue",
           label="Position cost")
  ax1.plot(history.u_cost_T_fwd_per_iter,
           linestyle="dashed",
           color="tab:blue",
           label="Control cost")
  plt.axhline(opt_cost_fwd, linestyle="--", color="gray")
  ax1.legend(loc="upper left")

  ax2 = ax1.twinx()
  ax2.set_ylabel("Error", color="tab:red")
  ax2.set_yscale("log")
  ax2.tick_params(axis="y", labelcolor="tab:red")
  ax2.plot(cost_0_bwd_per_iter**2, alpha=0.5, color="tab:red", label="Cost rewind error")
  ax2.plot(jnp.sum((history.x0_bwd_per_iter - x0)**2, axis=-1),
           alpha=0.5,
           color="tab:purple",
           label="x(0) rewind error")
  ax2.plot(jnp.sum(history.xT_fwd_per_iter**2, axis=-1),
           alpha=0.5,
           color="tab:brown",
           label="x(T) squared norm")
  ax2.legend(loc="upper right")

  plt.title("ODE control of LQR problem")

  blt.show()
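In inner_loop above the gradient is obtained by calling jax.vjp and pulling back a cotangent that is one for the forward cost outputs and zero everywhere else, rather than using value_and_grad. A toy sketch (my own two-output function, not from the example) of selecting one output's gradient this way:

import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(w), jnp.sum(w ** 2)           # two scalar outputs

w = jnp.arange(3.0)
(_y1, _y2), vjp = jax.vjp(f, w)
# A cotangent of (0, 1) selects the gradient of the second output only.
grad_y2, = vjp((jnp.zeros(()), jnp.ones(())))
assert jnp.allclose(grad_y2, 2.0 * w)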
Example #8
def main():
  num_keypoints = 64
  train_batch_size = 64
  eval_batch_size = 1024
  gamma = 0.9
  rng = random.PRNGKey(0)

  ### Set up the problem/environment
  # xdot = Ax + Bu
  # u = - Kx
  # cost = xQx + uRu + 2xNu

  A = -0.1 * jp.eye(1)
  B = jp.eye(1)
  Q = jp.eye(1)
  R = jp.eye(1)
  N = jp.zeros((1, 1))

  cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

  ### Solve the Riccati equation to get the infinite-horizon optimal solution.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jp.array(K)

  def loss(batch_size):
    def lossy_loss(KK, rng):
      rng_t, rng_x0 = random.split(rng)
      x0 = random.normal(rng_x0, shape=(batch_size, ))
      t = random.exponential(rng_t, shape=(num_keypoints, )) / -jp.log(gamma)
      x_t = jp.outer(jp.exp(t * jp.squeeze(A - B @ KK)), x0)
      costs = vmap(lambda x: cost_fn(x, -KK @ x))(jp.reshape(x_t, (-1, 1)))
      return jp.mean(costs)

    return lossy_loss

  t0 = time.time()
  rng_eval_keypoints, rng = random.split(rng)
  opt_all_costs = loss(eval_batch_size)(K, rng_eval_keypoints)
  opt_cost = jp.mean(opt_all_costs)
  print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

  ### Set up the learned policy model.
  rng_init_params, rng = random.split(rng)
  opt = make_optimizer(optimizers.adam(1e-3))(random.normal(rng_init_params, shape=(1, 1)))
  cost_and_grad = jit(value_and_grad(loss(train_batch_size)))

  ### Main optimization loop.
  costs = []
  for i in range(10000):
    t0 = time.time()
    rng_iter, rng = random.split(rng)
    cost, g = cost_and_grad(opt.value, rng_iter)
    opt = opt.update(g)
    print(f"Episode {i}: excess cost = {cost - opt_cost}, elapsed = {time.time() - t0}")
    costs.append(float(cost))

  print(f"Opt solution cost from starting point: {opt_cost}")
  print(f"Est solution parameters: {opt.value}")
  print(f"Opt solution parameters: {K}")

  ### Plot performance per iteration, incl. average optimal policy performance.
  plt.figure()
  plt.plot(costs)
  plt.axhline(opt_cost, linestyle="--", color="gray")
  plt.yscale("log")
  plt.xlabel("Iteration")
  plt.ylabel(f"Cost (gamma = {gamma}, num_keypoints = {num_keypoints})")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("ODE control of LQR problem (keypoint sampling)")

  plt.show()
Example #9
def main():
    total_time = 20.0
    gamma = 1.0
    x_dim = 2
    z_dim = 32
    rng = random.PRNGKey(0)

    x0 = jp.array([2.0, 1.0])

    ### Set up the problem/environment
    # xdot = Ax + Bu
    # u = - Kx
    # cost = xQx + uRu + 2xNu
    A, B, Q, R, N = fixed_env(x_dim)
    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u
    policy_loss = policy_integrate_cost(x_dim, z_dim, dynamics_fn, cost_fn,
                                        gamma)

    ### Solve the Riccati equation to get the infinite-horizon optimal solution.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jp.array(K)

    t0 = time.time()
    opt_y_fwd, opt_y_bwd = policy_loss(lambda _, x: -K @ x)(None, x0,
                                                            total_time)
    opt_cost = opt_y_fwd[1, 0]
    print(f"opt_cost = {opt_cost} in {time.time() - t0}s")
    print(opt_y_fwd)
    print(opt_y_bwd)
    print(f"l2 error: {jp.sqrt(jp.sum((opt_y_fwd - opt_y_bwd)**2))}")

    ### Set up the learned policy model.
    policy_init, policy = stax.serial(
        Dense(64),
        Tanh,
        Dense(x_dim),
    )

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    runny_run = jit(policy_loss(policy))

    ### Main optimization loop.
    costs = []
    bwd_errors = []
    for i in range(5000):
        t0 = time.time()
        (y_fwd, y_bwd), vjp = jax.vjp(runny_run, opt.value, x0, total_time)
        cost = y_fwd[1, 0]

        y_fwd_bar = jax.ops.index_update(jp.zeros_like(y_fwd), (1, 0), 1)
        g, _, _ = vjp((y_fwd_bar, jp.zeros_like(y_bwd)))
        opt = opt.update(g)

        bwd_err = jp.sqrt(jp.sum((y_fwd - y_bwd)**2))
        bwd_errors.append(bwd_err)

        print(f"Episode {i}:")
        print(f"    excess cost = {cost - opt_cost}")
        print(f"    bwd error = {bwd_err}")
        print(f"    elapsed = {time.time() - t0}")
        costs.append(float(cost))

    print(f"Opt solution cost from starting point: {opt_cost}")

    ### Plot performance per iteration, incl. average optimal policy performance.
    _, ax1 = plt.subplots()
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Cost", color="tab:blue")
    ax1.set_yscale("log")
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.plot(costs, color="tab:blue")
    plt.axhline(opt_cost, linestyle="--", color="gray")

    ax2 = ax1.twinx()
    ax2.set_ylabel("Backward solve L2 error", color="tab:red")
    ax2.set_yscale("log")
    ax2.tick_params(axis="y", labelcolor="tab:red")
    ax2.plot(bwd_errors, color="tab:red")
    plt.title("ODE control of LQR problem")

    blt.show()
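A note on the jax.ops.index_update call used to build y_fwd_bar above: that function has been deprecated and removed in recent JAX releases, where the same one-hot cotangent is built with the .at[] syntax. A minimal equivalent sketch:

import jax.numpy as jnp

y_fwd = jnp.zeros((3, 2))                # stand-in for the integrator output
y_fwd_bar = y_fwd.at[1, 0].set(1.0)      # same result as jax.ops.index_update(y_fwd, (1, 0), 1)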