def update_actor_temp(self, states, actions, rewards, next_states, dones):
    # Freeze target and critic parameters while the actor is updated.
    for p in self.sac_net.target.parameters():
        p.requires_grad = False
    for p in self.sac_net.critic.parameters():
        p.requires_grad = False

    # update actor:
    actions, log_probs, aux_losses = self.sac_net.sample(states, training=True)
    q1, q2 = self.sac_net.critic(states, actions)
    q_old = torch.min(q1, q2)
    actor_loss = (self.sac_net.alpha.detach() * log_probs - q_old).mean()
    aux_losses = compute_sum_aux_losses(aux_losses)
    overall_loss = actor_loss + aux_losses
    self.actor_optimizer.zero_grad()
    overall_loss.backward()
    self.actor_optimizer.step()

    # update temp:
    temp_loss = self.sac_net.log_alpha.exp() * (
        -log_probs.detach().mean() + self.action_size
    ).detach()
    self.log_alpha_optimizer.zero_grad()
    temp_loss.backward()
    self.log_alpha_optimizer.step()
    self.sac_net.alpha.data = self.sac_net.log_alpha.exp().detach()

    # Unfreeze target and critic parameters.
    for p in self.sac_net.target.parameters():
        p.requires_grad = True
    for p in self.sac_net.critic.parameters():
        p.requires_grad = True
    return actor_loss, temp_loss
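# `compute_sum_aux_losses` is called in the updates above and below but is not
# defined in this section. A minimal sketch of the assumed behaviour
# (hypothetical; the real helper may store per-head losses differently):
# collapse the auxiliary losses returned by a `training=True` forward pass
# into a single scalar that can be added to the main loss.
import torch


def compute_sum_aux_losses(aux_losses):
    # Assumption: `aux_losses` maps an auxiliary-head name to a scalar tensor.
    if not aux_losses:
        return torch.tensor(0.0)
    return torch.stack(list(aux_losses.values())).sum()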
def update_critic(self, states, actions, rewards, next_states, dones):
    q1_current, q2_current, aux_losses = self.sac_net.critic(
        states, actions, training=True
    )
    # Entropy-regularised TD target computed from the target critic.
    with torch.no_grad():
        next_actions, log_probs, _ = self.sac_net.sample(next_states)
        q1_next, q2_next = self.sac_net.target(next_states, next_actions)
        v_next = (
            torch.min(q1_next, q2_next) - self.sac_net.alpha.detach() * log_probs
        )
        q_target = (rewards + ((1 - dones) * self.gamma * v_next)).detach()

    critic_loss = F.mse_loss(q1_current, q_target) + F.mse_loss(q2_current, q_target)
    aux_losses = compute_sum_aux_losses(aux_losses)
    overall_loss = critic_loss + aux_losses
    self.critic_optimizer.zero_grad()
    overall_loss.backward()
    self.critic_optimizer.step()
    return critic_loss
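# Usage sketch for the two SAC updates above (hypothetical driver code; the
# `agent` and `batch` names are assumptions, not part of the original class).
# Each training step updates the twin critics against the entropy-regularised
# target, then updates the actor and the temperature; the target critic is
# assumed to be refreshed elsewhere, e.g. by Polyak averaging.
def sac_train_step(agent, batch):
    states, actions, rewards, next_states, dones = batch
    critic_loss = agent.update_critic(states, actions, rewards, next_states, dones)
    actor_loss, temp_loss = agent.update_actor_temp(
        states, actions, rewards, next_states, dones
    )
    return critic_loss, actor_loss, temp_loss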
def learn(self):
    output = {}
    states, actions, rewards, next_states, dones, others = self.memory.sample(
        device=self.device
    )
    actions = actions.squeeze(dim=1)

    # Target policy smoothing: add clipped noise to the target actor's action,
    # then clamp the result to the valid action range.
    next_actions = self.actor_target(next_states)
    noise = torch.randn_like(next_actions).mul(self.policy_noise)
    noise = noise.clamp(-self.noise_clip, self.noise_clip)
    next_actions += noise
    next_actions = torch.max(
        torch.min(next_actions, self.action_high.to(self.device)),
        self.action_low.to(self.device),
    )

    # Clipped double-Q target.
    target_Q1 = self.critic_1_target(next_states, next_actions)
    target_Q2 = self.critic_2_target(next_states, next_actions)
    target_Q = torch.min(target_Q1, target_Q2)
    target_Q = (rewards + ((1 - dones) * self.gamma * target_Q)).detach()

    # Optimize Critic 1:
    current_Q1, aux_losses_Q1 = self.critic_1(states, actions, training=True)
    loss_Q1 = F.mse_loss(current_Q1, target_Q) + compute_sum_aux_losses(aux_losses_Q1)
    self.critic_1_optimizer.zero_grad()
    loss_Q1.backward()
    self.critic_1_optimizer.step()

    # Optimize Critic 2:
    current_Q2, aux_losses_Q2 = self.critic_2(states, actions, training=True)
    loss_Q2 = F.mse_loss(current_Q2, target_Q) + compute_sum_aux_losses(aux_losses_Q2)
    self.critic_2_optimizer.zero_grad()
    loss_Q2.backward()
    self.critic_2_optimizer.step()

    # delayed actor updates
    if (self.step_count + 1) % self.policy_delay == 0:
        critic_out = self.critic_1(states, self.actor(states), training=True)
        actor_loss, actor_aux_losses = -critic_out[0], critic_out[1]
        actor_loss = actor_loss.mean() + compute_sum_aux_losses(actor_aux_losses)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.soft_update(self.actor_target, self.actor, self.actor_tau)
        self.num_actor_updates += 1

        output = {
            "loss/critic_1": {
                "type": "scalar",
                "data": loss_Q1.data.cpu().numpy(),
                "freq": 10,
            },
            "loss/actor": {
                "type": "scalar",
                "data": actor_loss.data.cpu().numpy(),
                "freq": 10,
            },
        }

    self.soft_update(self.critic_1_target, self.critic_1, self.critic_tau)
    self.soft_update(self.critic_2_target, self.critic_2, self.critic_tau)
    self.current_iteration += 1
    return output
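# `soft_update` is used above but not shown. A minimal sketch, assuming the
# usual Polyak-averaging signature `soft_update(target, source, tau)`, where
# `tau` is the interpolation weight given to the source parameters:
import torch


@torch.no_grad()
def soft_update(target_net, source_net, tau):
    # target <- tau * source + (1 - tau) * target, parameter by parameter.
    for target_param, source_param in zip(
        target_net.parameters(), source_net.parameters()
    ):
        target_param.data.mul_(1.0 - tau).add_(source_param.data, alpha=tau)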
def update(self, n_epochs, mini_batch_size, states, actions, log_probs, returns, advantages):
    total_actor_loss = 0
    total_critic_loss = 0
    total_entropy_loss = 0
    # multiple epochs
    for _ in range(n_epochs):
        # minibatch updates
        for (
            state,
            action,
            old_pi_log_probs,
            return_batch,
            advantage,
        ) in self.get_minibatch(
            mini_batch_size, states, actions, log_probs, returns, advantages
        ):
            (dist, value), aux_losses = self.ppo_net(state, training=True)
            entropy = dist.entropy().mean()  # L_S
            new_pi_log_probs = dist.log_prob(action)
            ratio = self.get_ratio(new_pi_log_probs, old_pi_log_probs)
            L_CPI = ratio * advantage
            clipped_version = (
                torch.clamp(ratio, 1.0 - self.eps, 1.0 + self.eps) * advantage
            )
            # loss and clipping
            actor_loss = -torch.min(L_CPI, clipped_version).mean()  # L_CLIP
            critic_loss = (return_batch - value).pow(2).mean()  # L_VF (squared error loss)
            aux_losses = compute_sum_aux_losses(aux_losses)
            # overall loss
            loss = (
                self.critic_tau * critic_loss
                + self.actor_tau * actor_loss
                - self.entropy_tau * entropy
                + aux_losses
            )
            # calculate gradients and update the weights
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_actor_loss += actor_loss.item()
            total_critic_loss += critic_loss.item()
            total_entropy_loss += entropy.item()

    average_actor_loss = total_actor_loss / (
        n_epochs * (self.batch_size / self.mini_batch_size)
    )
    average_critic_loss = total_critic_loss / (
        n_epochs * (self.batch_size / self.mini_batch_size)
    )
    average_entropy_loss = total_entropy_loss / (
        n_epochs * (self.batch_size / self.mini_batch_size)
    )
    output = {
        "loss/critic": {
            "type": "scalar",
            "data": average_critic_loss,
            "freq": self.logging_freq,
        },
        "loss/actor": {
            "type": "scalar",
            "data": average_actor_loss,
            "freq": self.logging_freq,
        },
        "loss/entropy": {
            "type": "scalar",
            "data": average_entropy_loss,
            "freq": self.logging_freq,
        },
    }
    return output
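# `get_ratio` and `get_minibatch` are referenced above but not defined here.
# Minimal sketches under assumed signatures (hypothetical; shown as free
# functions for brevity, whereas the class uses methods, and the real helpers
# may shuffle or batch differently):
import torch


def get_ratio(new_pi_log_probs, old_pi_log_probs):
    # Importance ratio r(theta) = pi_theta(a|s) / pi_theta_old(a|s),
    # computed in log space for numerical stability.
    return torch.exp(new_pi_log_probs - old_pi_log_probs)


def get_minibatch(mini_batch_size, states, actions, log_probs, returns, advantages):
    # Yield shuffled minibatches of the rollout tensors along dimension 0.
    batch_size = states.size(0)
    for idx in torch.randperm(batch_size).split(mini_batch_size):
        yield states[idx], actions[idx], log_probs[idx], returns[idx], advantages[idx]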