def test_predict_proba_shape():
    """Oracle.predict_proba returns one row per test sample, one column per class."""
    n_train_samples = 1000
    n_test_samples = 200
    n_classes = 2
    # Generate a single seeded problem and split it, so train and test share
    # the same underlying distribution and the test is reproducible.
    X, y = make_classification(
        n_samples=n_train_samples + n_test_samples,
        n_classes=n_classes,
        random_state=42,
    )
    X_train, y_train = X[:n_train_samples], y[:n_train_samples]
    X_test, y_test = X[n_train_samples:], y[n_train_samples:]
    pool = RandomForestClassifier(max_depth=3, random_state=42).fit(
        X_train, y_train
    )
    oracle = Oracle(pool_classifiers=pool).fit(X_train, y_train)

    # Oracle's predict_proba takes the true test labels as a second argument.
    proba = oracle.predict_proba(X_test, y_test)
    assert proba.shape == (n_test_samples, n_classes)


def test_predict_proba_right_class():
    """The most probable class from predict_proba matches Oracle.predict."""
    n_train_samples = 1000
    n_test_samples = 200
    # Generate a single seeded problem and split it, so train and test share
    # the same underlying distribution and the test is reproducible.
    X, y = make_classification(
        n_samples=n_train_samples + n_test_samples, random_state=42
    )
    X_train, y_train = X[:n_train_samples], y[:n_train_samples]
    X_test, y_test = X[n_train_samples:], y[n_train_samples:]
    pool = RandomForestClassifier(max_depth=3, random_state=42).fit(
        X_train, y_train
    )
    oracle = Oracle(pool_classifiers=pool).fit(X_train, y_train)

    # Oracle's predict/predict_proba take the true test labels as a second
    # argument.
    preds = oracle.predict(X_test, y_test)
    proba = oracle.predict_proba(X_test, y_test)
    # make_classification yields integer labels 0..n_classes-1, so the argmax
    # column index is directly comparable to the predicted label.
    probas_max = np.argmax(proba, axis=1)
    # Labels are discrete: require exact equality rather than a float tolerance.
    assert np.array_equal(probas_max, preds)