def test_v1_5_metric_auc_auroc():
    """Check that the AUC / ROC / AUROC classes and functions warn about their v1.5 removal.

    Before each ``pytest.deprecated_call`` block the callable's ``_warned``
    attribute is reset to ``False`` so that the deprecation warning is
    emitted again for that block.
    """
    removal_msg = 'It will be removed in v1.5.0'

    # Class-based metrics: constructing each one must emit the warning.
    for metric_cls in (AUC, ROC, AUROC):
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_msg):
            metric_cls()

    # Functional ``auc``.
    xs = torch.tensor([0, 1, 2, 3])
    ys = torch.tensor([0, 1, 2, 2])
    auc._warned = False
    with pytest.deprecated_call(match=removal_msg):
        assert auc(xs, ys) == torch.tensor(4.)

    # Functional ``roc``.
    roc_preds = torch.tensor([0, 1, 2, 3])
    roc_target = torch.tensor([0, 1, 1, 1])
    roc._warned = False
    with pytest.deprecated_call(match=removal_msg):
        fpr, tpr, thresholds = roc(roc_preds, roc_target, pos_label=1)
    expected_fpr = torch.tensor([0., 0., 0., 0., 1.])
    expected_tpr = torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
    expected_thresholds = torch.tensor([4, 3, 2, 1, 0])
    assert torch.equal(fpr, expected_fpr)
    assert torch.allclose(tpr, expected_tpr, atol=1e-4)
    assert torch.equal(thresholds, expected_thresholds)

    # Functional ``auroc``.
    scores = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
    labels = torch.tensor([0, 0, 1, 1, 1])
    auroc._warned = False
    with pytest.deprecated_call(match=removal_msg):
        assert auroc(scores, labels) == torch.tensor(0.5)
def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor):
    """Return the per-class precision-recall AUC, averaged over ``self.num_classes``.

    Args:
        logits: unnormalised class scores. Softmax is taken over dim 1, so the
            class axis is presumably dim 1 (e.g. ``(batch, num_classes)``) — TODO confirm.
        labels: integer class targets passed straight to ``self.multi_prc``.

    Returns:
        A one-element tensor built with ``torch.Tensor`` that holds the summed
        per-class AUPRC divided by ``self.num_classes``.
    """
    multi_prcs = self.multi_prc(pred=logits.softmax(dim=1), target=labels, sample_weight=None)
    avg_auprc = 0.
    # assumes multi_prc yields one (precision, recall, thresholds) triple per class — verify
    for precision_, recall_, _ in multi_prcs:
        # A PR curve plots precision (y) against recall (x). Passing the axes the
        # other way round integrates recall over precision, which is not AUPRC.
        avg_auprc += auc(x=recall_, y=precision_, reorder=True)
    return torch.Tensor([avg_auprc / self.num_classes])
def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor):
    """Return the per-class ROC AUC, averaged over ``self.num_classes``.

    Args:
        logits: unnormalised class scores. Softmax is applied over dim 1 before
            the ROC curves are built (class axis presumably dim 1 — TODO confirm).
        labels: integer class targets forwarded to ``self.multi_roc``.

    Returns:
        A one-element tensor built with ``torch.Tensor`` containing the summed
        per-class ROC AUC divided by ``self.num_classes``.
    """
    probabilities = logits.softmax(dim=1)
    per_class_curves = self.multi_roc(pred=probabilities, target=labels, sample_weight=None)

    total_area = 0.
    for false_pos_rate, true_pos_rate, _thresholds in per_class_curves:
        # ROC curve: true-positive rate (y) over false-positive rate (x).
        total_area += auc(x=false_pos_rate, y=true_pos_rate, reorder=True)

    mean_area = total_area / self.num_classes
    return torch.Tensor([mean_area])
def auprc_pytorch(y_true, y_scores):
    """
    Compute AUPRC for 1 class

    Args:
        y_true (torch.Tensor): ground-truth labels for this class (presumably the
            0/1 column of a one-hot encoding — confirm with callers)
        y_scores (torch.Tensor): model prediction scores for this class

    Return:
        auc: the Area Under the Recall Precision curve, as returned by ``plm.auc``
    """
    precision, recall, _ = plm.precision_recall_curve(y_scores, y_true)
    # Pad so the curve explicitly includes the end points
    # (recall=1, precision=0) at the start and (recall=0, precision=1) at the end.
    # new_ones / new_zeros keep the padding on the same device and dtype as the
    # curve itself, so torch.cat never mixes dtypes or devices.
    recall = torch.cat((recall.new_ones(1), recall, recall.new_zeros(1)))
    precision = torch.cat((precision.new_zeros(1), precision, precision.new_ones(1)))
    return plm.auc(recall, precision)
def compute(self) -> torch.Tensor:
    """Return the area under the precision-recall curve for the accumulated state.

    Returns:
        ``auc(recall, precision)`` of the curve built from the stored predictions
        and targets. Returns ``nan`` when the targets contain fewer than two
        distinct values, including the case where no targets were accumulated.
    """
    preds, targets = self._get_preds_and_targets()
    # A PR area needs at least two distinct target values. Using `< 2` rather
    # than `== 1` also routes an empty target tensor to NaN instead of passing
    # it on to precision_recall_curve.
    if torch.unique(targets).numel() < 2:
        return torch.tensor(np.nan)
    prec, recall, _ = precision_recall_curve(preds, targets)
    return auc(recall, prec)  # type: ignore