def run(args):
    """Build every experiment component from *args*, execute one closed-loop
    rollout, then plot and persist the resulting trajectories.

    Expects *args* to carry at least ``result_dir`` plus whatever the
    ``make_*`` factories read from it.
    """
    make_logger(args.result_dir)

    # Assemble all experiment pieces up front.
    environment = make_env(args)
    cfg = make_config(args)
    ref_planner = make_planner(args, cfg)
    dynamics_model = make_model(args, cfg)
    ctrl = make_controller(args, cfg, dynamics_model)
    simulator = make_runner(args)

    # Run the experiment once.
    traj_x, traj_u, traj_g = simulator.run(environment, ctrl, ref_planner)

    # Visualize and persist the outcome.
    plot_results(args, traj_x, traj_u, history_g=traj_g)
    save_plot_data(args, traj_x, traj_u, history_g=traj_g)
# Beispiel #2 (0) — scraped example separator; commented out because the bare
# name "Beispiel" would raise NameError if this file were imported.
def main():
    """Parse CLI arguments, train an imitation-learning policy via the DAgger
    loop, then evaluate the trained policy and plot/save the result.

    Side effects: writes logs/plots under ``--result_dir`` and optionally
    renders an animation when ``--save_anim`` is truthy.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--env", type=str, default="TwoWheeledTrack")
    # parser.add_argument("--env", type=str, default="CartPole")

    parser.add_argument("--save_anim", type=bool_flag, default=1)
    parser.add_argument("--controller_type", type=str, default="NMPCCGMRES")
    # parser.add_argument("--controller_type", type=str, default="MPPI")

    parser.add_argument("--result_dir", type=str, default="./result")
    # bool_flag, not str: with type=str any provided value — even "False" —
    # is a non-empty, truthy string.
    parser.add_argument("--use_learning", type=bool_flag, default=True)
    # argparse applies `type` to each token, so np.ndarray was never a valid
    # converter; accept one or more ints instead (default unchanged).
    parser.add_argument("--num_train_steps_per_iter",
                        type=int,
                        nargs="+",
                        default=[40000, 80000, 120000, 160000])
    parser.add_argument("--relabel_with_expert", type=bool_flag, default=True)

    args = parser.parse_args()

    trainer = IL_trainer(args)
    # Join the path components properly instead of concatenating with '/'.
    initial_expertdata_path = os.path.join(args.result_dir,
                                           args.controller_type)
    print(
        'Running rl_trainer for {0} with controller type {1} \n Using Dagger: {2}'
        .format(args.env, args.controller_type, args.relabel_with_expert))

    # The sampled trajectory length must match the initial training
    # trajectory length, so read it from the environment config.
    traj_steps = trainer.env.config["max_step"]

    trainer.run_training_loop(n_iter=4,
                              initial_expertdata=initial_expertdata_path,
                              relabel_with_expert=args.relabel_with_expert,
                              start_relabel_with_expert=1,
                              traj_steps=traj_steps)

    print('Finish training ...')

    # Evaluate the trained policy (still testing, needs refactoring).
    make_logger(args.result_dir)
    config = make_config(args)
    planner = make_planner(args, config)

    history_x, history_u, history_g, cost = trainer.run(planner)
    plot_results(history_x, history_u, history_g=history_g, args=args)
    save_plot_data(history_x,
                   history_u,
                   history_g=history_g,
                   cost=cost,
                   args=args)

    if args.save_anim:
        animator = Animator(env=trainer.env, args=args)
        print("first in history_x", history_x[0])
        animator.draw(history_x, history_g)
# Beispiel #3 (0) — scraped example separator; commented out because the bare
# name "Beispiel" would raise NameError if this file were imported.
def run(args):
    """Collect ``args.n_sample`` rollouts from the expert policy, plot the
    last one, and save every collected trajectory to disk.

    Optionally renders an animation of the final rollout when
    ``args.save_anim`` is truthy.
    """
    make_logger(args.result_dir)

    env = make_env(args)
    config = make_config(args)
    planner = make_planner(args, config)
    model = make_model(args, config)
    controller = make_controller(args, config, model)
    runner = make_runner(args)

    # Accumulators for the n_sample expert trajectories.
    history_x_all = []
    history_u_all = []
    history_g_all = []
    history_cost_all = []

    for iter_sample in range(args.n_sample):
        print("Sampling {} th trajectory generated by expert policy:".format(
            iter_sample))
        history_x, history_u, history_g, cost = runner.run(
            env, controller, planner)
        history_x_all.append(history_x)
        history_u_all.append(history_u)
        history_g_all.append(history_g)
        history_cost_all.append(cost)

    # Plot only the most recent trajectory; the saved data keeps all of them.
    plot_results(history_x, history_u, history_g=history_g, args=args)
    save_plot_data(history_x_all,
                   history_u_all,
                   history_g=history_g_all,
                   cost=history_cost_all,
                   args=args)

    if args.save_anim:
        animator = Animator(env, args=args)
        animator.draw(history_x, history_g)