def test_merged_labels_over_threshold_confusion_matrix():
    """Check that merging two matrices with many distinct labels raises.

    The two matrices are built over disjoint label ranges (200 and 99
    labels respectively), so a merged matrix would need 299 labels.
    The test expects ``ValueError`` in that situation.
    """
    labels = range(200)
    more_labels = range(201, 300)
    matrix_1 = ConfusionMatrix(labels)
    matrix_2 = ConfusionMatrix(more_labels)

    with pytest.raises(ValueError):
        # NOTE(review): the body of this ``with`` block was missing in the
        # source; the merge call is reconstructed from the test name and
        # the two matrices built above -- confirm against upstream.
        matrix_1.merge(matrix_2)
# Example #2
def test_merge_with_none():
    """Check the field names on the result of merging a matrix with ``None``.

    A three-label matrix is populated with one batch of predictions and
    then merged with ``other_cm=None``; the result's ``target_field``,
    ``prediction_field`` and ``score_field`` are all expected to be ``None``.
    """
    actual = ["cat", "dog", "pig"]
    predicted = ["cat", "dog", "dog"]
    confidences = [0.1, 0.2, 0.4]

    animal_labels = ["cat", "dog", "pig"]
    populated = ConfusionMatrix(animal_labels)
    populated.add(predicted, actual, confidences)

    merged = populated.merge(other_cm=None)

    # Same three checks, in the same order, driven by attribute name.
    for field_name in ("target_field", "prediction_field", "score_field"):
        assert getattr(merged, field_name) is None
# Example #3
def _assert_cell_counts(matrix, expected):
    """Assert that every cell of ``matrix`` holds the expected count.

    Args:
        matrix: object exposing ``labels`` and a ``confusion_matrix`` that
            is indexable as ``[row, col]`` with cells exposing
            ``floats.count``.
        expected: nested list of counts, read as ``expected[row][col]``.
    """
    for row, _ in enumerate(matrix.labels):
        for col, _ in enumerate(matrix.labels):
            actual = matrix.confusion_matrix[row, col].floats.count
            # The message replaces the former debug print of the indices.
            assert actual == expected[row][col], (
                f"cell ({row}, {col}): expected {expected[row][col]}, "
                f"got {actual}"
            )


def test_merge_conf_matrix():
    """
    tests merging two confusion matrices of different sizes
    """
    targets_1 = ["cat", "dog", "pig"]
    targets_2 = ["cat", "dog"]

    predictions_1 = ["cat", "dog", "dog"]
    predictions_2 = ["cat", "cat"]

    scores_1 = [0.1, 0.2, 0.4]
    scores_2 = [0.4, 0.3]

    expected_1 = [[1, 0, 0], [0, 1, 1], [0, 0, 0]]
    expected_2 = [[1, 1], [0, 0]]
    # Cell-wise sum of expected_1 and expected_2 (the latter padded to 3x3).
    expected_merge = [[2, 1, 0], [0, 1, 1], [0, 0, 0]]

    labels_1 = ["cat", "dog", "pig"]
    conf_M_1 = ConfusionMatrix(labels_1)
    conf_M_1.add(predictions_1, targets_1, scores_1)
    _assert_cell_counts(conf_M_1, expected_1)

    labels_2 = ["cat", "dog"]
    conf_M_2 = ConfusionMatrix(labels_2)
    conf_M_2.add(predictions_2, targets_2, scores_2)
    # NOTE(review): the right-hand side of this check was truncated in the
    # source; ``expected_2`` is the only otherwise-unused expectation and
    # matches the pattern of the other two checks -- confirm against upstream.
    _assert_cell_counts(conf_M_2, expected_2)

    new_conf = conf_M_1.merge(conf_M_2)
    _assert_cell_counts(new_conf, expected_merge)