コード例 #1
0
def test_wrong_params(threshold):
    """Check that both Hamming distance APIs reject ``threshold``.

    The class-based metric and the functional ``hamming_distance`` are each
    expected to raise ``ValueError`` when given the supplied ``threshold``.

    NOTE(review): ``threshold`` is presumably injected by a parametrize
    decorator or fixture that is not visible in this excerpt -- confirm.
    """
    preds = _input_mcls_prob.preds
    target = _input_mcls_prob.target

    # Class-based API: the error may surface at construction, on the forward
    # call, or on ``compute()``; any one of them satisfies the expectation.
    with pytest.raises(ValueError):
        metric = HammingDistance(threshold=threshold)
        metric(preds, target)
        metric.compute()

    # Functional API: a single call is expected to raise.
    with pytest.raises(ValueError):
        hamming_distance(preds, target, threshold=threshold)
コード例 #2
0
def test_v1_5_metric_classif_mix():
    """Verify the v1.5.0 deprecation warnings of mixed classification metrics.

    Each class-based metric is instantiated, and each functional metric is
    called, inside ``pytest.deprecated_call``; the functional results are
    also compared against fixed expected tensors.

    NOTE(review): ``_warned`` is reset to ``False`` before every check,
    presumably because the deprecation wrapper warns only once per
    callable -- confirm against the deprecation helper.
    """
    removal_msg = "It will be removed in v1.5.0"

    # Class-based metrics, paired with the constructor arguments they need.
    deprecated_classes = (
        (ConfusionMatrix, {"num_classes": 1}),
        (FBeta, {"num_classes": 1}),
        (F1, {"num_classes": 1}),
        (HammingDistance, {}),
        (StatScores, {}),
    )
    for metric_cls, init_kwargs in deprecated_classes:
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_msg):
            metric_cls(**init_kwargs)

    # Functional confusion matrix on a binary problem.
    target = torch.tensor([1, 1, 0, 0])
    preds = torch.tensor([0, 1, 0, 0])
    confusion_matrix._warned = False
    with pytest.deprecated_call(match=removal_msg):
        confmat = confusion_matrix(preds, target, num_classes=2).float()
        assert torch.equal(confmat, torch.tensor([[2.0, 0.0], [1.0, 1.0]]))

    # Functional F-beta and F1 share the same three-class inputs.
    target = torch.tensor([0, 1, 2, 0, 1, 2])
    preds = torch.tensor([0, 2, 1, 0, 0, 1])
    fbeta._warned = False
    with pytest.deprecated_call(match=removal_msg):
        fbeta_score = fbeta(preds, target, num_classes=3, beta=0.5)
        assert torch.allclose(fbeta_score, torch.tensor(0.3333), atol=1e-4)

    f1._warned = False
    with pytest.deprecated_call(match=removal_msg):
        f1_score = f1(preds, target, num_classes=3)
        assert torch.allclose(f1_score, torch.tensor(0.3333), atol=1e-4)

    # Functional Hamming distance on 2x2 inputs.
    target = torch.tensor([[0, 1], [1, 1]])
    preds = torch.tensor([[0, 1], [0, 1]])
    hamming_distance._warned = False
    with pytest.deprecated_call(match=removal_msg):
        distance = hamming_distance(preds, target)
        assert distance == torch.tensor(0.25)

    # Functional stat scores with micro reduction.
    preds = torch.tensor([1, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])
    stat_scores._warned = False
    with pytest.deprecated_call(match=removal_msg):
        scores = stat_scores(preds, target, reduce="micro")
        assert torch.equal(scores, torch.tensor([2, 2, 6, 2, 4]))