Example #1
0
def simulation(methods, log_dir, simu_dir):
    """Roll out each requested controller on the vehicle model and save results.

    For every entry in ``methods`` the vehicle is simulated for
    ``config.SIMULATION_STEPS`` steps from the same initial state, and the
    visited states (plus, for 'ADP' and 'MPC', the applied controls) are
    written as text files into ``simu_dir``. Afterwards the plotting helpers
    ``adp_simulation_plot`` and ``plot_comparison`` are called.

    Args:
        methods: Iterable of method names. 'ADP' uses the loaded policy
            network, 'MPC' re-solves with ``solver.mpcSolver`` at every step,
            and any other name replays the pre-computed open-loop controls.
        log_dir: Directory handed to ``load_parameters`` of the policy and
            value networks.
        simu_dir: Directory in which all result files are written; also
            passed to the plotting helpers.
    """
    policy = Actor(S_DIM, A_DIM)
    value = Critic(S_DIM, A_DIM)
    config = DynamicsConfig()
    solver = Solver()
    policy.load_parameters(log_dir)
    # The critic is not used by the roll-out below; it is still loaded so the
    # function keeps failing early if the checkpoint in log_dir is incomplete.
    value.load_parameters(log_dir)
    statemodel_plt = Dynamics.VehicleDynamics()
    plot_length = config.SIMULATION_STEPS

    # Open-loop reference: solved once up front, replayed in the fallback
    # branch of the roll-out loop.
    x_init = [0.0, 0.0, config.psi_init, 0.0, 0.0]
    _, op_control = solver.openLoopMpcSolver(x_init, config.NP_TOTAL)
    np.savetxt(os.path.join(simu_dir, 'Open_loop_control.txt'), op_control)

    for method in methods:
        cal_time = 0
        state = torch.tensor([[0.0, 0.0, config.psi_init, 0.0, 0.0]])
        state.requires_grad_(False)
        # Relative state: first four components expressed w.r.t. the
        # reference evaluated at the last state component.
        x_ref = statemodel_plt.reference_trajectory(state[:, -1])
        state_r = state.detach().clone()
        state_r[:, 0:4] = state_r[:, 0:4] - x_ref

        # Collect per-step snapshots in lists and stack once after the loop;
        # np.append inside the loop would re-copy the full history each step.
        # Snapshots are copied so they never alias a tensor's storage.
        state_rows = [state.detach().numpy().copy()]
        control_rows = []

        print('\nCALCULATION TIME:')
        for i in range(plot_length):
            if method == 'ADP':
                # perf_counter is the clock meant for timing short intervals.
                time_start = time.perf_counter()
                u = policy.forward(state_r[:, 0:4])
                cal_time += time.perf_counter() - time_start
            elif method == 'MPC':
                x = state_r.tolist()[0]
                time_start = time.perf_counter()
                _, control = solver.mpcSolver(x, config.NP)  # todo: retrieve
                cal_time += time.perf_counter() - time_start
                # Only the first control of the solution is applied.
                u = np.array(control[0],
                             dtype='float32').reshape(-1, config.ACTION_DIM)
                u = torch.from_numpy(u)
            else:
                u = np.array(op_control[i],
                             dtype='float32').reshape(-1, config.ACTION_DIM)
                u = torch.from_numpy(u)

            state, state_r = step_relative(statemodel_plt, state, u)
            state_rows.append(state.detach().numpy().copy())
            # Flattened float64 copy: matches what np.append([], u) produced.
            control_rows.append(
                np.array(u.detach().numpy(), dtype=np.float64).ravel())

        state_history = np.concatenate(state_rows, axis=0)
        if control_rows:
            control_history = np.concatenate(control_rows)
        else:
            # Zero simulation steps: keep an empty 1-D float array.
            control_history = np.array([])

        if method == 'ADP':
            print(" ADP: {:.3f}".format(cal_time) + "s")
            np.savetxt(os.path.join(simu_dir, 'ADP_state.txt'), state_history)
            np.savetxt(os.path.join(simu_dir, 'ADP_control.txt'),
                       control_history)

        elif method == 'MPC':
            print(" MPC: {:.3f}".format(cal_time) + "s")
            np.savetxt(os.path.join(simu_dir, 'structured_MPC_state.txt'),
                       state_history)
            np.savetxt(os.path.join(simu_dir, 'structured_MPC_control.txt'),
                       control_history)

        else:
            # Open-loop controls were already saved before the loop.
            np.savetxt(os.path.join(simu_dir, 'Open_loop_state.txt'),
                       state_history)

    adp_simulation_plot(simu_dir)
    plot_comparison(simu_dir, methods)
# Module-level setup: build the actor and critic networks, passing LR_P / LR_V
# as their `lr` arguments. The critic is created with an output size of 1.
policy = Actor(config.STATE_DIM, config.ACTION_DIM, lr=LR_P)
value = Critic(config.STATE_DIM, 1, lr=LR_V)
# Vehicle model and the initial state batch it provides.
vehicleDynamics = Dynamics.VehicleDynamics()
state_batch = vehicleDynamics.initialize_state()
# presumably a TensorBoard SummaryWriter with its default log dir — TODO confirm
writer = SummaryWriter()

# Training
iteration_index = 0
# Optionally warm-start both networks from a previously saved run.
# NOTE: `load_dir` is only assigned inside this branch, so later code may use
# it only when LOAD_PARA_FLAG == 1.
if LOAD_PARA_FLAG == 1:
    print(
        "********************************* LOAD PARAMETERS *********************************"
    )
    # load pre-trained parameters
    load_dir = "./Results_dir/2020-10-09-14-42-10000"
    policy.load_parameters(load_dir)
    value.load_parameters(load_dir)

# Training entry point; runs only when TRAIN_FLAG == 1.
if TRAIN_FLAG == 1:
    # Interval (in iterations) announced in the banner below.
    print_iters = 10
    print(
        "********************************** START TRAINING **********************************"
    )
    # NOTE(review): there is no space between the number and "iterations" in
    # the printed banner (e.g. "10iterations").
    print("************************** PRINT LOSS EVERY " + str(print_iters) +
          "iterations ***************************")
    # train the network by policy iteration
    train = Train()
    # train.agent_batch = vehicleDynamics.initialize_state()
    # When resuming, restore the saved agent batch from the same directory the
    # network parameters were loaded from, then call setInitState().
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version/arguments — only load trusted files; confirm.
    if LOAD_PARA_FLAG == 1:
        train.agent_batch = torch.load(
            os.path.join(load_dir, 'agent_buffer.pth'))
        train.setInitState()