Beispiel #1
0
    def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1,
                       done_t):
        """Compute actor and critic losses for a quantile-regression critic.

        The actor objective is the negated mean of the critic's output for
        the actor's own actions. The critic is regressed (via
        ``utils.quantile_loss``) towards a one-step bootstrapped target built
        from ``target_actor`` / ``target_critic``.

        Returns:
            tuple: ``(policy_loss, value_loss)``.
        """
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        # Actor objective: push the critic's estimate of the actor's actions up.
        actor_actions_t = self.actor(states_t)
        policy_loss = -torch.mean(self.critic(states_t, actor_actions_t))

        # Critic objective (quantile regression).
        # Online atoms for the taken actions: [bs; num_heads; num_atoms]
        atoms_t = self.critic(states_t, actions_t).squeeze_(dim=2)
        # Target atoms from the target pair, excluded from the graph:
        # [bs; num_heads; num_atoms]
        target_actions_tp1 = self.target_actor(states_tp1)
        atoms_tp1 = self.target_critic(states_tp1, target_actions_tp1)
        atoms_tp1 = atoms_tp1.squeeze_(dim=2).detach()

        # One-step bootstrapped target; terminal transitions drop the tail.
        not_done_t = 1 - done_t
        atoms_target_t = rewards_t + not_done_t * gammas * atoms_tp1

        # Fold heads into the batch axis: [{bs * num_heads}; num_atoms]
        flat_atoms_t = atoms_t.view(-1, self.num_atoms)
        flat_atoms_target_t = atoms_target_t.view(-1, self.num_atoms)
        value_loss = utils.quantile_loss(
            flat_atoms_t,
            flat_atoms_target_t,
            self.tau,
            self.num_atoms,
            self.critic_criterion)

        return policy_loss, value_loss
Beispiel #2
0
    def _quantile_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        """Quantile-regression loss for a discrete-action, multi-head critic.

        Online atoms are taken for the actions actually played. The target is
        bootstrapped from ``target_critic`` at ``t+1`` using, per head, the
        action whose atoms have the highest mean. When
        ``self.entropy_regularization`` is set, the term
        ``entropy_regularization * self._compute_entropy(mean_q_t)`` is
        subtracted, where ``mean_q_t`` is the online atoms averaged over atoms.

        Returns:
            The value loss produced by ``utils.quantile_loss`` (minus the
            optional entropy term).
        """
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)
        num_heads = self._num_heads
        num_atoms = self.num_atoms

        # --- online atoms for the taken actions ---
        # [bs; 1] -> [bs; 1; 1; 1] -> [bs; num_heads; 1; num_atoms]
        taken_idx = actions_t[:, None, None, :].repeat(
            1, num_heads, 1, num_atoms)
        # [bs; num_heads; num_actions; num_atoms]
        q_atoms_t = self.critic(states_t)
        # gather on the action axis -> [bs; num_heads; 1; num_atoms],
        # drop that axis and fold heads into the batch:
        # [{bs * num_heads}; num_atoms]
        selected_t = q_atoms_t.gather(-2, taken_idx)
        atoms_t = selected_t.squeeze(-2).view(-1, num_atoms)

        # --- bootstrapped target ---
        # [bs; num_heads; num_actions; num_atoms]
        q_atoms_tp1 = self.target_critic(states_tp1)
        # Greedy action per head w.r.t. the mean over atoms: [bs; num_heads; 1]
        greedy_tp1 = q_atoms_tp1.mean(dim=-1).argmax(dim=-1, keepdim=True)
        # [bs; num_heads; 1; 1] -> [bs; num_heads; 1; num_atoms]
        greedy_idx = greedy_tp1.unsqueeze(-1).repeat(1, 1, 1, num_atoms)
        # [bs; num_heads; num_atoms]
        atoms_tp1 = q_atoms_tp1.gather(-2, greedy_idx).squeeze(-2)

        # [{bs * num_heads}; num_atoms], cut from the graph
        bootstrapped = rewards_t + (1 - done_t) * gammas * atoms_tp1
        atoms_target_t = bootstrapped.view(-1, num_atoms).detach()

        value_loss = utils.quantile_loss(
            atoms_t, atoms_target_t, self.tau, num_atoms,
            self.critic_criterion
        )

        if self.entropy_regularization is None:
            return value_loss

        # Mean over atoms: [bs; num_heads; num_actions]
        q_values_t = torch.mean(q_atoms_t, dim=-1)
        value_loss -= (
            self.entropy_regularization * self._compute_entropy(q_values_t)
        )
        return value_loss
Beispiel #3
0
    def _quantile_value_loss(self, states_t, atoms_t, returns_t, states_tp1,
                             done_t):
        """Blend a mean-value loss and a quantile-regression loss, 50/50.

        The first term compares the mean of the critic's current atoms with
        the mean of ``atoms_t`` via ``self._value_loss`` (also given
        ``returns_t``). The second term regresses the current atoms towards
        ``returns_t + (1 - done_t) * self._gammas_torch * atoms_tp1``, where
        ``atoms_tp1`` comes from the *online* critic on ``states_tp1``
        (detached; no target network is used here).

        Returns:
            ``0.5 * value_term + 0.5 * quantile_term``.
        """
        # @TODO: WIP, no guarantees
        # Critic atoms for the current states, singleton dim 2 removed.
        atoms_tp0 = self.critic(states_t).squeeze_(dim=2)
        # Collapse atoms to their mean to obtain value estimates.
        values_tp0 = atoms_tp0.mean(dim=-1, keepdim=True)
        values_t = torch.mean(atoms_t, dim=-1, keepdim=True)
        mean_value_term = self._value_loss(values_tp0, values_t, returns_t)

        # B x num_heads x num_atoms, bootstrapped from the online critic.
        atoms_tp1 = self.critic(states_tp1).squeeze_(dim=2).detach()
        # B x num_heads x num_atoms
        atoms_target_t = (
            returns_t + (1 - done_t) * self._gammas_torch * atoms_tp1
        )

        quantile_term = utils.quantile_loss(
            atoms_tp0.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.tau,
            self.num_atoms,
            self.critic_criterion)

        return 0.5 * mean_value_term + 0.5 * quantile_term
Beispiel #4
0
    def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1,
                       done_t):
        """Soft actor-critic style losses with an ensemble of quantile critics.

        Actor: mean over batch and heads of
        ``logprob / reward_scale - min_c Q_c``, where ``Q_c`` is the mean of
        critic ``c``'s atoms for the actor's freshly sampled action.

        Critics: each critic in ``self.critics`` is regressed (via
        ``utils.quantile_loss``) towards
        ``r + (1 - done) * gamma * (atoms_c* - logprob_tp1 / reward_scale)``.
        Here ``c*`` is, per sample and head, the target critic whose atoms
        have the lowest mean.

        Returns:
            tuple: ``(policy_loss, value_loss)`` where ``value_loss`` is a list
            with one quantile loss per critic in ``self.critics``.
        """
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        # ----- actor loss -----
        # [bs; action_size]; logprob assumed [bs] (the [bs; 1; 1] broadcast in
        # the target computation below relies on it) -- TODO confirm
        actions_tp0, logprob_tp0 = self.actor(states_t, logprob=True)
        # [bs; 1]
        logprob_tp0 = logprob_tp0[:, None] / self.reward_scale
        # {num_critics} * [bs; num_heads; num_atoms; 1]
        atoms_tp0 = [
            x(states_t, actions_tp0).squeeze_(dim=2).unsqueeze_(-1)
            for x in self.critics
        ]
        # [bs; num_heads; num_atoms; num_critics] -> quantile value (mean)
        # [bs; num_heads; num_critics] -> min over all critics
        # [bs; num_heads]
        q_values_tp0_min = (
            torch.cat(atoms_tp0, dim=-1).mean(dim=-2).min(dim=-1)[0]
        )
        # The same actor (hence the same log-prob) serves every head, so the
        # [bs; 1] log-prob broadcasts per sample over [bs; num_heads]. Do not
        # add another axis here: that would pair every sample's log-prob with
        # every other sample's Q value (an O(bs^2 * num_heads) tensor).
        policy_loss = torch.mean(logprob_tp0 - q_values_tp0_min)

        # ----- critic loss (quantile regression) -----
        # [bs; action_size], logprob [bs]
        actions_tp1, logprob_tp1 = self.actor(states_tp1, logprob=True)
        # [bs; 1]
        logprob_tp1 = logprob_tp1[:, None] / self.reward_scale

        # {num_critics} * [bs; num_heads; num_atoms]
        # -> many-heads view transform
        # {num_critics} * [{bs * num_heads}; num_atoms]
        atoms_t = [
            x(states_t, actions_t).squeeze_(dim=2).view(-1, self.num_atoms)
            for x in self.critics
        ]

        # [bs; num_heads; num_atoms; num_critics]
        atoms_tp1 = torch.cat(
            [
                x(states_tp1, actions_tp1).squeeze_(dim=2).unsqueeze_(-1)
                for x in self.target_critics
            ],
            dim=-1,
        )
        # Per sample and head, the target critic with the lowest mean atom
        # value: [bs; num_heads; 1]
        critic_ids_tp1 = atoms_tp1.mean(dim=-2).argmin(dim=-1, keepdim=True)
        # Broadcast that choice over atoms and select it on the critic axis:
        # [bs; num_heads; num_atoms; 1] -> [bs; num_heads; num_atoms]
        gather_ids = critic_ids_tp1.unsqueeze(-2).expand(
            -1, -1, atoms_tp1.size(-2), -1)
        atoms_tp1 = atoms_tp1.gather(-1, gather_ids).squeeze(-1)

        # Same log_pi for each head: [bs; 1; 1] broadcasts over
        # [bs; num_heads; num_atoms].
        atoms_tp1 = (atoms_tp1 - logprob_tp1.unsqueeze(1)).detach()
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        atoms_target_t = (rewards_t + (1 - done_t) * gammas * atoms_tp1).view(
            -1, self.num_atoms).detach()

        value_loss = [
            utils.quantile_loss(
                # [{bs * num_heads}; num_atoms]
                x,
                # [{bs * num_heads}; num_atoms]
                atoms_target_t,
                self.tau,
                self.num_atoms,
                self.critic_criterion) for x in atoms_t
        ]

        return policy_loss, value_loss
Beispiel #5
0
    def _quantile_loss(self, states_t, actions_t, rewards_t, states_tp1,
                       done_t):
        """TD3-style losses with an ensemble of quantile-regression critics.

        Actor: maximize the smallest (over ``self.critics``) mean atom value of
        the actor's own actions.

        Critics: regress each critic's atoms (via ``utils.quantile_loss``)
        towards ``r + (1 - done) * gamma * atoms_c*``. Next actions come from
        ``target_actor`` passed through ``_add_noise_to_actions`` (detached).
        ``c*`` is, per sample and head, the target critic whose atoms have the
        lowest mean.

        Returns:
            tuple: ``(policy_loss, value_loss)``; ``value_loss`` holds one
            quantile loss per critic in ``self.critics``.
        """
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)
        num_atoms = self.num_atoms

        def _stack_atoms(critics, states, actions):
            # Each critic gives [bs; num_heads; num_atoms]; stack them on a
            # trailing axis -> [bs; num_heads; num_atoms; num_critics]
            per_critic = [
                critic(states, actions).squeeze_(dim=2).unsqueeze_(-1)
                for critic in critics
            ]
            return torch.cat(per_critic, dim=-1)

        # ----- actor -----
        policy_actions_t = self.actor(states_t)
        atoms_policy_t = _stack_atoms(self.critics, states_t, policy_actions_t)
        # [{bs * num_heads}; num_atoms; num_critics] -> mean over atoms
        # -> [{bs * num_heads}; num_critics] -> min over critics
        q_min_t = (
            atoms_policy_t.view(-1, num_atoms, self._num_critics)
            .mean(dim=1)
            .min(dim=1)
            .values
        )
        policy_loss = -torch.mean(q_min_t)

        # ----- critics -----
        noisy_actions_tp1 = self._add_noise_to_actions(
            self.target_actor(states_tp1)).detach()

        # {num_critics} * [{bs * num_heads}; num_atoms]
        atoms_t = [
            critic(states_t, actions_t).squeeze_(dim=2).view(-1, num_atoms)
            for critic in self.critics
        ]

        # [bs; num_heads; num_atoms; num_critics]
        atoms_tp1 = _stack_atoms(
            self.target_critics, states_tp1, noisy_actions_tp1)
        # Cheapest target critic per sample and head: [bs; num_heads; 1]
        cheapest = atoms_tp1.mean(dim=-2).argmin(dim=-1, keepdim=True)
        # Spread the choice over atoms and pick it on the critic axis:
        # [bs; num_heads; num_atoms; 1] -> [bs; num_heads; num_atoms]
        cheapest = cheapest.unsqueeze(-2).expand(
            -1, -1, atoms_tp1.size(-2), -1)
        atoms_tp1 = atoms_tp1.gather(-1, cheapest).squeeze(-1)

        # [{bs * num_heads}; num_atoms], cut from the graph
        not_done_t = 1 - done_t
        bootstrapped = rewards_t + not_done_t * gammas * atoms_tp1
        atoms_target_t = bootstrapped.view(-1, num_atoms).detach()

        value_loss = [
            utils.quantile_loss(
                atoms, atoms_target_t, self.tau, num_atoms,
                self.critic_criterion)
            for atoms in atoms_t
        ]

        return policy_loss, value_loss