class Fair_rew_NN:
    """Neural-network classifier combined with an AIF360 ``Reweighing`` step.

    ``fit`` wraps a DataFrame in a ``BinaryLabelDataset``, runs the
    reweighing preprocessor on it, removes the first protected attribute
    from the feature matrix and trains ``FairClass`` on float32 tensors.
    ``predict_proba`` applies the same transform + column removal before
    delegating to ``FairClass.predict_proba``.

    NOTE(review): the instance weights produced by ``Reweighing`` are not
    passed to ``FairClass.fit``, so as written reweighing does not influence
    training. Whether ``FairClass.fit`` can accept per-sample weights is not
    visible here — confirm and forward ``instance_weights`` if so.
    """

    def __init__(self, un_gr, pr_gr, inp_size, num_layers_y, step_y):
        # Fairness preprocessor built from the unprivileged/privileged groups.
        self.model_reweight = Reweighing(un_gr, pr_gr)
        # Underlying network; the size/layer/step arguments pass straight through.
        self.model = FairClass(inp_size, num_layers_y, step_y)

    def _strip_protected(self, dataset):
        """Return ``dataset.features`` minus the first protected column, as a float32 tensor."""
        column = dataset.feature_names.index(self.prot[0])
        remaining = np.delete(dataset.features, column, 1)
        return torch.tensor(remaining).type('torch.FloatTensor')

    def fit(self, data, labels, prot):
        """Reweigh ``data`` and train the network on it.

        Args:
            data: DataFrame given to ``BinaryLabelDataset`` as ``df``.
            labels: label column name(s) for the dataset.
            prot: protected attribute name(s); only ``prot[0]`` is removed
                from the training features.
        """
        dataset = BinaryLabelDataset(df=data,
                                     label_names=labels,
                                     protected_attribute_names=prot)
        # Remembered so predict_proba drops the same column.
        self.prot = prot
        reweighted = self.model_reweight.fit_transform(dataset)
        features = self._strip_protected(reweighted)
        targets = torch.tensor(reweighted.labels).type('torch.FloatTensor')
        self.model.fit(features, targets)

    def predict_proba(self, data_test):
        """Return ``FairClass`` probabilities for ``data_test``.

        ``data_test`` is handed directly to ``Reweighing.transform`` (unlike
        ``fit``, which wraps a DataFrame itself) — presumably it must already
        be an AIF360 dataset; confirm with callers.
        """
        transformed = self.model_reweight.transform(data_test)
        return self.model.predict_proba(self._strip_protected(transformed))
class Fair_rew_RF():
    """Random-forest classifier trained on AIF360 ``Reweighing`` output.

    ``fit`` wraps a DataFrame in a ``BinaryLabelDataset`` and computes
    reweighing instance weights. It then drops the first protected attribute
    from the features and fits a ``RandomForestClassifier``, passing those
    weights as ``sample_weight``. ``predict_proba`` returns the probability
    column for ``self.model.classes_[1]``.
    """

    def __init__(self, un_gr, pr_gr, n_est=100, min_sam_leaf=25):
        """
        Args:
            un_gr: unprivileged group spec passed to ``Reweighing``.
            pr_gr: privileged group spec passed to ``Reweighing``.
            n_est: ``n_estimators`` of the forest.
            min_sam_leaf: ``min_samples_leaf`` of the forest.
        """
        self.model_reweight = Reweighing(un_gr, pr_gr)
        self.model = RandomForestClassifier(n_estimators=n_est,
                                            min_samples_leaf=min_sam_leaf)

    def _features_without_prot(self, dataset):
        """Return ``dataset.features`` with the first protected attribute's column removed."""
        index = dataset.feature_names.index(self.prot[0])
        return np.delete(dataset.features, index, 1)

    def _fit_reweighted(self, dataset):
        """Fit the forest on an already-reweighed dataset, honouring its instance weights."""
        x_train = self._features_without_prot(dataset)
        y_train = dataset.labels.ravel()
        # The instance weights are the whole output of Reweighing; without
        # passing them the forest is trained exactly like an unweighted model.
        self.model.fit(x_train, y_train,
                       sample_weight=dataset.instance_weights)

    def fit(self, data, labels, prot):
        """Reweigh ``data`` and fit the forest with the resulting sample weights.

        Args:
            data: DataFrame given to ``BinaryLabelDataset`` as ``df``.
            labels: label column name(s) for the dataset.
            prot: protected attribute name(s); only ``prot[0]`` is removed
                from the training features.
        """
        ds = BinaryLabelDataset(df=data,
                                label_names=labels,
                                protected_attribute_names=prot)
        # Remembered so predict_proba drops the same column.
        self.prot = prot
        reweighted = self.model_reweight.fit_transform(ds)
        self._fit_reweighted(reweighted)

    def predict_proba(self, data_test):
        """Return the forest's probability for ``classes_[1]`` on ``data_test``.

        ``data_test`` is handed directly to ``Reweighing.transform`` (unlike
        ``fit``, which wraps a DataFrame itself) — presumably it must already
        be an AIF360 dataset; confirm with callers.
        """
        x = self.model_reweight.transform(data_test)
        x_test = self._features_without_prot(x)
        # predict_proba columns follow self.model.classes_ (sorted); for 0/1
        # labels column 1 is the positive class.
        return self.model.predict_proba(x_test)[:, 1]
# Overwrite the "two_year_recid" column of `tmp` with this model's predictions,
# presumably so the statistical-parity check below is computed on predicted
# rather than observed outcomes (`tmp` is built outside this view — confirm).
tmp["two_year_recid"] = pred_

# Fairness metrics for this model: statistical-parity difference on `tmp`, and
# the tpr_diff / tpr_priv / tpr_unpriv values returned by compute_metrics
# (presumably the true-positive-rate gap and per-group rates).
# NOTE(review): `t_data` must be the frame defined before this cell; a later
# cell rebinds the global `t_data`, so re-running this cell after it would pass
# a different frame — mind the cell execution order.
parity_diff = compute_statistical_parity(tmp, unpriv_group, priv_group)
tpr_diff, tpr_priv, tpr_unpriv = compute_metrics(t_data, pred_, unpriv_group,
                                                 priv_group)
# Results row: (label, precision, recall, F1, accuracy, parity_diff, tpr_diff,
# tpr_priv, tpr_unpriv). "Without Race" presumably names a model trained
# without the race feature — confirm.
all_results.append(("Without Race", ps, rs, fs, as_, parity_diff, tpr_diff,
                    tpr_priv, tpr_unpriv))

print(
    f"The precision is {ps}.\nThe recall is {rs}.\nThe F1 is {fs}.\nThe accuracy is {as_}."
)

# Keep this model's inputs and predictions together; presumably consumed by a
# later comparison cell — verify before renaming.
bag = [X_, y_, pred_]
# %%
# Reweighing inspection: fit AIF360 Reweighing on the training set, transform
# it, and compare instance weights before and after.
# NOTE: despite the `log_reg_` prefix, this object is the fitted Reweighing
# preprocessor, not a logistic-regression model.
log_reg_RW = Reweighing(unpriv_group, priv_group).fit(train_data)
transformed_data = log_reg_RW.transform(train_data)
# Mean and std of the original weights, then of the reweighed weights.
display(train_data.instance_weights.mean(), train_data.instance_weights.std())
display(transformed_data.instance_weights.mean(),
        transformed_data.instance_weights.std())
# Training set as a DataFrame (first element of convert_to_dataframe()'s
# result) with the reweighed instance weights attached as a "weights" column.
# NOTE(review): this rebinds the global `t_data` that the preceding cell passes
# to compute_metrics.
t_data = train_data.convert_to_dataframe()[0]
t_data["weights"] = transformed_data.instance_weights
# Split rows on the `race` column. The variable names imply race == 0 encodes
# Black and race == 1 encodes White — confirm against the dataset's encoding.
t_data_blacks = t_data[t_data.race == 0]
t_data_whites = t_data[t_data.race == 1]
# Summary statistics of the reweighed weights for each group.
print(t_data_blacks.weights.describe())
print(t_data_whites.weights.describe())
# Distribution of the reweighed weights per race value.
t_data.boxplot(["weights"], by="race", figsize=(10, 5))
plt.show()

# %%
# Training set as a DataFrame, split into features `X` (every column except
# "two_year_recid") and target `y` ("two_year_recid").
# NOTE: the `race` column is present in this frame (see `t_data.race` in the
# previous cell), so `X` still contains it.
data = train_data.convert_to_dataframe()[0]
X, y = data.drop(["two_year_recid"], axis=1), data["two_year_recid"]