Exemplo n.º 1
0
    def test_grid_search(self):
        """Fit ATGridSearchCV over ``foo_param`` and sanity-check the result.

        Checks that the best estimator's ``foo_param`` is 2 or 3, that each
        candidate in ``{1, 2, 3}`` appears in ``grid_scores_``, that the fitted
        search exposes score/predict_proba/decision_function/transform, and
        that an invalid ``scoring`` value makes ``fit`` raise ValueError.
        """
        # Test that the best estimator contains the right value for foo_param
        clf = MockClassifier()
        grid_search = ATGridSearchCV(clf, {'foo_param': [1, 2, 3]},
                                     verbose=3,
                                     webserver_url=self.live_server_url)
        # Swallow the verbose=3 output while fitting.  stdout is restored in
        # a ``finally`` so a failing fit cannot leave sys.stdout redirected
        # for every test that runs afterwards.
        old_stdout = sys.stdout
        sys.stdout = StringIO()
        try:
            wait(grid_search.fit(X, y))
        finally:
            sys.stdout = old_stdout
        # Either 2 or 3 is accepted as the best foo_param.  NOTE(review):
        # presumably these candidates tie and which one wins is not
        # guaranteed here -- confirm against MockClassifier.
        self.assertIn(grid_search.best_estimator_.foo_param, [2, 3])

        # Sort the results by foo_param (``_sort_grid_scores`` is used as a
        # cmp-style comparator) so the check does not depend on the order in
        # which results were collected; entry ``idx`` must then carry the
        # parameter dict {'foo_param': idx + 1}.
        _mock_sort = partial(_sort_grid_scores, param='foo_param')

        for idx, tup in enumerate(
                sorted(grid_search.grid_scores_, key=cmp_to_key(_mock_sort))):
            self.assertEqual(tup[0], {'foo_param': idx + 1},
                             '%d%s' % (idx, tup))

        # Smoke test the score etc:
        grid_search.score(X, y)
        grid_search.predict_proba(X)
        grid_search.decision_function(X)
        grid_search.transform(X)

        # Test exception handling on scoring: an invalid scoring value must
        # make fit raise ValueError.
        grid_search.scoring = 'sklearn'
        self.assertRaises(ValueError, grid_search.fit, X, y)
Exemplo n.º 2
0
    def test_grid_search_no_score(self):
        """Grid search over an estimator without ``score`` needs ``scoring``.

        Runs the same ``C`` grid with LinearSVC and with LinearSVCNoScore
        (both with ``scoring='accuracy'``), compares their best params and
        scores, then checks that omitting ``scoring`` for the score-less
        estimator makes ``fit`` raise a TypeError mentioning "no scoring".
        """
        # Test grid-search on classifier that has no score function.
        clf = LinearSVC(random_state=0)
        X, y = make_blobs(random_state=0, centers=2)
        Cs = [.1, 1, 10]
        clf_no_score = LinearSVCNoScore(random_state=0)
        grid_search = ATGridSearchCV(clf, {'C': Cs},
                                     scoring='accuracy',
                                     webserver_url=self.live_server_url)
        wait(grid_search.fit(X, y))

        grid_search_no_score = ATGridSearchCV(
            clf_no_score, {'C': Cs},
            scoring='accuracy',
            webserver_url=self.live_server_url)
        # smoketest grid search
        wait(grid_search_no_score.fit(X, y))

        # Check that best params are equal.  If they differ, the no-score
        # search must pick C=10 when the reference picked C=1, and C=1
        # otherwise.  NOTE(review): presumably this tolerates tied
        # candidates whose winner is not deterministic here -- confirm.
        try:
            assert_equal(grid_search_no_score.best_params_,
                         grid_search.best_params_)
        except AssertionError:
            if grid_search.best_params_ == {'C': 1}:
                assert_equal(grid_search_no_score.best_params_, {'C': 10})
            else:
                assert_equal(grid_search_no_score.best_params_, {'C': 1})
        # check that we can call score and that it gives the correct result
        assert_equal(grid_search.score(X, y), grid_search_no_score.score(X, y))

        # Giving no scoring function raises an error.  ``webserver_url`` is a
        # constructor argument (passed here); it is deliberately NOT forwarded
        # to ``fit``, where it could trigger an unrelated TypeError instead of
        # the "no scoring" one this assertion is about.
        grid_search_no_score = ATGridSearchCV(
            clf_no_score, {'C': Cs}, webserver_url=self.live_server_url)
        assert_raise_message(TypeError,
                             "no scoring",
                             grid_search_no_score.fit, [[1]])