コード例 #1
0
ファイル: test_classifier.py プロジェクト: r3v1/skranger
    def test_split_rule(self, iris_X, iris_y, split_rule):
        """Check validation and encoding of the ``split_rule`` parameter.

        Unsupported rules must be rejected at fit time. Supported rules must
        be encoded into ``split_rule_`` (gini -> 1, extratrees -> 5,
        hellinger -> 7). Disallowed rule-specific parameter combinations must
        raise ``ValueError``.
        """
        tree = RangerTreeClassifier(split_rule=split_rule)
        assert tree.criterion == split_rule

        if split_rule not in ["gini", "extratrees", "hellinger"]:
            with pytest.raises(ValueError):
                tree.fit(iris_X, iris_y)
            return

        # hellinger can only be used in binary classification
        if split_rule == "hellinger":
            with pytest.raises(ValueError):
                tree.fit(iris_X, iris_y)

        # Fold label 2 into label 0 to get a binary target that every
        # supported rule (including hellinger) accepts.
        iris_2 = [0 if v == 2 else v for v in iris_y]
        tree.fit(iris_X, iris_2)

        if split_rule == "gini":
            assert tree.split_rule_ == 1
        elif split_rule == "extratrees":
            assert tree.split_rule_ == 5
        elif split_rule == "hellinger":
            assert tree.split_rule_ == 7

        if split_rule == "extratrees":
            # extratrees + partitioned categoricals + save_memory is expected
            # to be rejected.
            tree = RangerTreeClassifier(
                split_rule=split_rule,
                respect_categorical_features="partition",
                save_memory=True,
            )
            with pytest.raises(ValueError):
                tree.fit(iris_X, iris_y)
        else:
            # num_random_splits > 1 is expected to be rejected for rules other
            # than extratrees. Fit on the binary target so hellinger fails
            # because of this parameter, not because a multiclass target is
            # rejected (which would make this check vacuous).
            tree = RangerTreeClassifier(split_rule=split_rule,
                                        num_random_splits=2)
            with pytest.raises(ValueError):
                tree.fit(iris_X, iris_2)
コード例 #2
0
ファイル: test_classifier.py プロジェクト: r3v1/skranger
    def test_sample_fraction(self, iris_X, iris_y):
        """``sample_fraction`` given as a list or a float is stored as a list.

        The last fitted tree must also predict on a single-row input, with
        every prediction method returning exactly one entry.
        """
        # Both spellings of the same fraction must normalise to [0.69].
        for fraction in ([0.69], 0.69):
            tree = RangerTreeClassifier(sample_fraction=fraction)
            tree.fit(iris_X, iris_y)
            assert tree.sample_fraction_ == [0.69]

        # test with single record, using the tree fitted with the float form
        first_row = iris_X[0:1, :]
        for method_name in ("predict", "predict_proba", "predict_log_proba"):
            output = getattr(tree, method_name)(first_row)
            assert len(output) == 1
コード例 #3
0
ファイル: test_classifier.py プロジェクト: r3v1/skranger
 def test_init(self):
     """Constructing a classifier with only default arguments must not raise."""
     RangerTreeClassifier()
コード例 #4
0
ファイル: test_classifier.py プロジェクト: r3v1/skranger
    def test_inbag(self, iris_X, iris_y):
        """Check manual ``inbag`` usage and its incompatibilities at fit time.

        Fitting with an ``inbag`` specification must succeed. Combining it
        with ``sample_weight`` or with class-wise ``sample_fraction`` must
        raise ``ValueError``.
        """
        inbag = [[1, 2, 3]]
        tree = RangerTreeClassifier(inbag=inbag)
        tree.fit(iris_X, iris_y)

        # can't use inbag with sample weight
        tree = RangerTreeClassifier(inbag=inbag)
        with pytest.raises(ValueError):
            tree.fit(iris_X, iris_y, sample_weight=[1] * len(iris_y))

        # can't use class sampling and inbag. Supply exactly one fraction per
        # class so the only invalid aspect is the inbag/class-sampling
        # combination. A list whose length mismatches the number of classes
        # could raise ValueError for an unrelated reason.
        n_classes = len(set(iris_y))
        tree = RangerTreeClassifier(inbag=inbag,
                                    sample_fraction=[1] * n_classes)
        with pytest.raises(ValueError):
            tree.fit(iris_X, iris_y)