def test_get_weight(self):
        # set up L matrix
        true_accs = [0.95, 0.6, 0.7, 0.55, 0.8]
        coverage = [1.0, 0.8, 1.0, 1.0, 1.0]
        L = -1 * np.ones((1000, len(true_accs)))
        Y = np.zeros(1000)

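        # Simulate 1000 data points: draw a balanced binary label Y[i], then let
        # each of the five labeling functions vote with its own coverage and
        # accuracy (entries left untouched stay -1, Snorkel's abstain value).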
        for i in range(1000):
            Y[i] = 1 if np.random.rand() <= 0.5 else 0
            for j in range(5):
                if np.random.rand() <= coverage[j]:
                    L[i, j] = (Y[i] if np.random.rand() <= true_accs[j] else
                               np.abs(Y[i] - 1))

        label_model = LabelModel(cardinality=2)
        label_model.fit(L, n_epochs=1000, seed=123)

        accs = label_model.get_weights()
        for i in range(len(accs)):
            true_acc = true_accs[i]
            self.assertAlmostEqual(accs[i], true_acc, delta=0.1)
Example 2
def label_user(inp_path, prefix=""):
    df_train = pd.read_pickle(inp_path)

    ########## threshold on word similarity
    take_first = 100
    overall_first = 10000
    global thresh_by_value, overall_thresh
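    # Map each value to its canonical root value via syn_to_hob, then compute two
    # lexicon-count thresholds: per root value (the take_first-th largest count in
    # the group) and overall (the overall_first-th largest count in the frame).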
    df_train['root_value'] = df_train['value'].swifter.set_dask_threshold(
        dask_threshold=0.001).allow_dask_on_strings().apply(
            lambda x: syn_to_hob[x])
    thresh_by_value = df_train.groupby(
        ["root_value"]).apply(lambda x: np.partition(
            x['lexicon_counts'], max(len(x['lexicon_counts']) - take_first, 0)
        )[max(len(x['lexicon_counts']) - take_first, 0)]).to_dict()
    overall_thresh = np.partition(df_train["lexicon_counts"].to_numpy(),
                                  max(len(df_train) - overall_first, 0))[max(
                                      len(df_train) - overall_first, 0)]
    print(overall_thresh)
    #############################

    # separately loose - strict, pos - neg, period - without
    names_pool = [
        "context:2_count_pos", "context:3_count_pos", "context:100_count_pos",
        "context:2_period_count_pos", "context:3_period_count_pos",
        "context:100_period_count_pos", "context:2_count_neg",
        "context:3_count_neg", "context:100_count_neg",
        "context:2_period_count_neg", "context:3_period_count_neg",
        "context:100_period_count_neg"
    ]
    for f_name in names_pool:
        curr_cols = [x for x in df_train.columns if f_name in x]
        df_train['total_' + f_name] = df_train[curr_cols].swifter.apply(sum,
                                                                        axis=1)
        df_train = df_train.drop(curr_cols, axis=1)
    # Derive "new_total_*" columns as clipped differences between pairs of
    # context-count totals, separately for positive and negative contexts.
    for p in ["pos", "neg"]:
        diff_specs = [
            ("new_total_context:100_count_", "total_context:100_count_",
             "total_context:3_count_"),
            ("new_total_context:3_count_", "total_context:3_count_",
             "total_context:2_count_"),
            ("new_total_context:100_period_count_",
             "total_context:100_period_count_",
             "total_context:3_period_count_"),
            ("new_total_context:3_period_count_",
             "total_context:3_period_count_",
             "total_context:2_period_count_"),
            ("new_total_context:2_count_", "total_context:2_count_",
             "total_context:100_period_count_"),
        ]
        for new_col, minuend, subtrahend in diff_specs:
            df_train[new_col + p] = (df_train[minuend + p] -
                                     df_train[subtrahend + p]).clip(lower=0)

    df_train = df_train.drop(
        ["total_" + x for x in names_pool if "2_period_count" not in x],
        axis=1)

    lfs = [val_in_name, positive_lexicon_overall, positive_lexicon_pervalue]
    num_of_thresholds = 3
    step = 100 // num_of_thresholds

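    # For every numeric feature column, pick percentile-based thresholds and build
    # one labeling function per threshold (lexicon columns get a fixed threshold
    # and a dedicated LF constructor).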
    for col in df_train:
        if col not in ["author", "value", "idd", "root_value"]:
            if col not in [
                    "pos_prob_mean", "neg_prob_mean", "num_good_posts"
            ]:  # , "lexicon_counts", "subreddit_counts", "name_in_subr_count"]:
                thresholds = [0]
                if "lexicon" in col and "unique" not in col:
                    continue
                if True:  # col in ["lexicon_counts", "unique_lexicon_counts"]:
                    vals = df_train[col].to_numpy()
                    thresholds = np.percentile(
                        vals, list(range(0 + step, 99 + step,
                                         step))).astype(int)
                    thresholds = sorted(list(set(thresholds)))
                    if len(thresholds) > 1:
                        thresholds = thresholds[:-1]
                    if "lexicon" in col:
                        thresholds = [3]
                    # max_val = max(vals)
                    # thresholds = list(range(0, int(max_val), int(max_val/5) + 1))
                # elif col == "pos_prob_mean":
                #    thresholds = [0.5 + 0.1 * x for x in range(5)]
                for i, thresh in enumerate(thresholds):
                    next_threshold = (sys.maxsize if i == len(thresholds) - 1
                                      else thresholds[i + 1])
                    previous_threshold = (-sys.maxsize
                                          if i == 0 else thresholds[i - 1])
                    if "lexicon_counts" not in col:
                        lfs.append(
                            make_thresold_lf(thresh=thresh,
                                             col_name=col,
                                             next_threshold=next_threshold))
                    else:
                        lfs.append(
                            make_lexicon_lf(
                                thresh=thresh,
                                pref=col,
                                previous_threshold=previous_threshold))

    num_annotators = 0
    if num_annotators > 0:
        for i in range(1, num_annotators + 1):
            lfs.append(make_annotator_lf(worker_index=i))

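    # Keep only the LF families of interest (threshold "less"-style LFs, context:2
    # counts, annotator/"worker" LFs, and lexicon LFs); everything else is dropped.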
    lfs = [
        x for x in lfs
        if any(y in str(x) for y in ["less", "context:2", "worker", "lexicon"])
    ]
    print("created lfs their number", len(lfs))
    print("\n".join(str(x) for x in lfs))

    #### validation #####
    do_val = False
    if do_val:
        df_golden = pd.read_csv(
            "/home/tigunova/PycharmProjects/snorkel_labels/data/profession/gold_dev.csv"
        )
        name_val = list(df_golden["auth_val"])
        # df_train['root_value'] = df_train['value'].swifter.apply(lambda x: syn_to_hob[x])
        df_train["auth_val"] = df_train[["author", "value"]].swifter.apply(
            lambda x: x["author"] + "+++" + x["value"], axis=1)
        df_val = df_train[df_train.auth_val.isin(name_val)]
        df_dev = df_train[~df_train.auth_val.isin(name_val)]
        print("Number val", df_val.shape)
        print("Number dev", df_dev.shape)
        df_val = df_val.merge(df_golden, on="auth_val")
        y_val = np.array(df_val["final"])
        df_val = df_val.drop(labels="final", axis=1)
        # create test set as well

        with TQDMDaskProgressBar(desc="Dask Apply"):
            applier = PandasParallelLFApplier(lfs=lfs)
            L_val = applier.apply(df=df_val, n_parallel=num_cpu)
            L_dev = applier.apply(df=df_dev, n_parallel=num_cpu)

        dev_analysis = LFAnalysis(L=L_dev, lfs=lfs).lf_summary()
        analysis = LFAnalysis(L=L_val, lfs=lfs).lf_summary(y_val)
        analysis.to_csv("/home/tigunova/val_analysis.csv")
        dev_analysis.to_csv("/home/tigunova/dev_analysis.csv")
        print(analysis)
        label_model = LabelModel(cardinality=2, verbose=True)
        label_model.fit(L_dev)  #, Y_dev=y_val)
        model_stat = label_model.score(L=L_val, Y=y_val)
        print(model_stat)
        exit(0)
    ###########

    #### picking threshold #####
    do_threshold = False
    if do_threshold:
        df_golden = pd.read_csv(
            "/home/tigunova/PycharmProjects/snorkel_labels/data/profession/gold_validation.csv"
        )
        name_val = list(df_golden["auth_val"])
        # df_train['root_value'] = df_train['value'].swifter.apply(lambda x: syn_to_hob[x])
        df_train["auth_val"] = df_train[["author", "value"]].swifter.apply(
            lambda x: x["author"] + "+++" + x["value"], axis=1)
        df_val = df_train[df_train.auth_val.isin(name_val)]
        df_dev = df_train[~df_train.auth_val.isin(name_val)]
        pop_size = df_dev.shape[0]
        print("Number val", df_val.shape)
        print("Number dev", df_dev.shape)
        applier = PandasParallelLFApplier(lfs=lfs)
        df_val = df_val.merge(df_golden, on="auth_val")
        L_val = applier.apply(df=df_val, n_parallel=num_cpu)
        val_thresholds = [0.01 * x for x in range(100)]
        label_model = LabelModel(cardinality=2, verbose=True)
        with TQDMDaskProgressBar(desc="Dask Apply"):
            L_dev = applier.apply(df=df_dev, n_parallel=num_cpu)
            label_model.fit(L_dev, class_balance=[0.5, 0.5])  # , Y_dev=y_val)
            wghts = label_model.get_weights()
            print("\n".join(str(x) for x in zip(lfs, wghts)))
            probs_val = label_model.predict_proba(L=L_val)
            probs_df = pd.DataFrame(probs_val,
                                    columns=["neg_prob", "pos_prob"])
            df_val = pd.concat([df_val.reset_index(), probs_df], axis=1)
            probs_dev = label_model.predict_proba(L=L_dev)
            probs_df = pd.DataFrame(probs_dev,
                                    columns=["neg_prob", "pos_prob"])
            df_dev = pd.concat([df_dev.reset_index(), probs_df], axis=1)
            y_true = np.array(df_val["final"])
        for th in val_thresholds:
            y_pred = np.array(
                df_val["pos_prob"].apply(lambda x: 1 if x > th else 0))
            #print("true negatives")
            #print(df_val[df_val["final"] == 1][df_val["pos_prob"] <= th][["auth_val", "text"]])
            prec = precision_score(y_true, y_pred)

            pred_labels = y_pred
            true_labels = y_true
            # True Positive (TP): we predict a label of 1 (positive), and the true label is 1.
            TP = np.sum(np.logical_and(pred_labels == 1, true_labels == 1))

            # True Negative (TN): we predict a label of 0 (negative), and the true label is 0.
            TN = np.sum(np.logical_and(pred_labels == 0, true_labels == 0))

            # False Positive (FP): we predict a label of 1 (positive), but the true label is 0.
            FP = np.sum(np.logical_and(pred_labels == 1, true_labels == 0))

            # False Negative (FN): we predict a label of 0 (negative), but the true label is 1.
            FN = np.sum(np.logical_and(pred_labels == 0, true_labels == 1))

            print('TP: %i, FP: %i, TN: %i, FN: %i' % (TP, FP, TN, FN))

            # print(list(zip(label_model.predict(L=L_val_curr), y_val_curr)))
            # print("******************************")
            print("threshold %s, proportion population %.4f, precision %s" %
                  (str(th), df_dev[df_dev["pos_prob"] > th].shape[0] /
                   pop_size, str(prec)))
        exit(0)
    ###########

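    # Apply every labeling function to the full training frame in parallel and
    # inspect the usual LF summary statistics.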
    with TQDMDaskProgressBar(desc="Dask Apply"):
        applier = PandasParallelLFApplier(lfs=lfs)
        L_train = applier.apply(df=df_train, n_parallel=num_cpu)

    analysis = LFAnalysis(L=L_train, lfs=lfs).lf_summary()
    print(analysis)

    df_l_train = pd.DataFrame(
        L_train, columns=["llf_" + str(x).split(",")[0] for x in lfs])
    print(df_train.shape)
    print(df_l_train.shape)
    df_train = pd.concat([df_train.reset_index(), df_l_train], axis=1)
    print(df_train.shape)
    print("********************************************")

    t4 = time.time()
    label_model = LabelModel(cardinality=2, verbose=True)
    label_model.fit(L_train=L_train,
                    n_epochs=1000,
                    lr=0.001,
                    log_freq=100,
                    seed=123,
                    class_balance=[0.3, 0.7])

    probs_train = label_model.predict_proba(L=L_train)
    print("labeling model work ", (time.time() - t4) / 60)

    df_train_filtered, probs_train_filtered = filter_unlabeled_dataframe(
        X=df_train, y=probs_train, L=L_train)

    probs_df = pd.DataFrame(probs_train_filtered,
                            columns=["neg_prob", "pos_prob"])
    print(df_train_filtered.shape)
    print(probs_df.shape)
    result_filtered = pd.concat([
        df_train_filtered[['author', 'value', 'idd']].reset_index(), probs_df
    ],
                                axis=1)
    print(result_filtered.shape)
    print("****************************************************")

    result_filtered.to_csv("/home/tigunova/some_result_" + prefix + ".csv")

    print(df_train_filtered.shape)
    print(probs_df.shape)
    df_train_filtered = pd.concat([df_train_filtered.reset_index(), probs_df],
                                  axis=1)
    df_train_filtered = df_train_filtered.drop(["index"], axis=1)
    print(df_train_filtered.shape)
    df_train_filtered.to_pickle(
        "/home/tigunova/PycharmProjects/snorkel_labels/data/profession/user_" +
        prefix + ".pkl")
    df_train_filtered.to_csv(
        "/home/tigunova/PycharmProjects/snorkel_labels/data/profession/user_" +
        prefix + ".csv")

    # df_train.iloc[L_train[:, 1] == POS].to_csv("/home/tigunova/PycharmProjects/snorkel_labels/data/user_" + prefix + ".csv")

    ### write dict
    output_threshold = 0.63
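    # auth_hobby_dict keeps every (value, pos_prob) pair per author; output_dict
    # keeps only pairs whose positive probability exceeds output_threshold.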
    output_dict = defaultdict(list)
    auth_hobby_dict = defaultdict(list)
    for index, row in result_filtered.iterrows():
        if row.value == row.value and row.author == row.author:
            auth_hobby_dict[row.author].append([row.value, row.pos_prob])

    allowed_labels = []
    for index, row in df_train_filtered.iterrows():
        if row.value == row.value and row.author == row.author:
            if row.pos_prob > output_threshold:
                output_dict[row.author].append([row.value] + row.idd +
                                               [row.pos_prob])
                allowed_labels.append(syn_to_hob[row.value])
    print("\n".join([
        str(y) for y in sorted(dict(Counter(allowed_labels)).items(),
                               key=lambda x: x[1])
    ]))
    print(
        "After cropping",
        sum([
            x if x < 500 else 500
            for x in dict(Counter(allowed_labels)).values()
        ]))
    print("users in total", len(output_dict))
    for auth, stuffs in output_dict.items():
        prof = ":::".join(set([x[0] for x in stuffs]))
        prob = ":::".join([str(x[-1]) for x in stuffs])
        msgs = set([x for l in stuffs for x in l[1:-1]])
        output_dict[auth] = [prof] + list(msgs) + [prob]

    with open(
            "/home/tigunova/PycharmProjects/snorkel_labels/data/profession/sources/final_author_dict_"
            + prefix + ".txt", "w") as f_out:
        f_out.write(repr(dict(auth_hobby_dict)))
    with open("/home/tigunova/users_profession1.txt", "w") as f_out:
        f_out.write(repr(dict(output_dict)))
Example 3
def run_snorkel_labelling_classification(labeling_functions, file, l_train,
                                         l_valid):
    lfs = labeling_functions
    # lfs = [lf.is_same_thread, lf.has_entities, lf.enity_overlap_jacc, lf.entity_type_overlap_jacc]
    # lfs = [is_same_thread, enity_overlap, entity_types, entity_type_overlap]

    # lfs = [is_long, has_votes, is_doctor_reply, is_same_thread, enity_overlap, has_type_dsyn, has_type_patf, has_type_sosy,
    #        has_type_dora, has_type_fndg, has_type_menp, has_type_chem, has_type_orch, has_type_horm, has_type_phsu,
    #        has_type_medd, has_type_bhvr, has_type_diap, has_type_bacs, has_type_enzy, has_type_inpo, has_type_elii]
    # lfs = [has_votes, is_doctor_reply, is_same_thread, enity_overlap]
    # lfs = [is_same_thread, enity_overlap, is_doctor_reply]

    # analysis = LFAnalysis(L=l_train, lfs=lfs).lf_summary(Y=Y_train)
    # print(analysis)
    # print(analysis['Conflicts'])
    # print(analysis['Overlaps'])

    label_model = LabelModel(cardinality=2, verbose=True)
    label_model.fit(L_train=l_train,
                    n_epochs=20000,
                    lr=0.0001,
                    log_freq=10,
                    seed=2345)
    # label_model.fit(L_train=L_train, n_epochs=20, lr=0.0001, log_freq=10, seed=81794)

    print("Model weights: " + str(label_model.get_weights()))

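    # Score the validation split and store the positive-class probability on
    # df_valid (replacing any column left over from a previous run).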
    valid_probabilities = label_model.predict_proba(L=l_valid)
    if 'predicted_prob' in df_valid:
        # df_valid.drop(columns=['predicted_prob'], axis=1)
        del df_valid['predicted_prob']
    df_valid.insert(50, 'predicted_prob', valid_probabilities[:, 1])

    # df_valid.to_csv("/container/filip/json/ehealthforum/trac/validation_df2.txt", sep="\t", header=True)
    # df_valid = pd.read_csv("/filip/json/ehealthforum/trac/validation_df.txt", sep="\t")

    def compute_precision_at_k(l, k):
        """Precision at k: fraction of the first k entries of l that are positive."""
        l = l[:k]
        return sum(l) / k

    PROBABILITY_CUTOFF = 0.5

    df_valid['predicted_label'] = (df_valid['predicted_prob'] >=
                                   PROBABILITY_CUTOFF)

    # Ratio of rows labeled relevant in the ground truth to rows predicted
    # relevant (a count ratio, not a per-example true-positive rate).
    true_positive_ratio = df_valid[df_valid.bm25_relevant == 1].count()['bm25_relevant'] / \
                          df_valid[df_valid.predicted_label == 1].count()['predicted_label']

    print("Number of True relevant: " +
          str(df_valid[df_valid.bm25_relevant == 1].count()['bm25_relevant']))
    print("Number of Predicted relevant: " + str(df_valid[
        df_valid.predicted_label == 1].count()['predicted_label']) + '\n')
    print('True positive ratio: ' + str(true_positive_ratio) + '\n')

    df_tru = df_valid.groupby(['query_thread']).head(10)['bm25_relevant']

    df_pred = df_valid.groupby(['query_thread']).head(10)['predicted_label']

    overall_precision = []

    for query, group in df_valid.groupby(['query_thread']):
        precision = compute_precision_at_k(
            group['predicted_label'].head(10).tolist(), 10)
        overall_precision.append(precision)

    print('Overall precision: ' +
          str(sum(overall_precision) / len(overall_precision)))
    print("Accuracy: " + str(accuracy_score(df_tru, df_pred)))

    label_model_acc = label_model.score(L=l_valid, Y=Y_valid)["accuracy"]
    print(f"{'Label Model Accuracy:':<25} {label_model_acc * 100:.1f}%")
Example 4
class Modeler:
    def __init__(self,
                 df_train,
                 df_dev,
                 df_valid,
                 df_test,
                 df_heldout,
                 lfs=None,
                 label_model=None):
        df_train["seen"] = 0
        self.df_train = df_train.reset_index()
        self.df_dev = df_dev
        self.df_valid = df_valid
        self.df_test = df_test
        self.df_heldout = df_heldout
        #self.Y_train = df_train.label.values
        self.Y_dev = df_dev.label.values
        self.Y_valid = df_valid.label.values
        self.Y_test = df_test.label.values
        self.Y_heldout = df_heldout.label.values

        # Avoid sharing a mutable default dict across Modeler instances.
        self.lfs = lfs if lfs is not None else {}

        self.L_train = None
        self.L_dev = None
        self.L_valid = None
        self.L_heldout = None
        cardinality = len(df_valid.label.unique())

        # for DEMOing purposes
        self.first_text_indices = [
            1262,  #"check out" "youtube"
            1892,  # I love
            1117,  # url concept
            1706,  # emoji concept
            952,  # "nice"
            971,  # positive concept
            958,  # actually use emoji concept
        ]

        self.count = 0

        if label_model is None:
            self.label_model = LabelModel(cardinality=cardinality,
                                          verbose=True)
        else:
            self.label_model = label_model

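        # Unigram/bigram bag-of-words vectorizer, fit once on the training text
        # and reused for the dev/valid/test/heldout end-model features.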
        self.vectorizer = CountVectorizer(ngram_range=(1, 2))
        self.vectorizer.fit(df_train.text.tolist())

    def get_lfs(self):
        return list(self.lfs.values())

    def add_lfs(self, new_lfs: dict):
        self.lfs.update(new_lfs)

    def remove_lfs(self, old_lf_ids: list):
        for lf_id in old_lf_ids:
            del self.lfs[lf_id]
        return len(self.lfs)

    def apply_lfs(self):
        applier = PandasLFApplier(lfs=self.get_lfs())
        self.L_train = applier.apply(df=self.df_train)
        self.L_dev = applier.apply(df=self.df_dev)
        self.L_heldout = applier.apply(df=self.df_heldout)
        #self.L_valid = applier.apply(df=self.df_valid)

    def find_duplicate_signature(self):
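        """Map each LF whose label column duplicates an earlier LF's column to that earlier LF's name."""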
        label_matrix = np.vstack([self.L_train, self.L_dev])
        seen_signatures = {}
        dupes = {}
        lfs = self.get_lfs()
        signatures = [
            hash(label_matrix[:, i].tobytes()) for i in range(len(lfs))
        ]
        for i, s in enumerate(signatures):
            lf = lfs[i]
            if s in seen_signatures:
                dupes[lf.name] = seen_signatures[s]
            else:
                seen_signatures[s] = lf.name
        return dupes

    def lf_examples(self, lf_id, n=5):
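        """Return up to n training examples on which the given LF did not abstain."""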
        lf = self.lfs[lf_id]
        applier = PandasLFApplier(lfs=[lf])
        L_train = applier.apply(df=self.df_train)
        labeled_examples = self.df_train[L_train != -1]
        samples = labeled_examples.sample(min(n, len(labeled_examples)),
                                          random_state=13)
        return [{"text": t} for t in samples["text"].values]

    def lf_mistakes(self, lf_id, n=5):
        lf = self.lfs[lf_id]
        applier = PandasLFApplier(lfs=[lf])
        L_dev = applier.apply(df=self.df_dev).squeeze()
        labeled_examples = self.df_dev[(L_dev != -1)
                                       & (L_dev != self.df_dev["label"])]
        samples = labeled_examples.sample(min(n, len(labeled_examples)),
                                          random_state=13)
        return [{"text": t} for t in samples["text"].values]

    def fit_label_model(self):
        assert self.L_train is not None

        self.label_model.fit(L_train=self.L_train,
                             n_epochs=1000,
                             lr=0.001,
                             log_freq=100,
                             seed=123)

    def analyze_lfs(self):
        if len(self.lfs) > 0:
            df = LFAnalysis(L=self.L_train, lfs=self.get_lfs()).lf_summary()
            dev_df = LFAnalysis(L=self.L_dev,
                                lfs=self.get_lfs()).lf_summary(Y=self.Y_dev)
            df = df.merge(dev_df,
                          how="outer",
                          suffixes=(" Training", " Dev."),
                          left_index=True,
                          right_index=True)
            df["Weight"] = self.label_model.get_weights()
            df["Duplicate"] = None
            for dupe, OG in self.find_duplicate_signature().items():
                print("Duplicate labeling signature detected")
                print(dupe, OG)
                df.at[dupe, "Duplicate"] = OG

            return df
        return None

    def get_label_model_stats(self):
        result = self.label_model.score(L=self.L_dev,
                                        Y=self.Y_dev,
                                        metrics=["f1", "precision", "recall"])

        probs_train = self.label_model.predict_proba(L=self.L_train)
        df_train_filtered, probs_train_filtered = filter_unlabeled_dataframe(
            X=self.df_train, y=probs_train, L=self.L_train)
        result["training_label_coverage"] = len(probs_train_filtered) / len(
            probs_train)
        result["class_0_ratio"] = (probs_train_filtered[:, 0] >
                                   0.5).sum() / len(probs_train_filtered)
        if len(probs_train_filtered) == 0:
            result["class_0_ratio"] = 0

        return result

    def get_heldout_stats(self):
        if self.L_heldout is not None:
            return self.label_model.score(
                L=self.L_heldout,
                Y=self.Y_heldout,
                metrics=["f1", "precision", "recall"])
        return {}

    def train(self):
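        """Train the Keras end model on probabilistic labels and report test-set metrics."""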
        probs_train = self.label_model.predict_proba(L=self.L_train)

        df_train_filtered, probs_train_filtered = filter_unlabeled_dataframe(
            X=self.df_train, y=probs_train, L=self.L_train)

        if len(df_train_filtered) == 0:
            print("Labeling functions cover none of the training examples!",
                  file=sys.stderr)
            return {"micro_f1": 0}

        #from tensorflow.keras.utils import to_categorical
        #df_train_filtered, probs_train_filtered = self.df_dev, to_categorical(self.df_dev["label"].values)

        vectorizer = self.vectorizer
        X_train = vectorizer.transform(df_train_filtered.text.tolist())

        X_dev = vectorizer.transform(self.df_dev.text.tolist())
        X_valid = vectorizer.transform(self.df_valid.text.tolist())
        X_test = vectorizer.transform(self.df_test.text.tolist())

        self.keras_model = get_keras_logreg(input_dim=X_train.shape[1])

        self.keras_model.fit(
            x=X_train,
            y=probs_train_filtered,
            validation_data=(X_valid, preds_to_probs(self.Y_valid, 2)),
            callbacks=[get_keras_early_stopping()],
            epochs=20,
            verbose=0,
        )

        preds_test = self.keras_model.predict(x=X_test).argmax(axis=1)

        #return preds_test
        return self.get_stats(self.Y_test, preds_test)

    def get_heldout_lr_stats(self):
        X_heldout = self.vectorizer.transform(self.df_heldout.text.tolist())
        preds_test = self.keras_model.predict(x=X_heldout).argmax(axis=1)
        return self.get_stats(self.Y_heldout, preds_test)

    def get_stats(self, Y_test, preds_test):
        label_classes = np.unique(self.Y_test)
        accuracy = metrics.accuracy_score(Y_test, preds_test)
        precision_0, precision_1 = metrics.precision_score(
            Y_test, preds_test, labels=label_classes, average=None)
        recall_0, recall_1 = metrics.recall_score(Y_test,
                                                  preds_test,
                                                  labels=label_classes,
                                                  average=None)
        test_f1 = metrics.f1_score(Y_test, preds_test, labels=label_classes)

        #recall_0, recall_1 = metrics.precision_recall_fscore_support(self.Y_test, preds_test, labels=label_classes)["recall"]
        return {
            "micro_f1": test_f1,
            "recall_0": recall_0,
            "precision_0": precision_0,
            "accuracy": accuracy,
            "recall_1": recall_1,
            "precision_1": precision_1
        }

    def entropy(self, prob_dist):
        """Shannon entropy of a probability distribution (0 * log 0 treated as 0)."""
        #return(-(L_row_i==-1).sum())
        return -sum(x * log(x) for x in prob_dist if x > 0)

    def save(self, dir_name):
        self.label_model.save(os.path.join(dir_name, 'label_model.pkl'))
        with open(os.path.join(dir_name, 'model_lfs.pkl'), "wb+") as file:
            pickle.dump(self.lfs, file)

    def load(self, dir_name):
        with open(os.path.join(dir_name, 'model_lfs.pkl'), "rb") as file:
            lfs = pickle.load(file)
            label_model = LabelModel.load(
                os.path.join(dir_name, 'label_model.pkl'))
            self.lfs = lfs
            self.label_model = label_model