def policy_loss_update(self, sample):
    share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch = sample

    old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
    adv_targ = check(adv_targ).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)

    # Reshape to do in a single forward pass for all steps
    action_log_probs, dist_entropy = self.policy.evaluate_actions(
        obs_batch, rnn_states_batch, actions_batch, masks_batch,
        available_actions_batch, active_masks_batch)

    # PPO clipped surrogate objective, masked by active agents
    ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
    surr1 = ratio * adv_targ
    surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
    action_loss = (-torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True)
                   * active_masks_batch).sum() / active_masks_batch.sum()

    # update common and action network
    self.policy.actor_optimizer.zero_grad()
    (action_loss - dist_entropy * self.entropy_coef).backward()

    if self._use_max_grad_norm:
        grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
    else:
        grad_norm = get_gard_norm(self.policy.actor.parameters())

    self.policy.actor_optimizer.step()

    return action_loss, dist_entropy, grad_norm, ratio
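The check(x).to(**self.tpdv) pattern used throughout these methods assumes a small input-conversion helper and a tensor-properties dict. A minimal sketch of what they might look like; only the call sites above are taken from the code, the bodies here are assumptions:

import numpy as np
import torch

def check(x):
    # Assumed helper: convert numpy inputs to torch tensors, pass tensors through unchanged.
    return torch.from_numpy(x) if isinstance(x, np.ndarray) else x

# self.tpdv is assumed to hold the target dtype/device, set once in __init__
# and unpacked as .to(**self.tpdv):
tpdv = dict(dtype=torch.float32, device=torch.device("cpu"))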
def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
    """
    Compute actions from the given inputs.
    :param obs: (np.ndarray / torch.Tensor) observation inputs into network.
    :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
    :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
    :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent
                              (if None, all actions available)
    :param deterministic: (bool) whether to sample from action distribution or return the mode.

    :return actions: (torch.Tensor) actions to take.
    :return action_log_probs: (torch.Tensor) log probabilities of taken actions.
    :return rnn_states: (torch.Tensor) updated RNN hidden states.
    """
    obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)
    if available_actions is not None:
        available_actions = check(available_actions).to(**self.tpdv)

    actor_features = self.base(obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    actions, action_log_probs = self.act(actor_features, available_actions, deterministic)

    return actions, action_log_probs, rnn_states
def get_actions(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
    obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)
    if available_actions is not None:
        available_actions = check(available_actions).to(**self.tpdv)

    x = obs
    x = self.obs_prep(x)

    # common
    actor_features = self.common(x)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    actions, action_log_probs = self.act(actor_features, available_actions, deterministic)

    return actions, action_log_probs, rnn_states
def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
    obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    action = check(action).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)
    if available_actions is not None:
        available_actions = check(available_actions).to(**self.tpdv)
    if active_masks is not None:
        active_masks = check(active_masks).to(**self.tpdv)

    x = obs
    x = self.obs_prep(x)

    actor_features = self.common(x)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    action_log_probs, dist_entropy = self.act.evaluate_actions(
        actor_features, action, available_actions, active_masks)

    return action_log_probs, dist_entropy
def value_loss_update(self, sample):
    share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch = sample

    value_preds_batch = check(value_preds_batch).to(**self.tpdv)
    return_batch = check(return_batch).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)

    values = self.policy.get_values(share_obs_batch, rnn_states_critic_batch, masks_batch)

    if self._use_popart:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param)
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param)
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    if self._use_clipped_value_loss:
        value_loss = torch.max(value_loss_original, value_loss_clipped)
    else:
        value_loss = value_loss_original

    if self._use_value_active_masks:
        value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        value_loss = value_loss.mean()

    self.policy.critic_optimizer.zero_grad()
    (value_loss * self.value_loss_coef).backward()

    if self._use_max_grad_norm:
        grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
    else:
        grad_norm = get_gard_norm(self.policy.critic.parameters())

    self.policy.critic_optimizer.step()

    return value_loss, grad_norm
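value_loss_update calls huber_loss, mse_loss, and get_gard_norm, which are not shown in this excerpt. Minimal sketches consistent with how they are called above (the losses are elementwise so the caller can mask or average them); the exact implementations are assumptions:

import torch

def huber_loss(e, d):
    # Elementwise Huber loss with threshold d, returned unreduced.
    a = (abs(e) <= d).float()
    b = (abs(e) > d).float()
    return a * e ** 2 / 2 + b * d * (abs(e) - d / 2)

def mse_loss(e):
    # Elementwise squared error, likewise left unreduced.
    return e ** 2 / 2

def get_gard_norm(it):
    # Total L2 norm of existing gradients; used when grad-norm clipping is disabled.
    sum_grad = 0
    for x in it:
        if x.grad is None:
            continue
        sum_grad += x.grad.norm() ** 2
    return sum_grad ** 0.5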
def forward(self, share_obs, rnn_states, masks):
    share_obs = check(share_obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)

    critic_features = self.base(share_obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)

    values = self.v_out(critic_features)

    return values, rnn_states
def get_policy_values(self, obs, rnn_states, masks):
    obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)

    actor_features = self.base(obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    values = self.v_out(actor_features)

    return values
def get_values(self, share_obs, rnn_states, masks):
    share_obs = check(share_obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)

    share_x = share_obs
    share_x = self.share_obs_prep(share_x)

    critic_features = self.common(share_x)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)

    values = self.v_out(critic_features)

    return values, rnn_states
def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
    if self._mixed_obs:
        for key in obs.keys():
            obs[key] = check(obs[key]).to(**self.tpdv)
    else:
        obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    action = check(action).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)
    if available_actions is not None:
        available_actions = check(available_actions).to(**self.tpdv)
    if active_masks is not None:
        active_masks = check(active_masks).to(**self.tpdv)

    actor_features = self.base(obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    if self._use_influence_policy:
        mlp_obs = self.mlp(obs)
        actor_features = torch.cat([actor_features, mlp_obs], dim=1)

    action_log_probs, dist_entropy = self.act.evaluate_actions(
        actor_features, action, available_actions,
        active_masks=active_masks if self._use_policy_active_masks else None)

    values = self.v_out(actor_features) if self._use_policy_vhead else None

    return action_log_probs, dist_entropy, values
def forward(self, cent_obs, rnn_states, masks):
    """
    Compute value function predictions from the given inputs.
    :param cent_obs: (np.ndarray / torch.Tensor) centralized observation inputs into network.
    :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
    :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.

    :return values: (torch.Tensor) value function predictions.
    :return rnn_states: (torch.Tensor) updated RNN hidden states.
    """
    cent_obs = check(cent_obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)

    critic_features = self.base(cent_obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)

    values = self.v_out(critic_features)

    return values, rnn_states
def get_policy_values(self, obs, rnn_states, masks):
    if self._mixed_obs:
        for key in obs.keys():
            obs[key] = check(obs[key]).to(**self.tpdv)
    else:
        obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)

    actor_features = self.base(obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    if self._use_influence_policy:
        mlp_obs = self.mlp(obs)
        actor_features = torch.cat([actor_features, mlp_obs], dim=1)

    values = self.v_out(actor_features)

    return values
def forward(self, share_obs, rnn_states, masks):
    if self._mixed_obs:
        for key in share_obs.keys():
            share_obs[key] = check(share_obs[key]).to(**self.tpdv)
    else:
        share_obs = check(share_obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)

    critic_features = self.base(share_obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)

    if self._use_influence_policy:
        mlp_share_obs = self.mlp(share_obs)
        critic_features = torch.cat([critic_features, mlp_share_obs], dim=1)

    values = self.v_out(critic_features)

    return values, rnn_states
def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
    obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)
    if available_actions is not None:
        available_actions = check(available_actions).to(**self.tpdv)

    actor_features = self.base(obs)

    if self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    actions, _ = self.act(actor_features, available_actions, deterministic)

    return actions, rnn_states
def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
    obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    action = check(action).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)
    if available_actions is not None:
        available_actions = check(available_actions).to(**self.tpdv)
    if active_masks is not None:
        active_masks = check(active_masks).to(**self.tpdv)

    actor_features = self.base(obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    action_log_probs, dist_entropy = self.act.evaluate_actions(
        actor_features, action, available_actions,
        active_masks=active_masks if self._use_policy_active_masks else None)

    values = self.v_out(actor_features) if self._use_policy_vhead else None

    return action_log_probs, dist_entropy, values
def act(self, obs):
    """
    Act greedily according to the state-action value model.
    :param obs: (np.ndarray / torch.Tensor) current observation.
    :return: the greedy action as a numpy array.
    """
    obs = check(obs).to(**self.tpdv)
    if len(obs.shape) < 2:
        obs = obs.unsqueeze(0)

    values = self.value_net(obs)
    actions = torch.argmax(values, dim=-1, keepdim=True)

    return actions.detach().cpu().numpy()
def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
    if self._mixed_obs:
        for key in obs.keys():
            obs[key] = check(obs[key]).to(**self.tpdv)
    else:
        obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)
    if available_actions is not None:
        available_actions = check(available_actions).to(**self.tpdv)

    actor_features = self.base(obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    if self._use_influence_policy:
        mlp_obs = self.mlp(obs)
        actor_features = torch.cat([actor_features, mlp_obs], dim=1)

    actions, action_log_probs = self.act(actor_features, available_actions, deterministic)

    return actions, action_log_probs, rnn_states
def forward(self, obs, rnn_states=None, masks=None, available_actions=None, deterministic=False):
    """
    Compute greedy actions from the state-action value model.
    :param obs: (np.ndarray / torch.Tensor) observation inputs into network.
    :param deterministic: (bool) only greedy (deterministic) action selection is supported.
    :return actions: (torch.Tensor) greedy actions.
    :return rnn_states: unchanged RNN hidden states, kept for interface compatibility.
    """
    obs = check(obs).to(**self.tpdv)

    values = self.value_net(obs)
    if deterministic:
        actions = torch.argmax(values, dim=-1, keepdim=True)
    else:
        raise NotImplementedError("only support greedy action while evaluating!")

    return actions, rnn_states
def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
    """
    Compute log probability and entropy of given actions.
    :param obs: (torch.Tensor) observation inputs into network.
    :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.
    :param rnn_states: (torch.Tensor) if RNN network, hidden states for RNN.
    :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
    :param available_actions: (torch.Tensor) denotes which actions are available to agent
                              (if None, all actions available)
    :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

    :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
    :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
    """
    obs = check(obs).to(**self.tpdv)
    rnn_states = check(rnn_states).to(**self.tpdv)
    action = check(action).to(**self.tpdv)
    masks = check(masks).to(**self.tpdv)
    if available_actions is not None:
        available_actions = check(available_actions).to(**self.tpdv)
    if active_masks is not None:
        active_masks = check(active_masks).to(**self.tpdv)

    actor_features = self.base(obs)

    if self._use_naive_recurrent_policy or self._use_recurrent_policy:
        actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

    action_log_probs, dist_entropy = self.act.evaluate_actions(
        actor_features, action, available_actions,
        active_masks=active_masks if self._use_policy_active_masks else None)

    return action_log_probs, dist_entropy
def ppo_update(self, sample):
    share_obs_batch, obs_batch, recurrent_hidden_states_batch, recurrent_hidden_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch = sample

    old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
    adv_targ = check(adv_targ).to(**self.tpdv)
    value_preds_batch = check(value_preds_batch).to(**self.tpdv)
    return_batch = check(return_batch).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)

    # policy loss
    # Reshape to do in a single forward pass for all steps
    action_log_probs, dist_entropy = self.policy.evaluate_actions(
        obs_batch, recurrent_hidden_states_batch, actions_batch, masks_batch,
        available_actions_batch, active_masks_batch)

    ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
    surr1 = ratio * adv_targ
    surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
    action_loss = (-torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True)
                   * active_masks_batch).sum() / active_masks_batch.sum()

    # value loss
    values = self.policy.get_values(share_obs_batch, recurrent_hidden_states_critic_batch, masks_batch)

    if self._use_popart:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param)
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param)
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    if self._use_clipped_value_loss:
        value_loss = torch.max(value_loss_original, value_loss_clipped)
    else:
        value_loss = value_loss_original

    if self._use_value_active_masks:
        value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        value_loss = value_loss.mean()

    # update common and action network
    self.policy.optimizer.zero_grad()
    (action_loss - dist_entropy * self.entropy_coef + value_loss * self.value_loss_coef).backward()

    if self._use_max_grad_norm:
        grad_norm = nn.utils.clip_grad_norm_(self.policy.model.parameters(), self.max_grad_norm)
    else:
        grad_norm = get_gard_norm(self.policy.model.parameters())

    self.policy.optimizer.step()

    return value_loss, action_loss, dist_entropy, grad_norm, ratio
def ppo_update(self, sample, update_actor=True):
    """
    Update actor and critic networks.
    :param sample: (Tuple) contains data batch with which to update networks.
    :param update_actor: (bool) whether to update actor network.

    :return value_loss: (torch.Tensor) value function loss.
    :return critic_grad_norm: (torch.Tensor) gradient norm from critic update.
    :return policy_loss: (torch.Tensor) actor (policy) loss value.
    :return dist_entropy: (torch.Tensor) action entropies.
    :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
    :return imp_weights: (torch.Tensor) importance sampling weights.
    """
    share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch = sample

    old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
    adv_targ = check(adv_targ).to(**self.tpdv)
    value_preds_batch = check(value_preds_batch).to(**self.tpdv)
    return_batch = check(return_batch).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)

    # Reshape to do in a single forward pass for all steps
    values, action_log_probs, dist_entropy = self.policy.evaluate_actions(
        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch,
        actions_batch, masks_batch, available_actions_batch, active_masks_batch)

    # actor update
    imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)

    surr1 = imp_weights * adv_targ
    surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ

    if self._use_policy_active_masks:
        policy_action_loss = (-torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True)
                              * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        policy_action_loss = -torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True).mean()

    policy_loss = policy_action_loss

    self.policy.actor_optimizer.zero_grad()

    if update_actor:
        (policy_loss - dist_entropy * self.entropy_coef).backward()

    if self._use_max_grad_norm:
        actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
    else:
        actor_grad_norm = get_gard_norm(self.policy.actor.parameters())

    self.policy.actor_optimizer.step()

    # critic update
    value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)

    if self.args.use_q_head:
        self.policy.q_optimizer.zero_grad()
        (value_loss * self.value_loss_coef).backward()

        if self._use_max_grad_norm:
            critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.q_head.parameters(), self.max_grad_norm)
        else:
            critic_grad_norm = get_gard_norm(self.policy.q_head.parameters())

        self.policy.q_optimizer.step()
    else:
        self.policy.critic_optimizer.zero_grad()
        (value_loss * self.value_loss_coef).backward()

        if self._use_max_grad_norm:
            critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
        else:
            critic_grad_norm = get_gard_norm(self.policy.critic.parameters())

        self.policy.critic_optimizer.step()

    return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights
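This ppo_update (and the variant further below) delegates the critic loss to self.cal_value_loss, which is not part of this excerpt. A sketch that mirrors the clipped/Huber value-loss logic spelled out in value_loss_update above; the attribute names follow that code, but the method body itself is an assumption:

def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
    # Clip the new value prediction to stay within clip_param of the old one.
    value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
        -self.clip_param, self.clip_param)
    if self._use_popart:
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    value_loss = torch.max(value_loss_original, value_loss_clipped) \
        if self._use_clipped_value_loss else value_loss_original

    if self._use_value_active_masks:
        return (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    return value_loss.mean()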
def auxiliary_loss_update(self, sample):
    share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        old_action_probs_batch, available_actions_batch = sample

    old_action_probs_batch = check(old_action_probs_batch).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)
    value_preds_batch = check(value_preds_batch).to(**self.tpdv)
    return_batch = check(return_batch).to(**self.tpdv)

    # Reshape to do in a single forward pass for all steps
    values, new_action_probs = self.policy.get_policy_values_and_probs(
        obs_batch, rnn_states_batch, masks_batch, available_actions_batch)

    # kl = sum p * log(p / q) = sum p * (log p - log q) = sum p*log p - p*log q
    # cross-entropy = sum -p*log q
    eps = (old_action_probs_batch == 0) * 1e-8
    old_action_log_probs_batch = torch.log(old_action_probs_batch + eps.float().detach())
    eps = (new_action_probs == 0) * 1e-8
    new_action_log_probs = torch.log(new_action_probs + eps.float().detach())

    kl_divergence = torch.sum(
        old_action_probs_batch * (old_action_log_probs_batch - new_action_log_probs),
        dim=-1, keepdim=True)
    kl_loss = (kl_divergence * active_masks_batch).sum() / active_masks_batch.sum()

    if self._use_popart:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param)
        error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
        error_original = self.value_normalizer(return_batch) - values
    else:
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
            -self.clip_param, self.clip_param)
        error_clipped = return_batch - value_pred_clipped
        error_original = return_batch - values

    if self._use_huber_loss:
        value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
        value_loss_original = huber_loss(error_original, self.huber_delta)
    else:
        value_loss_clipped = mse_loss(error_clipped)
        value_loss_original = mse_loss(error_original)

    if self._use_clipped_value_loss:
        value_loss = torch.max(value_loss_original, value_loss_clipped)
    else:
        value_loss = value_loss_original

    if self._use_value_active_masks:
        value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        value_loss = value_loss.mean()

    joint_loss = value_loss + self.clone_coef * kl_loss

    self.policy.actor_optimizer.zero_grad()
    joint_loss.backward()

    if self._use_max_grad_norm:
        grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
    else:
        grad_norm = get_gard_norm(self.policy.actor.parameters())

    self.policy.actor_optimizer.step()

    return joint_loss, grad_norm
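The manual KL term in auxiliary_loss_update computes, per sample, the sum over actions of p_old * (log p_old - log p_new). For reference, this matches PyTorch's pointwise KL; a small self-contained sanity check (shapes and tensor names here are illustrative only, not taken from the code above):

import torch
import torch.nn.functional as F

old_p = torch.softmax(torch.randn(4, 6), dim=-1)  # old action probabilities
new_p = torch.softmax(torch.randn(4, 6), dim=-1)  # new action probabilities

# Manual form used in auxiliary_loss_update.
manual_kl = (old_p * (old_p.log() - new_p.log())).sum(dim=-1, keepdim=True)

# F.kl_div expects log-probabilities of the approximating distribution as `input`
# and the target probabilities as `target`.
builtin_kl = F.kl_div(new_p.log(), old_p, reduction='none').sum(dim=-1, keepdim=True)

assert torch.allclose(manual_kl, builtin_kl)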
def ppo_update(self, sample, turn_on=True):
    share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch = sample

    old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
    adv_targ = check(adv_targ).to(**self.tpdv)
    value_preds_batch = check(value_preds_batch).to(**self.tpdv)
    return_batch = check(return_batch).to(**self.tpdv)
    active_masks_batch = check(active_masks_batch).to(**self.tpdv)

    # Reshape to do in a single forward pass for all steps
    values, action_log_probs, dist_entropy, policy_values = self.policy.evaluate_actions(
        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch,
        actions_batch, masks_batch, available_actions_batch, active_masks_batch)

    # actor update
    ratio = torch.exp(action_log_probs - old_action_log_probs_batch)

    surr1 = ratio * adv_targ
    surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ

    if self._use_policy_active_masks:
        policy_action_loss = (-torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True)
                              * active_masks_batch).sum() / active_masks_batch.sum()
    else:
        policy_action_loss = -torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True).mean()

    if self._use_policy_vhead:
        policy_value_loss = self.cal_value_loss(policy_values, value_preds_batch, return_batch, active_masks_batch)
        policy_loss = policy_action_loss + policy_value_loss * self.policy_value_loss_coef
    else:
        policy_loss = policy_action_loss

    self.policy.actor_optimizer.zero_grad()

    if turn_on:
        (policy_loss - dist_entropy * self.entropy_coef).backward()

    if self._use_max_grad_norm:
        actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
    else:
        actor_grad_norm = get_gard_norm(self.policy.actor.parameters())

    self.policy.actor_optimizer.step()

    # critic update
    value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)

    self.policy.critic_optimizer.zero_grad()
    (value_loss * self.value_loss_coef).backward()

    if self._use_max_grad_norm:
        critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
    else:
        critic_grad_norm = get_gard_norm(self.policy.critic.parameters())

    self.policy.critic_optimizer.step()

    return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, ratio
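These update methods are typically driven by a train() loop that draws minibatches from a rollout buffer for several PPO epochs. A hedged sketch of such a driver, modeled on the ppo_update signature directly above; the buffer API and the names ppo_epoch, num_mini_batch, and feed_forward_generator are assumptions, not taken from this excerpt:

def train(self, buffer, turn_on=True):
    # Hypothetical driver: compute and normalize advantages once per rollout,
    # then run ppo_update on each minibatch for several epochs.
    advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

    train_info = {'value_loss': 0.0, 'policy_loss': 0.0, 'dist_entropy': 0.0}
    for _ in range(self.ppo_epoch):
        data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)
        for sample in data_generator:
            value_loss, critic_grad_norm, policy_loss, dist_entropy, \
                actor_grad_norm, ratio = self.ppo_update(sample, turn_on)
            train_info['value_loss'] += value_loss.item()
            train_info['policy_loss'] += policy_loss.item()
            train_info['dist_entropy'] += dist_entropy.item()

    num_updates = self.ppo_epoch * self.num_mini_batch
    for k in train_info:
        train_info[k] /= num_updates
    return train_info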