Example #1
0
    def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1,
                       done_t):
        """Compute actor and quantile-regression critic losses (multi-head).

        The actor is pushed towards actions with a higher mean critic
        output. The critic is regressed towards an n-step Bellman target
        built from the target actor/critic, using one discount factor per
        head (``self._gammas ** self._n_step``).

        Returns:
            tuple: ``(policy_loss, value_loss)``.
        """
        # --- actor loss -------------------------------------------------
        actions_pi = self.actor(states_t)
        policy_loss = -self.critic(states_t, actions_pi).mean()

        # --- critic loss (quantile regression) --------------------------
        # B x num_heads x num_atoms
        atoms_t = self.critic(states_t, actions_t)
        next_actions = self.target_actor(states_tp1)
        # B x num_heads x num_atoms, excluded from the gradient
        atoms_tp1 = self.target_critic(states_tp1, next_actions).detach()

        # Reshape per-sample scalars and per-head discounts so they
        # broadcast against the B x num_heads x num_atoms atoms.
        not_done = 1 - done_t[:, None, :]  # B x 1 x 1
        reward = rewards_t[:, None, :]  # B x 1 x 1
        discount = (self._gammas**self._n_step)[None, :, None]
        # discount: 1 x num_heads x 1

        atoms_target_t = reward + not_done * discount * atoms_tp1

        # Flatten heads into the batch dimension for the loss helper.
        n_atoms = self.num_atoms
        value_loss = quantile_loss(
            atoms_t.view(-1, n_atoms),
            atoms_target_t.view(-1, n_atoms),
            self.tau,
            n_atoms,
            self.critic_criterion,
        )
        return policy_loss, value_loss
Example #2
0
    def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1,
                       done_t):
        """Compute TD3-style actor and quantile-regression critic losses.

        The actor maximises, per sample, the smallest mean value among the
        critics in ``self.critics``. Each critic is regressed towards an
        n-step target that uses, per sample, the target critic with the
        lowest mean value, evaluated at noise-perturbed target actions.

        Returns:
            tuple: ``(policy_loss, value_loss)`` where ``value_loss`` is a
            list with one loss per critic.
        """
        # --- actor loss: maximise the most pessimistic critic ------------
        pi_actions = self.actor(states_t)
        # stacked along a new last axis -> (critic output) x n_critics
        pi_atoms = torch.stack(
            [critic(states_t, pi_actions) for critic in self.critics],
            dim=-1)
        pi_q_min, _ = pi_atoms.mean(dim=1).min(dim=1)
        policy_loss = -pi_q_min.mean()

        # --- critic loss (quantile regression) ---------------------------
        next_actions = self.target_actor(states_tp1).detach()
        next_actions = self._add_noise_to_actions(next_actions)
        atoms_t = [critic(states_t, actions_t) for critic in self.critics]
        next_atoms = torch.stack(
            [critic(states_tp1, next_actions)
             for critic in self.target_critics],
            dim=-1)
        # per sample, keep the target critic with the lowest mean value
        min_ids = next_atoms.mean(dim=1).argmin(dim=1)
        batch_ids = range(len(next_atoms))
        next_atoms = next_atoms[batch_ids, :, min_ids].detach()

        discount = self.gamma**self.n_step
        atoms_target_t = rewards_t + (1 - done_t) * discount * next_atoms

        value_loss = []
        for atoms in atoms_t:
            value_loss.append(
                quantile_loss(atoms, atoms_target_t, self.tau, self.n_atoms,
                              self.critic_criterion))
        return policy_loss, value_loss
Example #3
0
    def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1,
                       done_t):
        """Compute DDPG-style actor and quantile-regression critic losses.

        Returns:
            tuple: ``(policy_loss, value_loss)`` where the actor loss is the
            negated mean critic output for the actor's own actions and the
            critic loss regresses towards an n-step target built from the
            target actor/critic.
        """
        discount = self.gamma**self.n_step

        # actor loss
        q_pi = self.critic(states_t, self.actor(states_t))
        policy_loss = -q_pi.mean()

        # critic loss (quantile regression)
        atoms_t = self.critic(states_t, actions_t)
        next_actions = self.target_actor(states_tp1)
        # target atoms are treated as constants
        atoms_tp1 = self.target_critic(states_tp1, next_actions).detach()
        not_done = 1 - done_t
        atoms_target_t = rewards_t + not_done * discount * atoms_tp1

        value_loss = quantile_loss(
            atoms_t,
            atoms_target_t,
            self.tau,
            self.n_atoms,
            self.critic_criterion,
        )
        return policy_loss, value_loss
Example #4
0
    def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1,
                       done_t):
        """Compute the quantile-regression loss of a discrete-action critic.

        The prediction is the atoms of the taken action. The target uses the
        target critic's atoms for the action with the highest mean target
        value, discounted by ``self._gamma ** self._n_step``.

        Returns:
            The critic loss returned by ``quantile_loss``.
        """
        discount = self._gamma**self._n_step
        n_atoms = self.num_atoms

        def _atoms_of(all_atoms, actions):
            # Select the atoms of `actions` along dim 1; assumes all_atoms is
            # B x num_actions x num_atoms and actions is B x 1 -- confirm.
            index = actions.repeat(1, n_atoms).unsqueeze(1)
            return all_atoms.gather(1, index).squeeze(1)

        # critic loss (quantile regression)
        atoms_t = _atoms_of(self.critic(states_t), actions_t)

        all_atoms_tp1 = self.target_critic(states_tp1).detach()
        greedy_actions = all_atoms_tp1.mean(dim=-1).argmax(
            dim=-1, keepdim=True)
        atoms_tp1 = _atoms_of(all_atoms_tp1, greedy_actions)

        atoms_target_t = rewards_t + (1 - done_t) * discount * atoms_tp1

        return quantile_loss(atoms_t, atoms_target_t, self.tau, n_atoms,
                             self.critic_criterion)
Example #5
0
    def _quantile_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        """Compute SAC-style actor and quantile-regression critic losses.

        Actor: minimises the mean of ``log_pi / reward_scale - min_k Q_k``,
        where ``Q_k`` is the mean of critic ``k``'s atoms for the actor's
        own actions.
        Critics: each critic in ``self.critics`` is regressed towards an
        n-step target built from the target critic with the lowest mean
        value, with the scaled log-probability of the next action
        subtracted.

        Returns:
            tuple: ``(policy_loss, value_loss)`` where ``value_loss`` is a
            list with one loss per critic.
        """
        # actor loss
        actions_tp0, log_pi_tp0 = self.actor(states_t, with_log_pi=True)
        # B x 1 -- assumes the actor returns log_pi with shape (B,); the
        # target computation below relies on the same assumption.
        log_pi_tp0 = log_pi_tp0[:, None] / self.reward_scale
        atoms_tp0 = [
            x(states_t, actions_tp0).unsqueeze_(-1) for x in self.critics
        ]
        # keepdim=True keeps this B x 1 so it lines up with log_pi_tp0.
        # A 1-D (B,) tensor here would make the subtraction below broadcast
        # to a B x B matrix (same mean, but O(B^2) memory and compute).
        q_values_tp0_min = torch.cat(
            atoms_tp0, dim=-1
        ).mean(dim=1).min(dim=1, keepdim=True)[0]
        policy_loss = torch.mean(log_pi_tp0 - q_values_tp0_min)

        # critic loss (quantile regression)
        actions_tp1, log_pi_tp1 = self.actor(
            states_tp1, with_log_pi=True
        )
        # B x 1, broadcasts against the per-sample target atoms
        log_pi_tp1 = log_pi_tp1[:, None] / self.reward_scale
        atoms_t = [x(states_t, actions_t) for x in self.critics]
        # target atoms stacked along a new last axis, one slice per critic
        atoms_tp1 = torch.cat(
            [
                x(states_tp1, actions_tp1).unsqueeze_(-1)
                for x in self.target_critics
            ],
            dim=-1
        )
        # per sample, keep the target critic with the lowest mean value
        atoms_ids_tp1_min = atoms_tp1.mean(dim=1).argmin(dim=1)
        atoms_tp1 = atoms_tp1[range(len(atoms_tp1)), :, atoms_ids_tp1_min]
        gamma = self.gamma**self.n_step
        # subtract the scaled log-probability term; the target is detached
        # so no gradient flows into the target critics or the next-state
        # actor sample
        atoms_tp1 = (atoms_tp1 - log_pi_tp1).detach()
        atoms_target_t = rewards_t + (1 - done_t) * gamma * atoms_tp1
        value_loss = [
            quantile_loss(
                x, atoms_target_t, self.tau, self.n_atoms,
                self.critic_criterion
            ) for x in atoms_t
        ]

        return policy_loss, value_loss
Example #6
0
    def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1,
                       done_t):
        """Quantile-regression loss for a multi-head discrete-action critic.

        Each head has its own discount (``self._gammas ** self._n_step``).
        The prediction is the atoms of the taken action. The target uses,
        per head, the target critic's atoms for the action with the highest
        mean target value.

        Returns:
            The critic loss returned by ``quantile_loss`` with heads folded
            into the batch dimension.
        """
        n_heads = self._num_heads
        n_atoms = self.num_atoms

        # 1 x num_heads x 1
        discounts = (self._gammas**self._n_step)[None, :, None]
        # B x 1 x 1 each
        not_done = 1 - done_t[:, None, :]
        rewards = rewards_t[:, None, :]

        # critic loss (quantile regression)
        # index of the taken action: B x num_heads x 1 x num_atoms
        taken_index = actions_t[:, None, None, :].repeat(
            1, n_heads, 1, n_atoms)
        # B x num_heads x num_atoms
        atoms_t = self.critic(states_t).gather(-2, taken_index).squeeze(-2)

        # B x num_heads x num_actions x num_atoms
        next_atoms = self.target_critic(states_tp1).detach()
        # greedy action per head w.r.t. mean target value: B x num_heads x 1
        greedy = next_atoms.mean(dim=-1).argmax(dim=-1, keepdim=True)
        # B x num_heads x 1 x num_atoms
        greedy_index = greedy[..., None].repeat(1, 1, 1, n_atoms)
        # B x num_heads x num_atoms
        atoms_tp1 = next_atoms.gather(-2, greedy_index).squeeze(-2)

        atoms_target_t = rewards + not_done * discounts * atoms_tp1

        return quantile_loss(
            atoms_t.view(-1, n_atoms),
            atoms_target_t.view(-1, n_atoms),
            self.tau,
            n_atoms,
            self.critic_criterion,
        )