def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    Computes the loss for the next state prediction.
    """
    return torch.mean(
        ModelUtils.dynamic_partition(
            self.compute_reward(mini_batch),
            ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
            2,
        )[1]
    )
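# A minimal sketch of the masking step used above, assuming
# ModelUtils.dynamic_partition mirrors tf.dynamic_partition: `data` is split
# into `num_partitions` buckets by the integer value of `partitions`, so
# indexing the result with [1] keeps only the active (mask == 1) timesteps and
# excludes padded steps from the mean. `_dynamic_partition_sketch` is a
# hypothetical helper for illustration, not the ML-Agents API.
def _dynamic_partition_sketch(data, partitions, num_partitions):
    # data: 1-D tensor of per-timestep values; partitions: 0/1 mask tensor.
    # Returns a list with one tensor per partition id.
    return [data[partitions.long() == i] for i in range(num_partitions)]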
def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    Computes the inverse loss for a mini_batch. Corresponds to the error on the
    action prediction (given the current and next state).
    """
    predicted_action = self.predict_action(mini_batch)
    actions = AgentAction.from_dict(mini_batch)
    _inverse_loss = 0
    if self._action_spec.continuous_size > 0:
        # Squared error between the true and predicted continuous actions,
        # summed over action dimensions.
        sq_difference = (
            actions.continuous_tensor - predicted_action.continuous
        ) ** 2
        sq_difference = torch.sum(sq_difference, dim=1)
        _inverse_loss += torch.mean(
            ModelUtils.dynamic_partition(
                sq_difference,
                ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                2,
            )[1]
        )
    if self._action_spec.discrete_size > 0:
        # Cross-entropy between the one-hot true actions and the predicted
        # distribution, summed over all discrete branches.
        true_action = torch.cat(
            ModelUtils.actions_to_onehot(
                actions.discrete_tensor, self._action_spec.discrete_branches
            ),
            dim=1,
        )
        cross_entropy = torch.sum(
            -torch.log(predicted_action.discrete + self.EPSILON) * true_action,
            dim=1,
        )
        _inverse_loss += torch.mean(
            ModelUtils.dynamic_partition(
                cross_entropy,
                ModelUtils.list_to_tensor(
                    mini_batch["masks"], dtype=torch.float
                ),  # use masks not action_masks
                2,
            )[1]
        )
    return _inverse_loss
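# A hedged sketch of the one-hot step in the discrete branch above, assuming
# ModelUtils.actions_to_onehot returns one one-hot tensor per discrete branch
# (the caller then concatenates them along dim=1 before the cross-entropy).
# `_actions_to_onehot_sketch` is a hypothetical name, not the ML-Agents API.
def _actions_to_onehot_sketch(discrete_actions, branch_sizes):
    # discrete_actions: (batch, num_branches) integer action indices;
    # branch_sizes: number of choices in each discrete branch.
    return [
        torch.nn.functional.one_hot(discrete_actions[:, i].long(), size).float()
        for i, size in enumerate(branch_sizes)
    ]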
# Alternate version of compute_inverse_loss for a policy whose action space is
# entirely continuous or entirely discrete (branches on self._policy_specs
# rather than self._action_spec).
def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    Computes the inverse loss for a mini_batch. Corresponds to the error on the
    action prediction (given the current and next state).
    """
    predicted_action = self.predict_action(mini_batch)
    if self._policy_specs.is_action_continuous():
        sq_difference = (
            ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
            - predicted_action
        ) ** 2
        sq_difference = torch.sum(sq_difference, dim=1)
        return torch.mean(
            ModelUtils.dynamic_partition(
                sq_difference,
                ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                2,
            )[1]
        )
    else:
        true_action = torch.cat(
            ModelUtils.actions_to_onehot(
                ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
                self._policy_specs.discrete_action_branches,
            ),
            dim=1,
        )
        cross_entropy = torch.sum(
            -torch.log(predicted_action + self.EPSILON) * true_action, dim=1
        )
        return torch.mean(
            ModelUtils.dynamic_partition(
                cross_entropy,
                ModelUtils.list_to_tensor(
                    mini_batch["masks"], dtype=torch.float
                ),  # use masks not action_masks
                2,
            )[1]
        )
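# Hypothetical usage sketch showing how the forward and inverse losses are
# typically combined into a single curiosity (ICM-style) update. The `beta`
# trade-off and `self.optimizer` are assumptions for illustration, not taken
# from this file.
def _update_sketch(self, mini_batch: AgentBuffer, beta: float = 0.2) -> torch.Tensor:
    # Weighted sum of the two losses; beta balances forward vs. inverse.
    loss = beta * self.compute_forward_loss(mini_batch) + (
        1.0 - beta
    ) * self.compute_inverse_loss(mini_batch)
    self.optimizer.zero_grad()  # assumed torch.optim.Optimizer over this module
    loss.backward()
    self.optimizer.step()
    return loss.detach()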