Example #1
def demographic_parity(df_test_encoded, predictions, print_=False):
    dpd_sex = demographic_parity_difference(df_test_encoded.earnings, predictions, sensitive_features=df_test_encoded.sex)
    dpr_sex = demographic_parity_ratio(df_test_encoded.earnings, predictions, sensitive_features=df_test_encoded.sex)

    if print_:
        print("Demographic parity difference sex:", dpd_sex)
        print("Demographic parity ratio sex:", dpr_sex)
Example #2
def fair_metrics(bst, data, column, thresh):
    tr = list(data.get_label())
    best_iteration = bst.best_ntree_limit
    pred = bst.predict(data, ntree_limit=best_iteration)
    pred = [1 if p > thresh else 0 for p in pred]
    na0 = 0
    na1 = 0
    nd0 = 0
    nd1 = 0
    for p, c in zip(pred, column):
        if (p == 1 and c == 0):
            nd1 += 1
        if (p == 1 and c == 1):
            na1 += 1
        if (p == 0 and c == 0):
            nd0 += 1
        if (p == 0 and c == 1):
            na0 += 1
    Pa1 = na1 / (na1 + na0)
    Pd1 = nd1 / (nd1 + nd0)
    Pa0 = na0 / (na1 + na0)
    Pd0 = nd0 / (nd1 + nd0)
    dsp_metric = np.abs(Pd1 - Pa1)
    #dsp_metric = np.abs((first-second)/(first+second))
    sr_metric = selection_rate(tr, pred, pos_label=1)
    dpd_metric = demographic_parity_difference(tr,
                                               pred,
                                               sensitive_features=column)
    dpr_metric = demographic_parity_ratio(tr, pred, sensitive_features=column)
    eod_metric = equalized_odds_difference(tr, pred, sensitive_features=column)

    return dsp_metric, sr_metric, dpd_metric, dpr_metric, eod_metric
Example #3
    def compute_ideal_area(Y, C):
        len_c = len(numpy.unique(C))
        len_y = len(numpy.unique(Y))
        p_y_c = numpy.zeros((len_y, len_c))

        for c in range(len_c):
            for y in range(len_y):
                p_y_c[y, c] = numpy.logical_and(Y == y, C == c).mean()
        print(p_y_c)

        # compute desired rate i.e p(y=1|C=c)
        desired_rate = p_y_c[1, :].mean()
        errors = p_y_c[1, :] - desired_rate

        majority_acc = max(numpy.mean(Y == 1), 1 - numpy.mean(Y == 1))
        max_dp = demographic_parity_difference(Y, Y, sensitive_features=C)

        solution = get_optimal_front(Y, C)
        # add no error and max_dp to the solution
        solution.append([1, max_dp])

        solution = numpy.array(solution)

        # sort by dp
        solution = solution[solution[:, 1].argsort()]

        area = numpy.sum(
            # acc                            * dp_next - dp_cur
            (solution[:-1, 0] - majority_acc) *
            (solution[1:, 1] - solution[0:-1, 1]))
        return area, majority_acc, max_dp
Example #4
def test_against_demographic_parity_difference(method):
    expected = metrics.demographic_parity_difference(
        y_true, y_pred, sensitive_features=sf_binary, method=method
    )
    actual = metrics.selection_rate_difference(
        y_true, y_pred, sensitive_features=sf_binary, method=method
    )
    assert expected == actual
Example #5
def test_demographic_parity_difference(agg_method):
    actual = demographic_parity_difference(y_t,
                                           y_p,
                                           sensitive_features=g_1,
                                           method=agg_method)

    gm = MetricFrame(selection_rate, y_t, y_p, sensitive_features=g_1)

    assert actual == gm.difference(method=agg_method)
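For context, the identity this test relies on can be reproduced with concrete inputs; a hedged sketch assuming a fairlearn version where MetricFrame takes keyword arguments (y_t, y_p, g_1 below are made up):

import numpy as np
from fairlearn.metrics import MetricFrame, selection_rate, demographic_parity_difference

y_t = np.array([0, 1, 1, 0, 1, 0, 1, 1])
y_p = np.array([0, 1, 1, 1, 0, 0, 1, 1])
g_1 = np.array(["a", "a", "a", "a", "b", "b", "b", "b"])

# demographic_parity_difference is the between-group difference in selection rate
gm = MetricFrame(metrics=selection_rate, y_true=y_t, y_pred=y_p, sensitive_features=g_1)
assert np.isclose(
    demographic_parity_difference(y_t, y_p, sensitive_features=g_1, method="between_groups"),
    gm.difference(method="between_groups"),
)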
Example #6
def main(argv):
    df_data = pd.read_csv(r"adults_dataset/adult_train.csv")
    df_data = name_columns(df_data)
    df_test = pd.read_csv(r"adults_dataset/adult_test.csv")
    df_test = name_columns(df_test)

    df_data = data_preprocessing(df_data)
    df_test = data_preprocessing(df_test)

    # fig_proportion_of_rich(df_test, argv[1], False)

    df_data_encoded = one_hot_encoding(df_data)
    df_test_encoded = one_hot_encoding(df_test)

    normalization(df_data_encoded)
    normalization(df_test_encoded)

    samples = split_samples(df_data_encoded, df_test_encoded)
    
    model = random_forest_classifier(samples)

    predictions = predict(model, samples, False)

    # proportion_of_rich(argv[2], samples, predictions, False)

    gender_performance(df_test_encoded, predictions)
    demographic_parity(df_test_encoded, predictions)
    equalized_odds(df_test_encoded, predictions)

    #Kamiran and Calders
    train_sds = StandardDataset(df_data_encoded, label_name="earnings", favorable_classes=[1], 
                                protected_attribute_names=["sex"], privileged_classes=[[1]])

    test_sds = StandardDataset(df_test_encoded, label_name="earnings", favorable_classes=[1],
                               protected_attribute_names=["sex"], privileged_classes=[[1]])

    privileged_groups = [{"sex": 1.0}]
    unprivileged_groups = [{"sex": 0.0}]

    RW = Reweighing(unprivileged_groups=unprivileged_groups, privileged_groups=privileged_groups)
    RW.fit(train_sds)

    test_sds_pred = test_sds.copy(deepcopy=True)
    test_sds_transf = RW.transform(test_sds)

    samples_fair = split_samples_fair(train_sds, test_sds, test_sds_pred)
    
    model_fair = logistic_regression(test_sds_transf)

    predictions_fair, test_pred = predict_fair(model_fair, samples_fair, True)
    test_pred = test_pred.astype(int)

    dpd = demographic_parity_difference(
        df_test_encoded.earnings, test_pred, sensitive_features=df_test_encoded.sex)

    print(f"Model demographic parity difference:", dpd)
Example #7
def __binary_group_fairness_measures(X,
                                     prtc_attr,
                                     y_true,
                                     y_pred,
                                     y_prob=None,
                                     priv_grp=1):
    """[summary]

    Args:
        X (pandas DataFrame): Sample features
        prtc_attr (named array-like): values for the protected attribute
            (note: protected attribute may also be present in X)
        y_true (pandas DataFrame): Sample targets
        y_pred (pandas DataFrame): Sample target predictions
        y_prob (pandas DataFrame, optional): Sample target probabilities. Defaults
            to None.

    Returns:
        [type]: [description]
    """
    pa_names = prtc_attr.columns.tolist()
    gf_vals = {}
    gf_key = 'Group Fairness'
    gf_vals['Statistical Parity Difference'] = \
        aif_mtrc.statistical_parity_difference(y_true, y_pred, prot_attr=pa_names)
    gf_vals['Disparate Impact Ratio'] = \
        aif_mtrc.disparate_impact_ratio(y_true, y_pred, prot_attr=pa_names)
    if not helper.is_tutorial_running() and not len(pa_names) > 1:
        gf_vals['Demographic Parity Difference'] = \
            fl_mtrc.demographic_parity_difference(y_true, y_pred,
                                                  sensitive_features=prtc_attr)
        gf_vals['Demographic Parity Ratio'] = \
            fl_mtrc.demographic_parity_ratio(y_true, y_pred,
                                             sensitive_features=prtc_attr)
    gf_vals['Average Odds Difference'] = \
        aif_mtrc.average_odds_difference(y_true, y_pred, prot_attr=pa_names)
    gf_vals['Equal Opportunity Difference'] = \
        aif_mtrc.equal_opportunity_difference(y_true, y_pred, prot_attr=pa_names)
    if not helper.is_tutorial_running() and not len(pa_names) > 1:
        gf_vals['Equalized Odds Difference'] = \
            fl_mtrc.equalized_odds_difference(y_true, y_pred,
                                              sensitive_features=prtc_attr)
        gf_vals['Equalized Odds Ratio'] = \
            fl_mtrc.equalized_odds_ratio(y_true, y_pred,
                                         sensitive_features=prtc_attr)
    gf_vals['Positive Predictive Parity Difference'] = \
        aif_mtrc.difference(sk_metric.precision_score, y_true,
                            y_pred, prot_attr=pa_names, priv_group=priv_grp)
    gf_vals['Balanced Accuracy Difference'] = \
        aif_mtrc.difference(sk_metric.balanced_accuracy_score, y_true,
                            y_pred, prot_attr=pa_names, priv_group=priv_grp)
    if y_prob is not None:
        gf_vals['AUC Difference'] = \
            aif_mtrc.difference(sk_metric.roc_auc_score, y_true, y_prob,
                                prot_attr=pa_names, priv_group=priv_grp)
    return (gf_key, gf_vals)
Example #8
def demographic_parity_difference(y, c, y_hat):
    """This will only return max, mean implementation is not there"""
    c = c.reshape(-1)
    assert y_hat.shape[1] == 2
    assert y.shape[0] == c.shape[0]

    y_pred = y_hat[:, 1] > 0.5

    return fairlearn_metrics.demographic_parity_difference(
        y, y_pred, sensitive_features=c), None
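A hypothetical call to the wrapper above (the probability matrix is invented; the wrapper and its fairlearn_metrics import are assumed to be in scope):

import numpy as np

y = np.array([1, 0, 1, 0])
c = np.array([[0], [0], [1], [1]])            # flattened to 1-D inside the wrapper
y_hat = np.array([[0.2, 0.8],
                  [0.4, 0.6],
                  [0.1, 0.9],
                  [0.7, 0.3]])                # columns: P(y=0), P(y=1)

# Thresholding at 0.5 gives predictions [1, 1, 1, 0]:
# P(pred=1 | c=0) = 1.0, P(pred=1 | c=1) = 0.5, so the difference is 0.5
dpd, _ = demographic_parity_difference(y, c, y_hat)
print(dpd)  # 0.5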
Example #9
def run_thresholdoptimizer_classification(estimator):
    """Run classification test with ThresholdOptimizer."""
    X_train, Y_train, A_train, X_test, Y_test, A_test = fetch_adult()

    unmitigated = copy.deepcopy(estimator)
    unmitigated.fit(X_train, Y_train)
    unmitigated_predictions = unmitigated.predict(X_test)

    to = ThresholdOptimizer(estimator=estimator,
                            prefit=False,
                            predict_method='predict')
    to.fit(X_train, Y_train, sensitive_features=A_train)

    mitigated_predictions = to.predict(X_test, sensitive_features=A_test)

    dp_diff_unmitigated = demographic_parity_difference(
        Y_test, unmitigated_predictions, sensitive_features=A_test)

    dp_diff_mitigated = demographic_parity_difference(
        Y_test, mitigated_predictions, sensitive_features=A_test)
    assert dp_diff_mitigated <= dp_diff_unmitigated
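Such a test helper would typically be parametrized with ordinary scikit-learn classifiers; a hedged sketch, assuming the helper and its fetch_adult fixture are importable:

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

# Any estimator exposing fit/predict should work, since ThresholdOptimizer
# only post-processes its hard predictions here.
run_thresholdoptimizer_classification(LogisticRegression(max_iter=1000))
run_thresholdoptimizer_classification(DecisionTreeClassifier(min_samples_leaf=10))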
Example #10
def test_demographic_parity_difference_weighted(agg_method):
    actual = demographic_parity_difference(y_t,
                                           y_p,
                                           sensitive_features=g_1,
                                           sample_weight=s_w,
                                           method=agg_method)

    gm = MetricFrame(selection_rate,
                     y_t,
                     y_p,
                     sensitive_features=g_1,
                     sample_params={'sample_weight': s_w})

    assert actual == gm.difference(method=agg_method)
Example #11
def conditional_demographic_parity_difference(labels, pred, attr, groups):
    """
    Calculate conditional demographic parity by calculating the average
    demographic parity difference across bins defined by `groups`.
    """
    diffs = []

    for group in set(groups):
        mask = groups == group

        diffs.append(
            demographic_parity_difference(labels[mask],
                                          pred[mask],
                                          sensitive_features=attr[mask]))

    return np.mean(diffs)
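A toy call illustrating the averaging over bins (all arrays below are hypothetical):

import numpy as np

labels = np.array([1, 0, 1, 0, 1, 0, 1, 0])
pred = np.array([1, 1, 0, 0, 1, 0, 1, 1])
attr = np.array([0, 1, 0, 1, 0, 1, 0, 1])    # sensitive attribute
groups = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # conditioning bins

# bin 0: both attr groups have selection rate 0.5 -> difference 0.0
# bin 1: selection rates 1.0 vs 0.5               -> difference 0.5
print(conditional_demographic_parity_difference(labels, pred, attr, groups))  # 0.25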
Example #12
    def compute_ideal_stats(Y, C):
        len_c = len(numpy.unique(C))
        len_y = len(numpy.unique(Y))
        p_y_c = numpy.zeros((len_y, len_c))

        for c in range(len_c):
            for y in range(len_y):
                p_y_c[y, c] = numpy.logical_and(Y == y, C == c).mean()
        print(p_y_c)

        # compute desired rate i.e p(y=1|C=c)
        desired_rate = p_y_c[1, :].mean()
        errors = p_y_c[1, :] - desired_rate

        majority_acc = max(numpy.mean(Y == 1), 1 - numpy.mean(Y == 1))
        max_dp = demographic_parity_difference(Y, Y, sensitive_features=C)

        return 0, majority_acc, max_dp
Example #13
def evaluate_model(model, device, criterion, data_loader):
    model.eval()
    y_true = []
    y_pred = []
    y_out = []
    sensitives = []
    for i, data in enumerate(data_loader):
        x, y, sensitive_features = data
        x = x.to(device)
        y = y.to(device)
        sensitive_features = sensitive_features.to(device)
        with torch.no_grad():
            logit = model(x)
        # logit : binary prediction size=(b, 1)
        bina = (torch.sigmoid(logit) > 0.5).float()
        y_true += y.cpu().tolist()
        y_pred += bina.cpu().tolist()
        y_out += torch.sigmoid(logit).tolist()
        sensitives += sensitive_features.cpu().tolist()
    result = {}
    result["acc"] = skm.accuracy_score(y_true, y_pred)
    result["f1score"] = skm.f1_score(y_true, y_pred)
    result["AUC"] = skm.roc_auc_score(y_true, y_out)
    # use the sensitive attributes accumulated over all batches, not the
    # tensor from the last batch only
    result['DP'] = {
        "diff":
        flm.demographic_parity_difference(
            y_true, y_pred, sensitive_features=sensitives),
        "ratio":
        flm.demographic_parity_ratio(y_true,
                                     y_pred,
                                     sensitive_features=sensitives),
    }
    result["EO"] = {
        "diff":
        flm.equalized_odds_difference(y_true,
                                      y_pred,
                                      sensitive_features=sensitives),
        "ratio":
        flm.equalized_odds_ratio(y_true,
                                 y_pred,
                                 sensitive_features=sensitives),
    }
    return result
Example #14
def get_optimal_front(Y, C):
    len_c = len(numpy.unique(C))
    len_y = len(numpy.unique(Y))
    p_y_c = numpy.zeros((len_y, len_c))

    for c in range(len_c):
        for y in range(len_y):
            p_y_c[y, c] = numpy.logical_and(Y == y, C == c).mean()
    # print(p_y_c)
    #
    # # compute desired rate i.e p(y=1|C=c)
    # desired_rate = p_y_c[1, :].mean()
    # errors = p_y_c[1, :] - desired_rate

    majority_acc = max(numpy.mean(Y == 1), 1 - numpy.mean(Y == 1))
    max_dp = demographic_parity_difference(Y, Y, sensitive_features=C)
    STEPS = 0.01

    solution = []
    for dp in numpy.arange(0, max_dp, STEPS):
        delta = cvxpy.Variable(p_y_c.shape[1])

        # delta is fraction flipped to 0
        objective = cvxpy.Maximize(1 - cvxpy.sum(cvxpy.abs(delta)))
        constraints = []
        constraints.extend([-p_y_c[0, :] <= delta, delta <= p_y_c[1, :]])

        p_c = p_y_c.sum(axis=0)
        for i in range(p_y_c.shape[1]):
            for j in range(i + 1, p_y_c.shape[1]):
                constraints.extend([
                    -dp <= (p_y_c[1, i] - delta[i]) / p_c[i] -
                    (p_y_c[1, j] - delta[j]) / p_c[j],
                    (p_y_c[1, i] - delta[i]) / p_c[i] -
                    (p_y_c[1, j] - delta[j]) / p_c[j] <= dp,
                ])
        prob = cvxpy.Problem(objective, constraints)
        result = prob.solve()
        # breakpoint()
        solution.append([result, dp])
        print(f"DP: {dp}, sol : {result}")

    return solution
Example #15
def test_student(args, student_train_loader, student_labels, student_test_loader, test_size, cat_emb_size, num_conts, device, sensitive_idx):
    student_model = RegressionModel(emb_szs=cat_emb_size,
                    n_cont=num_conts,
                    emb_drop=0.04,
                    out_sz=1,
                    szs=[1000, 500, 250],
                    drops=[0.001, 0.01, 0.01],
                    y_range=(0, 1)).to(device)

    criterion = nn.BCELoss()
    optimizer = optim.SGD(student_model.parameters(), lr=args.lr, momentum=0)
    steps = 0
    running_loss = 0
    correct = 0
    print("========== Testing Student Model ==========")
    for epoch in range(args.epochs):
        student_model.train()
        train_loader = student_loader(student_train_loader, student_labels)
        for (cats, conts) , labels in train_loader:
        #for _batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            #cats = data[0]
            #conts = data[1]
            steps += 1

            optimizer.zero_grad()
            output = student_model(cats, conts).view(-1)
            labels = labels.to(torch.float32)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        #            if steps % 50 == 0:
            student_model.eval()
            test_loss = 0
            correct = 0
            i = 0

            avg_recall = 0
            avg_precision = 0
            overall_results = []
            avg_eq_odds = 0
            avg_dem_par = 0
            avg_tpr = 0
            avg_tp = 0
            avg_tn = 0
            avg_fp = 0
            avg_fn = 0

            with torch.no_grad():
                for batch_idx, (cats, conts, target) in enumerate(student_test_loader):
                    print("target\n", sum(target))
                    i+=1
                    output = student_model(cats, conts)
                    loss += criterion(output, target).item()
                    test_loss = test_loss + ((1 / (batch_idx + 1)) * (loss.data - test_loss))
                    pred = (output > 0.5).float()
                    print("pred\n", sum(pred))
                    correct += pred.eq(target.view_as(pred)).sum().item()

                    curr_datetime = datetime.now()
                    curr_hour = curr_datetime.hour
                    curr_min = curr_datetime.minute

                    pred_df = pd.DataFrame(pred.numpy())
                    pred_df.to_csv(f"pred_results/{args.run_name}_{curr_hour}-{curr_min}.csv")

                    #print(pred, np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).cpu().numpy()))
                    #correct += np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).cpu().numpy())
                    #total += cats.size(0)


                    # confusion matrix
                    tn, fp, fn, tp = confusion_matrix(target, pred).ravel()
                    avg_tn += tn
                    avg_fp += fp
                    avg_fn += fn
                    avg_tp += tp

                    # position of col for sensitive values
                    sensitive = [i[sensitive_idx].item() for i in cats]
                    cat_len = max(sensitive)

                    #exit()
                    sub_cm = []
                    # print(cat_len)
                    for j in range(cat_len+1):
                        try:
                            idx = list(locate(sensitive, lambda x: x == j))
                            sub_tar = target[idx]
                            sub_pred = pred[idx]
                            sub_tn, sub_fp, sub_fn, sub_tp = confusion_matrix(sub_tar, sub_pred).ravel()
                        except:
                            # when only one value to predict
                            print("----WHAT?")
                            temp_tar = int(sub_tar.numpy()[0])
                            temp_pred = int(sub_pred.numpy()[0])
                            # print(tar, pred)
                            if temp_tar and temp_pred:
                                sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 0, 1
                            elif temp_tar and not temp_pred:
                                sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 1, 0
                            elif not temp_tar and not temp_pred:
                                sub_tn, sub_fp, sub_fn, sub_tp = 1, 0, 0, 0
                            elif not temp_tar and temp_pred:
                                sub_tn, sub_fp, sub_fn, sub_tp = 0, 1, 0, 0
                            else:
                                sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 0, 0

                        total = mysum(sub_tn, sub_fp, sub_fn, sub_tp)
                        print("??", total)
                        sub_cm.append((sub_tn / total, sub_fp / total, sub_fn / total, sub_tp / total))

                    # Fairness metrics

                    group_metrics = MetricFrame({'precision': skm.precision_score, 'recall': skm.recall_score},
                                                target, pred,
                                                sensitive_features=sensitive)


                    demographic_parity = flm.demographic_parity_difference(target, pred,
                                                                           sensitive_features=sensitive)

                    eq_odds = flm.equalized_odds_difference(target, pred,
                                                            sensitive_features=sensitive)

                    # metric_fns = {'true_positive_rate': true_positive_rate}

                    tpr = MetricFrame(true_positive_rate,
                                      target, pred,
                                      sensitive_features=sensitive)

                    # tpr = flm.true_positive_rate(target, pred,sample_weight=sensitive)
                    sub_results = group_metrics.overall.to_dict()
                    sub_results_by_group = group_metrics.by_group.to_dict()

                    # print("\n", group_metrics.by_group, "\n")
                    avg_precision += sub_results['precision']
                    avg_recall += sub_results['recall']
                    print("pre_rec", sub_results)
                    overall_results.append(sub_results_by_group)
                    avg_eq_odds += eq_odds
                    print("eqo", eq_odds)
                    avg_dem_par += demographic_parity
                    print("dempar", demographic_parity)
                    avg_tpr += tpr.difference(method='between_groups')
                    print("tpr", tpr.difference(method='between_groups'))

            total = mysum(avg_tn, avg_fp, avg_fn, avg_tp)
            print("!!", total)
            cm = (avg_tn / total, avg_fp / total, avg_fn / total, avg_tp / total)
            test_loss /= test_size
            accuracy = correct / test_size
            avg_loss = test_loss

            return accuracy, avg_loss, avg_precision, avg_recall, avg_eq_odds, avg_tpr, avg_dem_par, cm, sub_cm, overall_results
Example #16
def fair_metrics(gt, y, group):
    metrics_dict = {
        "DPd": demographic_parity_difference(gt, y, sensitive_features=group),
        "EOd": equalized_odds_difference(gt, y, sensitive_features=group),
    }
    return metrics_dict
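A hypothetical call (gt, y, group below are made up):

import numpy as np

gt = np.array([1, 0, 1, 1, 0, 0])
y = np.array([1, 0, 0, 1, 1, 0])
group = np.array(["f", "f", "f", "m", "m", "m"])

print(fair_metrics(gt, y, group))  # -> {'DPd': 0.333..., 'EOd': 0.5} for this toy data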
Example #17
def test(args, model, device, test_loader, test_size, sensitive_idx):

    model.eval()
    criterion = nn.BCELoss()
    test_loss = 0
    correct = 0
    i = 0

    avg_recall = 0
    avg_precision = 0
    overall_results = []
    avg_eq_odds = 0
    avg_dem_par = 0
    avg_tpr = 0
    avg_tp = 0
    avg_tn = 0
    avg_fp = 0
    avg_fn = 0
    with torch.no_grad():
        for cats, conts, target in tqdm(test_loader):
            print("*********")
            #i += 1
            cats, conts, target = cats.to(device), conts.to(device), target.to(device)


            output = model(cats, conts)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = (output > 0.5).float()
            correct += pred.eq(target.view_as(pred)).sum().item()

            curr_datetime = datetime.now()
            curr_hour = curr_datetime.hour
            curr_min = curr_datetime.minute

            pred_df = pd.DataFrame(pred.numpy())
            pred_df.to_csv(f"pred_results/{args.run_name}_{curr_hour}-{curr_min}.csv")

            # confusion matrix
            tn, fp, fn, tp = confusion_matrix(target, pred).ravel()
            avg_tn+=tn
            avg_fp+=fp
            avg_fn+=fn
            avg_tp+=tp

            # position of col for sensitive values
            sensitive = [i[sensitive_idx].item() for i in cats]
            cat_len = max(sensitive)
            print(cat_len)
            #exit()
            sub_cm = []
            #print(cat_len)
            for j in range(cat_len+1):
                try:
                    idx = list(locate(sensitive, lambda x: x == j))
                    sub_tar = target[idx]
                    sub_pred = pred[idx]
                    sub_tn, sub_fp, sub_fn, sub_tp = confusion_matrix(sub_tar, sub_pred).ravel()
                except:
                    # when only one value to predict
                    temp_tar = int(sub_tar.numpy()[0])
                    temp_pred = int(sub_pred.numpy()[0])
                    #print(tar, pred)
                    if temp_tar and temp_pred:
                        sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 0, 1
                    elif temp_tar and not temp_pred:
                        sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 1, 0
                    elif not temp_tar and not temp_pred:
                        sub_tn, sub_fp, sub_fn, sub_tp = 1, 0, 0, 0
                    elif not temp_tar and temp_pred:
                        sub_tn, sub_fp, sub_fn, sub_tp = 0, 1, 0, 0
                    else:
                        sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 0, 0

                total = mysum(sub_tn, sub_fp, sub_fn, sub_tp)
                sub_cm.append((sub_tn/total, sub_fp/total, sub_fn/total, sub_tp/total))

            # Fairness metrics
            group_metrics = MetricFrame({'precision': skm.precision_score, 'recall': skm.recall_score},
                                        target, pred,
                                        sensitive_features=sensitive)

            demographic_parity = flm.demographic_parity_difference(target, pred,
                                                                   sensitive_features=sensitive)

            eq_odds = flm.equalized_odds_difference(target, pred,
                                                    sensitive_features=sensitive)

            # metric_fns = {'true_positive_rate': true_positive_rate}

            tpr = MetricFrame(true_positive_rate,
                              target, pred,
                              sensitive_features=sensitive)

            # tpr = flm.true_positive_rate(target, pred,sample_weight=sensitive)
            sub_results = group_metrics.overall.to_dict()
            sub_results_by_group = group_metrics.by_group.to_dict()

            #print("\n", group_metrics.by_group, "\n")
            avg_precision += sub_results['precision']
            avg_recall += sub_results['recall']
            overall_results.append(sub_results_by_group)
            avg_eq_odds += eq_odds
            avg_dem_par += demographic_parity
            avg_tpr += tpr.difference(method='between_groups')

    print(i)
    total = mysum(avg_tn, avg_fp, avg_fn, avg_tp)
    cm = (avg_tn/total, avg_fp/total, avg_fn/total, avg_tp/total)
    test_loss /= test_size
    accuracy = correct / test_size
    avg_loss = test_loss


    return accuracy, avg_loss, avg_precision, avg_recall, avg_eq_odds, avg_tpr, avg_dem_par, cm, sub_cm, overall_results
Example #18
def test_student(args, student_train_loader, student_labels,
                 student_test_loader, test_size, cat_emb_size, num_conts,
                 device, sensitive_idx):
    student_model = RandomForestClassifier(random_state=42,
                                           warm_start=True,
                                           n_estimators=100)

    print("========== Testing Student Model ==========")
    for epoch in range(args.epochs):
        train_loader = student_loader(student_train_loader, student_labels)
        for (cats, conts), labels in train_loader:
            X = torch.cat((cats, conts), 1)
            student_model = student_model.fit(X, labels)

            test_loss = 0
            correct = 0
            i = 0

            avg_recall = 0
            avg_precision = 0
            overall_results = []
            avg_eq_odds = 0
            avg_dem_par = 0
            avg_tpr = 0
            avg_tp = 0
            avg_tn = 0
            avg_fp = 0
            avg_fn = 0

            with torch.no_grad():
                for batch_idx, (cats, conts,
                                target) in enumerate(student_test_loader):
                    print("target\n", sum(target))
                    i += 1
                    X = torch.cat((cats, conts), 1)
                    output = student_model.predict(X)
                    output = torch.from_numpy(output)
                    pred = (output > 0.5).float()
                    print("pred\n", sum(pred))
                    correct += pred.eq(target.view_as(pred)).sum().item()

                    curr_datetime = datetime.now()
                    curr_hour = curr_datetime.hour
                    curr_min = curr_datetime.minute

                    pred_df = pd.DataFrame(pred.numpy())
                    pred_df.to_csv(
                        f"pred_results/{args.run_name}_{curr_hour}-{curr_min}.csv"
                    )

                    #print(pred, np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).cpu().numpy()))
                    #correct += np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).cpu().numpy())
                    #total += cats.size(0)

                    # confusion matrix
                    tn, fp, fn, tp = confusion_matrix(target, pred).ravel()
                    avg_tn += tn
                    avg_fp += fp
                    avg_fn += fn
                    avg_tp += tp

                    # position of col for sensitive values
                    sensitive = [i[sensitive_idx].item() for i in cats]
                    cat_len = max(sensitive)

                    #exit()
                    sub_cm = []
                    # print(cat_len)
                    for j in range(cat_len + 1):
                        try:
                            idx = list(locate(sensitive, lambda x: x == j))
                            sub_tar = target[idx]
                            sub_pred = pred[idx]
                            sub_tn, sub_fp, sub_fn, sub_tp = confusion_matrix(
                                sub_tar, sub_pred).ravel()
                        except:
                            # when only one value to predict
                            temp_tar = int(sub_tar.numpy()[0])
                            temp_pred = int(sub_pred.numpy()[0])
                            # print(tar, pred)
                            if temp_tar and temp_pred:
                                sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 0, 1
                            elif temp_tar and not temp_pred:
                                sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 1, 0
                            elif not temp_tar and not temp_pred:
                                sub_tn, sub_fp, sub_fn, sub_tp = 1, 0, 0, 0
                            elif not temp_tar and temp_pred:
                                sub_tn, sub_fp, sub_fn, sub_tp = 0, 1, 0, 0
                            else:
                                sub_tn, sub_fp, sub_fn, sub_tp = 0, 0, 0, 0

                        total = mysum(sub_tn, sub_fp, sub_fn, sub_tp)
                        print("??", total)
                        sub_cm.append((sub_tn / total, sub_fp / total,
                                       sub_fn / total, sub_tp / total))

                    # Fairness metrics

                    group_metrics = MetricFrame(
                        {
                            'precision': skm.precision_score,
                            'recall': skm.recall_score
                        },
                        target,
                        pred,
                        sensitive_features=sensitive)

                    print(target)
                    print(pred)
                    demographic_parity = flm.demographic_parity_difference(
                        target, pred, sensitive_features=sensitive)

                    eq_odds = flm.equalized_odds_difference(
                        target, pred, sensitive_features=sensitive)

                    # metric_fns = {'true_positive_rate': true_positive_rate}

                    tpr = MetricFrame(true_positive_rate,
                                      target,
                                      pred,
                                      sensitive_features=sensitive)

                    # tpr = flm.true_positive_rate(target, pred,sample_weight=sensitive)
                    sub_results = group_metrics.overall.to_dict()
                    sub_results_by_group = group_metrics.by_group.to_dict()

                    # print("\n", group_metrics.by_group, "\n")
                    avg_precision += sub_results['precision']
                    avg_recall += sub_results['recall']
                    print("pre_rec", sub_results)
                    overall_results.append(sub_results_by_group)
                    avg_eq_odds += eq_odds
                    print("eqo", eq_odds)
                    avg_dem_par += demographic_parity
                    print("dempar", demographic_parity)
                    avg_tpr += tpr.difference(method='between_groups')
                    print("tpr", tpr.difference(method='between_groups'))

            total = mysum(avg_tn, avg_fp, avg_fn, avg_tp)
            print("!!", total)
            cm = (avg_tn / total, avg_fp / total, avg_fn / total,
                  avg_tp / total)
            test_loss /= test_size
            accuracy = correct / test_size
            avg_loss = test_loss

            return accuracy, avg_loss, avg_precision, avg_recall, avg_eq_odds, avg_tpr, avg_dem_par, cm, sub_cm, overall_results