import copy

import numpy as np
import torch
import torch.nn.functional as F

# Actor and Critic are the project's network definitions; the import path is assumed.
from model import Actor, Critic


class TD3(object):
    """Agent class that handles the training of the networks and provides outputs as actions.

    Args:
        state_dim (array): state size
        action_dim (array): action size
        max_action (array): (low, mid, high) bounds of the action range
        discount (float): discount factor
        tau (float): soft update rate for the target networks
        policy_noise (float): std of the noise added to target policy actions
        noise_clip (float): range to clip the target policy noise
        policy_freq (int): frequency of delayed policy updates
        device (device): cuda or cpu to process the tensors
    """

    def __init__(self, state_dim, action_dim, max_action, discount, tau,
                 policy_noise, noise_clip, policy_freq, device):
        self.state_dim = len(state_dim[0])
        self.action_dim = len(action_dim)
        self.max_action = max_action

        self.actor = Actor(self.state_dim, self.action_dim, self.max_action[2]).to(device)
        self.actor_target = copy.deepcopy(self.actor).float()
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)  # or 1e-3

        self.critic = Critic(self.state_dim, self.action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic).float()
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)  # or 1e-2

        self.device = device
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.total_it = 0

    def select_action(self, state):
        """Select an appropriate action from the agent policy.

        Args:
            state (array): current state of environment

        Returns:
            action (np.ndarray): action chosen by the deterministic policy
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, batch_size=100):
        """Train and update actor and critic networks.

        Args:
            replay_buffer (ReplayBuffer): buffer for experience replay
            batch_size (int): batch size to sample from replay buffer

        Return:
            actor_loss (float): loss from actor network
            critic_loss (float): loss from critic network
        """
        self.total_it += 1

        # Sample replay buffer
        state, next_state, action, reward, done = replay_buffer.sample(batch_size)

        # The buffer stores dict-valued observations wrapped in object arrays;
        # unpack them into flat float tensors
        state = torch.as_tensor(
            np.asarray([np.array(list(i.item().values())) for i in state]),
            dtype=torch.float32).to(self.device)
        next_state = torch.as_tensor(
            np.asarray([np.array(list(i.item().values())) for i in next_state]),
            dtype=torch.float32).to(self.device)
        action = torch.as_tensor(action, dtype=torch.float32).to(self.device)
        reward = torch.as_tensor(reward, dtype=torch.float32).reshape(-1, 1).to(self.device)
        done = torch.as_tensor(done, dtype=torch.float32).reshape(-1, 1).to(self.device)

        with torch.no_grad():
            # Select the next action according to the target policy and add clipped Gaussian noise
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(
                self.max_action[0], self.max_action[2])

            # Compute the target Q value from the target critics (clipped double-Q);
            # `done` is assumed to hold a continuation (not-done) mask from the buffer
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + done * self.discount * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:
            # Compute the actor loss
            actor_loss = -self.critic.get_q(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, filename, directory):
        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename="best_avg", directory="./saves"):
        self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
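# The sketch below is not part of the original module: it is a minimal, self-contained
# illustration of the clipped double-Q target that train() builds inside torch.no_grad().
# Gaussian noise is added to the target policy's action and clipped, both target critics
# score the result, and the smaller of the two Q values forms the TD target. The tiny
# stand-in networks, dimensions, and hyperparameters are illustrative assumptions only.


def clipped_double_q_target_demo():
    import torch
    import torch.nn as nn

    state_dim, action_dim, batch = 8, 2, 16
    max_action, policy_noise, noise_clip, discount = 1.0, 0.2, 0.5, 0.99

    actor_target = nn.Sequential(nn.Linear(state_dim, 32), nn.ReLU(),
                                 nn.Linear(32, action_dim), nn.Tanh())
    critic_target_1 = nn.Sequential(nn.Linear(state_dim + action_dim, 32), nn.ReLU(),
                                    nn.Linear(32, 1))
    critic_target_2 = nn.Sequential(nn.Linear(state_dim + action_dim, 32), nn.ReLU(),
                                    nn.Linear(32, 1))

    next_state = torch.randn(batch, state_dim)
    reward = torch.randn(batch, 1)
    not_done = torch.ones(batch, 1)  # continuation mask, 0 where the episode ended

    with torch.no_grad():
        # Target policy smoothing: clipped Gaussian noise on the target action
        noise = (torch.randn(batch, action_dim) * policy_noise).clamp(-noise_clip, noise_clip)
        next_action = (actor_target(next_state) * max_action + noise).clamp(-max_action, max_action)

        # Clipped double-Q: take the element-wise minimum of the two target critics
        sa = torch.cat([next_state, next_action], dim=1)
        target_q = torch.min(critic_target_1(sa), critic_target_2(sa))
        target_q = reward + not_done * discount * target_q

    return target_q  # shape: (batch, 1)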
import torch
import torch.nn.functional as F

# Project-local imports; paths are assumed. `cons` is the project's constants module
# (STATE_DIM, ACTION_DIM, DEVICE, GAMMA, TAU, etc.) referenced throughout this class.
import constants as cons
from model import Actor, Critic


class TD3(object):
    """Agent class that handles the training of the networks and provides outputs as actions."""

    def __init__(self):
        state_dim = cons.STATE_DIM.flatten().shape[0]
        action_dim = cons.ACTION_DIM

        self.actor = Actor(state_dim, action_dim, cons.MAX_ACTION).to(cons.DEVICE)
        self.actor_target = Actor(state_dim, action_dim, cons.MAX_ACTION).to(cons.DEVICE)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)  # or 1e-3

        self.critic = Critic(state_dim, action_dim).to(cons.DEVICE)
        self.critic_target = Critic(state_dim, action_dim).to(cons.DEVICE)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)  # or 1e-3

        self.total_it = 0
        self.critic_loss_plot = []
        self.actor_loss_plot = []

    def select_action(self, state, noise=cons.POLICY_NOISE):
        """Select an appropriate action from the agent policy.

        Args:
            state (array): current state of environment
            noise (float): how much noise to add to actions

        Returns:
            action (tensor): clipped action from the actor network
        """
        state = torch.FloatTensor(state).to(cons.DEVICE)
        action = self.actor(state)

        # Action space noise changes the likelihood of each action the agent might take
        if noise != 0:
            # Create a tensor of clipped Gaussian noise, one entry per action dimension
            noise = torch.clamp(torch.randn_like(action) * noise,
                                min=-cons.NOISE_CLIP, max=cons.NOISE_CLIP)
            action = action + noise

        return torch.clamp(action, min=cons.MIN_ACTION, max=cons.MAX_ACTION)

    def train(self, replay_buffer, iterations):
        """Train and update actor and critic networks.

        Args:
            replay_buffer (ReplayBuffer): buffer for experience replay
            iterations (int): how many times to run training

        Return:
            actor_loss (float): loss from actor network
            critic_loss (float): loss from critic network
        """
        for it in range(iterations):
            self.total_it += 1  # keep track of the total training iterations

            # Sample replay buffer (prioritized or uniform)
            if cons.PRIORITY:
                state, action, reward, next_state, done, weights, indexes = replay_buffer.sample(
                    cons.BATCH_SIZE, beta=cons.BETA_SCHED.value(it))
            else:
                state, action, reward, next_state, done = replay_buffer.sample(cons.BATCH_SIZE)

            state = torch.from_numpy(state).float().to(cons.DEVICE)            # torch.Size([100, 14])
            next_state = torch.from_numpy(next_state).float().to(cons.DEVICE)  # torch.Size([100, 14])
            action = torch.from_numpy(action).float().to(cons.DEVICE)          # torch.Size([100, 14])
            reward = torch.as_tensor(reward, dtype=torch.float32).to(cons.DEVICE)  # torch.Size([100])
            done = torch.as_tensor(done, dtype=torch.float32).to(cons.DEVICE)      # torch.Size([100])

            with torch.no_grad():
                # Select the next action according to the target policy and add clipped noise
                next_action = self.actor_target(next_state)
                noise = torch.clamp(torch.randn_like(next_action) * cons.POLICY_NOISE,
                                    min=-cons.NOISE_CLIP, max=cons.NOISE_CLIP)
                next_action = torch.clamp(next_action + noise, min=cons.MIN_ACTION, max=cons.MAX_ACTION)

                # Compute the target Q value from the target critics (clipped double-Q);
                # `done` is assumed to hold a continuation (not-done) mask from the buffer
                target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
                target_q = reward.unsqueeze(1) + done.unsqueeze(1) * cons.GAMMA * target_q

            # Get current Q estimates
            current_q1, current_q2 = self.critic(state, action)

            # Compute critic loss
            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
            cons.TD3_REPORT.write_critic_loss(self.total_it, it, critic_loss)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Use the minimum of the Q values as the priority weight; min prevents overestimation
            if cons.PRIORITY:
                new_priorities = torch.flatten(torch.min(current_q1, current_q2))
                # Clamp negative priorities to a small positive value (priorities cannot be negative)
                # and convert to a list for storage
                new_priorities = torch.clamp(new_priorities, min=0.0000001).tolist()
                replay_buffer.update_priorities(indexes, new_priorities)

            # Delayed policy updates
            if it % cons.POLICY_FREQ == 0:  # update the actor policy less frequently
                # Compute the actor loss
                actor_loss = -self.critic.get_q(state, self.actor(state)).mean()
                cons.TD3_REPORT.write_actor_loss(self.total_it, it, actor_loss, 1)

                # Optimize the actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()
                self.actor_loss_plot.append(actor_loss.item())

                # Update the frozen target models
                for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                    target_param.data.copy_(cons.TAU * param.data + (1 - cons.TAU) * target_param.data)

                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(cons.TAU * param.data + (1 - cons.TAU) * target_param.data)

    def save(self, filename, directory):
        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename="best_avg", directory="td3/saves/shared_agent"):
        self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
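# The sketch below is not part of the original module: it is a minimal, self-contained
# illustration of the Polyak (soft) target-network update both TD3 classes perform after
# each delayed policy step. Every target parameter moves a fraction tau toward its online
# counterpart, so the targets slowly track the trained networks. The tiny network and the
# tau value are illustrative assumptions only.


def soft_update_demo(tau=0.005):
    import copy

    import torch.nn as nn

    online = nn.Linear(4, 2)        # stand-in for the online actor or critic
    target = copy.deepcopy(online)  # target starts as an exact copy

    # ... the online network would be trained by gradient descent here ...

    # Soft update: target <- tau * online + (1 - tau) * target
    for param, target_param in zip(online.parameters(), target.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    return target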