Exemple #1
0
def test_classif_binary(weighting):
    """Compare a binary RobustWeightedClassifier with a plain SGDClassifier.

    Both models are fitted on ``(X_cb, y_cb)``. For each model the
    coefficients and the intercept are divided by the norm of that model's
    stacked ``[coef_, intercept_]`` vector, and the scaled values of the two
    models must differ by less than 0.5 in norm. The test also checks that
    ``weights_`` has as many entries as ``X_cb``.
    """
    robust = RobustWeightedClassifier(
        multi_class="binary",
        weighting=weighting,
        random_state=rng,
        max_iter=100,
        burn_in=0,
        k=0,
        c=1e7,
    )
    baseline = SGDClassifier(loss="log", random_state=rng)

    # Both estimators receive the same ``rng`` object, so the order of the
    # two fit calls is kept as-is.
    robust.fit(X_cb, y_cb)
    baseline.fit(X_cb, y_cb)

    def joint_norm(model):
        # Norm of the flattened coefficients stacked with the intercept.
        stacked = np.hstack([model.coef_.ravel(), model.intercept_])
        return np.linalg.norm(stacked)

    robust_scale = joint_norm(robust)
    baseline_scale = joint_norm(baseline)

    coef_gap = robust.coef_ / robust_scale - baseline.coef_ / baseline_scale
    intercept_gap = (
        robust.intercept_ / robust_scale
        - baseline.intercept_ / baseline_scale
    )

    assert np.linalg.norm(coef_gap) < 0.5
    assert np.linalg.norm(intercept_gap) < 0.5

    assert len(robust.weights_) == len(X_cb)
Exemple #2
0
def test_classif_corrupted_weights(weighting):
    """Check that the first three samples get a lower mean weight.

    A binary RobustWeightedClassifier (``k=5``, ``c=1``, ``burn_in=0``) is
    fitted on ``(X_cc, y_cc)``; the mean of ``weights_[:3]`` must be strictly
    smaller than the mean of the remaining weights.

    NOTE(review): presumably the first three rows of ``X_cc`` are the
    corrupted samples -- confirm against the fixture definition.
    """
    settings = {
        "max_iter": 100,
        "weighting": weighting,
        "k": 5,
        "c": 1,
        "burn_in": 0,
        "multi_class": "binary",
        "random_state": rng,
    }
    model = RobustWeightedClassifier(**settings)
    model.fit(X_cc, y_cc)

    head_mean = np.mean(model.weights_[:3])
    tail_mean = np.mean(model.weights_[3:])
    assert head_mean < tail_mean
Exemple #3
0
def test_corrupted_classif(loss, weighting, k, c, multi_class):
    """Check that the classifier scores above 0.8 on ``(X_cc, y_cc)``.

    The estimator is built from the given ``loss``, ``weighting``, ``k``,
    ``c`` and ``multi_class`` values, fitted on ``(X_cc, y_cc)`` and then
    scored on the same data.
    """
    settings = {
        "loss": loss,
        "max_iter": 100,
        "weighting": weighting,
        "k": k,
        "c": c,
        "multi_class": multi_class,
        "random_state": rng,
    }
    model = RobustWeightedClassifier(**settings)
    model.fit(X_cc, y_cc)

    assert model.score(X_cc, y_cc) > 0.8
Exemple #4
0
def test_predict_proba(weighting):
    """Compare thresholded probabilities with those of a plain SGDClassifier.

    Both models are fitted on ``(X_c, y_c)``. The second column of
    ``predict_proba`` is thresholded at 1/2 for each model, and the two
    thresholded vectors must agree on more than 80% of the samples.
    """
    robust = RobustWeightedClassifier(
        weighting=weighting,
        random_state=rng,
        max_iter=100,
        burn_in=0,
        k=0,
        c=1e7,
    )
    baseline = SGDClassifier(loss="log", random_state=rng)

    # Both estimators receive the same ``rng`` object, so the order of the
    # two fit calls is kept as-is.
    robust.fit(X_c, y_c)
    baseline.fit(X_c, y_c)

    robust_above = robust.base_estimator_.predict_proba(X_c)[:, 1] > 0.5
    baseline_above = baseline.predict_proba(X_c)[:, 1] > 0.5

    agreement = np.mean(robust_above == baseline_above)
    assert agreement > 0.8
def test_not_robust_classif(loss, weighting, multi_class):
    """Compare decision-function signs with those of a plain SGDClassifier.

    Both models use the given ``loss`` and are fitted on ``(X_c, y_c)``.
    The ``decision_function`` outputs of the robust model's
    ``base_estimator_`` and of the SGD model are compared through their
    ``> 0`` masks, which must agree on more than 80% of the entries.
    """
    robust = RobustWeightedClassifier(
        loss=loss,
        weighting=weighting,
        multi_class=multi_class,
        random_state=rng,
        max_iter=100,
        burn_in=0,
        k=0,
        c=1e7,
    )
    baseline = SGDClassifier(loss=loss, random_state=rng)

    # Both estimators receive the same ``rng`` object, so the order of the
    # two fit calls is kept as-is.
    robust.fit(X_c, y_c)
    baseline.fit(X_c, y_c)

    robust_positive = robust.base_estimator_.decision_function(X_c) > 0
    baseline_positive = baseline.decision_function(X_c) > 0

    agreement = np.mean(robust_positive == baseline_positive)
    assert agreement > 0.8
Exemple #6
0
def test_not_robust_classif(loss, weighting, multi_class):
    """Compare predictions with a plain SGDClassifier and check ``score``.

    Both models use the given ``loss`` and are fitted on ``(X_c, y_c)``.
    The ``> 0`` masks of the two ``predict`` outputs must agree on more than
    80% of the samples, and ``score`` of the robust model must equal the
    fraction of its predictions that match ``y_c``.
    """
    robust = RobustWeightedClassifier(
        loss=loss,
        weighting=weighting,
        multi_class=multi_class,
        random_state=rng,
        max_iter=100,
        burn_in=0,
        k=0,
        c=1e7,
    )
    baseline = SGDClassifier(loss=loss, random_state=rng)

    # Both estimators receive the same ``rng`` object, so the order of the
    # two fit calls is kept as-is.
    robust.fit(X_c, y_c)
    baseline.fit(X_c, y_c)

    robust_labels = robust.predict(X_c)
    baseline_labels = baseline.predict(X_c)

    agreement = np.mean((robust_labels > 0) == (baseline_labels > 0))
    assert agreement > 0.8

    accuracy = np.mean(robust_labels == y_c)
    assert robust.score(X_c, y_c) == accuracy
Exemple #7
0
def test_robust_estimator_unsupported_multiclass():
    """Test that a ValueError is raised for an unsupported multi_class value.

    The estimator is created with ``multi_class="invalid"``; the error is
    expected when ``fit`` is called, with a message matching
    "No such multiclass method implemented.".
    """
    model = RobustWeightedClassifier(multi_class="invalid")
    msg = "No such multiclass method implemented."
    with pytest.raises(ValueError, match=msg):
        model.fit(X_cc, y_cc)
Exemple #8
0
def test_robust_estimator_unsupported_weighting():
    """Check that fitting with an unknown ``weighting`` raises a ValueError.

    The error message must match "No such weighting scheme".
    """
    estimator = RobustWeightedClassifier(weighting="invalid")
    with pytest.raises(ValueError, match="No such weighting scheme"):
        estimator.fit(X_cc, y_cc)
Exemple #9
0
def test_robust_estimator_unsupported_loss():
    """Check that fitting with an unknown ``loss`` raises a ValueError.

    The error message must match "The loss invalid is not supported. ".
    """
    estimator = RobustWeightedClassifier(loss="invalid")
    expected = "The loss invalid is not supported. "
    with pytest.raises(ValueError, match=expected):
        estimator.fit(X_cc, y_cc)
Exemple #10
0
def test_robust_estimator_max_iter():
    """Check that fitting with ``max_iter=1`` emits a UserWarning.

    The warning message must match
    "Maximum number of iteration reached before".
    """
    estimator = RobustWeightedClassifier(max_iter=1)
    expected = "Maximum number of iteration reached before"
    with pytest.warns(UserWarning, match=expected):
        estimator.fit(X_cc, y_cc)
Exemple #11
0
def test_robust_estimator_unsupported_loss():
    """Check that ``multi_class="binary"`` rejects ``(X_c, y_c)``.

    Fitting must raise a ValueError whose message matches
    "y must be binary.".

    NOTE(review): the function name mentions an unsupported loss, but the
    body exercises the binary-label check -- consider renaming.
    """
    estimator = RobustWeightedClassifier(multi_class="binary")
    with pytest.raises(ValueError, match="y must be binary."):
        estimator.fit(X_c, y_c)