コード例 #1
0
ファイル: test_forest.py プロジェクト: pyensemble/wildwood
    def test_n_estimators(self):
        """Exercise the ``n_estimators`` property of ``ForestClassifier``.

        Covers the default value, assignment via the constructor and via the
        setter, the setter's type and lower-bound checks, and the guard that
        rejects changes once the ``_fitted`` flag is set.
        """
        # Default value, then a value supplied at construction time.
        forest = ForestClassifier()
        assert forest.n_estimators == 10
        forest = ForestClassifier(n_estimators=17)
        assert forest.n_estimators == 17

        # Plain assignment through the property setter.
        forest.n_estimators = 42
        assert forest.n_estimators == 42

        # Anything that is not an integer is rejected.
        type_msg = "n_estimators must be an integer number"
        for non_int in (0.42, None, "4"):
            with pytest.raises(ValueError, match=type_msg):
                forest.n_estimators = non_int

        # Integers below 1 are rejected.
        range_msg = "n_estimators must be >= 1"
        for too_small in (0, -3):
            with pytest.raises(ValueError, match=range_msg):
                forest.n_estimators = too_small

        # On a fresh instance flagged as fitted, the setter must refuse any
        # change. A value that would also fail the type check is used, so the
        # fitted guard is expected to take precedence over type validation.
        forest = ForestClassifier()
        forest._fitted = True
        with pytest.raises(
                ValueError,
                match="You cannot change n_estimators after calling fit"):
            forest.n_estimators = 0.42
コード例 #2
0
ファイル: test_forest.py プロジェクト: pyensemble/wildwood
    def test_n_samples_in_(self):
        """Check the ``n_samples_in_`` property of ``ForestClassifier``.

        The getter must raise until ``_fitted`` is set, even after a value has
        been assigned. Once the flag is set, the getter returns the stored
        value. After a real ``fit`` it reports the number of training rows.
        """
        not_fitted_msg = "You must call fit before asking for n_samples_in_"

        clf = ForestClassifier()
        with pytest.raises(ValueError, match=not_fitted_msg):
            _ = clf.n_samples_in_

        # Assigning a value is not enough: the getter still checks _fitted.
        clf.n_samples_in_ = 42
        with pytest.raises(ValueError, match=not_fitted_msg):
            _ = clf.n_samples_in_

        clf._fitted = True
        # Read twice: repeated access must keep returning the stored value.
        assert clf.n_samples_in_ == 42
        assert clf.n_samples_in_ == 42

        # Use a local RNG instead of np.random.seed(), so this test does not
        # reset NumPy's global random state shared with other tests.
        # RandomState(42) yields the same legacy MT19937 stream that
        # np.random.seed(42) did, so X is unchanged.
        rng = np.random.RandomState(42)
        X = rng.randn(10, 3)
        y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
        clf = ForestClassifier().fit(X, y)
        assert clf.n_samples_in_ == 10