Example #1
    def sac_policy_loss(
        self,
        log_probs: torch.Tensor,
        q1p_outs: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool,
    ) -> torch.Tensor:
        _ent_coef = torch.exp(self._log_ent_coef)
        # Average the Q1 estimates across all reward streams.
        mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
        if not discrete:
            mean_q1 = mean_q1.unsqueeze(1)
            # Continuous actions: mean over action dims of alpha * log pi(a|s) - Q(s, a).
            batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
            policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
        else:
            # Discrete actions: weight by the action probabilities and compute the loss
            # per action branch, each branch with its own entropy coefficient.
            action_probs = log_probs.exp()
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * action_probs, self.act_size
            )
            branched_q_term = ModelUtils.break_into_branches(
                mean_q1 * action_probs, self.act_size
            )
            branched_policy_loss = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True)
                    for i, (_lp, _qt) in enumerate(
                        zip(branched_per_action_ent, branched_q_term)
                    )
                ],
                dim=1,
            )
            batch_policy_loss = torch.squeeze(branched_policy_loss)
            policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
        return policy_loss
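
The continuous branch above is the standard SAC policy objective, L_pi = E[alpha * log pi(a|s) - Q(s, a)], averaged over the masked batch. As a point of reference, here is a minimal standalone sketch of that formula outside the ML-Agents optimizer class; the tensor names and shapes are illustrative assumptions, not part of the library:

import torch

# Illustrative batch of 4 samples with 2 continuous action dimensions.
log_probs = torch.randn(4, 2)                        # log pi(a|s) of sampled actions
q_values = torch.randn(4, 1)                         # critic estimate Q(s, a), broadcast over dims
log_ent_coef = torch.zeros(1, requires_grad=True)    # learnable log(alpha)

# alpha is detached here: in SAC it is updated by its own loss, not by the policy loss.
alpha = log_ent_coef.exp().detach()
policy_loss = torch.mean(alpha * log_probs - q_values)
print(policy_loss.item())

In the ML-Agents version the same expression is additionally averaged with ModelUtils.masked_mean over loss_masks (so padded timesteps do not contribute) and, in the discrete case, evaluated per action branch with a separate entropy coefficient per branch.
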
Example #2
    def sac_entropy_loss(
        self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
    ) -> torch.Tensor:
        if not discrete:
            with torch.no_grad():
                # Target entropy minus the current per-sample entropy estimate (-log pi),
                # summed over action dimensions.
                target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
            entropy_loss = -1 * ModelUtils.masked_mean(
                self._log_ent_coef * target_current_diff, loss_masks
            )
        else:
            with torch.no_grad():
                # p * log p per action; summed within a branch this is the negative
                # entropy of that branch's categorical distribution.
                branched_per_action_ent = ModelUtils.break_into_branches(
                    log_probs * log_probs.exp(), self.act_size
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, axis=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy
                        )
                    ],
                    axis=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, axis=2
                )
            entropy_loss = -1 * ModelUtils.masked_mean(
                torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
            )

        return entropy_loss
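
The continuous branch above implements SAC's automatic entropy-coefficient tuning: the loss J(alpha) = E[-log(alpha) * (log pi(a|s) + H_target)] raises alpha when the policy entropy falls below the target and lowers it otherwise. A minimal standalone sketch under the assumption of a single continuous action dimension (the names here are illustrative, not from the library):

import torch

log_probs = torch.randn(4, 1)                        # log pi(a|s), treated as fixed data here
target_entropy = -1.0                                # common heuristic: -dim(action space)
log_ent_coef = torch.zeros(1, requires_grad=True)    # learnable log(alpha)

with torch.no_grad():
    # Target entropy minus the per-sample entropy estimate (-log pi).
    target_current_diff = torch.sum(log_probs + target_entropy, dim=1)

entropy_loss = -(log_ent_coef * target_current_diff).mean()
entropy_loss.backward()
# If entropy is below target (positive diff), gradient descent increases log(alpha).
print(log_ent_coef.grad)

ML-Agents applies the same idea per action branch in the discrete case, with one entropy coefficient and one target entropy per branch, and averages the result with loss_masks.
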
Example #3
    def sac_value_loss(
        self,
        log_probs: torch.Tensor,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool,
    ) -> torch.Tensor:
        min_policy_qs = {}
        with torch.no_grad():
            _ent_coef = torch.exp(self._log_ent_coef)
            for name in values.keys():
                if not discrete:
                    # Clipped double-Q: take the minimum of the two policy-action Q estimates.
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    action_probs = log_probs.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * action_probs, self.act_size
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * action_probs, self.act_size
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q1p
                            ]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q2p
                            ]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if not discrete:
            for name in values.keys():
                with torch.no_grad():
                    # Soft value target: min(Q1, Q2) minus the entropy bonus alpha * log pi.
                    v_backup = min_policy_qs[name] - torch.sum(
                        _ent_coef * log_probs, dim=1
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                )
                value_losses.append(value_loss)
        else:
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss
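
sac_value_loss regresses each value head toward the soft value target of the original two-critic SAC formulation, V_target = min(Q1(s, a), Q2(s, a)) - alpha * log pi(a|s), with the final check guarding against divergence. A minimal standalone sketch of the continuous case (tensor names and shapes are assumptions for illustration only):

import torch
import torch.nn.functional as F

# Illustrative batch of 4 samples, one continuous action dimension.
v_pred = torch.randn(4, requires_grad=True)    # V(s) predicted by the value head
q1 = torch.randn(4, 1)                         # Q1(s, a) for actions sampled from the policy
q2 = torch.randn(4, 1)                         # Q2(s, a) for the same actions
log_probs = torch.randn(4, 1)                  # log pi(a|s)
alpha = 0.2                                    # entropy coefficient, kept fixed here

with torch.no_grad():
    # Clipped double-Q minus the entropy bonus gives the soft value target.
    v_backup = torch.min(q1, q2).squeeze(1) - torch.sum(alpha * log_probs, dim=1)

value_loss = 0.5 * F.mse_loss(v_pred, v_backup)
value_loss.backward()
print(value_loss.item())

The ML-Agents version does the same per reward stream (the values dict), applies ModelUtils.masked_mean with loss_masks, and in the discrete case first takes the expectation of Q and of the entropy bonus over each action branch's probabilities.
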