コード例 #1
0
def test_adv_debias_old():
    """Test that the predictions of the old and new AdversarialDebiasing match.

    The old implementation is driven through an explicit ``tf.Session`` and
    fitted on ``adult``. The new scikit-learn-style estimator is fitted on
    ``X``/``y`` (presumably the same data in a different format — confirm
    against the module fixtures). Both use 5 epochs and seed 123, and their
    predicted labels must agree.
    """
    tf.reset_default_graph()
    sess = tf.Session()
    try:
        old_adv_deb = OldAdversarialDebiasing(
            unprivileged_groups=[{'sex': 0}],
            privileged_groups=[{'sex': 1}],
            scope_name='old_classifier',
            sess=sess,
            num_epochs=5,
            seed=123)
        old_preds = old_adv_deb.fit_predict(adult)
    finally:
        # Release the session even if fitting/prediction raises.
        sess.close()

    adv_deb = AdversarialDebiasing('sex', num_epochs=5, random_state=123)
    try:
        new_preds = adv_deb.fit(X, y).predict(X)
    finally:
        # Guard with getattr in case fit() raised before setting sess_.
        new_sess = getattr(adv_deb, 'sess_', None)
        if new_sess is not None:
            new_sess.close()

    assert np.allclose(old_preds.labels.flatten(), new_preds)
コード例 #2
0
def test_adv_debias_grid():
    """Check that AdversarialDebiasing can be tuned with GridSearchCV.

    A 3-fold grid search over ``debias`` in {True, False} should select the
    non-debiased model. This means enabling debiasing is expected to score
    lower than leaving it off.
    """
    search = GridSearchCV(
        AdversarialDebiasing('sex', num_epochs=10, random_state=123),
        {'debias': [True, False]},
        cv=3,
    )
    search.fit(X, y)

    # Only the refitted best estimator is reachable after the search, so
    # its TensorFlow session is the one released here.
    search.best_estimator_.sess_.close()
    assert search.best_params_ == {'debias': False}
コード例 #3
0
def test_adv_debias_reproduce():
    """Test that the new AdversarialDebiasing is reproducible.

    Two estimators built with the same ``random_state`` are fitted on the
    same data, and each one's accuracy on ``X`` must be equal.
    """
    adv_deb = AdversarialDebiasing('sex', num_epochs=5, random_state=123)
    first_preds = adv_deb.fit(X, y).predict(X)
    adv_deb.sess_.close()
    first_acc = accuracy_score(y, first_preds)

    adv_deb2 = AdversarialDebiasing('sex', num_epochs=5, random_state=123)
    second_preds = adv_deb2.fit(X, y).predict(X)
    # Close the second estimator's own session; adv_deb's was closed above.
    adv_deb2.sess_.close()

    assert first_acc == accuracy_score(y, second_preds)
コード例 #4
0
def test_adv_debias_intersection():
    """Check that AdversarialDebiasing fits with more than two protected groups.

    No protected attribute is passed, so the estimator's default selection
    applies. The adversary output is then expected to have 4 columns.
    Presumably there is one column per intersectional group found in ``X``;
    confirm against the fixture.
    """
    model = AdversarialDebiasing(scope_name='intersect', num_epochs=5)
    model.fit(X, y)
    model.sess_.close()
    n_adversary_outputs = model.adversary_logits_.shape[1]
    assert n_adversary_outputs == 4