예제 #1
0
    def test_warm_start(self, ):
        """Check that growing a warm-started forest matches a one-shot fit.

        For each ``inference`` setting, a forest fit with 4 estimators and
        then refit after raising ``n_estimators`` to 8 is compared against a
        forest fit with 8 estimators directly. Predictions, subsample indices
        and the per-estimator ``random_state`` values must all agree.
        """
        seed = 123
        n_rows, n_cols = 100, 2
        X, y, _ = self._get_regression_data(n_rows, n_cols, seed)

        def snapshot(fitted):
            # Gather everything that is compared between the two forests:
            # predictions, subsample indices, per-estimator random states.
            return (fitted.predict(X),
                    fitted.get_subsample_inds(),
                    [member.random_state for member in fitted])

        for inference in (True, False):
            shared_kwargs = dict(inference=inference,
                                 warm_start=True,
                                 random_state=seed)

            # Forest grown in two steps: 4 estimators, then extended to 8.
            grown = RegressionForest(n_estimators=4, **shared_kwargs).fit(X, y)
            grown.n_estimators = 8
            grown.fit(X, y)
            grown_results = snapshot(grown)

            # Reference forest built with all 8 estimators at once.
            direct = RegressionForest(n_estimators=8, **shared_kwargs).fit(X, y)
            direct_results = snapshot(direct)

            for from_grown, from_direct in zip(grown_results, direct_results):
                np.testing.assert_allclose(from_grown, from_direct)
예제 #2
0
    def test_raise_exceptions(self, ):
        """Check that mishandled inputs make ``fit`` raise ``ValueError``.

        Every ``assert_raises`` context wraps only the single call that is
        expected to fail, so a ``ValueError`` coming from unrelated setup
        cannot make a case pass by accident.
        """
        n_features = 2
        n = 10
        random_state = 123
        X, y, _ = self._get_regression_data(n, n_features, random_state)

        # --- RegressionForest: data / sampling configuration ---

        # ``y`` truncated to 4 entries while ``X`` is passed whole.
        with np.testing.assert_raises(ValueError):
            RegressionForest(n_estimators=20).fit(X, y[:4])
        # subforest_size=3 together with n_estimators=20
        # (presumably rejected because 20 is not a multiple of 3 -- confirm).
        with np.testing.assert_raises(ValueError):
            RegressionForest(n_estimators=20,
                             subforest_size=3).fit(X, y)
        # inference=True combined with max_samples=.6.
        with np.testing.assert_raises(ValueError):
            RegressionForest(n_estimators=20,
                             inference=True,
                             max_samples=.6).fit(X, y)
        # Integer max_samples=20 while the data was requested with n=10.
        with np.testing.assert_raises(ValueError):
            RegressionForest(n_estimators=20,
                             max_samples=20).fit(X, y)
        # Fractional max_samples above 1.
        with np.testing.assert_raises(ValueError):
            RegressionForest(n_estimators=20,
                             max_samples=1.2).fit(X, y)

        # Refitting a warm-started forest after switching ``inference`` off
        # and raising ``n_estimators``. The initial fit is deliberately kept
        # outside the context: only the refit is allowed to raise.
        forest = RegressionForest(n_estimators=4,
                                  warm_start=True,
                                  inference=True).fit(X, y)
        forest.inference = False
        forest.n_estimators = 8
        with np.testing.assert_raises(ValueError):
            forest.fit(X, y)

        # --- CausalForest: invalid hyperparameter values ---

        with np.testing.assert_raises(ValueError):
            CausalForest(n_estimators=4,
                         criterion='peculiar').fit(X, y, y)
        with np.testing.assert_raises(ValueError):
            CausalForest(n_estimators=4, max_depth=-1).fit(X, y, y)
        with np.testing.assert_raises(ValueError):
            CausalForest(n_estimators=4,
                         min_samples_split=-1).fit(X, y, y)
        with np.testing.assert_raises(ValueError):
            CausalForest(n_estimators=4,
                         min_samples_leaf=-1).fit(X, y, y)
        with np.testing.assert_raises(ValueError):
            CausalForest(n_estimators=4,
                         min_weight_fraction_leaf=-1.0).fit(X, y, y)
        with np.testing.assert_raises(ValueError):
            CausalForest(n_estimators=4,
                         min_var_fraction_leaf=-1.0).fit(X, y, y)
        # max_features=10 while the data was requested with n_features=2.
        with np.testing.assert_raises(ValueError):
            CausalForest(n_estimators=4, max_features=10).fit(X, y, y)
        with np.testing.assert_raises(ValueError):
            CausalForest(n_estimators=4,
                         min_balancedness_tol=.55).fit(X, y, y)
예제 #3
0
    def test_pickling(self,):
        """Check that a fitted forest survives a joblib dump/load round trip.

        The forest is fit with 4 estimators, refit after raising
        ``n_estimators`` to 8 with ``warm_start=True``, and then serialized.
        The reloaded copy must report the same ``n_estimators`` and reproduce
        the predictions made before serialization.
        """
        # Function-scope imports: only this test needs them.
        import os
        import tempfile

        n_features = 2
        n = 10
        random_state = 123
        X, y, _ = self._get_regression_data(n, n_features, random_state)

        forest = RegressionForest(n_estimators=4, warm_start=True,
                                  random_state=random_state).fit(X, y)
        forest.n_estimators = 8
        forest.fit(X, y)
        pred1 = forest.predict(X)

        # Serialize into a throwaway directory so the test leaves no
        # 'forest.jbl' behind in the working directory and cannot collide
        # with another run writing the same file name.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dump_path = os.path.join(tmp_dir, 'forest.jbl')
            joblib.dump(forest, dump_path)
            loaded_forest = joblib.load(dump_path)

        np.testing.assert_equal(loaded_forest.n_estimators, forest.n_estimators)
        np.testing.assert_allclose(loaded_forest.predict(X), pred1)