コード例 #1
0
def test_v1_3_0_deprecated_metrics():
    """Verify that the legacy functional metrics still warn about their removal.

    Each metric is invoked inside ``pytest.deprecated_call`` and must emit a
    deprecation warning whose message matches ``'will be removed in v1.3'``.
    """
    from pytorch_lightning.metrics.functional.classification import (
        _roc,
        average_precision,
        get_num_classes,
        multiclass_precision_recall_curve,
        multiclass_roc,
        precision_recall_curve,
        roc,
        to_categorical,
        to_onehot,
    )
    from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce

    removal_msg = 'will be removed in v1.3'

    # --- label conversion helpers -------------------------------------------
    with pytest.deprecated_call(match=removal_msg):
        labels = torch.tensor([1, 2, 3])
        to_onehot(labels)

    with pytest.deprecated_call(match=removal_msg):
        probabilities = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
        to_categorical(probabilities)

    with pytest.deprecated_call(match=removal_msg):
        get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1]))

    # --- shared inputs for the curve-based metrics --------------------------
    flat_pred = torch.tensor([0, 1, 2, 3])
    flat_target = torch.tensor([0, 1, 2, 3])

    # One row of per-class scores per sample (4 samples, 4 classes).
    score_matrix = torch.tensor([
        [0.85, 0.05, 0.05, 0.05],
        [0.05, 0.85, 0.05, 0.05],
        [0.05, 0.05, 0.85, 0.05],
        [0.05, 0.05, 0.05, 0.85],
    ])
    class_target = torch.tensor([0, 1, 3, 2])

    # --- ROC family ----------------------------------------------------------
    for roc_metric in (roc, _roc):
        with pytest.deprecated_call(match=removal_msg):
            roc_metric(pred=flat_pred, target=flat_target)

    with pytest.deprecated_call(match=removal_msg):
        multiclass_roc(pred=score_matrix, target=class_target)

    # --- precision / recall family ------------------------------------------
    for pr_metric in (average_precision, precision_recall_curve):
        with pytest.deprecated_call(match=removal_msg):
            pr_metric(pred=flat_pred, target=flat_target)

    with pytest.deprecated_call(match=removal_msg):
        multiclass_precision_recall_curve(pred=score_matrix, target=class_target)

    # --- reductions ----------------------------------------------------------
    with pytest.deprecated_call(match=removal_msg):
        reduce(torch.tensor([0, 1, 1, 0]), 'sum')

    with pytest.deprecated_call(match=removal_msg):
        # Random inputs are drawn inside the context, in the same order as the
        # positional arguments, so the RNG is consumed exactly once per tensor.
        first = torch.randint(1, 10, (50, )).float()
        second = torch.randint(10, 20, (50, )).float()
        third = torch.randint(1, 100, (50, )).float()
        class_reduce(first, second, third)
コード例 #2
0
ファイル: model.py プロジェクト: jjedele/pytorch-isic2018
def confusion_matrix(pred: torch.Tensor,
                     target: torch.Tensor,
                     num_classes: int = None) -> torch.Tensor:
    """Count how often each (target, pred) label pair occurs.

    Args:
        pred: Predicted class labels. Flattened before counting.
        target: Ground-truth class labels with the same number of elements
            as ``pred``. Flattened before counting.
        num_classes: Number of classes. Passed to ``get_num_classes``, which
            resolves the value actually used (presumably inferring it from
            the data when ``None`` -- confirm against that helper).

    Returns:
        A float tensor built from a ``(num_classes, num_classes)`` count
        matrix where entry ``[i, j]`` is the number of samples with target
        label ``i`` and predicted label ``j``. Size-1 dimensions are squeezed
        away, so a single-class input yields a 0-d tensor.

    Note:
        Labels are assumed to lie in ``[0, num_classes)``; larger labels make
        the bin count exceed ``num_classes ** 2`` and the final reshape fail.
    """
    num_classes = get_num_classes(pred, target, num_classes)

    # Encode every (target, pred) pair as one index into the flattened matrix.
    # ``reshape(-1)`` is used instead of ``view(-1)`` so non-contiguous inputs
    # (e.g. transposed tensors) are accepted, and the result is cast to an
    # integer dtype because ``torch.bincount`` rejects floating-point input.
    pair_index = (target.reshape(-1) * num_classes + pred.reshape(-1)).long()

    counts = torch.bincount(pair_index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes).squeeze().float()
コード例 #3
0
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
    """Check that ``get_num_classes`` resolves to the expected class count."""
    resolved = get_num_classes(pred, target, num_classes)
    assert resolved == expected_num_classes