# Imports assumed for this snippet (ML-Agents' torch trainer layout at the
# time AgentAction and ModelUtils lived under mlagents.trainers.torch):
import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.utils import ModelUtils


def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    Computes the loss for the next-state prediction. The forward-model error
    is also the curiosity reward, so compute_reward is reused as the loss.
    """
    return torch.mean(
        ModelUtils.dynamic_partition(
            self.compute_reward(mini_batch),
            ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
            2,
        )[1]  # partition 1: timesteps whose mask is set (active agents)
    )
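
Both losses use ModelUtils.dynamic_partition to drop padded timesteps before averaging. A minimal sketch of its semantics, assuming it mirrors TensorFlow's tf.dynamic_partition (the sketch name and body below are illustrative, not the library's implementation):

import torch
from typing import List

def dynamic_partition_sketch(
    data: torch.Tensor, partitions: torch.Tensor, num_partitions: int
) -> List[torch.Tensor]:
    # Group the rows of `data` by the integer label in `partitions`. With the
    # 0/1 "masks" column, index [1] keeps only the unmasked (active) timesteps,
    # so the subsequent mean ignores padding.
    return [data[partitions.to(torch.int64) == i] for i in range(num_partitions)]
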
def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    Computes the inverse loss for a mini_batch. Corresponds to the error on the
    action prediction (given the current and next state).
    """
    predicted_action = self.predict_action(mini_batch)
    actions = AgentAction.from_dict(mini_batch)
    _inverse_loss = 0
    if self._action_spec.continuous_size > 0:
        # Continuous branch: squared error between the true and predicted
        # continuous actions, summed over action dimensions.
        sq_difference = (
            actions.continuous_tensor - predicted_action.continuous
        ) ** 2
        sq_difference = torch.sum(sq_difference, dim=1)
        _inverse_loss += torch.mean(
            ModelUtils.dynamic_partition(
                sq_difference,
                ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                2,
            )[1]  # keep only the unmasked (active) timesteps
        )
    if self._action_spec.discrete_size > 0:
        # Discrete branch: cross-entropy between the one-hot true actions and
        # the predicted per-branch action probabilities.
        true_action = torch.cat(
            ModelUtils.actions_to_onehot(
                actions.discrete_tensor, self._action_spec.discrete_branches
            ),
            dim=1,
        )
        cross_entropy = torch.sum(
            -torch.log(predicted_action.discrete + self.EPSILON) * true_action,
            dim=1,
        )  # EPSILON guards against log(0)
        _inverse_loss += torch.mean(
            ModelUtils.dynamic_partition(
                cross_entropy,
                ModelUtils.list_to_tensor(
                    mini_batch["masks"], dtype=torch.float
                ),  # use masks, not action_masks
                2,
            )[1]
        )
    return _inverse_loss
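
For context, the forward and inverse losses are typically summed into one curiosity (ICM) objective and optimized together. A hypothetical training-step sketch, assuming a torch optimizer over the module's parameters and a forward_weight blend in the style of the ICM paper (neither is confirmed by the snippet above):

def curiosity_update_sketch(network, mini_batch, optimizer, forward_weight=0.2):
    # ICM-style update: the inverse model shapes the state encoding, while the
    # forward model's prediction error doubles as the curiosity reward.
    loss = forward_weight * network.compute_forward_loss(mini_batch) + (
        1.0 - forward_weight
    ) * network.compute_inverse_loss(mini_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
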
Example #3: an older variant of the same inverse loss, written against the single-action-type policy_specs API (a policy is either all-continuous or all-discrete, not hybrid).
def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    Computes the inverse loss for a mini_batch. Corresponds to the error on the
    action prediction (given the current and next state).
    """
    predicted_action = self.predict_action(mini_batch)
    if self._policy_specs.is_action_continuous():
        # Continuous policy: squared error on the predicted actions.
        sq_difference = (
            ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
            - predicted_action
        ) ** 2
        sq_difference = torch.sum(sq_difference, dim=1)
        return torch.mean(
            ModelUtils.dynamic_partition(
                sq_difference,
                ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                2,
            )[1]
        )
    else:
        # Discrete policy: cross-entropy against the one-hot true actions.
        true_action = torch.cat(
            ModelUtils.actions_to_onehot(
                ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
                self._policy_specs.discrete_action_branches,
            ),
            dim=1,
        )
        cross_entropy = torch.sum(
            -torch.log(predicted_action + self.EPSILON) * true_action, dim=1
        )
        return torch.mean(
            ModelUtils.dynamic_partition(
                cross_entropy,
                ModelUtils.list_to_tensor(
                    mini_batch["masks"], dtype=torch.float
                ),  # use masks, not action_masks
                2,
            )[1]
        )
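
Both variants one-hot encode the discrete actions with ModelUtils.actions_to_onehot before the cross-entropy. A minimal sketch of those semantics, inferred from the call sites (the helper body below is an assumption, not the library's code):

import torch
from typing import List, Tuple

def actions_to_onehot_sketch(
    discrete_actions: torch.Tensor, branch_sizes: Tuple[int, ...]
) -> List[torch.Tensor]:
    # One one-hot tensor per discrete branch: column i of `discrete_actions`
    # holds the chosen index within a branch of branch_sizes[i] options.
    return [
        torch.nn.functional.one_hot(
            discrete_actions[:, i].long(), num_classes=size
        ).float()
        for i, size in enumerate(branch_sizes)
    ]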