Example #1
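    # Assumed but not shown in this snippet: `import torch` and
    # `import torch.nn.functional as F`; EdgeProbingTask, the scorers, and
    # self.loss_type come from the enclosing class/module.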
    def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor, task: EdgeProbingTask):
        """ Compute loss & eval metrics.

        Expects logits and labels to already be "selected" for good targets,
        i.e. this function does not do any masking internally.

        Args:
            logits: [total_num_targets, n_classes] Tensor of float scores
            labels: [total_num_targets, n_classes] Tensor of sparse binary targets

        Returns:
            loss: scalar Tensor
        """
        binary_preds = logits.ge(0).long()  # {0,1}

        # Matthews coefficient and accuracy computed on {0,1} labels.
        task.mcc_scorer(binary_preds, labels.long())
        task.acc_scorer(binary_preds, labels.long())

        # F1Measure() expects [total_num_targets, n_classes, 2]
        # to compute binarized F1.
        binary_scores = torch.stack([-1 * logits, logits], dim=2)
        task.f1_scorer(binary_scores, labels)

        if self.loss_type == "sigmoid":
            return F.binary_cross_entropy(torch.sigmoid(logits), labels.float())
        else:
            raise ValueError("Unsupported loss type '%s' for edge probing." % self.loss_type)
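
A quick sanity check on the stacking trick above: stacking [-logits, logits] along a new last dimension gives F1Measure a two-way score per class whose argmax reproduces the ge(0) threshold (up to ties at exactly 0). A minimal sketch with made-up values:

import torch

logits = torch.tensor([[-1.2, 0.3], [0.8, -0.5]])

# Two-way scores per class: index 0 = "negative", index 1 = "positive".
binary_scores = torch.stack([-1 * logits, logits], dim=2)  # shape [2, 2, 2]

# argmax over the last dim matches thresholding the raw logits at 0.
assert torch.equal(binary_scores.argmax(dim=2), logits.ge(0).long())
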
Example #2
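    # Assumed but not shown: `import torch`, `import torch.nn.functional as F`,
    # and a `one_hot` helper (see the stand-in sketch after this example).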
    def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor,
                     task: EdgeProbingTask):
        """ Compute loss & eval metrics.

        Expects logits and labels to already be "selected" for good targets,
        i.e. this function does not do any masking internally.

        Args:
            logits: [total_num_targets, n_classes] Tensor of float scores
            labels: [total_num_targets, n_classes] Tensor of sparse binary targets

        Returns:
            loss: scalar Tensor
        """
        if self.loss_type == "sigmoid":
            binary_preds = logits.ge(0).long()  # {0,1}

            # Matthews coefficient and accuracy computed on {0,1} labels.
            task.mcc_scorer(binary_preds, labels.long())
            task.acc_scorer(binary_preds, labels.long())


            # F1Measure() expects [total_num_targets, n_classes, 2]
            # to compute binarized F1.
            binary_scores = torch.stack([-1 * logits, logits], dim=2)
            task.f1_scorer(binary_scores, labels)

            loss = F.binary_cross_entropy(torch.sigmoid(logits),
                                          labels.float())
            task.xent_scorer(loss.mean().item())
            return loss

        elif self.loss_type == "softmax":

            preds = one_hot(logits.argmax(dim=-1), depth=logits.shape[-1])

            # Matthews coefficient and accuracy computed on {0,1} labels.
            task.mcc_scorer(preds.long(), labels.long())
            task.acc_scorer(preds.long(), labels.long())


            # F1Measure() expects [total_num_targets, n_classes, 2]
            # to compute binarized F1.
            binary_scores = torch.stack([-1 * logits, logits], dim=2)
            task.f1_scorer(binary_scores, labels)

            loss = F.cross_entropy(logits, labels.argmax(dim=-1))
            task.xent_scorer(loss.mean().item())
            return loss
        else:
            raise ValueError("Unsupported loss type '%s' "
                             "for edge probing." % self.loss_type)