Example #1
File: DDPG.py  Project: zoetsekas/rl_lib
import torch
import torch.nn.functional as F
import torch.optim as optim

# Actor, Critic, ReplayBuffer and the module-level names device, buffer_size,
# batch_size, gamma and tau are assumed to be defined elsewhere in the project.


class DDPG:
    def __init__(self, state_space, action_space):
        self.actor = Actor(state_space, action_space).to(device)
        self.critic = Critic(state_space, action_space).to(device)

        self.actor_target = Actor(state_space, action_space).to(device)
        self.critic_target = Critic(state_space, action_space).to(device)

        self.actor_optimiser = optim.Adam(self.actor.parameters(), lr=1e-3)
        self.critic_optimiser = optim.Adam(self.critic.parameters(), lr=1e-3)

        self.mem = ReplayBuffer(buffer_size)

    def act(self, state, add_noise=False):
        return self.actor.act(state, add_noise)

    def save(self, fn):
        torch.save(self.actor.state_dict(), "{}_actor_model.pth".format(fn))
        torch.save(self.critic.state_dict(), "{}_critic_model.pth".format(fn))

    def learn(self):

        state_batch, action_batch, reward_batch, next_state_batch, masks = self.mem.sample(
            batch_size)

        state_batch = torch.FloatTensor(state_batch).to(device)
        action_batch = torch.FloatTensor(action_batch).to(device)
        reward_batch = torch.FloatTensor(reward_batch).to(device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(device)
        masks = torch.FloatTensor(masks).to(device)

        # Update Critic
        self.update_critic(states=state_batch,
                           next_states=next_state_batch,
                           actions=action_batch,
                           rewards=reward_batch,
                           dones=masks)

        # Update actor
        self.update_actor(states=state_batch)

        # Update target networks
        self.update_target_networks()

    def update_actor(self, states):
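        # Deterministic policy gradient: maximise Q(s, pi(s)) by minimising its negative.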
        actions_pred = self.actor(states)
        loss = -self.critic(states, actions_pred).mean()

        self.actor_optimiser.zero_grad()
        loss.backward()
        self.actor_optimiser.step()

    def update_critic(self, states, next_states, actions, rewards, dones):
        # Build the TD target from the target networks without tracking gradients.
        with torch.no_grad():
            next_actions = self.actor_target(next_states)
            y_i = rewards + (gamma *
                             self.critic_target(next_states, next_actions) *
                             (1 - dones))

        expected_Q = self.critic(states, actions)

        loss = F.mse_loss(expected_Q, y_i)

        self.critic_optimiser.zero_grad()
        loss.backward()
        self.critic_optimiser.step()

    def update_target_networks(self):
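        # Soft (Polyak) update: target <- tau * local + (1 - tau) * target.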
        for target, local in zip(self.actor_target.parameters(),
                                 self.actor.parameters()):
            target.data.copy_(tau * local.data + (1.0 - tau) * target.data)

        for target, local in zip(self.critic_target.parameters(),
                                 self.critic.parameters()):
            target.data.copy_(tau * local.data + (1.0 - tau) * target.data)
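
A minimal, self-contained sketch of the soft (Polyak) target update performed by update_target_networks above; the two small MLPs and the tau value are illustrative placeholders, not taken from the project.

import torch
import torch.nn as nn


def soft_update(target: nn.Module, local: nn.Module, tau: float) -> None:
    # theta_target <- tau * theta_local + (1 - tau) * theta_target
    with torch.no_grad():
        for t_param, l_param in zip(target.parameters(), local.parameters()):
            t_param.copy_(tau * l_param + (1.0 - tau) * t_param)


# Two identical two-layer MLPs stand in for the actor and its target network.
local_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
target_net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
target_net.load_state_dict(local_net.state_dict())  # hard copy at initialisation
soft_update(target_net, local_net, tau=1e-3)         # nudge target towards local weights
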
Example #2
import os
import time

import numpy as np
import torch

# Actor, Critic, DynamicsConfig, Solver, Dynamics, step_relative, S_DIM, A_DIM,
# adp_simulation_plot and plot_comparison come from elsewhere in the project.


def simulation(methods, log_dir, simu_dir):
    policy = Actor(S_DIM, A_DIM)
    value = Critic(S_DIM, A_DIM)
    config = DynamicsConfig()
    solver = Solver()
    load_dir = log_dir
    policy.load_parameters(load_dir)
    value.load_parameters(load_dir)
    statemodel_plt = Dynamics.VehicleDynamics()
    plot_length = config.SIMULATION_STEPS

    # initial_state = torch.tensor([[0.5, 0.0, config.psi_init, 0.0, 0.0]])
    # baseline = Baseline(initial_state, simu_dir)
    # baseline.mpcSolution()
    # baseline.openLoopSolution()

    # Open-loop reference
    x_init = [0.0, 0.0, config.psi_init, 0.0, 0.0]
    op_state, op_control = solver.openLoopMpcSolver(x_init, config.NP_TOTAL)
    np.savetxt(os.path.join(simu_dir, 'Open_loop_control.txt'), op_control)

    for method in methods:
        cal_time = 0
        state = torch.tensor([[0.0, 0.0, config.psi_init, 0.0, 0.0]])
        # state = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0]])
        state.requires_grad_(False)
        # x_ref = statemodel_plt.reference_trajectory(state[:, -1])
        x_ref = statemodel_plt.reference_trajectory(state[:, -1])
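        # Form the error state: subtract the reference trajectory from the first four states.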
        state_r = state.detach().clone()
        state_r[:, 0:4] = state_r[:, 0:4] - x_ref

        state_history = state.detach().numpy()
        control_history = []

        print('\nCALCULATION TIME:')
        for i in range(plot_length):
            if method == 'ADP':
                time_start = time.time()
                u = policy.forward(state_r[:, 0:4])
                cal_time += time.time() - time_start
            elif method == 'MPC':
                x = state_r.tolist()[0]
                time_start = time.time()
                _, control = solver.mpcSolver(x, config.NP)  # TODO: retrieve
                cal_time += time.time() - time_start
                u = np.array(control[0],
                             dtype='float32').reshape(-1, config.ACTION_DIM)
                u = torch.from_numpy(u)
            else:
                u = np.array(op_control[i],
                             dtype='float32').reshape(-1, config.ACTION_DIM)
                u = torch.from_numpy(u)

            state, state_r = step_relative(statemodel_plt, state, u)
            # state_next, deri_state, utility, F_y1, F_y2, alpha_1, alpha_2 = statemodel_plt.step(state, u)
            # state_r_old, _, _, _, _, _, _ = statemodel_plt.step(state_r, u)
            # state_r = state_r_old.detach().clone()
            # state_r[:, [0, 2]] = state_next[:, [0, 2]]
            # x_ref = statemodel_plt.reference_trajectory(state_next[:, -1])
            # state_r[:, 0:4] = state_r[:, 0:4] - x_ref
            # state = state_next.clone().detach()
            # s = state_next.detach().numpy()
            state_history = np.append(state_history,
                                      state.detach().numpy(),
                                      axis=0)
            control_history = np.append(control_history, u.detach().numpy())

        if method == 'ADP':
            print(" ADP: {:.3f}".format(cal_time) + "s")
            np.savetxt(os.path.join(simu_dir, 'ADP_state.txt'), state_history)
            np.savetxt(os.path.join(simu_dir, 'ADP_control.txt'),
                       control_history)

        elif method == 'MPC':
            print(" MPC: {:.3f}".format(cal_time) + "s")
            np.savetxt(os.path.join(simu_dir, 'structured_MPC_state.txt'),
                       state_history)
            np.savetxt(os.path.join(simu_dir, 'structured_MPC_control.txt'),
                       control_history)

        else:
            np.savetxt(os.path.join(simu_dir, 'Open_loop_state.txt'),
                       state_history)

    adp_simulation_plot(simu_dir)
    plot_comparison(simu_dir, methods)
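
The loop above grows state_history and control_history with np.append at every step (which also flattens control_history to 1-D, since no axis is given). For comparison, a minimal self-contained sketch of the same bookkeeping done with lists and a single np.concatenate; the shapes and file name are placeholders rather than the project's.

import numpy as np

state_rows, control_rows = [], []
for i in range(10):                               # stands in for the plot_length loop
    state_rows.append(np.random.rand(1, 5))       # stands in for state.detach().numpy()
    control_rows.append(np.random.rand(1, 1))     # stands in for u.detach().numpy()

state_history = np.concatenate(state_rows, axis=0)      # shape (10, 5)
control_history = np.concatenate(control_rows, axis=0)  # shape (10, 1), kept 2-D
np.savetxt('demo_state.txt', state_history)             # same call pattern as simulation()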