예제 #1
0
    def test_manual_grid(self):
        """Grid-search a lale ``SVC`` over ``kernel`` and ``C`` on iris, then predict.

        Warnings are silenced only inside a ``warnings.catch_warnings()``
        block, so the ``"ignore"`` filter is restored when the block exits
        instead of leaking into every test that runs afterwards.
        """
        import warnings

        from sklearn.datasets import load_iris

        from lale import wrap_imported_operators
        from lale.lib.lale import GridSearchCV
        from lale.lib.sklearn import SVC

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            wrap_imported_operators()
            iris = load_iris()
            parameters = {"kernel": ("linear", "rbf"), "C": [1, 10]}
            clf = GridSearchCV(estimator=SVC(), param_grid=parameters)
            # Predict with the operator returned by ``fit`` rather than with
            # the trainable ``clf`` itself, so the test does not depend on
            # state that ``fit`` may or may not have stored on ``clf``.
            trained = clf.fit(iris.data, iris.target)
            predictions = trained.predict(iris.data)
        # One prediction per input row.
        self.assertEqual(len(predictions), len(iris.target))
예제 #2
0
 def test_break_ties(self):
     """Constructing ``SVC`` with ``break_ties`` must fail schema validation."""
     # NOTE(review): presumably the SVC schema targets a scikit-learn release
     # without ``break_ties``; revisit this expectation when upgrading.
     expected_message = "argument 'break_ties' was unexpected"
     with self.assertRaisesRegex(jsonschema.ValidationError, expected_message):
         SVC(break_ties=True)
예제 #3
0
 def test_gamma(self):
     """The schema default for ``gamma`` should be ``"auto_deprecated"``."""
     # NOTE(review): this default is presumably tied to the scikit-learn
     # version the SVC schema was written against — confirm on upgrades.
     all_defaults = SVC.hyperparam_defaults()
     gamma_default = all_defaults["gamma"]
     self.assertEqual(gamma_default, "auto_deprecated")
예제 #4
0
 def test_with_defaults(self):
     """Fit an ``SVC`` with default hyperparameters, then predict on the test split."""
     fitted_model = SVC().fit(self.train_X, self.train_y)
     fitted_model.predict(self.test_X)
예제 #5
0
    def test_svc(self):
        """Fit an ``SVC`` configured with a linear kernel and ``gamma=1``."""
        from lale.lib.sklearn import SVC

        hyperparams = {"kernel": 'linear', "gamma": 1}
        classifier = SVC(**hyperparams)
        classifier.fit(self.X_train, self.y_train)
예제 #6
0
 def test_svc(self):
     """Fit an ``SVC`` configured with a linear kernel and ``gamma=1``."""
     hyperparams = {"kernel": "linear", "gamma": 1}
     classifier = SVC(**hyperparams)
     classifier.fit(self.X_train, self.y_train)
예제 #7
0
 def test_gamma(self):
     """The schema default for ``gamma`` should be ``'auto_deprecated'``."""
     # NOTE(review): presumably tied to the scikit-learn version behind the
     # SVC schema — confirm when upgrading.
     all_defaults = SVC.hyperparam_defaults()
     gamma_default = all_defaults['gamma']
     self.assertEqual(gamma_default, 'auto_deprecated')