Ejemplo n.º 1
0
 def test_get_estimator(self, boston_X, boston_y):
     """Exercise ``get_estimator`` error paths and its happy path.

     In order this checks: ``NotFittedError`` before ``fit``; ``ValueError``
     after fitting a forest built without ``enable_tree_details=True``; that a
     forest built with ``enable_tree_details=True`` yields a
     ``RangerTreeRegressor`` that can ``predict``; and ``IndexError`` for an
     index past the 10 trees in the forest.
     """
     n_trees = 10

     # Default forest: refused before fit, and still refused after fit.
     plain_forest = RangerForestRegressor(n_estimators=n_trees)
     with pytest.raises(NotFittedError):
         plain_forest.get_estimator(idx=0)
     plain_forest.fit(boston_X, boston_y)
     with pytest.raises(ValueError):
         plain_forest.get_estimator(0)

     # Forest with tree details: a single tree can be extracted and used.
     detailed_forest = RangerForestRegressor(
         n_estimators=n_trees, enable_tree_details=True
     )
     detailed_forest.fit(boston_X, boston_y)
     first_tree = detailed_forest.get_estimator(0)
     first_tree.predict(boston_X)
     assert isinstance(first_tree, RangerTreeRegressor)

     # Index 20 lies beyond the last tree of a 10-tree forest.
     with pytest.raises(IndexError):
         detailed_forest.get_estimator(idx=20)
Ejemplo n.º 2
0
def test_plot():
    """Smoke-test drawing one tree of a fitted forest with sklearn's ``plot_tree``.

    Fits a small ``RangerForestRegressor`` with ``enable_tree_details=True``
    (needed so ``get_estimator`` can return a tree). It then plots the first
    tree and saves the figure as SVG into a temporary directory, so the test
    leaves no files behind in the working directory.
    """
    import tempfile
    from pathlib import Path

    from matplotlib import pyplot as plt
    from sklearn.datasets import load_diabetes
    from sklearn.tree import plot_tree

    # load_boston was deprecated in scikit-learn 1.0 and removed in 1.2;
    # the diabetes regression dataset ships with sklearn (no download needed).
    X, y = load_diabetes(return_X_y=True)
    forest = RangerForestRegressor(n_estimators=10, enable_tree_details=True)
    forest.fit(X, y)
    estimator = forest.get_estimator(0)

    fig, ax = plt.subplots()
    try:
        plot_tree(
            estimator,
            impurity=False,  # impurity not yet implemented
            ax=ax,  # draw on our own figure rather than the implicit current one
        )
        with tempfile.TemporaryDirectory() as tmp_dir:
            out_path = Path(tmp_dir) / "tree.svg"
            fig.savefig(
                out_path,
                bbox_inches="tight",  # don't truncate
            )
            assert out_path.stat().st_size > 0
    finally:
        # Close explicitly so figures do not accumulate across the test session.
        plt.close(fig)