def _compute_next_value(self, experience_batch, policy_output): """Computes value of next state for target value computation.""" q_next = torch.min(*[critic_target(experience_batch.observation_next, self._prep_action(policy_output.action)).q for critic_target in self.critic_targets]) next_val = (q_next - self.alpha * policy_output.prior_divergence[:, None]) check_shape(next_val, [self._hp.batch_size, 1]) return next_val.squeeze(-1)
def _compute_policy_loss(self, experience_batch, policy_output): """Computes loss for policy update.""" q_est = torch.min(*[critic(experience_batch.observation, self._prep_action(policy_output.action)).q for critic in self.critics]) policy_loss = -1 * q_est + self.alpha * policy_output.prior_divergence[:, None] check_shape(policy_loss, [self._hp.batch_size, 1]) return policy_loss.mean()
def _compute_next_value(self, experience_batch, policy_output):
    """Computes value of next state for target value computation, using the policy log-probability (entropy term)."""
    q_next = torch.min(*[critic_target(experience_batch.observation_next,
                                       self._prep_action(policy_output.action)).q
                         for critic_target in self.critic_targets])
    next_val = (q_next - self.alpha * policy_output.log_prob[:, None])
    check_shape(next_val, [self._hp.batch_size, 1])
    return next_val.squeeze(-1)
def _compute_policy_loss(self, experience_batch, policy_output):
    """Computes loss for policy update, using the policy log-probability (entropy term) instead of the prior divergence."""
    q_est = torch.min(*[critic(experience_batch.observation,
                               self._prep_action(policy_output.action)).q
                        for critic in self.critics])
    policy_loss = -1 * q_est + self.alpha * policy_output.log_prob[:, None]
    check_shape(policy_loss, [self._hp.batch_size, 1])
    return policy_loss.mean()
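# `_prep_action` is referenced in the methods above but not shown in this
# excerpt. A minimal sketch, assuming a continuous action space where the only
# preprocessing needed is casting to float before the action is fed to the
# critics; the real implementation may differ (e.g. one-hot encoding for
# discrete actions).
def _prep_action(self, action):
    """Casts policy actions to float before passing them to the critics (assumed behavior)."""
    return action.float()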
def _compute_critic_loss(self, experience_batch, q_target):
    """Computes squared-error loss of each critic's Q estimate against the target Q value."""
    qs = self._compute_q_estimates(experience_batch)
    check_shape(qs[0], [self._hp.batch_size])
    critic_losses = [0.5 * (q - q_target).pow(2).mean() for q in qs]
    return critic_losses, qs
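# `_compute_q_estimates` is called above but not defined in this excerpt. A
# minimal sketch, assuming each critic returns an object with a `.q` tensor of
# shape [batch_size, 1] (matching how the critics are queried in
# `_compute_policy_loss`) and that the batch stores the executed actions under
# `experience_batch.action`; the actual helper may differ.
def _compute_q_estimates(self, experience_batch):
    """Returns one Q estimate per critic for the actions stored in the batch (assumed implementation)."""
    return [critic(experience_batch.observation,
                   self._prep_action(experience_batch.action)).q.squeeze(-1)
            for critic in self.critics]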
def update(self, experience_batch): """Updates actor and critics.""" # push experience batch into replay buffer self.add_experience(experience_batch) for _ in range(self._hp.update_iterations): # sample batch and normalize experience_batch = self._sample_experience() experience_batch = self._normalize_batch(experience_batch) experience_batch = map2torch(experience_batch, self._hp.device) experience_batch = self._preprocess_experience(experience_batch) policy_output = self._run_policy(experience_batch.observation) # update alpha alpha_loss = self._update_alpha(experience_batch, policy_output) # compute policy loss policy_loss = self._compute_policy_loss(experience_batch, policy_output) # compute target Q value with torch.no_grad(): policy_output_next = self._run_policy( experience_batch.observation_next) value_next = self._compute_next_value(experience_batch, policy_output_next) q_target = experience_batch.reward * self._hp.reward_scale + \ (1 - experience_batch.done) * self._hp.discount_factor * value_next if self._hp.clip_q_target: q_target = self._clip_q_target(q_target) q_target = q_target.detach() check_shape(q_target, [self._hp.batch_size]) # compute critic loss critic_losses, qs = self._compute_critic_loss( experience_batch, q_target) # update critic networks [ self._perform_update(critic_loss, critic_opt, critic) for critic_loss, critic_opt, critic in zip( critic_losses, self.critic_opts, self.critics) ] # update target networks [ self._soft_update_target_network(critic_target, critic) for critic_target, critic in zip(self.critic_targets, self.critics) ] # update policy network on policy loss self._perform_update(policy_loss, self.policy_opt, self.policy) # logging info = AttrDict( # losses policy_loss=policy_loss, alpha_loss=alpha_loss, critic_loss_1=critic_losses[0], critic_loss_2=critic_losses[1], ) if self._update_steps % 100 == 0: info.update( AttrDict( # gradient norms policy_grad_norm=avg_grad_norm(self.policy), critic_1_grad_norm=avg_grad_norm(self.critics[0]), critic_2_grad_norm=avg_grad_norm(self.critics[1]), )) info.update( AttrDict( # misc alpha=self.alpha, pi_log_prob=policy_output.log_prob.mean(), policy_entropy=policy_output.dist.entropy().mean(), q_target=q_target.mean(), q_1=qs[0].mean(), q_2=qs[1].mean(), )) info.update(self._aux_info(experience_batch, policy_output)) info = map_dict(ten2ar, info) self._update_steps += 1 return info