Example #1
    def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1,
                          done_t):
        gammas = self._gammas**self._n_step

        # actor loss
        # For now we have the same actor for all heads of the critic
        logits_tp0 = self.critic(states_t, self.actor(states_t))
        probs_tp0 = torch.softmax(logits_tp0, dim=-1)
        q_values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1)
        policy_loss = -torch.mean(q_values_tp0)

        # critic loss (kl-divergence between categorical distributions)

        done_t = done_t[:, None, :]
        # B x 1 x 1
        rewards_t = rewards_t[:, None, :]
        # B x 1 x 1
        gammas = gammas[None, :, None]
        # 1 x num_heads x 1

        logits_t = self.critic(states_t, actions_t)
        # B x num_heads x num_atoms
        logits_tp1 = self.target_critic(
            states_tp1, self.target_actor(states_tp1)).detach()
        # B x num_heads x num_atoms
        atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z
        # B x num_heads x num_atoms

        value_loss = categorical_loss(logits_t.view(-1, self.num_atoms),
                                      logits_tp1.view(-1, self.num_atoms),
                                      atoms_target_t.view(-1, self.num_atoms),
                                      self.z, self.delta_z, self.v_min,
                                      self.v_max)

        return policy_loss, value_loss
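
All of the snippets on this page delegate the distributional Bellman projection and the cross-entropy (KL) term to a shared categorical_loss(logits_t, logits_tp1, atoms_target_t, z, delta_z, v_min, v_max) helper that is not shown here. For orientation, a minimal C51-style implementation with that signature could look like the sketch below; the shapes, names, and details are assumptions, and the project's own helper may differ.

    import torch


    def categorical_loss(logits_t, logits_tp1, atoms_target_t,
                         z, delta_z, v_min, v_max):
        # Sketch only (assumed implementation, not the project's code).
        # logits_t       : B x num_atoms  -- critic logits for (s_t, a_t)
        # logits_tp1     : B x num_atoms  -- target-critic logits for (s_tp1, a_tp1)
        # atoms_target_t : B x num_atoms  -- Bellman-backed-up atom positions
        # z              : num_atoms      -- fixed support with spacing delta_z
        probs_tp1 = torch.softmax(logits_tp1, dim=-1)
        # Clamp the backed-up atoms into [v_min, v_max] and distribute each one
        # onto its neighbouring support atoms with linear interpolation weights.
        tz = atoms_target_t.clamp(v_min, v_max)
        w = (1 - (tz.unsqueeze(1) - z.view(1, -1, 1)).abs() / delta_z).clamp(0, 1)
        target_probs = (w * probs_tp1.unsqueeze(1)).sum(dim=-1)
        # Cross-entropy between the projected target distribution and the
        # current prediction (equal to the KL divergence up to a constant).
        log_probs_t = torch.log_softmax(logits_t, dim=-1)
        return -(target_probs.detach() * log_probs_t).sum(dim=-1).mean()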
Example #2
    def _categorical_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):

        # actor loss
        actions_tp0, log_pi_tp0 = self.actor(states_t, with_log_pi=True)
        log_pi_tp0 = log_pi_tp0 / self.reward_scale
        logits_tp0 = [x(states_t, actions_tp0) for x in self.critics]
        probs_tp0 = [F.softmax(x, dim=-1) for x in logits_tp0]
        q_values_tp0 = [
            torch.sum(x * self.z, dim=-1).unsqueeze_(-1) for x in probs_tp0
        ]
        # clipped double-Q: take the minimum over the critic ensemble
        q_values_tp0_min = torch.cat(q_values_tp0, dim=-1).min(dim=-1)[0]
        # entropy-regularized policy objective: minimize E[log pi - min_i Q_i]
        policy_loss = torch.mean(log_pi_tp0 - q_values_tp0_min)

        # critic loss (kl-divergence between categorical distributions)
        actions_tp1, log_pi_tp1 = self.actor(
            states_tp1, with_log_pi=True
        )
        log_pi_tp1 = log_pi_tp1 / self.reward_scale
        logits_t = [x(states_t, actions_t) for x in self.critics]
        logits_tp1 = [x(states_tp1, actions_tp1) for x in self.target_critics]
        probs_tp1 = [F.softmax(x, dim=-1) for x in logits_tp1]
        q_values_tp1 = [
            torch.sum(x * self.z, dim=-1).unsqueeze_(-1) for x in probs_tp1
        ]
        # per-sample index of the critic with the smallest Q-value estimate
        probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=1)

        # keep, for each sample, the logits of that minimizing critic
        logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
        logits_tp1 = logits_tp1[range(len(logits_tp1)), :, probs_ids_tp1_min]
        gamma = self.gamma**self.n_step
        # soft backup: shift the support by the (reward-scaled) policy log-prob
        z_target_tp1 = (self.z[None, :] - log_pi_tp1[:, None]).detach()
        atoms_target_t = rewards_t + (1 - done_t) * gamma * z_target_tp1
        value_loss = [
            categorical_loss(
                x, logits_tp1, atoms_target_t, self.z, self.delta_z,
                self.v_min, self.v_max
            ) for x in logits_t
        ]

        return policy_loss, value_loss
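
In this SAC-style snippet the support itself is shifted by the reward-scaled policy log-probability before the Bellman backup, so the projected target atoms are

    Tz = r_t + gamma**n_step * (1 - done_t) * (z - log_pi(a_tp1 | s_tp1) / reward_scale)

which is the distributional analogue of the soft value backup V(s_tp1) = E[Q(s_tp1, a_tp1) - alpha * log_pi(a_tp1 | s_tp1)], with the entropy temperature alpha corresponding to 1 / reward_scale.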
Example #3
    def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1,
                          done_t):

        critic = self.critic
        target_critic = self.target_critic
        gammas = (self._gammas**self._n_step)[None, :, None]
        # 1 x num_heads x 1

        done_t = done_t[:, None, :]  # B x 1 x 1
        rewards_t = rewards_t[:, None, :]  # B x 1 x 1
        actions_t = actions_t[:, None, None, :]  # B x 1 x 1 x 1
        indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
        # B x num_heads x 1 x num_atoms

        logits_t = critic(states_t).gather(-2, indices_t).squeeze(-2)
        # B x num_heads x num_atoms

        all_logits_tp1 = target_critic(states_tp1).detach()
        # B x num_heads x num_actions x num_atoms

        q_values_tp1 = torch.sum(
            torch.softmax(all_logits_tp1, dim=-1) * self.z, dim=-1
        )
        # B x num_heads x num_actions
        actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
        # B x num_heads x 1

        indices_tp1 = \
            actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
        # B x num_heads x 1 x num_atoms

        logits_tp1 = all_logits_tp1.gather(-2, indices_tp1).squeeze(-2)
        # B x num_heads x num_atoms
        atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z

        value_loss = categorical_loss(logits_t.view(-1, self.num_atoms),
                                      logits_tp1.view(-1, self.num_atoms),
                                      atoms_target_t.view(-1, self.num_atoms),
                                      self.z, self.delta_z, self.v_min,
                                      self.v_max)

        return value_loss
Example #4
    def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1,
                          done_t):
        gamma = self.gamma**self.n_step

        # actor loss
        logits_tp0 = self.critic(states_t, self.actor(states_t))
        probs_tp0 = F.softmax(logits_tp0, dim=-1)
        q_values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1)
        policy_loss = -torch.mean(q_values_tp0)

        # critic loss (kl-divergence between categorical distributions)
        logits_t = self.critic(states_t, actions_t)
        logits_tp1 = self.target_critic(
            states_tp1, self.target_actor(states_tp1)).detach()
        atoms_target_t = rewards_t + (1 - done_t) * gamma * self.z

        value_loss = categorical_loss(logits_t, logits_tp1, atoms_target_t,
                                      self.z, self.delta_z, self.v_min,
                                      self.v_max)

        return policy_loss, value_loss
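
In the DDPG-style snippets (Examples #1 and #4) the actor objective is the expected value of the critic's categorical distribution at the policy's action, recovered as a scalar Q-value from the predicted atom probabilities:

    policy_loss = -mean_batch( sum_i softmax(critic(s_t, actor(s_t)))_i * z_i )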
Example #5
File: dqn.py Project: qrltrader/catalyst
    def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1,
                          done_t):
        gamma = self._gamma**self._n_step

        # critic loss (kl-divergence between categorical distributions)
        indices_t = actions_t.repeat(1, self.num_atoms).unsqueeze(1)
        logits_t = self.critic(states_t).gather(1, indices_t).squeeze(1)

        all_logits_tp1 = self.target_critic(states_tp1).detach()
        q_values_tp1 = torch.sum(F.softmax(all_logits_tp1, dim=-1) * self.z,
                                 dim=-1)
        actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
        indices_tp1 = actions_tp1.repeat(1, self.num_atoms).unsqueeze(1)
        logits_tp1 = all_logits_tp1.gather(1, indices_tp1).squeeze(1)
        atoms_target_t = rewards_t + (1 - done_t) * gamma * self.z

        value_loss = categorical_loss(logits_t, logits_tp1, atoms_target_t,
                                      self.z, self.delta_z, self.v_min,
                                      self.v_max)

        return value_loss
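
For context, the support-related attributes that every snippet relies on (self.z, self.delta_z, self.num_atoms, self.v_min, self.v_max) are typically built once at construction time, roughly as in the sketch below; the concrete numbers are the common C51 defaults and only an assumption about this project.

    import torch

    num_atoms = 51                                  # number of support atoms
    v_min, v_max = -10.0, 10.0                      # value range of the support
    delta_z = (v_max - v_min) / (num_atoms - 1)     # spacing between atoms
    z = torch.linspace(v_min, v_max, num_atoms)     # fixed support, shape: num_atoms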
Example #6
    def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1,
                          done_t):

        # actor loss
        actions_tp0 = self.actor(states_t)
        logits_tp0 = [x(states_t, actions_tp0) for x in self.critics]
        probs_tp0 = [F.softmax(x, dim=-1) for x in logits_tp0]
        q_values_tp0 = [
            torch.sum(x * self.z, dim=-1).unsqueeze_(-1) for x in probs_tp0
        ]
        q_values_tp0_min = torch.cat(q_values_tp0, dim=-1).min(dim=-1)[0]
        policy_loss = -torch.mean(q_values_tp0_min)

        # critic loss (kl-divergence between categorical distributions)
        actions_tp1 = self.target_actor(states_tp1).detach()
        # target policy smoothing: perturb the target actions with noise
        actions_tp1 = self._add_noise_to_actions(actions_tp1)
        logits_t = [x(states_t, actions_t) for x in self.critics]
        logits_tp1 = [x(states_tp1, actions_tp1) for x in self.target_critics]
        probs_tp1 = [F.softmax(x, dim=-1) for x in logits_tp1]
        q_values_tp1 = [
            torch.sum(x * self.z, dim=-1).unsqueeze_(-1) for x in probs_tp1
        ]
        # per-sample index of the critic with the smallest Q-value estimate
        probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=1)

        # keep, for each sample, the (detached) logits of that minimizing critic
        logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
        logits_tp1 = \
            logits_tp1[range(len(logits_tp1)), :, probs_ids_tp1_min].detach()
        gamma = self.gamma**self.n_step
        atoms_target_t = rewards_t + (1 - done_t) * gamma * self.z
        value_loss = [
            categorical_loss(x, logits_tp1, atoms_target_t, self.z,
                             self.delta_z, self.v_min, self.v_max)
            for x in logits_t
        ]

        return policy_loss, value_loss
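
The self._add_noise_to_actions call above is project-specific and not reproduced on this page. In TD3-style target policy smoothing it typically adds clipped Gaussian noise to the target actions and clamps the result back into the valid action range; a sketch under assumed attribute names and bounds:

    import torch


    def _add_noise_to_actions(self, actions):
        # Sketch of TD3-style target policy smoothing; self.action_noise_std,
        # self.action_noise_clip and the [-1, 1] action bounds are assumptions.
        noise = torch.randn_like(actions) * self.action_noise_std
        noise = noise.clamp(-self.action_noise_clip, self.action_noise_clip)
        return (actions + noise).clamp(-1.0, 1.0)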