def test_v1_3_0_deprecated_metrics():
    """Verify that each legacy functional metric raises its v1.3 deprecation warning.

    Every deprecated function is called once inside ``pytest.deprecated_call``;
    the test fails as soon as one of them does not emit a warning whose message
    matches the removal notice.
    """
    from pytorch_lightning.metrics.functional.classification import (
        _roc,
        average_precision,
        get_num_classes,
        multiclass_precision_recall_curve,
        multiclass_roc,
        precision_recall_curve,
        roc,
        to_categorical,
        to_onehot,
    )
    from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce

    removal_notice = 'will be removed in v1.3'

    # Inputs shared by the 1D-input metrics.
    flat_pred = torch.tensor([0, 1, 2, 3])
    flat_target = torch.tensor([0, 1, 2, 3])

    # Inputs shared by the multiclass metrics: one row of scores per sample.
    class_scores = torch.tensor([
        [0.85, 0.05, 0.05, 0.05],
        [0.05, 0.85, 0.05, 0.05],
        [0.05, 0.05, 0.85, 0.05],
        [0.05, 0.05, 0.05, 0.85],
    ])
    class_target = torch.tensor([0, 1, 3, 2])

    # Zero-argument wrappers so that each call (including construction of its
    # one-off arguments) runs inside the warning-capturing context below.
    deprecated_invocations = [
        lambda: to_onehot(torch.tensor([1, 2, 3])),
        lambda: to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]])),
        lambda: get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1])),
        lambda: roc(pred=flat_pred, target=flat_target),
        lambda: _roc(pred=flat_pred, target=flat_target),
        lambda: multiclass_roc(pred=class_scores, target=class_target),
        lambda: average_precision(pred=flat_pred, target=flat_target),
        lambda: precision_recall_curve(pred=flat_pred, target=flat_target),
        lambda: multiclass_precision_recall_curve(pred=class_scores, target=class_target),
        lambda: reduce(torch.tensor([0, 1, 1, 0]), 'sum'),
        lambda: class_reduce(
            torch.randint(1, 10, (50, )).float(),
            torch.randint(10, 20, (50, )).float(),
            torch.randint(1, 100, (50, )).float(),
        ),
    ]

    for invoke in deprecated_invocations:
        with pytest.deprecated_call(match=removal_notice):
            invoke()
def test_average_precision_constant_values():
    """A constant predictor's average precision equals the share of positives.

    With a single score value there is only one threshold, so precision is the
    fraction of positive samples regardless of recall.
    """
    # Every fourth sample (indices 0, 4, 8, ...) is positive: 25% positives.
    labels = (torch.arange(100) % 4 == 0).to(torch.float)

    # The predictor assigns the same score to every sample.
    constant_scores = torch.ones(100)

    score = average_precision(constant_scores, labels).item()
    assert score == 0.25
def forward(self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None) -> torch.Tensor:
    """
    Compute the metric by delegating to ``average_precision``.

    Args:
        pred: predicted labels
        target: groundtruth labels
        sample_weight: the weights per sample

    Return:
        torch.Tensor: classification score
    """
    # The positive label is configured on the instance, not passed per call.
    positive_label = self.pos_label
    score = average_precision(
        pred=pred,
        target=target,
        sample_weight=sample_weight,
        pos_label=positive_label,
    )
    return score
def test_average_precision(scores, target, expected_score):
    """Check that ``average_precision`` returns the expected score for the given inputs."""
    computed_score = average_precision(scores, target)
    assert computed_score == expected_score