def test_v1_4_0_deprecated_metrics():
    """Verify that deprecated metric APIs warn about their removal in v1.4.

    Two groups of deprecations are exercised:

    * the legacy functions in ``pytorch_lightning.metrics.functional.classification``
      (``stat_scores_multiple_classes``, ``iou``, ``recall``, ``precision``,
      ``precision_recall``, ``auc``, ``auroc``, ``multiclass_auroc``,
      ``auc_decorator`` and ``multiclass_auc_decorator``);
    * the ``class_reduction`` argument of the new
      ``pytorch_lightning.metrics.functional`` ``precision`` and ``recall``.

    Every call runs inside its own ``pytest.deprecated_call`` context, so each
    individual call must emit a matching warning for the test to pass.
    """
    from pytorch_lightning.metrics.functional.classification import (
        auc,
        auc_decorator,
        auroc,
        iou,
        multiclass_auc_decorator,
        multiclass_auroc,
        precision_recall,
        stat_scores_multiple_classes,
    )
    from pytorch_lightning.metrics.functional.classification import precision as legacy_precision
    from pytorch_lightning.metrics.functional.classification import recall as legacy_recall
    from pytorch_lightning.metrics.functional import precision, recall

    def random_maps():
        # Fresh random 0/1 tensor of shape (10, 3, 3).
        return torch.randint(0, 2, (10, 3, 3))

    def random_bits():
        # Fresh random 0/1 tensor of shape (10,).
        return torch.randint(0, 2, (10, ))

    # Thunks defer tensor creation until the call runs inside deprecated_call,
    # so random draws happen in the same order and context as inline calls would.
    deprecated_calls = [
        lambda: stat_scores_multiple_classes(pred=torch.tensor([0, 1]), target=torch.tensor([0, 1])),
        lambda: iou(random_maps(), random_maps()),
        lambda: legacy_recall(random_maps(), random_maps()),
        lambda: legacy_precision(random_maps(), random_maps()),
        lambda: precision_recall(random_maps(), random_maps()),
        # For the *new* precision/recall only the class_reduction argument is deprecated.
        lambda: precision(random_bits(), random_bits(), class_reduction='micro'),
        lambda: recall(random_bits(), random_bits(), class_reduction='micro'),
        lambda: auc(torch.rand(10, ).sort().values, torch.rand(10, )),
        lambda: auroc(torch.rand(10, ), random_bits()),
        lambda: multiclass_auroc(torch.rand(20, 5).softmax(dim=-1), torch.randint(0, 5, (20, )), num_classes=5),
        lambda: auc_decorator(),
        lambda: multiclass_auc_decorator(),
    ]

    for invoke in deprecated_calls:
        with pytest.deprecated_call(match='will be removed in v1.4'):
            invoke()
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Run the metric computation for one batch.

    Args:
        pred: predicted labels.
        target: ground-truth labels.

    Returns:
        Whatever ``precision_recall`` returns when called with
        ``class_reduction='none'`` and ``return_state=True``.
        NOTE(review): with ``return_state=True`` this is presumably an
        accumulation state rather than a final score; confirm that the
        ``torch.Tensor`` return annotation is accurate.
    """
    # Fixed options: per-class (unreduced) results, and the state is returned.
    options = dict(
        num_classes=self.num_classes,
        class_reduction='none',
        return_state=True,
    )
    return precision_recall(pred=pred, target=target, **options)
def _plot_roc_with_eer(fpr: torch.Tensor, tpr: torch.Tensor, idx: torch.Tensor, eer) -> None:
    """Draw the ROC curve, its filled area and the EER point, then show the figure.

    ``idx`` must hold a single index (0-d, or 1-d of length 1) into ``fpr``/``tpr``.
    """
    assert idx.dim() == 0 or (idx.dim() == 1 and idx.size(0) == 1)
    point = int(idx.item())
    fpr_np = fpr.cpu().numpy()
    tpr_np = tpr.cpu().numpy()

    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.plot([0, 1], [0, 1], 'r--', label='Reference', alpha=0.6)
    plt.plot([1, 0], [0, 1], 'k--', label='EER line', alpha=0.6)
    plt.plot(fpr_np, tpr_np, label='ROC curve')
    plt.fill_between(fpr_np, tpr_np, 0, label='AUC', color='0.8')
    plt.plot(fpr_np[point], tpr_np[point], 'ko', label='EER = {:.2f}%'.format(eer * 100))
    plt.legend()
    plt.show()


def compute_evaluation_metrics(outputs: List[List[torch.Tensor]], plot: bool = False,
                               prefix: Optional[str] = None) -> Dict[str, torch.Tensor]:
    """Concatenate per-step scores/labels and compute ROC-based evaluation metrics.

    Args:
        outputs: per-step outputs. For every ``step``, the items of ``step[0]``
            are concatenated into the score vector and the items of ``step[1]``
            into the label vector. (Item shapes are assumed compatible with
            ``torch.cat``; TODO confirm against the caller.)
        plot: if true, draw the ROC curve with its EER point and call ``plt.show()``.
        prefix: if non-empty, every returned key is prefixed with ``'<prefix>_'``.

    Returns:
        Dict with keys ``eer``, ``eer_acc``, ``eer_threshold``, ``auc``,
        ``min_dcf``, ``min_dcf_acc``, ``min_dcf_threshold``, ``prec`` and
        ``recall`` (each optionally prefixed), in that order.

    Side effects:
        Prints the min/max score range.
    """
    scores = torch.cat([score for step in outputs for score in step[0]])
    # NOTE: Need sigmoid here because we skip the sigmoid in forward() due to using BCE with logits for loss.
    # NOTE(review): the `scores = torch.sigmoid(scores)` call was commented out,
    # so scores are used exactly as produced. ROC/AUC/EER depend only on score
    # ordering, but the reported thresholds (and possibly precision_recall) depend
    # on the score scale. Confirm which behaviour is intended.
    lowest = torch.min(scores).item()
    highest = torch.max(scores).item()
    print('Score range: [{}, {}]'.format(lowest, highest))
    labels = torch.cat([label for step in outputs for label in step[1]])

    auc_value = auroc(scores, labels, pos_label=1)
    fpr, tpr, thresholds = roc(scores, labels, pos_label=1)
    prec, recall_value = precision_recall(scores, labels)

    # mypy massaging: without num_classes (binary case) roc() yields single tensors.
    fpr = cast(torch.Tensor, fpr)
    tpr = cast(torch.Tensor, tpr)
    thresholds = cast(torch.Tensor, thresholds)
    fnr = 1 - tpr

    eer, eer_threshold, idx = equal_error_rate(fpr, fnr, thresholds)
    min_dcf, min_dcf_threshold = minDCF(fpr, fnr, thresholds)

    def accuracy_at(threshold):
        # Fraction of samples whose thresholded score (>= threshold -> 1) equals the label.
        predictions = (scores >= threshold).long()
        return torch.sum(predictions == labels).float() / labels.numel()

    eer_acc = accuracy_at(eer_threshold)
    min_dcf_acc = accuracy_at(min_dcf_threshold)

    if plot:
        _plot_roc_with_eer(fpr, tpr, idx, eer)

    key_prefix = '{}_'.format(prefix) if prefix else ''
    metrics = {
        'eer': eer,
        'eer_acc': eer_acc,
        'eer_threshold': eer_threshold,
        'auc': auc_value,
        'min_dcf': min_dcf,
        'min_dcf_acc': min_dcf_acc,
        'min_dcf_threshold': min_dcf_threshold,
        'prec': prec,
        'recall': recall_value,
    }
    return {key_prefix + name: value for name, value in metrics.items()}