Ejemplo n.º 1
0
def test_fit_complex(complex_data_split, complex_fitted_lightgbm):
    """Fit the plotter on the held-out split and plot with every binning option.

    ``fit`` must store ``X_test`` unchanged, wrap ``y_test`` in a Series that
    shares ``X_test``'s index, and set ``fitted`` to True. Afterwards, plotting
    feature ``'f2_missing'`` with each binning option listed below must not raise.
    """
    # Only the held-out halves of the split are used by this test.
    _, X_test, _, y_test = complex_data_split

    plotter = TreeDependencePlotter(complex_fitted_lightgbm)
    plotter.fit(X_test, y_test)

    pd.testing.assert_frame_equal(plotter.X, X_test)
    pd.testing.assert_series_equal(plotter.y, pd.Series(y_test, index=X_test.index))
    assert plotter.fitted is True

    # Smoke test: plotting must not raise. ``matplotlib.pyplot.figure`` is
    # patched with a mock, presumably so ``plot`` does not open real figures.
    with patch("matplotlib.pyplot.figure"):
        for binning in ["simple", "agglomerative", "quantile"]:
            plotter.plot(feature="f2_missing", type_binning=binning)
Ejemplo n.º 2
0
def test_plot_input(X_y, clf):
    """``plot`` must raise ValueError for each invalid argument combination below."""
    X, y = X_y[0], X_y[1]
    plotter = TreeDependencePlotter(clf).fit(X, y)

    invalid_calls = [
        {"feature": "not a feature"},            # feature not present in the data
        {"feature": 0, "type_binning": 5},        # non-string binning option
        {"feature": 0, "min_q": 1, "max_q": 0},   # lower quantile above upper quantile
    ]
    for kwargs in invalid_calls:
        with pytest.raises(ValueError):
            plotter.plot(**kwargs)
Ejemplo n.º 3
0
def test_get_X_y_shap_with_q_cut_normal(X_y, clf):
    """Check quantile-based row filtering in ``_get_X_y_shap_with_q_cut``.

    With ``min_q=0`` and ``max_q=1`` the feature column and target must come
    back unchanged. With ``min_q=0.2`` and ``max_q=0.8`` the remaining feature
    values and targets must match the pinned arrays below.
    """
    X, y = X_y

    plotter = TreeDependencePlotter(clf).fit(X, y)
    plotter.min_q, plotter.max_q = 0, 1

    # SHAP values are not checked by this test.
    X_cut, y_cut, _ = plotter._get_X_y_shap_with_q_cut(0)
    # assert_allclose also compares shapes, unlike np.isclose(...).all(), which
    # would broadcast a wrong-length result and could pass vacuously.
    # rtol/atol are set to np.isclose's defaults to keep the same tolerance.
    np.testing.assert_allclose(X_cut, X[0], rtol=1e-5, atol=1e-8)
    assert y.equals(y_cut)

    plotter.min_q = 0.2
    plotter.max_q = 0.8

    X_cut, y_cut, _ = plotter._get_X_y_shap_with_q_cut(0)
    np.testing.assert_allclose(
        X_cut,
        [
            -1.48382902,
            -0.44947744,
            -1.38101231,
            -0.18261804,
            0.27514902,
            -0.27264455,
            -1.27251335,
            -2.10917352,
            -1.25945582,
        ],
        rtol=1e-5,
        atol=1e-8,
    )
    # assert_array_equal also checks the length, not just element-wise equality.
    np.testing.assert_array_equal(y_cut.values, [1, 0, 0, 1, 1, 0, 0, 0, 0])
Ejemplo n.º 4
0
def test_not_fitted(clf):
    """A freshly constructed plotter must report ``fitted`` as exactly False."""
    assert TreeDependencePlotter(clf).fitted is False
Ejemplo n.º 5
0
def test__repr__(clf):
    """``str()`` of the plotter includes the wrapped model's class name.

    The expected text assumes the ``clf`` fixture is a RandomForestClassifier.
    """
    expected = "Shap dependence plotter for RandomForestClassifier"
    assert str(TreeDependencePlotter(clf)) == expected
Ejemplo n.º 6
0
def test_plot_class_names(X_y, clf):
    """Class names given to ``fit`` are stored, and plotting still succeeds."""
    names = ["a", "b"]
    plotter = TreeDependencePlotter(clf).fit(X_y[0], X_y[1], class_names=names)
    plotter.plot(feature=0)
    assert plotter.class_names == ["a", "b"]
Ejemplo n.º 7
0
def test_plot_normal(X_y, clf):
    """Plotting feature 0 must not raise for each binning option listed below."""
    X, y = X_y[0], X_y[1]
    plotter = TreeDependencePlotter(clf).fit(X, y)
    for strategy in ("simple", "agglomerative", "quantile"):
        plotter.plot(feature=0, type_binning=strategy)
Ejemplo n.º 8
0
def test_get_X_y_shap_with_q_cut_input(X_y, clf):
    """An unknown feature name passed to ``_get_X_y_shap_with_q_cut`` raises ValueError."""
    X, y = X_y[0], X_y[1]
    fitted_plotter = TreeDependencePlotter(clf).fit(X, y)
    with pytest.raises(ValueError):
        fitted_plotter._get_X_y_shap_with_q_cut("not a feature")
Ejemplo n.º 9
0
def test_get_X_y_shap_with_q_cut_unfitted(clf):
    """Calling ``_get_X_y_shap_with_q_cut`` before ``fit`` raises NotFittedError."""
    unfitted = TreeDependencePlotter(clf)
    with pytest.raises(NotFittedError):
        unfitted._get_X_y_shap_with_q_cut(0)