def test_v1_3_0_deprecated_metrics():
    """Verify that the legacy functional metrics emit a v1.3 deprecation warning.

    Every callable imported below is invoked once inside
    ``pytest.deprecated_call`` and must raise a deprecation warning whose
    message matches ``'will be removed in v1.3'``.
    """
    from pytorch_lightning.metrics.functional.classification import (
        _roc,
        average_precision,
        get_num_classes,
        multiclass_precision_recall_curve,
        multiclass_roc,
        precision_recall_curve,
        roc,
        to_categorical,
        to_onehot,
    )
    from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce

    removal_notice = 'will be removed in v1.3'

    # Shared fixtures: a 1-D integer pair and a 4x4 score matrix with labels.
    x_binary = torch.tensor([0, 1, 2, 3])
    y_binary = torch.tensor([0, 1, 2, 3])
    x_multy = torch.tensor([
        [0.85, 0.05, 0.05, 0.05],
        [0.05, 0.85, 0.05, 0.05],
        [0.05, 0.05, 0.85, 0.05],
        [0.05, 0.05, 0.05, 0.85],
    ])
    y_multy = torch.tensor([0, 1, 3, 2])

    # --- conversion / helper utilities ---
    with pytest.deprecated_call(match=removal_notice):
        to_onehot(torch.tensor([1, 2, 3]))

    with pytest.deprecated_call(match=removal_notice):
        to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]]))

    with pytest.deprecated_call(match=removal_notice):
        get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1]))

    # --- ROC variants ---
    with pytest.deprecated_call(match=removal_notice):
        roc(pred=x_binary, target=y_binary)

    with pytest.deprecated_call(match=removal_notice):
        _roc(pred=x_binary, target=y_binary)

    with pytest.deprecated_call(match=removal_notice):
        multiclass_roc(pred=x_multy, target=y_multy)

    # --- precision / recall variants ---
    with pytest.deprecated_call(match=removal_notice):
        average_precision(pred=x_binary, target=y_binary)

    with pytest.deprecated_call(match=removal_notice):
        precision_recall_curve(pred=x_binary, target=y_binary)

    with pytest.deprecated_call(match=removal_notice):
        multiclass_precision_recall_curve(pred=x_multy, target=y_multy)

    # --- reductions ---
    with pytest.deprecated_call(match=removal_notice):
        reduce(torch.tensor([0, 1, 1, 0]), 'sum')

    # The three random inputs are drawn in the same order as the positional
    # arguments they are passed as.
    first_arg = torch.randint(1, 10, (50, )).float()
    second_arg = torch.randint(10, 20, (50, )).float()
    third_arg = torch.randint(1, 100, (50, )).float()
    with pytest.deprecated_call(match=removal_notice):
        class_reduce(first_arg, second_arg, third_arg)
def confusion_matrix(pred: torch.Tensor, target: torch.Tensor, num_classes: int = None) -> torch.Tensor:
    """Count how often each (target, prediction) label pair occurs.

    Args:
        pred: predicted class labels; flattened before counting.
        target: ground-truth class labels; flattened before counting.
        num_classes: forwarded to ``get_num_classes``, whose return value is
            used as the class count (presumably inferred from the data when
            ``None`` -- confirm against that helper).

    Returns:
        A float tensor where entry ``[i, j]`` is the number of samples with
        target label ``i`` and predicted label ``j``. The result is squeezed,
        so size-1 dimensions (a single class) are dropped.
    """
    num_classes = get_num_classes(pred, target, num_classes)

    # reshape(-1) rather than view(-1) so non-contiguous inputs can be flattened.
    # Cast to long *before* combining: torch.bincount rejects floating-point
    # tensors, and doing the arithmetic in int64 keeps narrow integer dtypes
    # (e.g. uint8) from wrapping around in ``target * num_classes``.
    flat_pred = pred.reshape(-1).long()
    flat_target = target.reshape(-1).long()

    # Encode each (target, pred) pair as one index into a flattened
    # num_classes x num_classes grid, with target selecting the row.
    unique_labels = flat_target * num_classes + flat_pred

    # minlength guarantees a full grid even when the highest pairs never occur.
    bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
    return bins.reshape(num_classes, num_classes).squeeze().float()
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
    """Check that ``get_num_classes`` returns the expected class count."""
    resolved = get_num_classes(pred, target, num_classes)
    assert resolved == expected_num_classes