def test_demographicparity_fair_uneven_populations():
    """Grid-search a logistic regression under a demographic parity metric.

    Variant of test_demographicparity_already_fair in which the score
    threshold is adjustable. The same threshold value is passed twice to
    _simple_threshold_data (presumably once per group - confirm against
    that helper).

    NOTE(review): the test name and the original comment refer to unequal
    populations, yet both group sizes below are 4. The pinned numbers in
    the assertions depend on these inputs, so confirm the intent before
    changing the sizes.
    """
    threshold = 0.625
    size_a0, size_a1 = 4, 4
    label_a0, label_a1 = 17, 37
    n_grid_points = 11

    X, Y, A = _simple_threshold_data(size_a0, size_a1,
                                     threshold, threshold,
                                     label_a0, label_a1)

    estimator = LogisticRegression(solver='liblinear', fit_intercept=True)
    target = GridSearch(
        estimator,
        disparity_metric=moments.DemographicParity(),
        quality_metric=SimpleClassificationQualityMetric(),
        grid_size=n_grid_points,
    )
    target.fit(X, Y, sensitive_features=A)

    # One result is expected per grid point.
    assert len(target.all_results) == n_grid_points

    # Two probe rows, one carrying each sensitive-feature label.
    test_X = pd.DataFrame({
        "actual_feature": [0.2, 0.7],
        "sensitive_features": [label_a0, label_a1],
        "constant_ones_feature": [1, 1],
    })

    # target.predict is exercised here, but its output is not compared
    # against anything; only the probabilities below are pinned.
    target.predict(test_X)

    expected_proba = [[0.53748641, 0.46251359], [0.46688736, 0.53311264]]
    proba = target.predict_proba(test_X)
    assert np.allclose(proba, expected_proba)

    # Hard predictions of the model stored in the first entry of all_results.
    first_model_labels = target.all_results[0].model.predict(test_X)
    assert np.array_equal(first_model_labels, [1, 0])

    # Both posterior_* calls must return one entry per grid point.
    posterior_labels = target.posterior_predict(test_X)
    assert len(posterior_labels) == n_grid_points

    posterior_proba = target.posterior_predict_proba(test_X)
    assert len(posterior_proba) == n_grid_points
def test_bgl_unfair():
    """Grid-search a linear regression under a group-loss moment.

    Builds regression data from _simple_regression_data using group sizes
    5 and 7, factors 1 and 16, and sensitive-feature labels 2 and 3. It
    then fits a 7-point GridSearch whose disparity metric is
    GroupLossMoment(ZeroOneLoss()), and pins both the output of
    target.predict and the full posterior_predict output for two probe
    rows.
    """
    size_a0, size_a1 = 5, 7
    label_a0, label_a1 = 2, 3
    factor_a0, factor_a1 = 1, 16
    n_grid_points = 7

    X, Y, A = _simple_regression_data(size_a0, size_a1,
                                      factor_a0, factor_a1,
                                      label_a0, label_a1)

    estimator = LinearRegression()
    group_loss = moments.GroupLossMoment(moments.ZeroOneLoss())
    target = GridSearch(
        estimator,
        disparity_metric=group_loss,
        quality_metric=SimpleRegressionQualityMetric(),
        grid_size=n_grid_points,
    )
    target.fit(X, Y, sensitive_features=A)

    # One result is expected per grid point.
    assert len(target.all_results) == n_grid_points

    # Two probe rows, one carrying each sensitive-feature label.
    test_X = pd.DataFrame({
        "actual_feature": [0.2, 0.7],
        "sensitive_features": [label_a0, label_a1],
        "constant_ones_feature": [1, 1],
    })

    expected_best = [-1.91764706, 9.61176471]
    # The pinned best prediction is identical to the fourth row of the
    # pinned posterior predictions.
    expected_all = [
        [3.2, 11.2],
        [-3.47346939, 10.64897959],
        [-2.68, 10.12],
        expected_best,
        [-1.18461538, 9.12307692],
        [-0.47924528, 8.65283019],
        [0.2, 0.7],
    ]

    # Argument order (expected first) is kept on purpose: np.allclose
    # scales its relative tolerance by the second argument.
    predicted_best = target.predict(test_X)
    assert np.allclose(expected_best, predicted_best)

    predicted_all = target.posterior_predict(test_X)
    assert np.allclose(expected_all, predicted_all)