Example #1
    def _categorical_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        # actor loss
        # For now we have the same actor for all heads of the critic
        logits_tp0 = self.critic(states_t, self.actor(states_t))
        probs_tp0 = torch.softmax(logits_tp0, dim=-1)
        q_values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1)
        policy_loss = -torch.mean(q_values_tp0)

        # critic loss (kl-divergence between categorical distributions)
        # B x num_heads x num_atoms
        logits_t = self.critic(states_t, actions_t).squeeze_(dim=2)
        # B x num_heads x num_atoms
        logits_tp1 = self.target_critic(
            states_tp1, self.target_actor(states_tp1)
        ).squeeze_(dim=2).detach()

        # B x num_heads x num_atoms
        atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z
        value_loss = utils.categorical_loss(
            logits_t.view(-1, self.num_atoms),
            logits_tp1.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.z, self.delta_z,
            self.v_min, self.v_max
        )

        return policy_loss, value_loss
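
The snippets in this listing assume a fixed categorical support `self.z` with spacing `self.delta_z` over `[self.v_min, self.v_max]`. That setup code is not shown here; the fragment below is a minimal sketch, assuming a C51-style uniform support, of how such attributes are typically built and how the expected Q-value is read off a categorical head.

    import torch

    # Sketch only: uniform support of `num_atoms` values over [v_min, v_max].
    num_atoms, v_min, v_max = 51, -10.0, 10.0
    delta_z = (v_max - v_min) / (num_atoms - 1)   # spacing between atoms
    z = torch.linspace(v_min, v_max, num_atoms)   # [num_atoms]

    # Expected value of a categorical head: softmax over atoms,
    # then a weighted sum against the support (as done in the examples).
    logits = torch.randn(4, num_atoms)                               # [bs; num_atoms]
    q_values = torch.sum(torch.softmax(logits, dim=-1) * z, dim=-1)  # [bs]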
Example #2
    def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1,
                          done_t):

        # actor loss
        actions_tp0, log_pi_tp0 = self.actor(states_t, logprob=True)
        log_pi_tp0 = log_pi_tp0 / self.reward_scale
        logits_tp0 = [x(states_t, actions_tp0) for x in self.critics]
        probs_tp0 = [torch.softmax(x, dim=-1) for x in logits_tp0]
        q_values_tp0 = [
            torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp0
        ]
        q_values_tp0_min = torch.cat(q_values_tp0, dim=-1).min(dim=-1)[0]
        # B x num_heads
        # For now we use the same actor for each gamma
        policy_loss = torch.mean(log_pi_tp0[:, None] - q_values_tp0_min)

        # critic loss (kl-divergence between categorical distributions)
        actions_tp1, log_pi_tp1 = self.actor(states_tp1, logprob=True)
        log_pi_tp1 = log_pi_tp1 / self.reward_scale
        logits_t = [x(states_t, actions_t) for x in self.critics]
        logits_tp1 = [x(states_tp1, actions_tp1) for x in self.target_critics]
        probs_tp1 = [torch.softmax(x, dim=-1) for x in logits_tp1]
        q_values_tp1 = [
            torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp1
        ]
        probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)
        # B x num_heads

        logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
        # B x num_heads x num_atoms x num_critics

        # @TODO: smarter way to do this (other than reshaping)?
        probs_ids_tp1_min = probs_ids_tp1_min.view(-1)
        logits_tp1 = logits_tp1.view(-1, self.num_atoms, self._num_critics)
        logits_tp1 = logits_tp1[range(len(logits_tp1)), :, probs_ids_tp1_min].\
            view(-1, self._num_heads, self.num_atoms)
        # B x num_heads x num_atoms

        gammas = self._gammas**self._n_step
        done_t = done_t[:, None, :]  # B x 1 x 1
        rewards_t = rewards_t[:, None, :]  # B x 1 x 1
        gammas = gammas[None, :, None]  # 1 x num_heads x 1

        z_target_tp1 = (self.z[None, :] - log_pi_tp1[:, None]).detach()
        # B x num_atoms
        # Unsqueeze so it's the same for each head
        z_target_tp1 = z_target_tp1.unsqueeze(1)

        atoms_target_t = rewards_t + (1 - done_t) * gammas * z_target_tp1
        value_loss = [
            utils.categorical_loss(
                x.view(-1, self.num_atoms),
                logits_tp1.view(-1, self.num_atoms),
                atoms_target_t.view(-1, self.num_atoms),
                self.z, self.delta_z, self.v_min, self.v_max
            )
            for x in logits_t
        ]

        return policy_loss, value_loss
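
Several examples call `self._process_components(done_t, rewards_t)`; its body is not included in this listing. Judging from the inline reshaping in this example (the `gammas`, `done_t[:, None, :]`, `rewards_t[:, None, :]` block), a plausible reconstruction looks like the sketch below. Treat it as an assumption, not the library's actual helper.

    def _process_components(self, done_t, rewards_t):
        # Hypothetical sketch mirroring the inline reshaping above.
        gammas = (self._gammas ** self._n_step)[None, :, None]  # 1 x num_heads x 1
        done_t = done_t[:, None, :]          # B x 1 x 1
        rewards_t = rewards_t[:, None, :]    # B x 1 x 1
        return gammas, done_t, rewards_t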
Example #3
    def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1,
                          done_t):
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        # actor loss
        actions_tp0 = self.actor(states_t)
        # Again, we use the same actor for each critic
        logits_tp0 = [
            x(states_t, actions_tp0).squeeze_(dim=2) for x in self.critics
        ]
        probs_tp0 = [torch.softmax(x, dim=-1) for x in logits_tp0]
        q_values_tp0 = [
            torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp0
        ]
        q_values_tp0_min = torch.cat(q_values_tp0, dim=-1).min(dim=-1)[0]
        policy_loss = -torch.mean(q_values_tp0_min)

        # critic loss (kl-divergence between categorical distributions)
        actions_tp1 = self.target_actor(states_tp1).detach()
        actions_tp1 = self._add_noise_to_actions(actions_tp1)
        logits_t = [
            x(states_t, actions_t).squeeze_(dim=2) for x in self.critics
        ]
        logits_tp1 = [
            x(states_tp1, actions_tp1).squeeze_(dim=2)
            for x in self.target_critics
        ]
        probs_tp1 = [torch.softmax(x, dim=-1) for x in logits_tp1]
        q_values_tp1 = [
            torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp1
        ]
        probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)
        # B x num_heads

        logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
        # B x num_heads x num_atoms x num_critics
        # @TODO: smarter way to do this (other than reshaping)?
        probs_ids_tp1_min = probs_ids_tp1_min.view(-1)
        logits_tp1 = logits_tp1.view(-1, self.num_atoms, self._num_critics)

        logits_tp1 = \
            logits_tp1[range(len(logits_tp1)), :, probs_ids_tp1_min].\
            view(-1, self._num_heads, self.num_atoms).detach()

        atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z
        value_loss = [
            utils.categorical_loss(
                x.view(-1, self.num_atoms),
                logits_tp1.view(-1, self.num_atoms),
                atoms_target_t.view(-1, self.num_atoms),
                self.z, self.delta_z, self.v_min, self.v_max
            )
            for x in logits_t
        ]

        return policy_loss, value_loss
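
This TD3-flavoured variant perturbs the target actions with `self._add_noise_to_actions`, which is also not shown. A common choice is clipped Gaussian target-policy smoothing; the sketch below is purely illustrative, and the attribute names `_action_noise_std` and `_action_noise_clip` are assumptions.

    def _add_noise_to_actions(self, actions):
        # Illustrative sketch of TD3-style target policy smoothing;
        # the noise parameters are hypothetical attribute names.
        noise = torch.randn_like(actions) * self._action_noise_std
        noise = noise.clamp(-self._action_noise_clip, self._action_noise_clip)
        return (actions + noise).clamp(-1.0, 1.0)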
Example #4
    def _categorical_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        # actor loss
        # For now we have the same actor for all heads of the critic
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        logits_tp0 = (
            self.critic(states_t, self.actor(states_t)).squeeze_(dim=2)
            .view(-1, self.num_atoms)
        )
        # [{bs * num_heads}; num_atoms]
        probs_tp0 = torch.softmax(logits_tp0, dim=-1)
        # [{bs * num_heads}]
        q_values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1)
        policy_loss = -torch.mean(q_values_tp0)

        # critic loss (kl-divergence between categorical distributions)
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        logits_t = (
            self.critic(states_t, actions_t).squeeze_(dim=2)
            .view(-1, self.num_atoms)
        )

        # [bs; action_size]
        actions_tp1 = self.target_actor(states_tp1)
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        logits_tp1 = (
            self.target_critic(states_tp1, actions_tp1).squeeze_(dim=2)
            .view(-1, self.num_atoms)
        ).detach()

        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        atoms_target_t = (
            rewards_t + (1 - done_t) * gammas * self.z
        ).view(-1, self.num_atoms)

        value_loss = utils.categorical_loss(
            # [{bs * num_heads}; num_atoms]
            logits_t,
            # [{bs * num_heads}; num_atoms]
            logits_tp1,
            # [{bs * num_heads}; num_atoms]
            atoms_target_t,
            self.z, self.delta_z,
            self.v_min, self.v_max
        )

        return policy_loss, value_loss
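
Every example delegates the distributional TD update to `utils.categorical_loss`, whose implementation is outside this listing. The sketch below shows what such a helper typically computes for C51-style critics: clamp the target atoms to the support range, project them back onto the fixed support with triangular weights, and take the cross-entropy (KL divergence up to a constant) against the predicted logits. It is an illustration of the technique, not necessarily the library's exact code.

    import torch

    def categorical_loss_sketch(logits_t, logits_tp1, atoms_target_t,
                                z, delta_z, v_min, v_max):
        # Sketch of a C51-style projected cross-entropy loss.
        probs_tp1 = torch.softmax(logits_tp1, dim=-1)      # [B; num_atoms]
        tz = atoms_target_t.clamp(v_min, v_max)            # [B; num_atoms]
        # weights[b, i, j] = clip(1 - |tz[b, j] - z[i]| / delta_z, 0, 1)
        weights = 1 - (tz[:, None, :] - z[None, :, None]).abs() / delta_z
        weights = weights.clamp(0, 1)                      # [B; num_atoms; num_atoms]
        # project the target distribution onto the fixed support
        probs_target = torch.einsum("bij,bj->bi", weights, probs_tp1)
        log_probs_t = torch.log_softmax(logits_t, dim=-1)
        return -(probs_target * log_probs_t).sum(dim=-1).mean()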
Example #5
    def _categorical_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):

        gammas = (self._gammas**self._n_step)[None, :, None]
        # 1 x num_heads x 1

        done_t = done_t[:, None, :]  # B x 1 x 1
        rewards_t = rewards_t[:, None, :]  # B x 1 x 1
        actions_t = actions_t[:, None, None, :]  # B x 1 x 1 x 1
        indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
        # B x num_heads x 1 x num_atoms

        logits_t = self.critic(states_t).gather(-2, indices_t).squeeze(-2)
        # B x num_heads x num_atoms

        all_logits_tp1 = self.target_critic(states_tp1).detach()
        # B x num_heads x num_actions x num_atoms

        q_values_tp1 = torch.sum(
            torch.softmax(all_logits_tp1, dim=-1) * self.z, dim=-1
        )
        actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
        # B x num_heads x 1

        indices_tp1 = \
            actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
        # B x num_heads x 1 x num_atoms

        logits_tp1 = all_logits_tp1.gather(-2, indices_tp1).squeeze(-2)
        # B x num_heads x num_atoms
        atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z

        value_loss = utils.categorical_loss(
            logits_t.view(-1, self.num_atoms),
            logits_tp1.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms), self.z, self.delta_z,
            self.v_min, self.v_max
        )

        return value_loss
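
The discrete-action variants select the executed action's atom row with `gather(-2, indices)` over a `[B; num_heads; num_actions; num_atoms]` tensor. A tiny self-contained check of that pattern:

    import torch

    bs, num_heads, num_actions, num_atoms = 2, 3, 4, 5
    q_logits = torch.randn(bs, num_heads, num_actions, num_atoms)
    actions = torch.randint(num_actions, (bs, 1))                       # B x 1
    indices = actions[:, None, None, :].repeat(1, num_heads, 1, num_atoms)
    selected = q_logits.gather(-2, indices).squeeze(-2)
    assert selected.shape == (bs, num_heads, num_atoms)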
Example #6
    def _categorical_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas = self._gammas ** self._n_step

        # actor loss
        # For now we have the same actor for all heads of the critic
        logits_tp0 = self.critic(states_t, self.actor(states_t))
        probs_tp0 = torch.softmax(logits_tp0, dim=-1)
        q_values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1)
        policy_loss = -torch.mean(q_values_tp0)

        # critic loss (kl-divergence between categorical distributions)

        done_t = done_t[:, None, :]
        # B x 1 x 1
        rewards_t = rewards_t[:, None, :]
        # B x 1 x 1
        gammas = gammas[None, :, None]
        # 1 x num_heads x 1

        logits_t = self.critic(states_t, actions_t)
        # B x num_heads x num_atoms
        logits_tp1 = self.target_critic(
            states_tp1, self.target_actor(states_tp1)
        ).detach()
        # B x num_heads x num_atoms
        atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z
        # B x num_heads x num_atoms

        value_loss = utils.categorical_loss(
            logits_t.view(-1, self.num_atoms),
            logits_tp1.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.z,
            self.delta_z,
            self.v_min, self.v_max
        )

        return policy_loss, value_loss
Example #7
    def _categorical_loss(self, states_t, actions_t, rewards_t, states_tp1,
                          done_t):
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        actions_t = actions_t[:, None, None, :]  # B x 1 x 1 x 1
        # B x num_heads x 1 x num_atoms
        indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
        # B x num_heads x num_actions x num_atoms
        q_logits_t = self.critic(states_t)
        # B x num_heads x num_atoms
        logits_t = q_logits_t.gather(-2, indices_t).squeeze(-2)

        # B x num_heads x num_actions x num_atoms
        q_logits_tp1 = self.target_critic(states_tp1).detach()
        q_values_tp1 = torch.sum(torch.softmax(q_logits_tp1, dim=-1) * self.z,
                                 dim=-1)
        # B x num_heads x 1
        actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
        # B x num_heads x 1 x num_atoms
        indices_tp1 = \
            actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
        # B x num_heads x num_atoms
        logits_tp1 = q_logits_tp1.gather(-2, indices_tp1).squeeze(-2)

        atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z
        value_loss = utils.categorical_loss(
            logits_t.view(-1, self.num_atoms),
            logits_tp1.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms), self.z, self.delta_z,
            self.v_min, self.v_max)

        if self.entropy_regularization is not None:
            q_values_t = torch.sum(torch.softmax(q_logits_t, dim=-1) * self.z,
                                   dim=-1)
            value_loss -= \
                self.entropy_regularization * self._compute_entropy(q_values_t)

        return value_loss
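
`self._compute_entropy` is not part of this listing. One plausible form, given that it receives per-head Q-values of shape `[B; num_heads; num_actions]`, is the entropy of the Boltzmann policy induced by those Q-values; the sketch below is purely hypothetical.

    def _compute_entropy(self, q_values):
        # Hypothetical sketch: entropy of softmax(q_values) over actions,
        # averaged over batch and heads.
        log_probs = torch.log_softmax(q_values, dim=-1)
        return -(log_probs.exp() * log_probs).sum(dim=-1).mean()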
Example #8
    def _categorical_value_loss(self, states_t, logits_t, returns_t,
                                states_tp1, done_t):
        # @TODO: WIP, no guarantees
        logits_tp0 = self.critic(states_t).squeeze_(dim=2)
        probs_tp0 = torch.softmax(logits_tp0, dim=-1)
        values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1, keepdim=True)

        probs_t = torch.softmax(logits_t, dim=-1)
        values_t = torch.sum(probs_t * self.z, dim=-1, keepdim=True)

        value_loss = 0.5 * self._value_loss(values_tp0, values_t, returns_t)

        # B x num_heads x num_atoms
        logits_tp1 = self.critic(states_tp1).squeeze_(dim=2).detach()
        # B x num_heads x num_atoms
        atoms_target_t = returns_t + (1 - done_t) * self._gammas_torch * self.z

        value_loss += 0.5 * utils.categorical_loss(
            logits_tp0.view(-1, self.num_atoms),
            logits_tp1.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms), self.z, self.delta_z,
            self.v_min, self.v_max)

        return value_loss
Example #9
    def _categorical_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        # actor loss
        # [bs; action_size]
        actions_tp0, logprob_tp0 = self.actor(states_t, logprob=True)
        logprob_tp0 = logprob_tp0 / self.reward_scale
        # {num_critics} * [bs; num_heads; num_atoms]
        # -> many-heads view transform
        # {num_critics} * [{bs * num_heads}; num_atoms]
        logits_tp0 = [
            x(states_t, actions_tp0).squeeze_(dim=2).view(-1, self.num_atoms)
            for x in self.critics
        ]
        # -> categorical probs
        # {num_critics} * [{bs * num_heads}; num_atoms]
        probs_tp0 = [torch.softmax(x, dim=-1) for x in logits_tp0]
        # -> categorical value
        # {num_critics} * [{bs * num_heads}; 1]
        q_values_tp0 = [
            torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp0
        ]
        #  [{bs * num_heads}; num_critics] ->  min over all critics
        #  [{bs * num_heads}]
        q_values_tp0_min = torch.cat(q_values_tp0, dim=-1).min(dim=-1)[0]
        # For now we use the same actor for each gamma
        policy_loss = torch.mean(logprob_tp0[:, None] - q_values_tp0_min)

        # critic loss (kl-divergence between categorical distributions)
        # [bs; action_size]
        actions_tp1, logprob_tp1 = self.actor(states_tp1, logprob=True)
        logprob_tp1 = logprob_tp1 / self.reward_scale

        # {num_critics} * [bs; num_heads; num_atoms]
        # -> many-heads view transform
        # {num_critics} * [{bs * num_heads}; num_atoms]
        logits_t = [
            x(states_t, actions_t).squeeze_(dim=2).view(-1, self.num_atoms)
            for x in self.critics
        ]

        # {num_critics} * [bs; num_heads; num_atoms]
        logits_tp1 = [
            x(states_tp1, actions_tp1).squeeze_(dim=2)
            for x in self.target_critics
        ]
        # {num_critics} * [bs; num_heads; num_atoms]
        probs_tp1 = [torch.softmax(x, dim=-1) for x in logits_tp1]
        # {num_critics} * [bs; num_heads; 1]
        q_values_tp1 = [
            torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp1
        ]
        #  [bs; num_heads; num_critics] -> argmin over all critics
        #  [bs; num_heads]
        probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)

        # [bs; num_heads; num_atoms; num_critics]
        logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
        # @TODO: smarter way to do this (other than reshaping)?
        probs_ids_tp1_min = probs_ids_tp1_min.view(-1)
        # [bs; num_heads; num_atoms; num_critics] -> many-heads view transform
        # [{bs * num_heads}; num_atoms; num_critics] -> select the argmin critic
        # [{bs * num_heads}; num_atoms]
        logits_tp1 = (
            logits_tp1
            .view(-1, self.num_atoms, self._num_critics)[
                range(len(probs_ids_tp1_min)), :, probs_ids_tp1_min]
            .view(-1, self.num_atoms)
        ).detach()

        # [bs; num_atoms] -> unsqueeze so its the same for each head
        # [bs; 1; num_atoms]
        z_target_tp1 = (
            self.z[None, :] - logprob_tp1[:, None]
        ).unsqueeze(1).detach()
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        atoms_target_t = (
            rewards_t + (1 - done_t) * gammas * z_target_tp1
        ).view(-1, self.num_atoms)

        value_loss = [
            utils.categorical_loss(
                # [{bs * num_heads}; num_atoms]
                x,
                # [{bs * num_heads}; num_atoms]
                logits_tp1,
                # [{bs * num_heads}; num_atoms]
                atoms_target_t,
                self.z,
                self.delta_z,
                self.v_min,
                self.v_max
            ) for x in logits_t
        ]

        return policy_loss, value_loss
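
The ensemble variants pick, per batch element and per head, the distribution of the critic with the smallest expected value (a distributional take on clipped double-Q). The self-contained fragment below replays that indexing pattern on random tensors to make the shapes explicit; it mirrors the code above rather than any library helper.

    import torch

    bs, num_heads, num_atoms, num_critics = 2, 3, 5, 2
    z = torch.linspace(-1.0, 1.0, num_atoms)
    logits_per_critic = [
        torch.randn(bs, num_heads, num_atoms) for _ in range(num_critics)
    ]

    # expected value per critic: {num_critics} * [bs; num_heads; 1]
    q_values = [
        torch.sum(torch.softmax(x, dim=-1) * z, dim=-1, keepdim=True)
        for x in logits_per_critic
    ]
    # index of the weakest critic per (batch, head): [{bs * num_heads}]
    ids_min = torch.cat(q_values, dim=-1).argmin(dim=-1).view(-1)

    # stack critics, flatten heads, then gather the argmin critic's logits
    logits = torch.cat([x.unsqueeze(-1) for x in logits_per_critic], dim=-1)
    logits = logits.view(-1, num_atoms, num_critics)
    logits_min = logits[range(len(logits)), :, ids_min]   # [{bs * num_heads}; num_atoms]
    assert logits_min.shape == (bs * num_heads, num_atoms)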
Example #10
    def _categorical_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        # [bs; 1] ->
        # [bs; 1; 1; 1]
        actions_t = actions_t[:, None, None, :]

        # [bs; num_heads; 1; num_atoms]
        indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
        # [bs; num_heads; num_actions; num_atoms]
        q_logits_t = self.critic(states_t)
        # [bs; num_heads; 1; num_atoms] -> gathering selected actions
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        logits_t = (
            q_logits_t.gather(-2,
                              indices_t).squeeze(-2).view(-1, self.num_atoms)
        )

        # [bs; num_heads; num_actions; num_atoms]
        q_logits_tp1 = self.target_critic(states_tp1).detach()

        # [bs; num_heads; num_actions; num_atoms] -> categorical value
        # [bs; num_heads; num_actions] -> gathering best actions
        # [bs; num_heads; 1]
        actions_tp1 = (
            (torch.softmax(q_logits_tp1, dim=-1) *
             self.z).sum(dim=-1).argmax(dim=-1, keepdim=True)
        )
        # [bs; num_heads; 1] ->
        # [bs; num_heads; 1; 1] ->
        # [bs; num_heads; 1; num_atoms]
        indices_tp1 = actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
        # [bs; num_heads; 1; num_atoms] -> gathering best actions
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        logits_tp1 = (
            q_logits_tp1.gather(-2, indices_tp1).squeeze(-2).view(
                -1, self.num_atoms
            )
        ).detach()

        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        atoms_target_t = (rewards_t + (1 - done_t) * gammas *
                          self.z).view(-1, self.num_atoms).detach()

        value_loss = utils.categorical_loss(
            # [{bs * num_heads}; num_atoms]
            logits_t,
            # [{bs * num_heads}; num_atoms]
            logits_tp1,
            # [{bs * num_heads}; num_atoms]
            atoms_target_t,
            self.z,
            self.delta_z,
            self.v_min,
            self.v_max
        )

        if self.entropy_regularization is not None:
            q_values_t = torch.sum(
                torch.softmax(q_logits_t, dim=-1) * self.z, dim=-1
            )
            value_loss -= \
                self.entropy_regularization * self._compute_entropy(q_values_t)

        return value_loss