Пример #1
0
    def test_save_and_load(self):
        """Check that a saved and re-loaded LabelModel predicts identically.

        A model is fitted on a tiny label matrix, saved to disk, loaded
        into a fresh ``LabelModel`` instance, and the predictions of both
        instances on the same matrix are compared for exact equality.
        """
        import os

        L = np.array([[0, -1, 0], [0, 1, 1]])
        label_model = LabelModel(cardinality=2, verbose=False)
        label_model.fit(L, n_epochs=1)
        original_preds = label_model.predict(L)

        # The context manager removes the directory (and the saved model
        # inside it) even when save/load/predict raises.
        with tempfile.TemporaryDirectory() as dir_path:
            # Join with a separator so the file lands *inside* the temp
            # directory; plain string concatenation would create a sibling
            # file that the directory cleanup never deletes.
            save_path = os.path.join(dir_path, "label_model.pkl")
            label_model.save(save_path)

            label_model_new = LabelModel(cardinality=2, verbose=False)
            label_model_new.load(save_path)
            loaded_preds = label_model_new.predict(L)

        np.testing.assert_array_equal(loaded_preds, original_preds)
Пример #2
0
def label_dataset(
    task: Task,
    dataset: Dataset,
    path_config: Optional[PathConfig] = None,
    debug: bool = False,
) -> None:
    """Label a dataset with a stored label model and write it out as CSV.

    Reads the pickled heuristic matrix
    ``<generated>/<task.name>/heuristic_matrix_<dataset.name>.pkl`` and the
    label model ``<generated>/<task.name>/label_model.pkl``, passes them
    with the loaded dataset to ``do_labeling``, and writes the result to
    ``<labeled_data>/<task.name>/<dataset.name>.labeled.csv``.

    Args:
        task: Task whose ``name`` selects the sub-directories and whose
            ``labels`` are forwarded to ``do_labeling``.
        dataset: Dataset to label; ``dataset.load()`` supplies the frame.
        path_config: Directory configuration. When omitted (or falsy),
            ``PathConfig.load()`` is used.
        debug: When true, add an ``lfs`` column holding the
            ``;``-separated names of the heuristics whose vote for the row
            is not ``-1``.
    """
    path_config = path_config or PathConfig.load()
    generated_dir = path_config.generated / task.name

    # NOTE(review): unpickling executes arbitrary code; these files are
    # assumed to come from a trusted source -- confirm.
    applied_heuristics_df = pd.read_pickle(
        str(generated_dir / f"heuristic_matrix_{dataset.name}.pkl"))

    label_model = LabelModel()
    label_model.load(str(generated_dir / "label_model.pkl"))
    df = dataset.load()
    df_labeled = do_labeling(label_model, applied_heuristics_df.to_numpy(), df,
                             task.labels)

    if debug:
        heuristic_names = list(applied_heuristics_df.columns)
        # Every vote other than -1 counts as "fired", so labels beyond 0/1
        # are handled too. The matrix is only read here, never overwritten
        # (and no DataFrame.iteritems(), which pandas 2.0 removed).
        col_lfs = applied_heuristics_df.apply(
            lambda row: ";".join(
                name for name, vote in zip(heuristic_names, row)
                if vote != -1 and name),
            axis=1)
        df_labeled["lfs"] = col_lfs

    labeled_data_path = path_config.labeled_data / task.name
    # exist_ok avoids the check-then-create race of a separate exists() test.
    labeled_data_path.mkdir(parents=True, exist_ok=True)
    target_file = labeled_data_path / f"{dataset.name}.labeled.csv"
    df_labeled.to_csv(target_file, index=False)
    print(f"Labeled dataset has been written to {target_file}.")