Ejemplo n.º 1
0
def test_v1_4_0_deprecated_metrics():
    """Check that each metric API scheduled for removal in v1.4 warns when used.

    Every step imports one legacy entry point (or the new functional
    ``precision`` / ``recall`` together with the deprecated ``class_reduction``
    argument) and calls it once inside ``pytest.deprecated_call``; the emitted
    warning text must match ``'will be removed in v1.4'``.
    """
    removal_msg = 'will be removed in v1.4'

    def binary_maps():
        # Ten 3x3 tensors of random labels drawn from {0, 1}.
        return torch.randint(0, 2, (10, 3, 3))

    def binary_vector():
        # Ten random labels drawn from {0, 1}.
        return torch.randint(0, 2, (10, ))

    from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
    with pytest.deprecated_call(match=removal_msg):
        stat_scores_multiple_classes(pred=torch.tensor([0, 1]),
                                     target=torch.tensor([0, 1]))

    from pytorch_lightning.metrics.functional.classification import iou
    with pytest.deprecated_call(match=removal_msg):
        iou(binary_maps(), binary_maps())

    from pytorch_lightning.metrics.functional.classification import recall as legacy_recall
    with pytest.deprecated_call(match=removal_msg):
        legacy_recall(binary_maps(), binary_maps())

    from pytorch_lightning.metrics.functional.classification import precision as legacy_precision
    with pytest.deprecated_call(match=removal_msg):
        legacy_precision(binary_maps(), binary_maps())

    from pytorch_lightning.metrics.functional.classification import precision_recall
    with pytest.deprecated_call(match=removal_msg):
        precision_recall(binary_maps(), binary_maps())

    # For the *new* functional precision/recall, the deprecation under test is
    # the ``class_reduction`` argument rather than the function itself.
    from pytorch_lightning.metrics.functional import precision as new_precision
    with pytest.deprecated_call(match=removal_msg):
        new_precision(binary_vector(), binary_vector(), class_reduction='micro')

    from pytorch_lightning.metrics.functional import recall as new_recall
    with pytest.deprecated_call(match=removal_msg):
        new_recall(binary_vector(), binary_vector(), class_reduction='micro')

    from pytorch_lightning.metrics.functional.classification import auc
    with pytest.deprecated_call(match=removal_msg):
        # The first argument is passed in ascending order.
        auc(torch.rand(10, ).sort().values, torch.rand(10, ))

    from pytorch_lightning.metrics.functional.classification import auroc
    with pytest.deprecated_call(match=removal_msg):
        auroc(torch.rand(10, ), binary_vector())

    from pytorch_lightning.metrics.functional.classification import multiclass_auroc
    with pytest.deprecated_call(match=removal_msg):
        # Rows of the score matrix are softmax-normalised over 5 classes.
        multiclass_auroc(torch.rand(20, 5).softmax(dim=-1),
                         torch.randint(0, 5, (20, )),
                         num_classes=5)

    # The decorator factories are expected to warn as soon as they are called,
    # without decorating anything.
    from pytorch_lightning.metrics.functional.classification import auc_decorator
    with pytest.deprecated_call(match=removal_msg):
        auc_decorator()

    from pytorch_lightning.metrics.functional.classification import multiclass_auc_decorator
    with pytest.deprecated_call(match=removal_msg):
        multiclass_auc_decorator()


def test_multiclass_auroc():
    """``multiclass_auroc`` must reject invalid inputs with a ``ValueError``.

    Three invalid inputs are covered:

    * scores whose rows are not probability distributions;
    * a target in which not every class occurs;
    * a ``num_classes`` (6) that differs from the 4 columns of ``pred``.
    """
    not_prob_pattern = r".*probabilities, i.e. they should sum up to 1.0 over classes"
    missing_class_pattern = r".*not defined when all of the classes do not occur in the target.*"
    num_classes_pattern = r".*does not equal the number of classes passed in 'num_classes'.*"

    # The first row sums to 1.8, so these scores are not probabilities.
    with pytest.raises(ValueError, match=not_prob_pattern):
        bad_scores = torch.tensor([[0.9, 0.9], [1.0, 0]])
        multiclass_auroc(pred=bad_scores, target=torch.tensor([0, 1]))

    # Three score columns, but the target only ever contains classes 0 and 1.
    with pytest.raises(ValueError, match=missing_class_pattern):
        three_class_scores = torch.rand((4, 3)).softmax(dim=1)
        multiclass_auroc(pred=three_class_scores,
                         target=torch.tensor([1, 0, 1, 0]))

    # Four score columns, yet ``num_classes`` claims six.
    with pytest.raises(ValueError, match=num_classes_pattern):
        four_class_scores = torch.rand((5, 4)).softmax(dim=1)
        multiclass_auroc(pred=four_class_scores,
                         target=torch.tensor([0, 1, 2, 2, 3]),
                         num_classes=6)


def test_multiclass_auroc_against_sklearn(n_cls):
    """Compare ``multiclass_auroc`` with scikit-learn's ROC AUC on random data.

    Args:
        n_cls: number of classes to generate. ``2`` exercises the binary
            code path, where sklearn is given only the positive-class column.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    n_samples = 300
    pred = torch.rand(n_samples, n_cls, device=device).softmax(dim=1)
    target = torch.randint(n_cls, (n_samples, ), device=device)
    # Multiclass AUROC is only defined when every class occurs in ``target``,
    # so overwrite a slice with 0..n_cls-1. The labels are created directly on
    # ``device`` to avoid an implicit host-to-device copy on CUDA.
    target[10:10 + n_cls] = torch.arange(n_cls, device=device)

    pl_score = multiclass_auroc(pred, target)

    # For the binary case, sklearn expects an (n_samples,) array of
    # probabilities of the positive class instead of the full score matrix.
    sk_pred = pred[:, 1] if n_cls == 2 else pred
    sk_score = sk_roc_auc_score(
        target.detach().cpu().numpy(),
        sk_pred.detach().cpu().numpy(),
        multi_class="ovr",
    )

    # Move the reference score to the same dtype/device so allclose can compare.
    sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
    assert torch.allclose(sk_score, pl_score)