Example #1
    def _compute_next_value(self, experience_batch, policy_output):
        """Computes value of next state for target value computation."""
        q_next = torch.min(*[critic_target(experience_batch.observation_next, self._prep_action(policy_output.action)).q
                             for critic_target in self.critic_targets])
        next_val = q_next - self.alpha * policy_output.prior_divergence[:, None]
        check_shape(next_val, [self._hp.batch_size, 1])
        return next_val.squeeze(-1)
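
All of the examples on this page call a `check_shape` helper that is not listed here. A minimal sketch of what it is assumed to do (an assertion on a tensor's shape) is:

    def check_shape(tensor, expected_shape):
        # Assumed helper (not from the listing above): fail fast if a tensor's
        # shape drifts from what the loss computation expects, e.g. [batch_size, 1].
        assert list(tensor.shape) == list(expected_shape), \
            "expected shape {}, got {}".format(expected_shape, list(tensor.shape))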
Example #2
    def _compute_policy_loss(self, experience_batch, policy_output):
        """Computes loss for policy update."""
        q_est = torch.min(*[critic(experience_batch.observation, self._prep_action(policy_output.action)).q
                            for critic in self.critics])
        policy_loss = -1 * q_est + self.alpha * policy_output.prior_divergence[:, None]
        check_shape(policy_loss, [self._hp.batch_size, 1])
        return policy_loss.mean()
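
For intuition, the `torch.min(*[...])` call above is the elementwise minimum over the two critics (the clipped double-Q trick, which reduces overestimation bias). A standalone sketch with dummy tensors; all names below are illustrative, not from the repository:

    import torch

    batch_size = 4
    q1 = torch.randn(batch_size, 1)   # stand-in for critics[0](obs, act).q
    q2 = torch.randn(batch_size, 1)   # stand-in for critics[1](obs, act).q

    # torch.min(*[q1, q2]) unpacks to torch.min(q1, q2), the elementwise minimum.
    q_est = torch.min(q1, q2)
    assert q_est.shape == (batch_size, 1)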
Example #3
    def _compute_next_value(self, experience_batch, policy_output):
        """Computes value of next state for target value computation."""
        q_next = torch.min(*[
            critic_target(experience_batch.observation_next,
                          self._prep_action(policy_output.action)).q
            for critic_target in self.critic_targets
        ])
        next_val = q_next - self.alpha * policy_output.log_prob[:, None]
        check_shape(next_val, [self._hp.batch_size, 1])
        return next_val.squeeze(-1)
Example #4
    def _compute_policy_loss(self, experience_batch, policy_output):
        """Computes loss for policy update."""
        q_est = torch.min(*[
            critic(experience_batch.observation,
                   self._prep_action(policy_output.action)).q
            for critic in self.critics
        ])
        policy_loss = -1 * q_est + self.alpha * policy_output.log_prob[:, None]
        check_shape(policy_loss, [self._hp.batch_size, 1])
        return policy_loss.mean()
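
Examples #1/#2 regularize the policy with a divergence from a prior (`prior_divergence`), while Examples #3/#4 use the standard SAC entropy term (`log_prob`). A minimal sketch contrasting the two regularizers on dummy values; names are illustrative only:

    import torch

    alpha = 0.1
    q_est = torch.randn(4, 1)
    log_prob = torch.randn(4)          # stand-in for policy_output.log_prob
    prior_divergence = torch.rand(4)   # stand-in for policy_output.prior_divergence

    sac_policy_loss = (-1 * q_est + alpha * log_prob[:, None]).mean()                # Examples #3/#4
    prior_reg_policy_loss = (-1 * q_est + alpha * prior_divergence[:, None]).mean()  # Examples #1/#2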
Example #5
    def _compute_critic_loss(self, experience_batch, q_target):
        """Computes MSE loss of each critic's Q estimate against the target Q value."""
        qs = self._compute_q_estimates(experience_batch)
        check_shape(qs[0], [self._hp.batch_size])
        critic_losses = [0.5 * (q - q_target).pow(2).mean() for q in qs]
        return critic_losses, qs
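
Example #5 relies on `_compute_q_estimates`, which is not shown here. Assuming it queries each critic on the stored observation/action pairs and returns per-critic Q values of shape [batch_size], a hypothetical sketch could look like:

    def _compute_q_estimates(self, experience_batch):
        # Hypothetical sketch (not from the listing above): evaluate every critic
        # on the batch's (observation, action) pairs and drop the trailing dim
        # so each estimate has shape [batch_size].
        return [critic(experience_batch.observation,
                       self._prep_action(experience_batch.action)).q.squeeze(-1)
                for critic in self.critics]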
Example #6
    def update(self, experience_batch):
        """Updates actor and critics."""
        # push experience batch into replay buffer
        self.add_experience(experience_batch)

        for _ in range(self._hp.update_iterations):
            # sample batch and normalize
            experience_batch = self._sample_experience()
            experience_batch = self._normalize_batch(experience_batch)
            experience_batch = map2torch(experience_batch, self._hp.device)
            experience_batch = self._preprocess_experience(experience_batch)

            policy_output = self._run_policy(experience_batch.observation)

            # update alpha
            alpha_loss = self._update_alpha(experience_batch, policy_output)

            # compute policy loss
            policy_loss = self._compute_policy_loss(experience_batch,
                                                    policy_output)

            # compute target Q value
            with torch.no_grad():
                policy_output_next = self._run_policy(
                    experience_batch.observation_next)
                value_next = self._compute_next_value(experience_batch,
                                                      policy_output_next)
                q_target = experience_batch.reward * self._hp.reward_scale + \
                           (1 - experience_batch.done) * self._hp.discount_factor * value_next
                if self._hp.clip_q_target:
                    q_target = self._clip_q_target(q_target)
                q_target = q_target.detach()
                check_shape(q_target, [self._hp.batch_size])

            # compute critic loss
            critic_losses, qs = self._compute_critic_loss(
                experience_batch, q_target)

            # update critic networks
            [
                self._perform_update(critic_loss, critic_opt, critic)
                for critic_loss, critic_opt, critic in zip(
                    critic_losses, self.critic_opts, self.critics)
            ]

            # update target networks
            [
                self._soft_update_target_network(critic_target, critic) for
                critic_target, critic in zip(self.critic_targets, self.critics)
            ]

            # update policy network on policy loss
            self._perform_update(policy_loss, self.policy_opt, self.policy)

            # logging
            info = AttrDict(  # losses
                policy_loss=policy_loss,
                alpha_loss=alpha_loss,
                critic_loss_1=critic_losses[0],
                critic_loss_2=critic_losses[1],
            )
            if self._update_steps % 100 == 0:
                info.update(
                    AttrDict(  # gradient norms
                        policy_grad_norm=avg_grad_norm(self.policy),
                        critic_1_grad_norm=avg_grad_norm(self.critics[0]),
                        critic_2_grad_norm=avg_grad_norm(self.critics[1]),
                    ))
            info.update(
                AttrDict(  # misc
                    alpha=self.alpha,
                    pi_log_prob=policy_output.log_prob.mean(),
                    policy_entropy=policy_output.dist.entropy().mean(),
                    q_target=q_target.mean(),
                    q_1=qs[0].mean(),
                    q_2=qs[1].mean(),
                ))
            info.update(self._aux_info(experience_batch, policy_output))
            info = map_dict(ten2ar, info)

            self._update_steps += 1

        return info
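
The update loop in Example #6 also uses `_perform_update` and `_soft_update_target_network`, which are not listed. Sketches of what they are assumed to do (a plain optimizer step and Polyak averaging of the target critics; the hyperparameter name `target_network_update_factor` is an assumption):

    def _perform_update(self, loss, opt, network):
        # Assumed behavior: clear stale gradients, backpropagate the loss,
        # and apply one optimizer step to the network's parameters.
        opt.zero_grad()
        loss.backward()
        opt.step()

    def _soft_update_target_network(self, target, source):
        # Assumed behavior: Polyak averaging, moving each target parameter a
        # small step tau towards the corresponding online-network parameter.
        tau = self._hp.target_network_update_factor
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.data.copy_(tau * s_param.data + (1 - tau) * t_param.data)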