def main():
    """Compare an i2c pass against finite-horizon LQR on a redefined linear system.

    The model's ``xag`` and ``zg_term`` are set to constant vectors of 10 and
    the affine term ``a`` is recomputed so that ``A @ xag + a == xag``. The
    problem is then solved with ``finite_horizon_lqr``, one forward/backward
    message pass and one backward Riccati pass are run on an ``I2cGraph``,
    and the trajectory, controller and value-function comparison plots are
    written to the results folder.
    """
    out_dir = make_results_folder("i2c_lqr_equivalence", 0, "", release=True)

    env = make_env(experiment)
    model = make_env_model(experiment.ENVIRONMENT, experiment.MODEL)

    # Force a single inference iteration in the experiment configuration.
    experiment.N_INFERENCE = 1

    horizon = experiment.N_DURATION
    cfg = experiment.INFERENCE

    # Redefine the linear system: pick the offset `a` such that
    # A @ xag + a == xag, and remove the environment's sig_eta term.
    column_shape = (env.dim_x, 1)
    model.xag = np.full(column_shape, 10.0)
    model.zg_term = np.full(column_shape, 10.0)
    model.a = model.xag - model.A @ model.xag
    env.a = model.a
    env.sig_eta = np.zeros((env.dim_x, env.dim_x))
    ug = np.zeros((env.dim_u,))

    # The column vectors are flattened to 1-D before being handed to the solver.
    lqr_solution = finite_horizon_lqr(
        horizon,
        model.A,
        model.a[:, 0],
        model.B,
        cfg.Q,
        cfg.R,
        model.x0[:, 0],
        model.xag[:, 0],
        ug,
        model.dim_x,
        model.dim_u,
    )
    x_lqr, u_lqr, K_lqr, k_lqr, _cost_lqr, P, p = lqr_solution

    # Not referenced below; the import is kept where the original had it.
    from i2c.exp_types import CubatureQuadrature

    graph_kwargs = dict(
        sys=model,
        horizon=horizon,
        Q=cfg.Q,
        R=cfg.R,
        Qf=cfg.Qf,
        alpha=1e-5,  # 1e-6 is noted as an alternative value
        alpha_update_tol=cfg.alpha_update_tol,
        mu_u=np.zeros((horizon, 1)),
        sig_u=1e2 * np.eye(1),
        mu_x_terminal=None,
        sig_x_terminal=None,
        inference=cfg.inference,
        res_dir=None,
    )
    graph = I2cGraph(**graph_kwargs)

    # NOTE(review): other entry points set `use_expert_controller` on each
    # cell rather than on the graph -- confirm the graph-level flag is intended.
    graph.use_expert_controller = False
    for cell in graph.cells:
        cell.state_action_independence = True

    # EM iteration: a single forward/backward message pass, then plot it.
    graph._forward_backward_msgs()
    graph.plot_traj(0, dir_name=out_dir, filename="lqr")

    # Compute the Riccati terms used by the controller/value plots below.
    graph._backward_ricatti_msgs()

    plot_trajectory(graph, x_lqr, u_lqr, dir_name=out_dir)
    plot_controller(graph, u_lqr, K_lqr, k_lqr, dir_name=out_dir)
    plot_value_function(graph, P, p, dir_name=out_dir)
# Example #2
# 0
def main():
    """Run the nonlinear covariance-control experiment and plot its results.

    Builds an ``I2cGraph`` from the experiment's inference settings, runs
    ``experiment.N_INFERENCE`` ``learn_msgs`` iterations, writes the resulting
    local linear policy into a time-indexed linear Gaussian policy, evaluates
    it for 50 rollouts in the environment and stores the plots in the results
    folder.
    """
    configure_plots()
    out_dir = make_results_folder(
        "i2c_nonlinear_covariance_control", 0, "", release=True
    )

    env = make_env(experiment)
    model = make_env_model(experiment.ENVIRONMENT, experiment.MODEL)

    cfg = experiment.INFERENCE
    horizon = experiment.N_DURATION
    graph_kwargs = dict(
        sys=model,
        horizon=horizon,
        Q=cfg.Q,
        R=cfg.R,
        Qf=cfg.Qf,
        alpha=cfg.alpha,
        alpha_update_tol=cfg.alpha_update_tol,
        mu_u=cfg.mu_u,
        sig_u=cfg.sig_u,
        mu_x_terminal=cfg.mu_x_term,
        sig_x_terminal=cfg.sig_x_term,
        inference=cfg.inference,
        res_dir=out_dir,
    )
    graph = I2cGraph(**graph_kwargs)

    # Disable the expert controller on every cell of the graph.
    for cell in graph.cells:
        cell.use_expert_controller = False
    graph._propagate = True

    policy = TimeIndexedLinearGaussianPolicy(
        experiment.POLICY_COVAR, horizon, graph.sys.dim_u, graph.sys.dim_x
    )

    # One propagate call up front, then the inference iterations.
    graph.propagate()
    for _ in tqdm(range(experiment.N_INFERENCE)):
        graph.learn_msgs()

    graph.plot_metrics(0, 0, dir_name=out_dir, filename="nonlinear_cc")

    # Copy the inferred local linear policy into the policy object.
    local_policy = graph.get_local_linear_policy()
    policy.write(*local_policy)

    rollouts, _, _, _ = env.batch_eval(policy=policy, n_eval=50, deterministic=True)
    env.plot_sim(rollouts, None, "final", out_dir)

    plot_covariance_control(
        graph, rollouts, filename="nonlinear_cc", dir_name=out_dir
    )
def make_pendulum_cov_control_gif():
    """Render GIFs of pendulum covariance control across inference iterations.

    Runs 200 ``learn_msgs`` iterations. After each one the current local
    linear policy is rolled out 500 times and a frame is drawn showing the
    rollouts together with the covariance and mean markers for ``sys.x0``
    and ``mu_x_terminal``. The collected frames are then saved as one GIF per
    target duration (1, 2, 3, 4, 5 and 10) in the ``assets`` directory next
    to ``DIR_NAME`` and passed to ``optimize``.
    """
    import experiments.pendulum_known_act_reg_quad as experiment
    from i2c.policy.linear import TimeIndexedLinearGaussianPolicy
    from i2c.utils import covariance_2d

    model = make_env_model(experiment.ENVIRONMENT, experiment.MODEL)
    env = make_env(experiment)

    cfg = experiment.INFERENCE
    graph = I2cGraph(
        sys=model,
        horizon=experiment.N_DURATION,
        Q=cfg.Q,
        R=cfg.R,
        Qf=cfg.Qf,
        alpha=cfg.alpha,
        alpha_update_tol=cfg.alpha_update_tol,
        mu_u=cfg.mu_u,
        # experiment.INFERENCE.sig_u is not used here; a fixed 1x1 identity
        # is passed instead.
        sig_u=np.eye(1),
        mu_x_terminal=cfg.mu_x_term,
        sig_x_terminal=cfg.sig_x_term,
        inference=cfg.inference,
        res_dir=None,
    )
    for cell in graph.cells:
        cell.use_expert_controller = False

    policy = TimeIndexedLinearGaussianPolicy(
        experiment.POLICY_COVAR,
        experiment.N_DURATION,
        graph.sys.dim_u,
        graph.sys.dim_x,
    )

    graph._propagate = False
    experiment.N_INFERENCE = 200
    gif_pattern = os.path.join(DIR_NAME, "..", "assets", "p_cc_%ds.gif")

    frames = []
    for iteration in tqdm(range(experiment.N_INFERENCE)):
        graph.learn_msgs()
        policy.write(*graph.get_local_linear_policy())
        rollouts, _, _, _ = env.batch_eval(
            policy=policy, n_eval=500, deterministic=False
        )

        fig, ax = plt.subplots(1, 1)
        ax.set_title(f"Pendulum Covariance Control\nIteration {iteration:03d}")

        # Faint full trajectories plus an opaque marker at each final state;
        # only the first rollout contributes a legend entry.
        for idx, rollout in enumerate(rollouts):
            legend_label = "rollouts" if idx == 0 else None
            ax.plot(rollout[:, 0], rollout[:, 1], ".c", alpha=0.1, markersize=1)
            ax.plot(
                rollout[-1, 0],
                rollout[-1, 1],
                ".c",
                alpha=1.0,
                label=legend_label,
                markersize=1,
            )

        system = graph.sys
        covariance_2d(system.sig_x0, system.x0, ax, facecolor="k")
        ax.plot(
            system.x0[0], system.x0[1], "xk", label="$\\mathbf{x}_0$", markersize=3
        )
        covariance_2d(graph.sig_x_terminal, graph.mu_x_terminal, ax, facecolor="r")
        ax.plot(
            graph.mu_x_terminal[0],
            graph.mu_x_terminal[1],
            "xr",
            label="$\\mathbf{x}_g$",
            markersize=3,
        )

        ax.set_xlabel(graph.sys.key[0])
        ax.set_ylabel(graph.sys.key[1])
        ax.set_xlim(-np.pi / 4, 3 * np.pi / 2)
        ax.set_ylim(-5, 5)
        ax.legend(loc="lower left")

        # Rasterise the figure and keep it as a (height, width, 3) uint8 frame.
        # NOTE(review): `tostring_rgb` is deprecated in newer Matplotlib
        # releases -- confirm the pinned version still provides it.
        canvas = fig.canvas
        canvas.draw()
        pixels = np.frombuffer(canvas.tostring_rgb(), dtype="uint8")
        width, height = canvas.get_width_height()
        frames.append(pixels.reshape((height, width, 3)))

        plt.close(fig)

    # One GIF per target duration: the frame rate is chosen so that all
    # frames fit into `duration` seconds.
    n_frames = len(frames)
    for duration in (1, 2, 3, 4, 5, 10):
        gif_path = gif_pattern % duration
        imageio.mimsave(gif_path, frames, fps=n_frames / duration)
        optimize(gif_path)
def run(experiment, res_dir, weight_path):
    """Run i2c inference for ``experiment`` and evaluate the resulting policies.

    Builds the environment, the model and an ``I2cGraph`` from the experiment
    configuration, iterates ``learn_msgs`` for ``experiment.N_INFERENCE``
    steps (evaluating and plotting along the way when the environment is
    simulated), then plots, rolls out and saves the final results.

    Args:
        experiment: Experiment configuration. This function reads
            ``ENVIRONMENT``, ``MODEL``, ``N_DURATION``, ``N_INFERENCE``,
            ``N_ITERS_PER_PLOT``, ``POLICY_COVAR`` and the ``INFERENCE``
            settings from it.
        res_dir: Results directory handed to every plotting and saving call.
        weight_path: Optional path to saved model weights; when not ``None``
            it is loaded with ``i2c.sys.model.load`` before inference.

    Raises:
        Exception: Any error raised inside the inference loop is logged, the
            metrics are plotted with the ``"esc"`` tag, and the error is
            re-raised unchanged.
    """
    env = make_env(experiment)
    model = make_env_model(experiment.ENVIRONMENT, experiment.MODEL)

    i2c = I2cGraph(
        model,
        experiment.N_DURATION,
        experiment.INFERENCE.Q,
        experiment.INFERENCE.R,
        experiment.INFERENCE.Qf,
        experiment.INFERENCE.alpha,
        experiment.INFERENCE.alpha_update_tol,
        experiment.INFERENCE.mu_u,
        experiment.INFERENCE.sig_u,
        experiment.INFERENCE.mu_x_term,
        experiment.INFERENCE.sig_x_term,
        experiment.INFERENCE.inference,
        res_dir=res_dir,
    )

    # Two policies are evaluated side by side: the plain time-indexed linear
    # Gaussian policy and the "expert" variant (constructed with soft=False).
    policy_class = ExpertTimeIndexedLinearGaussianPolicy
    policy_linear = TimeIndexedLinearGaussianPolicy(experiment.POLICY_COVAR,
                                                    experiment.N_DURATION,
                                                    i2c.sys.dim_u,
                                                    i2c.sys.dim_x)
    policy = policy_class(
        experiment.POLICY_COVAR,
        experiment.N_DURATION,
        i2c.sys.dim_u,
        i2c.sys.dim_x,
        soft=False,
    )

    if weight_path is not None:
        print("Loading i2c model with {}".format(weight_path))
        i2c.sys.model.load(weight_path)

    # initial marginal traj
    s_est = np.zeros((experiment.N_DURATION, model.dim_s))

    # Separate evaluators: per episode, per iteration (linear policy) and
    # per iteration (expert policy).
    dim_terminal = i2c.Qf.shape[0]
    traj_eval = StochasticTrajectoryEvaluator(i2c.QR, i2c.Qf, i2c.z,
                                              i2c.z_term, dim_terminal)
    traj_eval_iter = StochasticTrajectoryEvaluator(i2c.QR, i2c.Qf, i2c.z,
                                                   i2c.z_term, dim_terminal)
    traj_eval_safe_iter = StochasticTrajectoryEvaluator(
        i2c.QR, i2c.Qf, i2c.z, i2c.z_term, dim_terminal)

    i2c.reset_metrics()

    if env.simulated:
        # Baseline: evaluate the zeroed expert policy before any inference.
        policy.zero()
        xs, ys, zs, z_term = env.batch_eval(policy, N_EVAL)
        env.plot_sim(xs, s_est, "initial", res_dir)
        traj_eval.eval(zs, z_term, zs[0], z_term[0])

    # `i` is read after the loop (final metrics plot, error handler, save
    # name); bind it up front so an empty iteration range does not end in a
    # NameError.
    i = 0

    # inference
    try:
        for i in tqdm(range(experiment.N_INFERENCE)):
            # Plot every N_ITERS_PER_PLOT iterations and on the last one.
            plot = (i % experiment.N_ITERS_PER_PLOT
                    == 0) or (i == experiment.N_INFERENCE - 1)

            i2c.learn_msgs()

            if env.simulated:
                # eval policy
                policy_linear.write(*i2c.get_local_linear_policy())

                xs, ys, zs, zs_term = env.batch_eval(policy_linear, N_EVAL)
                z_est, z_term_est = i2c.get_marginal_observed_trajectory()
                traj_eval_iter.eval(zs, zs_term, z_est, z_term_est)

                policy.write(*i2c.get_local_expert_linear_policy())
                xs, ys, zs, zs_term = env.batch_eval(policy, N_EVAL)
                traj_eval_safe_iter.eval(zs, zs_term, z_est, z_term_est)

                logging.info(
                    f"{i:02d} Cost | Plan: {i2c.costs_m[-1]}, "
                    f"Predict: {i2c.costs_pf[-1]}, "
                    f"Sim: [{traj_eval_iter.actual_cost_10[-1]}, "
                    f"{traj_eval_iter.actual_cost_90[-1]}] "
                    f"alpha: {i2c.alphas[-1], i2c.alphas_desired[-1]}")

            if i == 0:  # see how well inference works at the start
                xs, ys, zs, zs_term = env.batch_eval(policy,
                                                     N_EVAL,
                                                     deterministic=False)
                env.plot_sim(xs, s_est, f"{i}_stochastic", res_dir)

            if plot:
                i2c.plot_metrics(0, i, res_dir, "msg")
                s_est = i2c.get_marginal_trajectory()
                env.plot_sim(xs, s_est, f"{i}_stochastic", res_dir)

        i2c.plot_metrics(0, i, res_dir, "msg")
    except Exception:
        # Keep a record of the metrics gathered so far, then propagate.
        logging.exception("Inference failed")
        i2c.plot_metrics(0, i, res_dir, "esc")
        raise

    # update policy
    if env.simulated:
        # policy.write(*i2c.get_local_linear_policy())
        policy_linear.write(*i2c.get_local_linear_policy())
        z_est, z_term_est = i2c.get_marginal_observed_trajectory()
        # NOTE(review): the "stochastic" and "deterministic" rollouts below
        # call batch_eval with identical arguments -- confirm whether a
        # `deterministic=` flag was meant to differ between them.
        xs, ys, zs, zs_term = env.batch_eval(policy_linear, N_EVAL)
        s_est = i2c.get_marginal_trajectory()
        env.plot_sim(xs, s_est, "evaluation stochastic", res_dir)

        xs, ys, zs, zs_term = env.batch_eval(policy_linear, N_EVAL)
        env.plot_sim(xs, s_est, "evaluation deterministic", res_dir)

        z_est, z_term_est = i2c.get_marginal_observed_trajectory()
        traj_eval_iter.eval(zs, zs_term, z_est, z_term_est)
        traj_eval.eval(zs, zs_term, z_est, z_term_est)
        traj_eval_iter.plot("over_iterations", res_dir)
        traj_eval.plot("over_episodes", res_dir)

    i2c.plot_alphas(res_dir, "final")
    i2c.plot_cost(res_dir, "cost_final")

    # Final single rollout with the inferred linear policy.
    policy_linear.write(*i2c.get_local_linear_policy())
    x_final, y_final, _, _ = env.run(policy_linear)
    s_est = i2c.get_marginal_trajectory()
    env.plot_sim(x_final, s_est, "Final", res_dir)
    # generate gif for mujoco envs
    env.run_render(policy_linear, res_dir)

    # Feedforward-only rollout: zero the policy, then set only its `k` term
    # from the marginal input.
    policy_linear.zero()
    policy_linear.k = i2c.get_marginal_input().reshape(policy_linear.k.shape)
    x_ff, _, _, _ = env.run(policy_linear)
    env.plot_sim(x_ff, s_est, "Final Feedforward", res_dir)

    # save model and data
    save_trajectories(x_final, y_final, i2c, res_dir)
    traj_eval.save("episodic", res_dir)
    traj_eval_iter.save("iter", res_dir)
    i2c.save(res_dir, f"{i}")

    i2c.close()
    env.close()