Code Example #1
def write_data_labels(num_processes, merged_entity_emb_file,
                      merged_storage_type, data_file, out_file,
                      train_in_candidates, dump_embs, data_config):
    logger = logging.getLogger(__name__)

    # Get sent mapping
    start = time.time()
    sent_idx_map = get_sent_idx_map(merged_entity_emb_file,
                                    merged_storage_type)
    sent_idx_map_file = tempfile.NamedTemporaryFile(
        suffix="bootleg_sent_idx_map")
    utils.create_single_item_trie(sent_idx_map,
                                  out_file=sent_idx_map_file.name)

    # Chunk file for parallel writing
    create_ex_indir = tempfile.TemporaryDirectory()
    create_ex_outdir = tempfile.TemporaryDirectory()
    logger.debug(f"Counting lines")
    total_input = sum(1 for _ in open(data_file))
    chunk_input = int(np.ceil(total_input / num_processes))
    logger.debug(
        f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines"
    )
    total_input_from_chunks, input_files_dict = utils.chunk_file(
        data_file, create_ex_indir.name, chunk_input)

    input_files = list(input_files_dict.keys())
    input_file_lines = [input_files_dict[k] for k in input_files]
    output_files = [
        in_file_name.replace(create_ex_indir.name, create_ex_outdir.name)
        for in_file_name in input_files
    ]
    assert total_input == total_input_from_chunks, f"Length of files {total_input} doesn't match {total_input_from_chunks}"
    logger.debug(f"Done chunking files")

    logger.info(f'Starting to write files with {num_processes} processes')
    pool = multiprocessing.Pool(processes=num_processes,
                                initializer=write_data_labels_initializer,
                                initargs=[
                                    merged_entity_emb_file,
                                    merged_storage_type,
                                    sent_idx_map_file.name,
                                    train_in_candidates, dump_embs, data_config
                                ])

    input_args = list(zip(input_files, input_file_lines, output_files))
    # Count completed chunks as the workers finish writing their output files
    total = 0
    for res in pool.imap(write_data_labels_hlp, input_args, chunksize=1):
        total += 1

    # Merge output files to final file
    logger.debug(f"Merging output files")
    with open(out_file, 'wb') as outfile:
        for filename in glob.glob(os.path.join(create_ex_outdir.name, "*")):
            if filename == out_file:
                # don't want to copy the output into the output
                continue
            with open(filename, 'rb') as readfile:
                shutil.copyfileobj(readfile, outfile)
    sent_idx_map_file.close()
    create_ex_indir.cleanup()
    create_ex_outdir.cleanup()
    logger.info(f'Time to write files {time.time()-start}s')
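
The pattern in Code Example #1 is generic: split the input file into roughly equal per-worker chunks, run the chunks through a multiprocessing pool, then concatenate the per-chunk outputs into the final file. Below is a minimal, self-contained sketch of that chunk/process/merge flow using only the standard library; the `process_chunk` worker and all paths are hypothetical stand-ins, not Bootleg's actual helpers.

import glob
import multiprocessing
import os
import shutil
import tempfile


def process_chunk(args):
    # Hypothetical worker: copies lines through unchanged. The real workers
    # in the example above write one labeled line per input line.
    in_path, out_path = args
    with open(in_path) as fin, open(out_path, "w") as fout:
        for line in fin:
            fout.write(line)
    return out_path


def parallel_write(data_file, out_file, num_processes=4):
    indir = tempfile.TemporaryDirectory()
    outdir = tempfile.TemporaryDirectory()
    # Split the input into roughly equal per-worker chunk files.
    with open(data_file) as f:
        lines = f.readlines()
    chunk = max(1, -(-len(lines) // num_processes))  # ceiling division
    pairs = []
    for i in range(0, len(lines), chunk):
        in_path = os.path.join(indir.name, f"in{i}.jsonl")
        out_path = os.path.join(outdir.name, f"out{i}.jsonl")
        with open(in_path, "w") as f:
            f.writelines(lines[i:i + chunk])
        pairs.append((in_path, out_path))
    # Process chunks in parallel, then merge the outputs in one pass.
    # As in the example above, chunk order in the final file is not guaranteed.
    with multiprocessing.Pool(processes=num_processes) as pool:
        for _ in pool.imap(process_chunk, pairs, chunksize=1):
            pass
    with open(out_file, "wb") as outfile:
        for name in glob.glob(os.path.join(outdir.name, "*")):
            with open(name, "rb") as readfile:
                shutil.copyfileobj(readfile, outfile)
    indir.cleanup()
    outdir.cleanup()


if __name__ == "__main__":
    # Placeholder paths; requires an existing input.jsonl to run.
    parallel_write("input.jsonl", "merged_output.jsonl", num_processes=2)

The same structure appears again in Code Example #3, where the chunking happens over a dictionary of rows instead of a file.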
Code Example #2
def merge_subsentences(num_processes,
                       data_file,
                       to_save_file,
                       to_save_storage,
                       to_read_file,
                       to_read_storage,
                       dump_embs=False):
    logger = logging.getLogger(__name__)
    logger.debug(f"Getting sentence mapping")
    sent_start_map, total_num_mentions = get_sent_start_map(data_file)
    sent_start_map_file = tempfile.NamedTemporaryFile(
        suffix="bootleg_sent_start_map")
    utils.create_single_item_trie(sent_start_map,
                                  out_file=sent_start_map_file.name)
    logger.debug(f"Done with sentence mapping")

    full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode='r')
    M = int(full_pred_data[0]['M'])
    K = int(full_pred_data[0]['K'])
    hidden_size = int(full_pred_data[0]['hidden_size'])

    filt_emb_data = np.memmap(to_save_file,
                              dtype=to_save_storage,
                              mode='w+',
                              shape=(total_num_mentions, ))
    filt_emb_data['hidden_size'] = hidden_size
    filt_emb_data['sent_idx'][:] = -1
    filt_emb_data['alias_list_pos'][:] = -1

    chunk_size = int(np.ceil(len(full_pred_data) / num_processes))
    all_ids = list(range(0, len(full_pred_data)))
    row_idx_set_chunks = [
        all_ids[ids:ids + chunk_size]
        for ids in range(0, len(full_pred_data), chunk_size)
    ]
    input_args = [[M, K, hidden_size, dump_embs, chunk]
                  for chunk in row_idx_set_chunks]

    logger.info(
        f"Merging sentences together with {num_processes} processes. Starting pool"
    )

    pool = multiprocessing.Pool(processes=num_processes,
                                initializer=merge_subsentences_initializer,
                                initargs=[
                                    to_save_file, to_save_storage,
                                    to_read_file, to_read_storage,
                                    sent_start_map_file.name
                                ])
    logger.debug(f"Finished pool")
    start = time.time()
    seen_ids = set()
    for sent_ids_seen in pool.imap_unordered(merge_subsentences_hlp,
                                             input_args,
                                             chunksize=1):
        for emb_id in sent_ids_seen:
            assert emb_id not in seen_ids, f'{emb_id} already seen, something went wrong with sub-sentences'
            seen_ids.add(emb_id)
    sent_start_map_file.close()
    logger.info(f'Time to merge sub-sentences {time.time()-start}s')
    return
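
Both the read and write sides of `merge_subsentences` are NumPy memory-mapped arrays with structured dtypes: one row per sub-sentence on the read side, one row per mention on the write side. The sketch below shows the create/fill/re-read pattern for such a structured memmap; the field names mirror the example above, while the file path and values are made up.

import numpy as np

hidden_size = 4
total_num_mentions = 3
merged_storage_type = np.dtype([
    ("hidden_size", int),
    ("sent_idx", int),
    ("alias_list_pos", int),
    ("entity_emb", float, hidden_size),
])

# Create the on-disk array with one row per mention and mark all rows unfilled.
filt_emb_data = np.memmap("merged_embs.mmap", dtype=merged_storage_type,
                          mode="w+", shape=(total_num_mentions, ))
filt_emb_data["hidden_size"] = hidden_size
filt_emb_data["sent_idx"][:] = -1
filt_emb_data["alias_list_pos"][:] = -1

# Fill one mention row; in the example above, each pool worker does this for its chunk.
filt_emb_data[0]["sent_idx"] = 7
filt_emb_data[0]["alias_list_pos"] = 0
filt_emb_data[0]["entity_emb"] = np.arange(hidden_size, dtype=float)
filt_emb_data.flush()

# Re-open read-only to confirm the data landed on disk.
check = np.memmap("merged_embs.mmap", dtype=merged_storage_type, mode="r")
assert int(check[0]["sent_idx"]) == 7

Initializing `sent_idx` and `alias_list_pos` to -1, as the example does, makes any row that was never written easy to spot afterwards.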
Code Example #3
def write_data_labels(
    num_processes,
    result_alias_offset,
    merged_entity_emb_file,
    merged_storage_type,
    sent_idx2row,
    cache_folder,
    out_file,
    entity_dump,
    train_in_candidates,
    max_candidates,
    dump_embs,
    trie_candidate_map_folder=None,
    trie_qid2eid_file=None,
):
    """Takes the flattened data from merge_sentences and writes out predictions
    to a file, one line per sentence.

    The embedding ids are added to the file if dump_embs is True.

    Args:
        num_processes: number of processes
        result_alias_offset: alias offset of this batch of examples for writing out
        merged_entity_emb_file: input memmap file after merge sentences
        merged_storage_type: input file storage type
        sent_idx2row: Dict of sentence idx to row relevant to this subbatch
        cache_folder: folder to save temporary outputs
        out_file: final output file for predictions
        entity_dump: entity dump
        train_in_candidates: whether NC entities are not in candidate lists
        max_candidates: maximum number of candidates
        dump_embs: whether to dump embeddings or not
        trie_candidate_map_folder: folder where trie of alias->candidate map is stored for parallel processing
        trie_qid2eid_file: file where trie of qid->eid map is stored for parallel processing

    Returns:
    """
    st = time.time()
    sental2embid = get_sental2embid(merged_entity_emb_file,
                                    merged_storage_type)
    log_rank_0_debug(logger,
                     f"Finished getting sentence map {time.time() - st}s")

    total_input = len(sent_idx2row)
    if num_processes == 1:
        filt_emb_data = np.memmap(merged_entity_emb_file,
                                  dtype=merged_storage_type,
                                  mode="r+")
        write_data_labels_single(
            sentidx2row=sent_idx2row,
            output_file=out_file,
            filt_emb_data=filt_emb_data,
            sental2embid=sental2embid,
            alias_cand_map=entity_dump.get_alias2qids(),
            qid2eid=entity_dump.get_qid2eid(),
            result_alias_offset=result_alias_offset,
            train_in_cands=train_in_candidates,
            max_cands=max_candidates,
            dump_embs=dump_embs,
        )
    else:
        assert (
            trie_candidate_map_folder is not None
        ), "trie_candidate_map_folder is None and you have parallel turned on"
        assert (trie_qid2eid_file is not None
                ), "trie_qid2eid_file is None and you have parallel turned on"

        # Get trie of sentence map
        trie_folder = os.path.join(cache_folder, "bootleg_sental2embid")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sental2embid, out_file=trie_file)
        # Chunk file for parallel writing
        # We do not use TemporaryDirectory as the temp dir may not have enough space for large files
        create_ex_indir = os.path.join(cache_folder,
                                       "_bootleg_eval_temp_indir")
        utils.ensure_dir(create_ex_indir)
        create_ex_outdir = os.path.join(cache_folder,
                                        "_bootleg_eval_temp_outdir")
        utils.ensure_dir(create_ex_outdir)
        chunk_input = int(np.ceil(total_input / num_processes))
        logger.debug(
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines"
        )
        # Chunk up dictionary of data for parallel processing
        input_files = []
        i = 0
        cur_lines = 0
        file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
        open_file = open(file_split, "w")
        for s_idx in sent_idx2row:
            if cur_lines >= chunk_input:
                open_file.close()
                input_files.append(file_split)
                cur_lines = 0
                i += 1
                file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
                open_file = open(file_split, "w")
            line = sent_idx2row[s_idx]
            open_file.write(ujson.dumps(line) + "\n")
            cur_lines += 1
        open_file.close()
        input_files.append(file_split)
        # Generate input/output pairs
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        log_rank_0_debug(logger, f"Done chunking files. Starting pool")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=write_data_labels_initializer,
            initargs=[
                merged_entity_emb_file,
                merged_storage_type,
                trie_file,
                result_alias_offset,
                train_in_candidates,
                max_candidates,
                dump_embs,
                trie_candidate_map_folder,
                trie_qid2eid_file,
            ],
        )

        input_args = list(zip(input_files, output_files))

        total = 0
        for res in pool.imap(write_data_labels_hlp, input_args, chunksize=1):
            total += 1

        # Merge output files to final file
        log_rank_0_debug(logger, f"Merging output files")
        with open(out_file, "wb") as outfile:
            for filename in glob.glob(os.path.join(create_ex_outdir, "*")):
                if filename == out_file:
                    # don't want to copy the output into the output
                    continue
                with open(filename, "rb") as readfile:
                    shutil.copyfileobj(readfile, outfile)
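
The manual chunking step in this example, the loop that writes `sent_idx2row` into per-worker JSONL sub-files, can be read as a small standalone helper. The sketch below assumes rows are JSON-serializable dicts; the helper name `chunk_dict_to_jsonl` and the use of the standard `json` module instead of `ujson` are choices made here, not part of the original code.

import json
import os


def chunk_dict_to_jsonl(sent_idx2row, out_dir, chunk_size):
    """Write roughly chunk_size rows per file and return the list of file paths."""
    os.makedirs(out_dir, exist_ok=True)
    input_files = []
    open_file, file_split, cur_lines, i = None, None, 0, 0
    for s_idx in sent_idx2row:
        if open_file is None or cur_lines >= chunk_size:
            # Close the current sub-file (if any) and start the next one.
            if open_file is not None:
                open_file.close()
                input_files.append(file_split)
            file_split = os.path.join(out_dir, f"out{i}.jsonl")
            open_file = open(file_split, "w")
            cur_lines = 0
            i += 1
        open_file.write(json.dumps(sent_idx2row[s_idx]) + "\n")
        cur_lines += 1
    if open_file is not None:
        open_file.close()
        input_files.append(file_split)
    return input_files


# Example: two sentences split into sub-files of at most one row each.
files = chunk_dict_to_jsonl({"0": {"aliases": ["a"]}, "1": {"aliases": ["b"]}},
                            "tmp_chunks", chunk_size=1)
assert len(files) == 2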
Code Example #4
def merge_subsentences(
    num_processes,
    subset_sent_idx2num_mens,
    cache_folder,
    to_save_file,
    to_save_storage,
    to_read_file,
    to_read_storage,
    dump_embs=False,
):
    """Flatten all sentences back together over sub-sentences; removing the PAD
    aliases from the data I.e., converts from sent_idx -> array of values to
    (sent_idx, alias_idx) -> value with varying numbers of aliases per
    sentence.

    Args:
        num_processes: number of processes
        subset_sent_idx2num_mens: Dict of sentence index to number of mentions for this batch
        cache_folder: cache directory
        to_save_file: memmap file to save results to
        to_save_storage: save file storage type
        to_read_file: memmap file to read predictions from
        to_read_storage: read file storage type
        dump_embs: whether to save embeddings or not

    Returns:
    """
    # Compute sent idx to offset so we know where to fill in mentions
    cur_offset = 0
    sentidx2offset = {}
    for k, v in subset_sent_idx2num_mens.items():
        sentidx2offset[k] = cur_offset
        cur_offset += v
        # print("Sent Idx, Num Mens, Offset", k, v, cur_offset)
    total_num_mentions = cur_offset
    # print("TOTAL", total_num_mentions)
    full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode="r")
    M = int(full_pred_data[0]["M"])
    K = int(full_pred_data[0]["K"])
    hidden_size = int(full_pred_data[0]["hidden_size"])
    # print("TOTAL MENS", total_num_mentions)
    filt_emb_data = np.memmap(to_save_file,
                              dtype=to_save_storage,
                              mode="w+",
                              shape=(total_num_mentions, ))
    filt_emb_data["hidden_size"] = hidden_size
    filt_emb_data["sent_idx"][:] = -1
    filt_emb_data["alias_list_pos"][:] = -1

    all_ids = list(range(0, len(full_pred_data)))
    start = time.time()
    if num_processes == 1:
        seen_ids = merge_subsentences_single(
            M,
            K,
            hidden_size,
            dump_embs,
            all_ids,
            filt_emb_data,
            full_pred_data,
            sentidx2offset,
        )
    else:
        # Get trie for sentence start map
        trie_folder = os.path.join(cache_folder, "bootleg_sent_idx2num_mens")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sentidx2offset, out_file=trie_file)
        # Chunk up data
        chunk_size = int(np.ceil(len(full_pred_data) / num_processes))
        row_idx_set_chunks = [
            all_ids[ids:ids + chunk_size]
            for ids in range(0, len(full_pred_data), chunk_size)
        ]
        # Start pool
        input_args = [[M, K, hidden_size, dump_embs, chunk]
                      for chunk in row_idx_set_chunks]
        log_rank_0_debug(
            logger,
            f"Merging sentences together with {num_processes} processes")
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=merge_subsentences_initializer,
            initargs=[
                to_save_file,
                to_save_storage,
                to_read_file,
                to_read_storage,
                trie_file,
            ],
        )

        seen_ids = set()
        for sent_ids_seen in pool.imap_unordered(merge_subsentences_hlp,
                                                 input_args,
                                                 chunksize=1):
            for emb_id in sent_ids_seen:
                assert (
                    emb_id not in seen_ids
                ), f"{emb_id} already seen, something went wrong with sub-sentences"
                seen_ids.add(emb_id)
    # filt_emb_data = np.memmap(to_save_file, dtype=to_save_storage, mode="r")
    # for i in range(len(filt_emb_data)):
    #     si = filt_emb_data[i]["sent_idx"]
    #     al_test = filt_emb_data[i]["alias_list_pos"]
    #     if si == -1 or al_test == -1:
    #         print("BAD", i, filt_emb_data[i])
    #         import ipdb; ipdb.set_trace()
    logging.debug(f"Saw {len(seen_ids)} sentences")
    logging.debug(f"Time to merge sub-sentences {time.time() - start}s")
    return
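
The key indexing idea in `merge_subsentences` is the prefix-sum offset map: each sentence is assigned the row offset of its first mention, and a mention then lands at `sentidx2offset[sent_idx] + alias_list_pos` in the flattened output array. A tiny self-contained illustration with made-up mention counts:

# Map each sentence to the row offset of its first mention (a running prefix sum).
subset_sent_idx2num_mens = {"0": 2, "1": 5}   # made-up mention counts per sentence
sentidx2offset, cur_offset = {}, 0
for sent_idx, num_mens in subset_sent_idx2num_mens.items():
    sentidx2offset[sent_idx] = cur_offset
    cur_offset += num_mens
total_num_mentions = cur_offset               # 7 rows in the flattened array

# A mention at alias_list_pos 3 in sentence "1" lands at row 2 + 3 = 5.
row = sentidx2offset["1"] + 3
assert (total_num_mentions, row) == (7, 5)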
Code Example #5
def disambig_dump_preds(
    result_idx,
    result_alias_offset,
    config,
    res_dict,
    sent_idx2num_mens,
    sent_idx2row,
    save_folder,
    entity_symbols,
    dump_embs,
    task_name,
):
    """Dumps the predictions of a disambiguation task.

    Args:
        result_idx: batch index of the result arrays
        result_alias_offset: overall offset of the starting example (i.e., the number of previous mentions already written)
        config: model config
        res_dict: result dictionary from Emmental predict
        sent_idx2num_mens: Dict sentence idx to number of mentions
        sent_idx2row: Dict sentence idx to row of eval data
        save_folder: folder to save results
        entity_symbols: entity symbols
        dump_embs: whether to save the contextualized embeddings or not
        task_name: task name

    Returns: saved prediction file, saved embedding file (will be None if dump_embs is False)
    """
    num_processes = min(config.run_config.dataset_threads,
                        int(multiprocessing.cpu_count() * 0.9))
    cache_dir = os.path.join(save_folder, f"cache_{result_idx}")
    utils.ensure_dir(cache_dir)
    trie_candidate_map_folder = None
    trie_qid2eid_file = None
    # Save the alias->QID candidate map and the QID->EID mapping in memory efficient structures for faster
    # prediction dumping
    if num_processes > 1:
        entity_prep_dir = data_utils.get_emb_prep_dir(config.data_config)
        trie_candidate_map_folder = os.path.join(entity_prep_dir,
                                                 "for_dumping_preds",
                                                 "alias_cand_trie")
        utils.ensure_dir(trie_candidate_map_folder)
        check_and_create_alias_cand_trie(trie_candidate_map_folder,
                                         entity_symbols)
        trie_qid2eid_file = os.path.join(entity_prep_dir, "for_dumping_preds",
                                         "qid2eid_trie.marisa")
        if not os.path.exists(trie_qid2eid_file):
            utils.create_single_item_trie(entity_symbols.get_qid2eid(),
                                          out_file=trie_qid2eid_file)

    # Collect the results for the disambiguation task we are dumping
    disambig_res_dict = {}
    for k in res_dict:
        assert task_name in res_dict[
            k], f"{task_name} not in res_dict for key {k}"
        disambig_res_dict[k] = res_dict[k][task_name]

    # write to file (M x hidden_size for each data point -- the next step will deal with recovering the original
    # sentence indices for overflowing sentences)
    unmerged_entity_emb_file = os.path.join(save_folder, f"entity_embs.pt")
    merged_entity_emb_file = os.path.join(save_folder,
                                          f"entity_embs_unmerged.pt")
    emb_file_config = os.path.splitext(
        unmerged_entity_emb_file)[0] + "_config.npy"
    M = config.data_config.max_aliases
    K = entity_symbols.max_candidates + (
        not config.data_config.train_in_candidates)
    if dump_embs:
        unmerged_storage_type = np.dtype([
            ("M", int),
            ("K", int),
            ("hidden_size", int),
            ("sent_idx", int),
            ("subsent_idx", int),
            ("alias_list_pos", int, (M, )),
            ("entity_emb", float, M * config.model_config.hidden_size),
            ("final_loss_true", int, (M, )),
            ("final_loss_pred", int, (M, )),
            ("final_loss_prob", float, (M, )),
            ("final_loss_cand_probs", float, M * K),
        ])
        merged_storage_type = np.dtype([
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("entity_emb", float, config.model_config.hidden_size),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ])
    else:
        # don't need to extract contextualized entity embedding
        unmerged_storage_type = np.dtype([
            ("M", int),
            ("K", int),
            ("hidden_size", int),
            ("sent_idx", int),
            ("subsent_idx", int),
            ("alias_list_pos", int, (M, )),
            ("final_loss_true", int, (M, )),
            ("final_loss_pred", int, (M, )),
            ("final_loss_prob", float, (M, )),
            ("final_loss_cand_probs", float, M * K),
        ])
        merged_storage_type = np.dtype([
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ])
    mmap_file = np.memmap(
        unmerged_entity_emb_file,
        dtype=unmerged_storage_type,
        mode="w+",
        shape=(len(disambig_res_dict["uids"]), ),
    )
    # print("MEMMAP FILE SHAPE", len(disambig_res_dict["uids"]))
    # Init sent_idx to -1 for debugging
    mmap_file[:]["sent_idx"] = -1
    np.save(emb_file_config, unmerged_storage_type, allow_pickle=True)
    log_rank_0_debug(
        logger,
        f"Created file {unmerged_entity_emb_file} to save predictions.")

    log_rank_0_debug(logger, f'{len(disambig_res_dict["uids"])} samples')
    for_iteration = [
        disambig_res_dict["uids"],
        disambig_res_dict["golds"],
        disambig_res_dict["probs"],
        disambig_res_dict["preds"],
    ]
    all_sent_idx = set()
    for i, (uid, gold, probs, model_pred) in enumerate(zip(*for_iteration)):
        # disambig_res_dict["output"] is dict with keys ['_input__alias_orig_list_pos',
        # 'bootleg_pred_1', '_input__sent_idx', '_input__for_dump_gold_cand_K_idx_train', '_input__subsent_idx', 0, 1]
        sent_idx = disambig_res_dict["outputs"]["_input__sent_idx"][i]
        # print("INSIDE LOOP", sent_idx, "AT", i)
        subsent_idx = disambig_res_dict["outputs"]["_input__subsent_idx"][i]
        alias_orig_list_pos = disambig_res_dict["outputs"][
            "_input__alias_orig_list_pos"][i]
        gold_cand_K_idx_train = disambig_res_dict["outputs"][
            "_input__for_dump_gold_cand_K_idx_train"][i]
        output_embeddings = disambig_res_dict["outputs"][
            f"{PRED_LAYER}_ent_embs"][i]
        mmap_file[i]["M"] = M
        mmap_file[i]["K"] = K
        mmap_file[i]["hidden_size"] = config.model_config.hidden_size
        mmap_file[i]["sent_idx"] = sent_idx
        mmap_file[i]["subsent_idx"] = subsent_idx
        mmap_file[i]["alias_list_pos"] = alias_orig_list_pos
        # This will give all aliases seen by the model during training, independent of whether they are gold or not
        mmap_file[i][f"final_loss_true"] = gold_cand_K_idx_train.reshape(M)

        # get max for each alias, probs is M x K
        max_probs = probs.max(axis=1)
        pred_cands = probs.argmax(axis=1)

        mmap_file[i]["final_loss_pred"] = pred_cands
        mmap_file[i]["final_loss_prob"] = max_probs
        mmap_file[i]["final_loss_cand_probs"] = probs.reshape(1, -1)

        all_sent_idx.add(str(sent_idx))
        # final_entity_embs is M x K x hidden_size, pred_cands is M
        if dump_embs:
            chosen_entity_embs = select_embs(embs=output_embeddings,
                                             pred_cands=pred_cands,
                                             M=M)

            # write chosen entity embs to file for contextualized entity embeddings
            mmap_file[i]["entity_emb"] = chosen_entity_embs.reshape(1, -1)

    # for i in range(len(mmap_file)):
    #     si = mmap_file[i]["sent_idx"]
    #     if -1 == si:
    #         import pdb
    #         pdb.set_trace()
    #     assert si != -1, f"{i} {mmap_file[i]}"
    # Store all predicted sentences to filter the sentence mapping by
    subset_sent_idx2num_mens = {
        k: v
        for k, v in sent_idx2num_mens.items() if k in all_sent_idx
    }
    # print("ALL SEEN", all_sent_idx)
    subsent_sent_idx2row = {
        k: v
        for k, v in sent_idx2row.items() if k in all_sent_idx
    }
    result_file = get_result_file(result_idx, save_folder)
    log_rank_0_debug(logger, f"Writing predictions to {result_file}...")
    merge_subsentences(
        num_processes=num_processes,
        subset_sent_idx2num_mens=subset_sent_idx2num_mens,
        cache_folder=cache_dir,
        to_save_file=merged_entity_emb_file,
        to_save_storage=merged_storage_type,
        to_read_file=unmerged_entity_emb_file,
        to_read_storage=unmerged_storage_type,
        dump_embs=dump_embs,
    )
    write_data_labels(
        num_processes=num_processes,
        result_alias_offset=result_alias_offset,
        merged_entity_emb_file=merged_entity_emb_file,
        merged_storage_type=merged_storage_type,
        sent_idx2row=subsent_sent_idx2row,
        cache_folder=cache_dir,
        out_file=result_file,
        entity_dump=entity_symbols,
        train_in_candidates=config.data_config.train_in_candidates,
        max_candidates=entity_symbols.max_candidates,
        dump_embs=dump_embs,
        trie_candidate_map_folder=trie_candidate_map_folder,
        trie_qid2eid_file=trie_qid2eid_file,
    )

    out_emb_file = None
    filt_emb_data = np.memmap(merged_entity_emb_file,
                              dtype=merged_storage_type,
                              mode="r+")
    total_mentions_seen = len(filt_emb_data)
    # save easier-to-use embedding file
    if dump_embs:
        hidden_size = filt_emb_data[0]["hidden_size"]
        out_emb_file = get_emb_file(result_idx, save_folder)
        np.save(out_emb_file,
                filt_emb_data["entity_emb"].reshape(-1, hidden_size))
        log_rank_0_debug(
            logger,
            f"Saving contextual entity embeddings for {result_idx} to {out_emb_file}",
        )
    filt_emb_data = None

    # Cleanup cache - sometimes the file in cache_dir is still open so we need to retry to delete it
    try_rmtree(cache_dir)

    log_rank_0_debug(logger,
                     f"Wrote predictions for {result_idx} to {result_file}")
    return result_file, out_emb_file, all_sent_idx, total_mentions_seen
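
Inside the dump loop of `disambig_dump_preds`, `probs` is an M x K matrix of per-alias candidate probabilities and the model emits M x K x hidden_size candidate embeddings; the predicted candidate and its contextual embedding are picked per alias with an argmax. `select_embs` itself is not shown in this section, so the NumPy sketch below is only an assumed equivalent, with made-up shapes and values.

import numpy as np

M, K, hidden_size = 2, 3, 4
rng = np.random.default_rng(0)
probs = rng.random((M, K))              # per-alias candidate probabilities
embs = rng.random((M, K, hidden_size))  # per-candidate contextual embeddings

pred_cands = probs.argmax(axis=1)       # chosen candidate index per alias, shape (M,)
max_probs = probs.max(axis=1)           # probability of the chosen candidate, shape (M,)
chosen_entity_embs = embs[np.arange(M), pred_cands]  # gathered embeddings, shape (M, hidden_size)

assert chosen_entity_embs.shape == (M, hidden_size)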
Code Example #6
    def test_write_out_subsentences(self):

        merged_entity_emb_file = tempfile.NamedTemporaryFile()
        out_file = tempfile.NamedTemporaryFile()
        data_file = tempfile.NamedTemporaryFile()
        cache_folder = tempfile.TemporaryDirectory()

        entity_dir = "test/entity_db"
        entity_map_dir = "entity_mappings"

        entity_symbols = EntitySymbolsSubclass()
        entity_symbols.save(save_dir=os.path.join(entity_dir, entity_map_dir))

        total_num_mentions = 7
        K = 2
        hidden_size = 2

        # create data file -- just needs aliases and sentence indices
        data = [
            {
                "aliases": ["a", "b"],
                "sent_idx_unq": 0
            },
            {
                "aliases": ["c", "d", "e", "f", "g"],
                "sent_idx_unq": 1
            },
        ]
        # Dict keys are strings because they are used as trie keys
        sent_idx2rows = {"0": data[0], "1": data[1]}
        with jsonlines.open(data_file.name, "w") as f:
            for row in data:
                f.write(row)

        merged_storage_type = np.dtype([
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("entity_emb", float, hidden_size),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ])

        merged_entity_emb = np.memmap(
            merged_entity_emb_file.name,
            dtype=merged_storage_type,
            mode="w+",
            shape=(total_num_mentions, ),
        )
        # 2 sentences, 1st sent has 1 subsentence, 2nd sentence has 2 subsentences - 7 mentions total
        merged_entity_emb["hidden_size"] = hidden_size
        # first men
        merged_entity_emb[0]["sent_idx"] = 0
        merged_entity_emb[0]["alias_list_pos"] = 0
        merged_entity_emb[0]["entity_emb"] = np.array([0, 1])
        merged_entity_emb[0]["final_loss_pred"] = 1
        merged_entity_emb[0]["final_loss_prob"] = 0.9
        merged_entity_emb[0]["final_loss_cand_probs"] = np.array([0.1, 0.9])
        # second men
        merged_entity_emb[1]["sent_idx"] = 0
        merged_entity_emb[1]["alias_list_pos"] = 1
        merged_entity_emb[1]["entity_emb"] = np.array([2, 3])
        merged_entity_emb[1]["final_loss_pred"] = 1
        merged_entity_emb[1]["final_loss_prob"] = 0.9
        merged_entity_emb[1]["final_loss_cand_probs"] = np.array([0.1, 0.9])
        # third men
        merged_entity_emb[2]["sent_idx"] = 1
        merged_entity_emb[2]["alias_list_pos"] = 0
        merged_entity_emb[2]["entity_emb"] = np.array([4, 5])
        merged_entity_emb[2]["final_loss_pred"] = 0
        merged_entity_emb[2]["final_loss_prob"] = 0.9
        merged_entity_emb[2]["final_loss_cand_probs"] = np.array([0.9, 0.1])
        # fourth men
        merged_entity_emb[3]["sent_idx"] = 1
        merged_entity_emb[3]["alias_list_pos"] = 1
        merged_entity_emb[3]["entity_emb"] = np.array([6, 7])
        merged_entity_emb[3]["final_loss_pred"] = 0
        merged_entity_emb[3]["final_loss_prob"] = 0.9
        merged_entity_emb[3]["final_loss_cand_probs"] = np.array([0.9, 0.1])
        # fifth men
        merged_entity_emb[4]["sent_idx"] = 1
        merged_entity_emb[4]["alias_list_pos"] = 2
        merged_entity_emb[4]["entity_emb"] = np.array([10, 11])
        merged_entity_emb[4]["final_loss_pred"] = 1
        merged_entity_emb[4]["final_loss_prob"] = 0.9
        merged_entity_emb[4]["final_loss_cand_probs"] = np.array([0.1, 0.9])
        # sixth men
        merged_entity_emb[5]["sent_idx"] = 1
        merged_entity_emb[5]["alias_list_pos"] = 3
        merged_entity_emb[5]["entity_emb"] = np.array([12, 13])
        merged_entity_emb[5]["final_loss_pred"] = 1
        merged_entity_emb[5]["final_loss_prob"] = 0.9
        merged_entity_emb[5]["final_loss_cand_probs"] = np.array([0.1, 0.9])
        # seventh men
        merged_entity_emb[6]["sent_idx"] = 1
        merged_entity_emb[6]["alias_list_pos"] = 4
        merged_entity_emb[6]["entity_emb"] = np.array([14, 15])
        merged_entity_emb[6]["final_loss_pred"] = 1
        merged_entity_emb[6]["final_loss_prob"] = 0.9
        merged_entity_emb[6]["final_loss_cand_probs"] = np.array([0.1, 0.9])

        num_processes = 1
        train_in_candidates = True
        dump_embs = True
        max_candidates = 2
        """
          "a":[["Q1",10.0],["Q4",6]],
          "b":[["Q2",5.0],["Q1",3]],
          "c":[["Q1",30.0],["Q2",3]],
          "d":[["Q4",20],["Q3",15.0]],
          "e":[["Q1",10.0],["Q4",6]],
          "f":[["Q2",5.0],["Q1",3]],
          "g":[["Q1",30.0],["Q2",3]]
        """

        gold_lines = [
            {
                "sent_idx_unq": 0,
                "aliases": ["a", "b"],
                "qids": ["Q4", "Q1"],
                "probs": [0.9, 0.9],
                "cands": [["Q1", "Q4"], ["Q2", "Q1"]],
                "cand_probs": [[0.1, 0.9], [0.1, 0.9]],
                "entity_ids": [4, 1],
                "ctx_emb_ids": [0, 1],
            },
            {
                "sent_idx_unq":
                1,
                "aliases": ["c", "d", "e", "f", "g"],
                "qids": ["Q1", "Q4", "Q4", "Q1", "Q2"],
                "probs": [0.9, 0.9, 0.9, 0.9, 0.9],
                "cands": [
                    ["Q1", "Q2"],
                    ["Q4", "Q3"],
                    ["Q1", "Q4"],
                    ["Q2", "Q1"],
                    ["Q1", "Q2"],
                ],
                "cand_probs": [
                    [0.9, 0.1],
                    [0.9, 0.1],
                    [0.1, 0.9],
                    [0.1, 0.9],
                    [0.1, 0.9],
                ],
                "entity_ids": [1, 4, 4, 1, 2],
                "ctx_emb_ids": [2, 3, 4, 5, 6],
            },
        ]

        write_data_labels(
            num_processes=num_processes,
            result_alias_offset=0,
            merged_entity_emb_file=merged_entity_emb_file.name,
            merged_storage_type=merged_storage_type,
            sent_idx2row=sent_idx2rows,
            cache_folder=cache_folder.name,
            out_file=out_file.name,
            entity_dump=entity_symbols,
            train_in_candidates=train_in_candidates,
            max_candidates=max_candidates,
            dump_embs=dump_embs,
            trie_candidate_map_folder=None,
            trie_qid2eid_file=None,
        )
        all_lines = []
        with open(out_file.name) as check_f:
            for line in check_f:
                all_lines.append(ujson.loads(line))

        assert len(all_lines) == len(gold_lines)

        all_lines_sent_idx_map = {
            line["sent_idx_unq"]: line
            for line in all_lines
        }
        gold_lines_sent_idx_map = {
            line["sent_idx_unq"]: line
            for line in gold_lines
        }

        assert len(all_lines_sent_idx_map) == len(gold_lines_sent_idx_map)
        for sent_idx in all_lines_sent_idx_map:
            self.assertDictEqual(
                gold_lines_sent_idx_map[sent_idx],
                all_lines_sent_idx_map[sent_idx],
                f"{ujson.dumps(gold_lines_sent_idx_map[sent_idx], indent=4)} VS "
                f"{ujson.dumps(all_lines_sent_idx_map[sent_idx], indent=4)}",
            )

        # TRY MULTIPROCESSING
        num_processes = 2
        # create memory-mapped trie files for multiprocessing
        trie_candidate_map_folder = tempfile.TemporaryDirectory()
        trie_qid2eid_file = tempfile.NamedTemporaryFile()
        create_single_item_trie(entity_symbols.get_qid2eid(),
                                out_file=trie_qid2eid_file.name)
        check_and_create_alias_cand_trie(trie_candidate_map_folder.name,
                                         entity_symbols)

        write_data_labels(
            num_processes=num_processes,
            result_alias_offset=0,
            merged_entity_emb_file=merged_entity_emb_file.name,
            merged_storage_type=merged_storage_type,
            sent_idx2row=sent_idx2rows,
            cache_folder=cache_folder.name,
            out_file=out_file.name,
            entity_dump=entity_symbols,
            train_in_candidates=train_in_candidates,
            max_candidates=max_candidates,
            dump_embs=dump_embs,
            trie_candidate_map_folder=trie_candidate_map_folder.name,
            trie_qid2eid_file=trie_qid2eid_file.name,
        )

        all_lines = []
        with open(out_file.name) as check_f:
            for line in check_f:
                all_lines.append(ujson.loads(line))

        assert len(all_lines) == len(gold_lines)

        all_lines_sent_idx_map = {
            line["sent_idx_unq"]: line
            for line in all_lines
        }
        gold_lines_sent_idx_map = {
            line["sent_idx_unq"]: line
            for line in gold_lines
        }

        assert len(all_lines_sent_idx_map) == len(gold_lines_sent_idx_map)
        for sent_idx in all_lines_sent_idx_map:
            self.assertDictEqual(
                gold_lines_sent_idx_map[sent_idx],
                all_lines_sent_idx_map[sent_idx],
                f"{ujson.dumps(gold_lines_sent_idx_map[sent_idx], indent=4)} VS "
                f"{ujson.dumps(all_lines_sent_idx_map[sent_idx], indent=4)}",
            )

        # clean up
        if os.path.exists(entity_dir):
            shutil.rmtree(entity_dir, ignore_errors=True)
        merged_entity_emb_file.close()
        out_file.close()
        data_file.close()
        trie_candidate_map_folder.cleanup()
        cache_folder.cleanup()
        trie_qid2eid_file.close()