예제 #1
0
def test_confusion_matrix_inconsistency():
    """`confusion_matrix` must reject label arrays of different lengths.

    Two predictions are paired with a single ground-truth label, so the
    call is expected to raise ``InvalidInput``.
    """
    mismatched_pred = np.array([0, 1])
    mismatched_true = np.array([1])

    with pytest.raises(InvalidInput):
        confusion_matrix(mismatched_pred, mismatched_true)
예제 #2
0
def test_confusion_matrix_non_int_labels():
    """`confusion_matrix` works with string (non-integer) class labels.

    In the expected matrix, rows are indexed by the true label and columns
    by the predicted label, with labels ordered "a", "b", "c".
    """
    labels_true = np.array(["a", "b", "c", "b"])
    labels_pred = np.array(["a", "b", "c", "a"])
    # Three samples are correct; the last "b" is mispredicted as "a".
    expected = np.array([
        [1, 0, 0],
        [1, 1, 0],
        [0, 0, 1],
    ])

    result = confusion_matrix(labels_pred, labels_true)
    assert np.array_equal(result, expected)
예제 #3
0
def test_confusion_matrix_binary_labels():
    """`confusion_matrix` on a two-class (0/1) problem.

    In the expected matrix, rows are indexed by the true label and columns
    by the predicted label.
    """
    predicted = np.array([1, 1, 1, 0, 1])
    actual = np.array([0, 1, 1, 1, 1])
    # One 0 predicted as 1, one 1 predicted as 0, three 1s predicted as 1.
    expected = np.array([
        [0, 1],
        [1, 3],
    ])

    result = confusion_matrix(predicted, actual)
    assert np.array_equal(result, expected)
예제 #4
0
def test_confusion_matrix_multiclass_labels():
    """`confusion_matrix` on a three-class (0/1/2) problem.

    In the expected matrix, rows are indexed by the true label and columns
    by the predicted label.
    """
    actual = np.array([0, 1, 2, 0, 1, 2])
    predicted = np.array([0, 2, 1, 0, 0, 1])
    # Class 0 is always right; class 1 goes once to 0 and once to 2;
    # class 2 is always predicted as 1.
    expected = np.array([
        [2, 0, 0],
        [1, 0, 1],
        [0, 2, 0],
    ])

    result = confusion_matrix(predicted, actual)
    assert np.array_equal(result, expected)
예제 #5
0
def test_confusion_matrix_no_prediction():
    """`confusion_matrix` when a class is never predicted correctly.

    Label 0 appears only among the predictions and label 1 only among the
    true labels. The expected result is still a 2x2 matrix covering both
    labels (rows: true label, columns: predicted label).
    """
    predicted = np.array([0])
    actual = np.array([1])
    # The single sample has true label 1 and predicted label 0.
    expected = np.array([
        [0, 0],
        [1, 0],
    ])

    result = confusion_matrix(predicted, actual)
    assert np.array_equal(result, expected)