def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
    if self._use_popart:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    if self._use_clipped_value_loss:
        value_loss = torch.max(value_loss_original, value_loss_clipped)
    else:
        value_loss = value_loss_original

    if self._use_value_active_masks:
        value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        value_loss = value_loss.mean()

    return value_loss
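# The element-wise loss helpers below are a minimal sketch of what `huber_loss` and
# `mse_loss` are assumed to look like, inferred from how they are used above: both must
# return per-element losses with no reduction, since masking or .mean() is applied later.
import torch


def huber_loss(e, d):
    # Quadratic inside the delta band, linear outside; element-wise, no reduction.
    a = (e.abs() <= d).float()
    b = (e.abs() > d).float()
    return a * e ** 2 / 2 + b * d * (e.abs() - d / 2)


def mse_loss(e):
    # Element-wise halved squared error, no reduction.
    return e ** 2 / 2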
def value_loss_update(self, sample):
    """
    Update the critic network using only the value loss.
    :param sample: (Tuple) contains data batch with which to update networks.
    :return value_loss: (torch.Tensor) value function loss.
    :return grad_norm: (torch.Tensor) gradient norm of the critic parameters.
    """
    share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch = sample

    value_preds_batch = check(value_preds_batch).to(**self.tpdv)
    return_batch = check(return_batch).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)

    values = self.policy.get_values(share_obs_batch, rnn_states_critic_batch, masks_batch)

    if self._use_popart:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    if self._use_clipped_value_loss:
        value_loss = torch.max(value_loss_original, value_loss_clipped)
    else:
        value_loss = value_loss_original

    if self._use_value_active_masks:
        value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        value_loss = value_loss.mean()

    self.policy.critic_optimizer.zero_grad()
    (value_loss * self.value_loss_coef).backward()

    if self._use_max_grad_norm:
        grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
    else:
        grad_norm = get_gard_norm(self.policy.critic.parameters())

    self.policy.critic_optimizer.step()

    return value_loss, grad_norm
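# `check` and `get_gard_norm` are utility functions from the project's util module; the
# versions below are only a sketch assumed from their usage here: `check` converts numpy
# arrays to torch tensors (so `.to(**self.tpdv)` can be chained), and `get_gard_norm`
# (spelling kept as called above) returns the global L2 norm of the parameter gradients.
import math

import numpy as np
import torch


def check(x):
    # Convert numpy input to a torch tensor; pass tensors through unchanged.
    return torch.from_numpy(x) if isinstance(x, np.ndarray) else x


def get_gard_norm(parameters):
    # Global L2 norm over all parameter gradients; parameters without gradients are skipped.
    total = 0.0
    for p in parameters:
        if p.grad is None:
            continue
        total += p.grad.norm() ** 2
    return math.sqrt(total)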
def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
    """
    Calculate value function loss.
    :param values: (torch.Tensor) value function predictions.
    :param value_preds_batch: (torch.Tensor) "old" value predictions from the data batch (used for the value clip loss).
    :param return_batch: (torch.Tensor) reward-to-go returns.
    :param active_masks_batch: (torch.Tensor) denotes whether an agent is active or dead at a given timestep.
    :return value_loss: (torch.Tensor) value function loss.
    """
    if self._use_popart:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    if self._use_clipped_value_loss:
        value_loss = torch.max(value_loss_original, value_loss_clipped)
    else:
        value_loss = value_loss_original

    if self._use_value_active_masks:
        value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        value_loss = value_loss.mean()

    return value_loss
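# A toy, self-contained illustration of the clipped value loss computed in cal_value_loss.
# The numbers and clip_param are made up; the logic mirrors the method with popart disabled,
# MSE losses, clipped value loss enabled, and active-mask averaging.
import torch

clip_param = 0.2
values = torch.tensor([[1.5], [0.2]])             # new value predictions
value_preds_batch = torch.tensor([[1.0], [0.5]])  # "old" value predictions from the buffer
return_batch = torch.tensor([[1.2], [0.4]])       # reward-to-go returns
active_masks_batch = torch.tensor([[1.0], [1.0]])

value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-clip_param, clip_param)
value_loss_original = (return_batch - values) ** 2 / 2
value_loss_clipped = (return_batch - value_pred_clipped) ** 2 / 2
value_loss = torch.max(value_loss_original, value_loss_clipped)
value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
print(value_loss)  # pessimistic max of clipped / unclipped errors, averaged over active steps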
def auxiliary_loss_update(self, sample):
    """
    Update the actor network with the auxiliary (joint) loss: the value loss of the actor's
    value head plus a KL (policy cloning) term weighted by clone_coef.
    :param sample: (Tuple) contains data batch with which to update networks.
    :return joint_loss: (torch.Tensor) auxiliary value loss + clone_coef * KL loss.
    :return grad_norm: (torch.Tensor) gradient norm of the actor parameters.
    """
    share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        old_action_probs_batch, available_actions_batch = sample

    old_action_probs_batch = check(old_action_probs_batch).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)
    value_preds_batch = check(value_preds_batch).to(**self.tpdv)
    return_batch = check(return_batch).to(**self.tpdv)

    # Reshape to do in a single forward pass for all steps
    values, new_action_probs = self.policy.get_policy_values_and_probs(
        obs_batch, rnn_states_batch, masks_batch, available_actions_batch)

    # kl = sum p * log(p / q) = sum p * (log p - log q) = sum p*log p - p*log q
    # cross-entropy = sum -p*log q
    eps = (old_action_probs_batch == 0) * 1e-8
    old_action_log_probs_batch = torch.log(old_action_probs_batch + eps.float().detach())
    eps = (new_action_probs == 0) * 1e-8
    new_action_log_probs = torch.log(new_action_probs + eps.float().detach())

    kl_divergence = torch.sum(
        (old_action_probs_batch * (old_action_log_probs_batch - new_action_log_probs)),
        dim=-1, keepdim=True)
    kl_loss = (kl_divergence * active_masks_batch).sum() / active_masks_batch.sum()

    if self._use_popart:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    if self._use_clipped_value_loss:
        value_loss = torch.max(value_loss_original, value_loss_clipped)
    else:
        value_loss = value_loss_original

    if self._use_value_active_masks:
        value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        value_loss = value_loss.mean()

    joint_loss = value_loss + self.clone_coef * kl_loss

    self.policy.actor_optimizer.zero_grad()
    joint_loss.backward()

    if self._use_max_grad_norm:
        grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
    else:
        grad_norm = get_gard_norm(self.policy.actor.parameters())

    self.policy.actor_optimizer.step()

    return joint_loss, grad_norm
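# A small numeric check of the masked KL term above: KL(old || new) = sum_a p_old(a) *
# (log p_old(a) - log p_new(a)), averaged over active (alive) agent steps. The probabilities
# and mask below are toy values.
import torch

old_action_probs_batch = torch.tensor([[0.7, 0.3, 0.0],
                                       [0.2, 0.5, 0.3]])
new_action_probs = torch.tensor([[0.6, 0.3, 0.1],
                                 [0.25, 0.5, 0.25]])
active_masks_batch = torch.tensor([[1.0], [0.0]])  # second step is masked out (dead agent)

# Same zero-probability guard as in auxiliary_loss_update.
old_log = torch.log(old_action_probs_batch + (old_action_probs_batch == 0).float() * 1e-8)
new_log = torch.log(new_action_probs + (new_action_probs == 0).float() * 1e-8)

kl_divergence = torch.sum(old_action_probs_batch * (old_log - new_log), dim=-1, keepdim=True)
kl_loss = (kl_divergence * active_masks_batch).sum() / active_masks_batch.sum()
print(kl_loss)  # only the first (active) step contributes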
def ppo_update(self, sample):
    """
    Update the shared actor-critic network with the PPO clipped surrogate loss,
    value loss, and entropy bonus.
    :param sample: (Tuple) contains data batch with which to update networks.
    :return value_loss: (torch.Tensor) value function loss.
    :return action_loss: (torch.Tensor) PPO clipped surrogate (policy) loss.
    :return dist_entropy: (torch.Tensor) entropy of the action distribution.
    :return grad_norm: (torch.Tensor) gradient norm of the model parameters.
    :return ratio: (torch.Tensor) new-to-old policy probability ratio.
    """
    share_obs_batch, obs_batch, recurrent_hidden_states_batch, recurrent_hidden_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch = sample

    old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
    adv_targ = check(adv_targ).to(**self.tpdv)
    value_preds_batch = check(value_preds_batch).to(**self.tpdv)
    return_batch = check(return_batch).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)

    # policy loss
    # Reshape to do in a single forward pass for all steps
    action_log_probs, dist_entropy = self.policy.evaluate_actions(
        obs_batch, recurrent_hidden_states_batch, actions_batch, masks_batch,
        available_actions_batch, active_masks_batch)

    ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
    surr1 = ratio * adv_targ
    surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
    action_loss = (-torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True)
                   * active_masks_batch).sum() / active_masks_batch.sum()

    # value loss
    values = self.policy.get_values(share_obs_batch, recurrent_hidden_states_critic_batch, masks_batch)

    if self._use_popart:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    if self._use_clipped_value_loss:
        value_loss = torch.max(value_loss_original, value_loss_clipped)
    else:
        value_loss = value_loss_original

    if self._use_value_active_masks:
        value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        value_loss = value_loss.mean()

    # update common and action network
    self.policy.optimizer.zero_grad()
    (action_loss - dist_entropy * self.entropy_coef + value_loss * self.value_loss_coef).backward()

    if self._use_max_grad_norm:
        grad_norm = nn.utils.clip_grad_norm_(self.policy.model.parameters(), self.max_grad_norm)
    else:
        grad_norm = get_gard_norm(self.policy.model.parameters())

    self.policy.optimizer.step()

    return value_loss, action_loss, dist_entropy, grad_norm, ratio
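# A toy illustration of the PPO clipped surrogate used in ppo_update (clip_param is assumed
# to be 0.2 for this example). The pessimistic min(surr1, surr2) prevents the objective from
# rewarding policy ratios outside [1 - clip_param, 1 + clip_param].
import torch

clip_param = 0.2
old_action_log_probs_batch = torch.tensor([[-1.0], [-0.5]])
action_log_probs = torch.tensor([[-0.6], [-0.9]])  # new policy log-probs
adv_targ = torch.tensor([[1.0], [-2.0]])
active_masks_batch = torch.ones(2, 1)

ratio = torch.exp(action_log_probs - old_action_log_probs_batch)  # ~[1.49, 0.67]
surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv_targ
action_loss = (-torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True)
               * active_masks_batch).sum() / active_masks_batch.sum()
print(ratio, action_loss)  # both ratios fall outside the clip range, so the clipped terms win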