    # NOTE: these methods rely on module-level imports of torch,
    # torch.nn.functional as F, and the project's losses utilities.
    def compute_loss(self, inputs, labels, outputs, grad_enabled, **kwargs):
        """Compute the classification loss plus an information penalty that
        matches the predicted gradient to the actual gradient of the
        classification loss, with optional L2/L1 norm penalties on the
        predicted gradient."""
        torch.set_grad_enabled(grad_enabled)

        pred_before = outputs['pred_before']
        grad_pred = outputs['grad_pred']
        y = labels[0].to(self.device)
        y_one_hot = F.one_hot(y, num_classes=self.num_classes).float()

        # classification loss
        classifier_loss = F.cross_entropy(input=outputs['pred'], target=y)

        # compute grad actual
        if self.detach:
            # NOTE: we detach here too, so that the classifier is trained using the predicted gradient only
            pred_softmax = torch.softmax(pred_before.detach(), dim=1)
        else:
            pred_softmax = torch.softmax(pred_before, dim=1)
        if self.loss_function in ['ce', 'none']:
            # analytic gradient of cross-entropy w.r.t. the logits
            grad_actual = pred_softmax - y_one_hot
        elif self.loss_function == 'mae':
            # analytic gradient of the MAE loss (between softmax probabilities
            # and the one-hot target) w.r.t. the logits, up to a constant factor
            grad_actual = torch.sum(pred_softmax * y_one_hot, dim=1).unsqueeze(dim=-1) *\
                          (pred_softmax - y_one_hot)
        else:
            raise NotImplementedError()

        # I(g : y | x) penalty
        if self.q_dist == 'Gaussian':
            info_penalty = losses.mse(grad_pred, grad_actual)
        elif self.q_dist == 'Laplace':
            info_penalty = losses.mae(grad_pred, grad_actual)
        elif self.q_dist == 'dot':
            # this corresponds to the first-order Taylor approximation of L(w + g_t)
            info_penalty = -torch.mean(
                (grad_pred * grad_actual).sum(dim=1), dim=0)
        elif self.q_dist == 'ce':
            # TODO: clarify which distribution will give this
            info_penalty = losses.get_classification_loss(
                target=y_one_hot,
                pred=outputs['q_label_pred'],
                loss_function='ce')
        else:
            raise NotImplementedError()

        batch_losses = {
            'classifier': classifier_loss,
            'info_penalty': info_penalty
        }

        # add L2 norm penalty on the predicted gradient
        if self.grad_weight_decay > 0:
            grad_l2_loss = self.grad_weight_decay *\
                           torch.mean(torch.sum(grad_pred**2, dim=1), dim=0)
            batch_losses['pred_grad_l2'] = grad_l2_loss

        # add L1 norm penalty on the predicted gradient
        if self.grad_l1_penalty > 0:
            grad_l1_loss = self.grad_l1_penalty *\
                           torch.mean(torch.sum(torch.abs(grad_pred), dim=1), dim=0)
            batch_losses['pred_grad_l1'] = grad_l1_loss

        return batch_losses, outputs
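
    # --- Hedged sketch (illustrative, not part of the original class): verify
    # that the 'ce' branch's grad_actual, softmax(z) - y_one_hot, equals the
    # autograd gradient of cross-entropy w.r.t. the logits. The helper name is
    # made up for the example.
    def _check_ce_logit_gradient(self):
        torch.manual_seed(0)
        logits = torch.randn(4, self.num_classes, requires_grad=True)
        y = torch.randint(0, self.num_classes, (4,))
        y_one_hot = F.one_hot(y, num_classes=self.num_classes).float()

        # autograd gradient of the summed cross-entropy w.r.t. the logits
        loss = F.cross_entropy(logits, y, reduction='sum')
        autograd_grad, = torch.autograd.grad(loss, logits)

        # closed-form gradient used for grad_actual in the 'ce' branch
        closed_form = torch.softmax(logits, dim=1) - y_one_hot
        assert torch.allclose(autograd_grad, closed_form, atol=1e-6)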

    # compute_loss for a classifier variant that applies only the
    # classification loss to precomputed model outputs.
    def compute_loss(self, inputs, labels, outputs, grad_enabled, **kwargs):
        torch.set_grad_enabled(grad_enabled)

        pred = outputs['pred']
        y = labels[0].to(self.device)

        # classification loss
        y_one_hot = F.one_hot(y, num_classes=self.num_classes).float()
        classifier_loss = losses.get_classification_loss(
            target=y_one_hot,
            pred=pred,
            loss_function=self.loss_function,
            loss_function_param=self.loss_function_param)

        batch_losses = {
            'classifier': classifier_loss,
        }

        return batch_losses, outputs
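
    # --- Hedged sketch (assumption, not the repo's losses module): a minimal
    # stand-in for what losses.get_classification_loss might compute for the
    # loss_function values used here: 'ce' as cross-entropy on the logits and
    # 'mae' as the mean absolute error between softmax probabilities and the
    # one-hot target, which is consistent with the grad_actual formulas in the
    # first compute_loss above. The helper name is made up for the example.
    def _classification_loss_sketch(self, pred, target_one_hot, loss_function='ce'):
        if loss_function == 'ce':
            # cross-entropy between the logits and the one-hot target
            return torch.mean(
                torch.sum(-target_one_hot * F.log_softmax(pred, dim=1), dim=1))
        if loss_function == 'mae':
            # L1 distance between predicted probabilities and the one-hot target
            probs = torch.softmax(pred, dim=1)
            return torch.mean(torch.sum(torch.abs(probs - target_one_hot), dim=1))
        raise NotImplementedError(loss_function)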

    # compute_loss for a classifier variant that runs its own forward pass and
    # applies only the classification loss.
    def compute_loss(self, inputs, labels, grad_enabled, **kwargs):
        torch.set_grad_enabled(grad_enabled)

        info = self.forward(inputs=inputs, grad_enabled=grad_enabled)
        pred = info["pred"]
        y = labels[0].to(self.device)

        # classification loss
        y_one_hot = F.one_hot(y, num_classes=self.num_classes).float()
        classifier_loss = losses.get_classification_loss(
            target=y_one_hot,
            pred=pred,
            loss_function=self.loss_function,
            loss_function_param=self.loss_function_param,
        )

        batch_losses = {
            "classifier": classifier_loss,
        }

        return batch_losses, info
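

# --- Hedged usage sketch (assumed trainer pattern, not the repo's actual
# training loop): compute_loss returns a dict of named loss terms; a typical
# consumer sums them into one scalar and backpropagates. The helper name and
# the optimizer argument are made up for the example.
def _training_step(model, optimizer, inputs, labels):
    optimizer.zero_grad()
    # signature matches the last compute_loss variant above
    batch_losses, _ = model.compute_loss(inputs=inputs, labels=labels,
                                         grad_enabled=True)
    total_loss = sum(batch_losses.values())
    total_loss.backward()
    optimizer.step()
    # return detached scalars for logging
    return {name: loss.item() for name, loss in batch_losses.items()}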