Exemplo n.º 1
0
def test_v1_4_0_deprecated_metrics():
    """Check that metric APIs scheduled for removal in v1.4 warn when used.

    Covers the legacy functions imported from
    ``pytorch_lightning.metrics.functional.classification`` as well as the
    ``class_reduction`` argument of the newer ``precision`` / ``recall``
    imported from ``pytorch_lightning.metrics.functional``. Every call must
    trigger a deprecation warning whose message matches
    ``'will be removed in v1.4'``.
    """

    def removal_warning():
        # A fresh context manager per check, so each call is verified on its own.
        return pytest.deprecated_call(match='will be removed in v1.4')

    def binary_maps():
        # Random 0/1 integer tensor of shape (10, 3, 3).
        return torch.randint(0, 2, (10, 3, 3))

    def binary_vector():
        # Random 0/1 integer tensor of shape (10,).
        return torch.randint(0, 2, (10, ))

    from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
    with removal_warning():
        stat_scores_multiple_classes(pred=torch.tensor([0, 1]),
                                     target=torch.tensor([0, 1]))

    from pytorch_lightning.metrics.functional.classification import iou
    with removal_warning():
        iou(binary_maps(), binary_maps())

    # Legacy implementations are aliased so they cannot be confused with the
    # newer functions of the same name imported further down.
    from pytorch_lightning.metrics.functional.classification import recall as legacy_recall
    with removal_warning():
        legacy_recall(binary_maps(), binary_maps())

    from pytorch_lightning.metrics.functional.classification import precision as legacy_precision
    with removal_warning():
        legacy_precision(binary_maps(), binary_maps())

    from pytorch_lightning.metrics.functional.classification import precision_recall
    with removal_warning():
        precision_recall(binary_maps(), binary_maps())

    # Testing deprecation of class_reduction arg in the *new* precision
    from pytorch_lightning.metrics.functional import precision as new_precision
    with removal_warning():
        new_precision(binary_vector(), binary_vector(), class_reduction='micro')

    # Testing deprecation of class_reduction arg in the *new* recall
    from pytorch_lightning.metrics.functional import recall as new_recall
    with removal_warning():
        new_recall(binary_vector(), binary_vector(), class_reduction='micro')

    from pytorch_lightning.metrics.functional.classification import auc
    with removal_warning():
        # The first argument is sorted before being handed to auc.
        auc(torch.rand(10, ).sort().values, torch.rand(10, ))

    from pytorch_lightning.metrics.functional.classification import auroc
    with removal_warning():
        auroc(torch.rand(10, ), binary_vector())

    from pytorch_lightning.metrics.functional.classification import multiclass_auroc
    with removal_warning():
        multiclass_auroc(torch.rand(20, 5).softmax(dim=-1),
                         torch.randint(0, 5, (20, )),
                         num_classes=5)

    from pytorch_lightning.metrics.functional.classification import auc_decorator
    with removal_warning():
        auc_decorator()

    from pytorch_lightning.metrics.functional.classification import multiclass_auc_decorator
    with removal_warning():
        multiclass_auc_decorator()
    def forward(self, pred: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Run ``precision_recall`` on a batch without class reduction.

        The call is configured with ``num_classes=self.num_classes``,
        ``class_reduction='none'`` and ``return_state=True``.

        Args:
            pred: predicted labels
            target: groundtruth labels

        Return:
            The value produced by ``precision_recall`` for these settings.
            NOTE(review): the annotation promises a single ``torch.Tensor``,
            but ``return_state=True`` is requested - confirm the actual
            return structure against ``precision_recall``.
        """
        class_count = self.num_classes
        result = precision_recall(pred=pred,
                                  target=target,
                                  num_classes=class_count,
                                  class_reduction='none',
                                  return_state=True)
        return result
Exemplo n.º 3
0
def compute_evaluation_metrics(outputs: List[List[torch.Tensor]],
                               plot: bool = False,
                               prefix: Optional[str] = None) -> Dict[str, torch.Tensor]:
    """Concatenate per-step outputs and derive threshold-based evaluation metrics.

    Args:
        outputs: one entry per step; ``step[0]`` is iterated for score tensors
            and ``step[1]`` for label tensors. Both are concatenated with
            ``torch.cat`` across all steps.
        plot: when true, draw the ROC curve, the reference and EER lines and
            the EER point through ``plt`` and call ``plt.show()``.
        prefix: optional key prefix; when truthy, ``'<prefix>_'`` is prepended
            to every key of the returned dict.

    Returns:
        Dict with the keys ``eer``, ``eer_acc``, ``eer_threshold``, ``auc``,
        ``min_dcf``, ``min_dcf_acc``, ``min_dcf_threshold``, ``prec`` and
        ``recall`` (each optionally prefixed).

    Side effects:
        Prints the minimum and maximum score to stdout.
    """
    scores = torch.cat([chunk for step in outputs for chunk in step[0]])
    # NOTE(review): the original note here says a sigmoid is needed because
    # forward() skips it (BCE-with-logits loss), yet the sigmoid call was left
    # commented out. Scores are therefore used exactly as produced upstream -
    # confirm which behaviour is intended.
    lowest = torch.min(scores).item()
    highest = torch.max(scores).item()
    print('Score range: [{}, {}]'.format(lowest, highest))
    labels = torch.cat([chunk for step in outputs for chunk in step[1]])

    area_under_curve = auroc(scores, labels, pos_label=1)
    raw_fpr, raw_tpr, raw_thresholds = roc(scores, labels, pos_label=1)
    prec, recall = precision_recall(scores, labels)

    # mypy massaging, single tensors when num_classes is not specified (= binary case).
    fpr = cast(torch.Tensor, raw_fpr)
    tpr = cast(torch.Tensor, raw_tpr)
    thresholds = cast(torch.Tensor, raw_thresholds)

    # False negative rate is the complement of the true positive rate.
    fnr = 1 - tpr
    eer, eer_threshold, idx = equal_error_rate(fpr, fnr, thresholds)
    min_dcf, min_dcf_threshold = minDCF(fpr, fnr, thresholds)

    def accuracy_at(threshold):
        # Fraction of samples whose thresholded score equals the label.
        decisions = (scores >= threshold).long()
        return torch.sum(decisions == labels).float() / labels.numel()

    # Accuracy based on EER and minDCF thresholds.
    eer_acc = accuracy_at(eer_threshold)
    min_dcf_acc = accuracy_at(min_dcf_threshold)

    if plot:
        # idx must identify exactly one point on the ROC curve.
        rank = idx.dim()
        assert rank == 0 or (rank == 1 and idx.size(0) == 1)
        eer_index = int(idx.item())
        fpr_values = fpr.cpu().numpy()
        tpr_values = tpr.cpu().numpy()
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.plot([0, 1], [0, 1], 'r--', label='Reference', alpha=0.6)
        plt.plot([1, 0], [0, 1], 'k--', label='EER line', alpha=0.6)
        plt.plot(fpr_values, tpr_values, label='ROC curve')
        plt.fill_between(fpr_values, tpr_values, 0, label='AUC', color='0.8')
        # EER point
        plt.plot(fpr_values[eer_index], tpr_values[eer_index], 'ko',
                 label='EER = {:.2f}%'.format(eer * 100))
        plt.legend()
        plt.show()

    tag = '{}_'.format(prefix) if prefix else ''

    # (key template, value) pairs; order defines the dict's insertion order.
    named_results = (
        ('{}eer', eer),
        ('{}eer_acc', eer_acc),
        ('{}eer_threshold', eer_threshold),
        ('{}auc', area_under_curve),
        ('{}min_dcf', min_dcf),
        ('{}min_dcf_acc', min_dcf_acc),
        ('{}min_dcf_threshold', min_dcf_threshold),
        ('{}prec', prec),
        ('{}recall', recall),
    )
    return {template.format(tag): value for template, value in named_results}