Пример #1
0
def test_v1_5_metric_auc_auroc():
    """Check that the AUC / ROC / AUROC metrics warn about their removal in v1.5.0.

    Both the class-based metrics and their functional counterparts are
    exercised. The functional calls are additionally checked to return the
    expected values while emitting the deprecation warning.
    """
    removal_notice = 'It will be removed in v1.5.0'

    # Class-based metrics: constructing each one must raise the warning.
    for metric_cls in (AUC, ROC, AUROC):
        # Reset the `_warned` marker first (presumably a warn-once flag set by
        # the deprecation wrapper -- confirm) so the warning fires again here.
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_notice):
            metric_cls()

    # Functional `auc`.
    xs = torch.tensor([0, 1, 2, 3])
    ys = torch.tensor([0, 1, 2, 2])
    auc._warned = False
    with pytest.deprecated_call(match=removal_notice):
        assert auc(xs, ys) == torch.tensor(4.)

    # Functional `roc`: check the returned curve element by element.
    scores = torch.tensor([0, 1, 2, 3])
    labels = torch.tensor([0, 1, 1, 1])
    roc._warned = False
    with pytest.deprecated_call(match=removal_notice):
        false_pos, true_pos, thresholds = roc(scores, labels, pos_label=1)
    expected_tpr = torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
    assert torch.equal(false_pos, torch.tensor([0., 0., 0., 0., 1.]))
    assert torch.allclose(true_pos, expected_tpr, atol=1e-4)
    assert torch.equal(thresholds, torch.tensor([4, 3, 2, 1, 0]))

    # Functional `auroc`.
    scores = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
    labels = torch.tensor([0, 0, 1, 1, 1])
    auroc._warned = False
    with pytest.deprecated_call(match=removal_notice):
        assert auroc(scores, labels) == torch.tensor(0.5)
Пример #2
0
    def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor):
        """Return the class-averaged area under the precision-recall curve.

        Args:
            logits: raw model outputs; a softmax over ``dim=1`` is applied
                before they are handed to ``self.multi_prc`` (presumably
                dim 1 is the class dimension -- confirm against callers).
            labels: ground-truth labels passed as ``target`` to
                ``self.multi_prc``.

        Returns:
            A one-element ``torch.Tensor`` holding the sum of the per-curve
            areas divided by ``self.num_classes``.
        """
        multi_prcs = self.multi_prc(pred=logits.softmax(dim=1),
                                    target=labels,
                                    sample_weight=None)
        avg_auprc = 0.
        for precision_, recall_, _ in multi_prcs:
            # The area under a precision-recall curve integrates precision
            # (y axis) over recall (x axis). Passing precision as `x` would
            # measure a different area.
            avg_auprc += auc(x=recall_, y=precision_, reorder=True)

        return torch.Tensor([avg_auprc / self.num_classes])
Пример #3
0
    def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor):
        """Return the class-averaged area under the ROC curve.

        Args:
            logits: raw model outputs; a softmax over ``dim=1`` is applied
                before they are handed to ``self.multi_roc`` (presumably
                dim 1 is the class dimension -- confirm against callers).
            labels: ground-truth labels passed as ``target`` to
                ``self.multi_roc``.

        Returns:
            A one-element ``torch.Tensor`` holding the sum of the per-curve
            areas divided by ``self.num_classes``.
        """
        class_probs = logits.softmax(dim=1)
        per_class_curves = self.multi_roc(pred=class_probs,
                                          target=labels,
                                          sample_weight=None)

        # Each curve is unpacked as (fpr, tpr, <ignored third item>); the
        # float start value keeps the accumulation identical to `0. + ...`.
        auroc_total = sum(
            (auc(x=false_pos_rate, y=true_pos_rate, reorder=True)
             for false_pos_rate, true_pos_rate, _ in per_class_curves),
            0.,
        )

        return torch.Tensor([auroc_total / self.num_classes])
Пример #4
0
def auprc_pytorch(y_true, y_scores):
    """Compute the area under the precision-recall curve for one class.

    Args:
        y_true: ground-truth labels. Its ``.device`` attribute is read, so it
            is presumably a torch tensor (the labels are presumably
            binary / one-hot for the class -- confirm against callers).
        y_scores: model predictions for the same samples.

    Returns:
        The result of ``plm.auc`` evaluated with the padded recall as x and
        the padded precision as y.
    """
    target_device = y_true.device
    precision, recall, _ = plm.precision_recall_curve(y_scores, y_true)

    # One-element end caps, created on the same device as the labels.
    one = torch.ones([1], device=target_device)
    zero = torch.zeros([1], device=target_device)

    # Pad both ends so the curve explicitly includes the points
    # (recall=1, precision=0) at the start and (recall=0, precision=1) at the end.
    padded_recall = torch.cat((one, recall, zero))
    padded_precision = torch.cat((zero, precision, one))

    return plm.auc(padded_recall, padded_precision)
Пример #5
0
 def compute(self) -> torch.Tensor:
     """Return the area under the precision-recall curve of the stored data.

     Predictions and targets come from ``self._get_preds_and_targets()``.
     When the targets contain exactly one distinct value, a NaN tensor is
     returned instead of computing the curve.
     """
     scores, labels = self._get_preds_and_targets()

     has_single_label_value = torch.unique(labels).numel() == 1
     if has_single_label_value:
         # Presumably the curve is not meaningful with only one label value
         # present, so NaN is reported rather than a number.
         return torch.tensor(np.nan)

     precision, recall, _ = precision_recall_curve(scores, labels)
     # Recall is the x axis, precision the y axis.
     return auc(recall, precision)  # type: ignore