def test_wrong_params(threshold):
    """Reject an invalid ``threshold`` with ``ValueError`` in both Hamming APIs.

    Both the module API (``HammingDistance``) and the functional API
    (``hamming_distance``) are exercised on the ``_input_mcls_prob`` inputs.
    NOTE(review): ``threshold`` presumably comes from a parametrization of
    invalid values defined outside this view — confirm.
    """
    fixture = _input_mcls_prob
    preds = fixture.preds
    target = fixture.target

    # The error may surface at construction, on update, or on compute;
    # whichever statement raises first satisfies the check and the rest are skipped.
    with pytest.raises(ValueError):
        metric = HammingDistance(threshold=threshold)
        metric(preds, target)
        metric.compute()

    with pytest.raises(ValueError):
        hamming_distance(preds, target, threshold=threshold)
def test_v1_5_metric_classif_mix():
    """Check that deprecated classification metrics warn about removal in v1.5.0.

    Covers the class-based metrics, which must warn on construction, and their
    functional counterparts, which must warn on call and still return the
    expected values.
    """
    removal_msg = "It will be removed in v1.5.0"

    def _expect_deprecation(owner):
        """Reset ``owner._warned`` and return a ``deprecated_call`` checker.

        The reset happens before the returned context manager is entered, which
        matches clearing the flag immediately before each ``with`` block.
        """
        # NOTE(review): ``_warned`` looks like a warn-once guard set by the
        # deprecation wrapper; clearing it lets the warning fire again regardless
        # of earlier calls — confirm against the wrapper implementation.
        owner._warned = False
        return pytest.deprecated_call(match=removal_msg)

    # Class-based metrics: the flag is reset on the deprecated ``__init__``.
    for metric_cls, init_kwargs in (
        (ConfusionMatrix, {"num_classes": 1}),
        (FBeta, {"num_classes": 1}),
        (F1, {"num_classes": 1}),
        (HammingDistance, {}),
        (StatScores, {}),
    ):
        with _expect_deprecation(metric_cls.__init__):
            metric_cls(**init_kwargs)

    # Functional metrics: each call must warn and still produce the expected value.
    target = torch.tensor([1, 1, 0, 0])
    preds = torch.tensor([0, 1, 0, 0])
    with _expect_deprecation(confusion_matrix):
        assert torch.equal(
            confusion_matrix(preds, target, num_classes=2).float(),
            torch.tensor([[2.0, 0.0], [1.0, 1.0]]),
        )

    target = torch.tensor([0, 1, 2, 0, 1, 2])
    preds = torch.tensor([0, 2, 1, 0, 0, 1])
    with _expect_deprecation(fbeta):
        assert torch.allclose(
            fbeta(preds, target, num_classes=3, beta=0.5), torch.tensor(0.3333), atol=1e-4
        )
    with _expect_deprecation(f1):
        assert torch.allclose(f1(preds, target, num_classes=3), torch.tensor(0.3333), atol=1e-4)

    target = torch.tensor([[0, 1], [1, 1]])
    preds = torch.tensor([[0, 1], [0, 1]])
    with _expect_deprecation(hamming_distance):
        assert hamming_distance(preds, target) == torch.tensor(0.25)

    preds = torch.tensor([1, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])
    with _expect_deprecation(stat_scores):
        assert torch.equal(stat_scores(preds, target, reduce="micro"), torch.tensor([2, 2, 6, 2, 4]))