def test_multiclass_hinge_sgd():
    """Hinge-loss SGD with ``multiclass=True`` should fit the multiclass set.

    Runs over dense and CSR inputs, with and without an intercept term, and
    requires training-set accuracy above 0.78 in every combination.
    """
    combinations = [(X, intercept)
                    for X in (mult_dense, mult_csr)
                    for intercept in (True, False)]
    for X, intercept in combinations:
        model = SGDClassifier(loss="hinge", multiclass=True,
                              fit_intercept=intercept, random_state=0)
        model.fit(X, mult_target)
        accuracy = model.score(X, mult_target)
        assert_greater(accuracy, 0.78)
def test_multiclass_log_sgd():
    """Log-loss SGD with ``multiclass="natural"`` should fit the dense set.

    Only the dense multiclass data is used here. Training-set accuracy must
    exceed 0.78 both with and without an intercept term.
    """
    # NOTE(review): siblings pass multiclass=True and also test CSR input;
    # confirm whether "natural" / dense-only is intentional for log loss.
    for intercept in (True, False):
        model = SGDClassifier(loss="log", multiclass="natural",
                              fit_intercept=intercept, random_state=0)
        model.fit(mult_dense, mult_target)
        accuracy = model.score(mult_dense, mult_target)
        assert_greater(accuracy, 0.78)
def test_multiclass_squared_hinge_sgd():
    """Squared-hinge SGD with a small constant step should fit the multiclass set.

    Uses ``learning_rate="constant"`` with ``eta0=1e-3`` on dense and CSR
    inputs, with and without an intercept. Training-set accuracy must exceed
    0.78.
    """
    shared = dict(loss="squared_hinge", multiclass=True,
                  learning_rate="constant", eta0=1e-3, random_state=0)
    for X in (mult_dense, mult_csr):
        for intercept in (True, False):
            model = SGDClassifier(fit_intercept=intercept, **shared)
            model.fit(X, mult_target)
            accuracy = model.score(X, mult_target)
            assert_greater(accuracy, 0.78)
def test_multiclass_hinge_sgd_l1l2():
    """Hinge-loss multiclass SGD with the ``"l1/l2"`` penalty should fit.

    Checks dense and CSR inputs using the default intercept setting.
    Training-set accuracy must exceed 0.75.
    """
    threshold = 0.75
    for X in (mult_dense, mult_csr):
        model = SGDClassifier(loss="hinge", penalty="l1/l2",
                              multiclass=True, random_state=0)
        model.fit(X, mult_target)
        accuracy = model.score(X, mult_target)
        assert_greater(accuracy, threshold)
def test_binary_linear_sgd():
    """Several loss / learning-rate combinations should fit the binary set.

    Every configuration below is fitted on both dense and CSR binary data
    with ``random_state=0``. Training-set accuracy must exceed 0.934.
    """
    # Constructor keyword arguments, one entry per estimator under test,
    # in the same order as they are fitted.
    configurations = [
        dict(loss="hinge", fit_intercept=True, learning_rate="pegasos"),
        dict(loss="hinge", fit_intercept=False, learning_rate="pegasos"),
        dict(loss="hinge", fit_intercept=True, learning_rate="invscaling"),
        dict(loss="hinge", fit_intercept=True, learning_rate="constant"),
        dict(loss="squared_hinge", eta0=1e-2, fit_intercept=True,
             learning_rate="constant"),
        dict(loss="log", fit_intercept=True, learning_rate="constant"),
        dict(loss="modified_huber", fit_intercept=True,
             learning_rate="constant"),
    ]
    for X in (bin_dense, bin_csr):
        for params in configurations:
            model = SGDClassifier(random_state=0, **params)
            model.fit(X, bin_target)
            accuracy = model.score(X, bin_target)
            assert_greater(accuracy, 0.934)
def test_multiclass_sgd():
    """A default SGDClassifier (``random_state=0``) should fit the dense multiclass set.

    Training-set accuracy must exceed 0.80.
    """
    model = SGDClassifier(random_state=0)
    model.fit(mult_dense, mult_target)
    accuracy = model.score(mult_dense, mult_target)
    assert_greater(accuracy, 0.80)