Ejemplo n.º 1
0
def test_confusion_matrix():
    """Check the custom ``confusion_matrix`` against a reference implementation.

    The custom function is called as ``confusion_matrix(y_pred, y_real)``
    (argument names taken from the error-case comments below), while the
    reference ``gt_confusion_matrix`` is called with the labels first —
    presumably sklearn-style ``(y_true, y_pred)``; confirm at the import site.
    Also checks that invalid argument types/dtypes raise ``TypeError``.
    """
    # Local, seeded generators: failures are reproducible and the global
    # ``random`` / ``np.random`` state used by other tests is left untouched.
    py_rng = random.Random(0)
    np_rng = np.random.RandomState(0)

    # Ground truth is a plain list of np.int64, predictions an int64 ndarray,
    # so both accepted container kinds are exercised.
    gt_labels = [np.int64(py_rng.randint(0, 9)) for _ in range(100)]
    pred_labels = np_rng.randint(10, size=100, dtype=np.int64)
    confusion_mat = confusion_matrix(pred_labels, gt_labels)
    gt_confusion_mat = gt_confusion_matrix(gt_labels, pred_labels)
    # assert_array_equal reports shape/value mismatches, unlike a bare bool.
    np.testing.assert_array_equal(confusion_mat, gt_confusion_mat)

    # A well-formed argument with an explicit int64 dtype. A bare ``[1]``
    # becomes the platform-default int (int32 on e.g. Windows with NumPy < 2),
    # which could trip the dtype check on the *other* argument first and let
    # the "y_real" cases below pass for the wrong reason.
    valid = [np.int64(1)]

    with pytest.raises(TypeError):
        # y_pred must be list or np.ndarray
        confusion_matrix(0.5, valid)

    with pytest.raises(TypeError):
        # y_real must be list or np.ndarray
        confusion_matrix(valid, 0.5)

    with pytest.raises(TypeError):
        # y_pred dtype must be np.int64
        confusion_matrix([0.5], valid)

    with pytest.raises(TypeError):
        # y_real dtype must be np.int64
        confusion_matrix(valid, [0.5])
Ejemplo n.º 2
0
def test_confusion_matrix():
    """Check the custom ``confusion_matrix`` (incl. ``normalize``) against a reference.

    The custom function is called as ``confusion_matrix(y_pred, y_real,
    normalize)``; the reference ``gt_confusion_matrix`` is called with the
    labels first and ``normalize`` passed by keyword. The reference is
    presumably ``sklearn.metrics.confusion_matrix`` (confirm at the import
    site), whose third positional parameter is ``labels`` and whose
    ``normalize`` is keyword-only. Passing ``normalize`` positionally there
    binds the wrong parameter or raises ``TypeError``.
    Also checks that invalid ``normalize`` values raise ``ValueError`` and
    invalid argument types/dtypes raise ``TypeError``.
    """
    # Local, seeded generators: failures are reproducible and the global
    # ``random`` / ``np.random`` state used by other tests is left untouched.
    py_rng = random.Random(0)
    np_rng = np.random.RandomState(0)

    # Ground truth is a plain list of np.int64, predictions an int64 ndarray,
    # so both accepted container kinds are exercised.
    gt_labels = [np.int64(py_rng.randint(0, 9)) for _ in range(100)]
    pred_labels = np_rng.randint(10, size=100, dtype=np.int64)

    for normalize in [None, 'true', 'pred', 'all']:
        cf_mat = confusion_matrix(pred_labels, gt_labels, normalize)
        # Keyword, not positional: see the docstring.
        gt_cf_mat = gt_confusion_matrix(
            gt_labels, pred_labels, normalize=normalize)
        assert_array_equal(cf_mat, gt_cf_mat)

    # A well-formed argument with an explicit int64 dtype. A bare ``[1]``
    # becomes the platform-default int (int32 on e.g. Windows with NumPy < 2),
    # which could trip the dtype check on the *other* argument first and let
    # the "y_real" cases below pass for the wrong reason.
    valid = [np.int64(1)]

    with pytest.raises(ValueError):
        # normalize must be in ['true', 'pred', 'all', None]
        confusion_matrix(valid, valid, 'unsupport')

    with pytest.raises(TypeError):
        # y_pred must be list or np.ndarray
        confusion_matrix(0.5, valid)

    with pytest.raises(TypeError):
        # y_real must be list or np.ndarray
        confusion_matrix(valid, 0.5)

    with pytest.raises(TypeError):
        # y_pred dtype must be np.int64
        confusion_matrix([0.5], valid)

    with pytest.raises(TypeError):
        # y_real dtype must be np.int64
        confusion_matrix(valid, [0.5])