Exemplo n.º 1
0
    def test_shallow_planned_nested_indiv_operator(self):
        """Shallow ``get_params`` of a bagging operator exposes ``bootstrap``.

        Wraps a ``DecisionTreeClassifier`` in a ``BaggingClassifier``, asks for
        the non-deep parameters, passes them through
        ``self.remove_lale_params`` and checks that the ``bootstrap`` entry is
        truthy.
        """
        from lale.lib.sklearn import BaggingClassifier, DecisionTreeClassifier

        tree = DecisionTreeClassifier()
        bagging = BaggingClassifier(base_estimator=tree)
        shallow = self.remove_lale_params(bagging.get_params(deep=False))
        assert shallow["bootstrap"]
Exemplo n.º 2
0
    def test_deep_planned_nested_indiv_operator(self):
        """Deep ``get_params`` of a bagging operator nests the inner params.

        The ``base_estimator__*`` entries reported by the outer operator
        (ignoring those starting with ``base_estimator___lale``) must equal
        the inner estimator's own deep parameters after they are prefixed via
        ``nest_HPparams("base_estimator", ...)``.
        """
        from lale.lib.sklearn import BaggingClassifier, DecisionTreeClassifier

        tree = DecisionTreeClassifier()
        bagging = BaggingClassifier(base_estimator=tree)
        deep = self.remove_lale_params(bagging.get_params(deep=True))

        # Expected side: the inner estimator's own parameters, re-prefixed.
        inner = deep["base_estimator"]
        inner_params = self.remove_lale_params(inner.get_params(deep=True))
        expected = nest_HPparams("base_estimator", inner_params)

        # Actual side: the nested entries reported by the outer operator.
        actual = {}
        for key, value in deep.items():
            if not key.startswith("base_estimator__"):
                continue
            if key.startswith("base_estimator___lale"):
                continue
            actual[key] = value

        self.assertDictEqual(actual, expected)
Exemplo n.º 3
0
    def test_deep_grammar(self):
        """Deep ``get_params`` on a ``Grammar`` reports rule and operator params.

        Builds a small recursive grammar (an optional transformer chain
        followed by a primitive estimator) and checks two entries of its deep
        parameters after filtering with ``self.remove_lale_params``:

        * ``start__name`` is ``"estimator"``;
        * ``prim_est__LogisticRegression__penalty`` is ``"l2"``.

        ``NoOp`` and ``PCA`` are not imported here; they are expected to be
        in scope from the enclosing module.
        """
        from lale.grammar import Grammar
        from lale.lib.sklearn import KNeighborsClassifier as KNN
        from lale.lib.sklearn import LogisticRegression as LR
        from lale.lib.sklearn import StandardScaler as Scaler

        g = Grammar()
        g.start = g.estimator
        # Both rules may be preceded by a (self-referential) transformer chain.
        g.estimator = (NoOp | g.transformer) >> g.prim_est
        g.transformer = (NoOp | g.transformer) >> g.prim_tfm

        g.prim_est = LR | KNN
        g.prim_tfm = PCA | Scaler

        params = g.get_params(deep=True)
        filtered_params = self.remove_lale_params(params)
        assert filtered_params["start__name"] == "estimator"
        assert filtered_params["prim_est__LogisticRegression__penalty"] == "l2"