def test_v1_3_0_deprecated_metrics():
    """Check that the legacy functional metrics warn about their removal.

    Every call below is executed inside ``pytest.deprecated_call`` and must
    emit a deprecation warning whose message matches
    ``'will be removed in v1.3'``.
    """
    from pytorch_lightning.metrics.functional.classification import (
        _roc,
        average_precision,
        get_num_classes,
        multiclass_precision_recall_curve,
        multiclass_roc,
        precision_recall_curve,
        roc,
        to_categorical,
        to_onehot,
    )
    from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce

    removal_notice = 'will be removed in v1.3'

    # Shared fixtures: the "binary" pair feeds the plain curve metrics, the
    # "multy" pair feeds the multiclass variants.
    x_binary = torch.tensor([0, 1, 2, 3])
    y_binary = torch.tensor([0, 1, 2, 3])
    x_multy = torch.tensor([
        [0.85, 0.05, 0.05, 0.05],
        [0.05, 0.85, 0.05, 0.05],
        [0.05, 0.05, 0.85, 0.05],
        [0.05, 0.05, 0.05, 0.85],
    ])
    y_multy = torch.tensor([0, 1, 3, 2])

    # Each entry is a thunk so that its arguments are only built (and the
    # deprecated function only invoked) inside the warning-capturing context.
    deprecated_invocations = [
        lambda: to_onehot(torch.tensor([1, 2, 3])),
        lambda: to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]])),
        lambda: get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1])),
        lambda: roc(pred=x_binary, target=y_binary),
        lambda: _roc(pred=x_binary, target=y_binary),
        lambda: multiclass_roc(pred=x_multy, target=y_multy),
        lambda: average_precision(pred=x_binary, target=y_binary),
        lambda: precision_recall_curve(pred=x_binary, target=y_binary),
        lambda: multiclass_precision_recall_curve(pred=x_multy, target=y_multy),
        lambda: reduce(torch.tensor([0, 1, 1, 0]), 'sum'),
        lambda: class_reduce(
            torch.randint(1, 10, (50, )).float(),
            torch.randint(10, 20, (50, )).float(),
            torch.randint(1, 100, (50, )).float(),
        ),
    ]

    for invoke in deprecated_invocations:
        with pytest.deprecated_call(match=removal_notice):
            invoke()
def test_to_categorical():
    """Check ``to_categorical`` on a hand-built one-hot style input.

    The input has shape ``(2, 10, 5)``: the first sample carries a 5x5
    identity pattern in its first five rows, the second sample carries it in
    its last five rows. The test asserts the result has shape ``(2, 5)`` and
    equals ``[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]``.

    NOTE(review): another ``test_to_categorical`` is defined later in this
    module, which would shadow this definition so that only one of them is
    collected -- confirm whether one should be renamed or removed.
    """
    side = 5
    identity_rows = [[1 if col == row else 0 for col in range(side)] for row in range(side)]
    zero_rows = [[0] * side for _ in range(side)]

    # Sample 0: identity block on top; sample 1: identity block at the bottom.
    test_tensor = torch.tensor([
        identity_rows + zero_rows,
        zero_rows + identity_rows,
    ]).to(torch.float)
    expected = torch.tensor([list(range(0, 5)), list(range(5, 10))])

    # Sanity-check the fixtures themselves before exercising the function.
    assert expected.shape == (2, 5)
    assert test_tensor.shape == (2, 10, 5)

    result = to_categorical(test_tensor)

    assert result.shape == expected.shape
    assert torch.allclose(result, expected.to(result.dtype))
def test_to_categorical():
    """Check ``to_categorical`` on an input assembled from identity blocks.

    Two ``(10, 5)`` samples are stacked into a ``(2, 10, 5)`` float tensor:
    the first has a 5x5 identity block followed by zeros, the second has
    zeros followed by the identity block. The test asserts the result has
    shape ``(2, 5)`` and equals ``[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]``.

    NOTE(review): an earlier function in this module uses the same name; this
    definition replaces it at import time -- confirm that is intended.
    """
    identity_block = torch.eye(5, dtype=int)
    zero_block = torch.zeros((5, 5), dtype=int)

    first_sample = torch.cat([identity_block, zero_block])
    second_sample = torch.cat([zero_block, identity_block])
    test_tensor = torch.stack([first_sample, second_sample]).to(torch.float)

    # Values 0..9 laid out as two rows of five.
    expected = torch.arange(10).reshape(2, 5)

    # Sanity-check the fixtures themselves before exercising the function.
    assert expected.shape == (2, 5)
    assert test_tensor.shape == (2, 10, 5)

    result = to_categorical(test_tensor)

    assert result.shape == expected.shape
    assert torch.allclose(result, expected.to(result.dtype))