def test_warm_start_detection(self):
    """Check what ``_can_early_stop()`` reports for three estimators.

    A ``TuneSearchCV`` is built around a random forest, a decision tree
    and a logistic regression in turn.  The first two searches are
    expected to report ``False``; the last one is expected to report
    ``True``.
    """
    search_space = {"alpha": Real(1e-4, 1e-1, 1)}
    # Keyword arguments shared by every search built in this test.
    shared_kwargs = {
        "n_jobs": 1,
        "max_iters": 10,
        "local_dir": "./test-result",
    }

    from sklearn.ensemble import RandomForestClassifier
    forest_search = TuneSearchCV(
        RandomForestClassifier(max_depth=2, random_state=0),
        search_space,
        **shared_kwargs)
    self.assertFalse(forest_search._can_early_stop())

    from sklearn.tree import DecisionTreeClassifier
    tree_search = TuneSearchCV(
        DecisionTreeClassifier(random_state=0),
        search_space,
        **shared_kwargs)
    self.assertFalse(tree_search._can_early_stop())

    from sklearn.linear_model import LogisticRegression
    logistic_search = TuneSearchCV(
        LogisticRegression(),
        search_space,
        **shared_kwargs)
    self.assertTrue(logistic_search._can_early_stop())
def test_warm_start_error(self):
    """Check that unsupported early-stopping setups raise ``ValueError``.

    Covers three situations:

    * a ``VotingClassifier`` search reports ``_can_early_stop()`` as
      ``False`` when built with ``early_stopping=False`` and raises
      ``ValueError`` when built with ``early_stopping=True``;
    * a ``LogisticRegression`` search whose grid contains ``max_iter``
      raises ``ValueError`` with ``early_stopping=True``;
    * a ``RandomForestClassifier`` search whose grid contains
      ``n_estimators`` raises ``ValueError`` with ``early_stopping=True``.

    NOTE(review): another method named ``test_warm_start_error`` is
    defined later in this file.  If both live in the same class, the
    later definition replaces this one and this test never runs --
    confirm and rename one of them.
    """
    alpha_space = {"alpha": Real(1e-4, 1e-1, 1)}

    from sklearn.ensemble import VotingClassifier, RandomForestClassifier
    voting = VotingClassifier(
        estimators=[
            ("rf", RandomForestClassifier(n_estimators=50, random_state=0)),
        ])

    # With early stopping disabled, construction succeeds.
    plain_search = TuneSearchCV(
        voting,
        alpha_space,
        n_jobs=1,
        early_stopping=False,
        max_iters=10,
        local_dir="./test-result")
    self.assertFalse(plain_search._can_early_stop())

    # Requesting early stopping for the same estimator must fail.
    with self.assertRaises(ValueError):
        TuneSearchCV(
            voting,
            alpha_space,
            n_jobs=1,
            early_stopping=True,
            max_iters=10,
            local_dir="./test-result")

    from sklearn.linear_model import LogisticRegression
    logistic = LogisticRegression()
    with self.assertRaises(ValueError):
        TuneSearchCV(
            logistic,
            {"max_iter": [1, 2]},
            early_stopping=True,
            n_jobs=1,
            max_iters=10,
            local_dir="./test-result")

    forest = RandomForestClassifier()
    with self.assertRaises(ValueError):
        TuneSearchCV(
            forest,
            {"n_estimators": [1, 2]},
            early_stopping=True,
            n_jobs=1,
            max_iters=10,
            local_dir="./test-result")
def test_warm_start_error(self):
    """Check that early stopping on a random forest raises ``ValueError``.

    With ``early_stopping=False`` the search is constructed and
    ``_can_early_stop()`` is expected to be ``False``; constructing the
    same search with ``early_stopping=True`` is expected to raise
    ``ValueError``.

    NOTE(review): an earlier method in this file is also named
    ``test_warm_start_error``.  If both live in the same class, this
    definition replaces the earlier one -- confirm and rename one of
    them.
    """
    search_space = {"alpha": Real(1e-4, 1e-1, 1)}

    from sklearn.ensemble import RandomForestClassifier
    forest = RandomForestClassifier(max_depth=2, random_state=0)

    # Keyword arguments common to both constructions below.
    shared_kwargs = {
        "n_jobs": 1,
        "max_iters": 10,
        "local_dir": "./test-result",
    }

    plain_search = TuneSearchCV(
        forest, search_space, early_stopping=False, **shared_kwargs)
    self.assertFalse(plain_search._can_early_stop())

    with self.assertRaises(ValueError):
        TuneSearchCV(
            forest, search_space, early_stopping=True, **shared_kwargs)