def test_grid_search(self):
    # Test that the best estimator contains the right value for foo_param,
    # that grid_scores_ covers every candidate, that the fitted search
    # exposes the usual estimator API, and that an invalid ``scoring``
    # makes fit raise ValueError.
    clf = MockClassifier()
    grid_search = ATGridSearchCV(clf, {'foo_param': [1, 2, 3]}, verbose=3,
                                 webserver_url=self.live_server_url)
    # Swallow the verbose=3 progress output while fitting. stdout is
    # restored in a ``finally`` so a failing fit/wait cannot leave
    # sys.stdout redirected for every test that runs afterwards.
    old_stdout = sys.stdout
    sys.stdout = StringIO()
    try:
        wait(grid_search.fit(X, y))
    finally:
        sys.stdout = old_stdout
    # Either foo_param=2 or foo_param=3 is accepted as the winner.
    # NOTE(review): an older comment claimed the smallest tied parameter is
    # selected, which would pin this to 2; the assertion is deliberately
    # looser -- confirm whether tie-breaking is guaranteed here.
    self.assertIn(grid_search.best_estimator_.foo_param, [2, 3])

    # After ordering with _sort_grid_scores on 'foo_param', the entries are
    # expected to carry foo_param 1, 2, 3 in that order.
    _mock_sort = partial(_sort_grid_scores, param='foo_param')
    ordered_scores = sorted(grid_search.grid_scores_,
                            key=cmp_to_key(_mock_sort))
    for idx, tup in enumerate(ordered_scores):
        self.assertEqual(tup[0], {'foo_param': idx + 1}, '%d%s' % (idx, tup))

    # Smoke test the score etc:
    grid_search.score(X, y)
    grid_search.predict_proba(X)
    grid_search.decision_function(X)
    grid_search.transform(X)

    # Test exception handling on scoring: 'sklearn' is not a valid scorer.
    grid_search.scoring = 'sklearn'
    self.assertRaises(ValueError, grid_search.fit, X, y)
def test_grid_search_no_score(self):
    # Test grid-search on a classifier that has no score function: with an
    # explicit scoring='accuracy' it should behave like the scoring
    # classifier, and without any scoring it must raise a TypeError.
    clf = LinearSVC(random_state=0)
    X, y = make_blobs(random_state=0, centers=2)
    Cs = [.1, 1, 10]
    clf_no_score = LinearSVCNoScore(random_state=0)
    grid_search = ATGridSearchCV(clf, {'C': Cs}, scoring='accuracy',
                                 webserver_url=self.live_server_url)
    wait(grid_search.fit(X, y))

    grid_search_no_score = ATGridSearchCV(
        clf_no_score, {'C': Cs}, scoring='accuracy',
        webserver_url=self.live_server_url)
    # smoketest grid search
    wait(grid_search_no_score.fit(X, y))

    # Check that best params are equal. If they differ, accept the swap
    # {'C': 1} <-> {'C': 10}; any other mismatch expects the no-score
    # search to have picked C=1.
    # NOTE(review): presumably these C values tie on accuracy, so either
    # may win -- confirm.
    try:
        assert_equal(grid_search_no_score.best_params_,
                     grid_search.best_params_)
    except AssertionError:
        if grid_search.best_params_ == {'C': 1}:
            assert_equal(grid_search_no_score.best_params_, {'C': 10})
        else:
            assert_equal(grid_search_no_score.best_params_, {'C': 1})

    # check that we can call score and that it gives the correct result
    assert_equal(grid_search.score(X, y), grid_search_no_score.score(X, y))

    # Giving no scoring function raises an error. webserver_url belongs to
    # the constructor (already passed above), not to fit(): passing it to
    # fit could raise an unrelated "unexpected keyword" TypeError instead
    # of the "no scoring" one this assertion is meant to check.
    grid_search_no_score = ATGridSearchCV(
        clf_no_score, {'C': Cs}, webserver_url=self.live_server_url)
    assert_raise_message(TypeError, "no scoring",
                         grid_search_no_score.fit, [[1]])