def write_data_labels(num_processes, merged_entity_emb_file, merged_storage_type,
                      data_file, out_file, train_in_candidates, dump_embs, data_config):
    logger = logging.getLogger(__name__)
    # Get sent mapping
    start = time.time()
    sent_idx_map = get_sent_idx_map(merged_entity_emb_file, merged_storage_type)
    sent_idx_map_file = tempfile.NamedTemporaryFile(suffix="bootleg_sent_idx_map")
    utils.create_single_item_trie(sent_idx_map, out_file=sent_idx_map_file.name)
    # Chunk file for parallel writing
    create_ex_indir = tempfile.TemporaryDirectory()
    create_ex_outdir = tempfile.TemporaryDirectory()
    logger.debug("Counting lines")
    with open(data_file) as in_f:
        total_input = sum(1 for _ in in_f)
    chunk_input = int(np.ceil(total_input / num_processes))
    logger.debug(
        f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines"
    )
    total_input_from_chunks, input_files_dict = utils.chunk_file(
        data_file, create_ex_indir.name, chunk_input
    )
    input_files = list(input_files_dict.keys())
    input_file_lines = [input_files_dict[k] for k in input_files]
    output_files = [
        in_file_name.replace(create_ex_indir.name, create_ex_outdir.name)
        for in_file_name in input_files
    ]
    assert total_input == total_input_from_chunks, (
        f"Total line count {total_input} does not match chunked count {total_input_from_chunks}"
    )
    logger.debug("Done chunking files")

    logger.info(f"Starting to write files with {num_processes} processes")
    pool = multiprocessing.Pool(
        processes=num_processes,
        initializer=write_data_labels_initializer,
        initargs=[
            merged_entity_emb_file,
            merged_storage_type,
            sent_idx_map_file.name,
            train_in_candidates,
            dump_embs,
            data_config,
        ],
    )
    input_args = list(zip(input_files, input_file_lines, output_files))
    # Store output files and counts for saving in next step
    total = 0
    for res in pool.imap(write_data_labels_hlp, input_args, chunksize=1):
        total += 1
    # Make sure all workers are done before merging their outputs
    pool.close()
    pool.join()

    # Merge output files to final file
    logger.debug("Merging output files")
    with open(out_file, "wb") as outfile:
        for filename in glob.glob(os.path.join(create_ex_outdir.name, "*")):
            # don't want to copy the output into the output
            if filename == out_file:
                continue
            with open(filename, "rb") as readfile:
                shutil.copyfileobj(readfile, outfile)
    sent_idx_map_file.close()
    create_ex_indir.cleanup()
    create_ex_outdir.cleanup()
    logger.info(f"Time to write files {time.time() - start}s")
def merge_subsentences(num_processes, data_file, to_save_file, to_save_storage,
                       to_read_file, to_read_storage, dump_embs=False):
    logger = logging.getLogger(__name__)
    logger.debug("Getting sentence mapping")
    sent_start_map, total_num_mentions = get_sent_start_map(data_file)
    sent_start_map_file = tempfile.NamedTemporaryFile(suffix="bootleg_sent_start_map")
    utils.create_single_item_trie(sent_start_map, out_file=sent_start_map_file.name)
    logger.debug("Done with sentence mapping")

    full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode="r")
    M = int(full_pred_data[0]["M"])
    K = int(full_pred_data[0]["K"])
    hidden_size = int(full_pred_data[0]["hidden_size"])

    filt_emb_data = np.memmap(
        to_save_file,
        dtype=to_save_storage,
        mode="w+",
        shape=(total_num_mentions,),
    )
    filt_emb_data["hidden_size"] = hidden_size
    filt_emb_data["sent_idx"][:] = -1
    filt_emb_data["alias_list_pos"][:] = -1

    chunk_size = int(np.ceil(len(full_pred_data) / num_processes))
    all_ids = list(range(0, len(full_pred_data)))
    row_idx_set_chunks = [
        all_ids[ids:ids + chunk_size]
        for ids in range(0, len(full_pred_data), chunk_size)
    ]
    input_args = [[M, K, hidden_size, dump_embs, chunk] for chunk in row_idx_set_chunks]

    logger.info(
        f"Merging sentences together with {num_processes} processes. Starting pool"
    )
    pool = multiprocessing.Pool(
        processes=num_processes,
        initializer=merge_subsentences_initializer,
        initargs=[
            to_save_file,
            to_save_storage,
            to_read_file,
            to_read_storage,
            sent_start_map_file.name,
        ],
    )
    logger.debug("Pool created")

    start = time.time()
    seen_ids = set()
    for sent_ids_seen in pool.imap_unordered(merge_subsentences_hlp, input_args, chunksize=1):
        for emb_id in sent_ids_seen:
            assert emb_id not in seen_ids, (
                f"{emb_id} already seen, something went wrong with sub-sentences"
            )
            seen_ids.add(emb_id)
    # Ensure all workers have finished writing before cleaning up
    pool.close()
    pool.join()
    sent_start_map_file.close()
    logger.info(f"Time to merge sub-sentences {time.time() - start}s")
    return
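# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): how a single-item trie like
# the one written by utils.create_single_item_trie above *might* be built and
# queried with marisa_trie so that worker processes can memory-map the same
# sent_idx -> start-offset map. The "<l" record format, string keys, and the
# helper names below are assumptions for illustration only.
# ---------------------------------------------------------------------------
def _build_offset_trie_sketch(sent_start_map, out_file):
    """Save a sentence-start map as a memory-mappable marisa RecordTrie."""
    import marisa_trie

    # keys are string sentence indices, values are 1-tuples holding the offset
    trie = marisa_trie.RecordTrie(
        "<l", [(str(k), (int(v),)) for k, v in sent_start_map.items()]
    )
    trie.save(out_file)


def _lookup_offset_sketch(trie_file, sent_idx):
    """Look up the start offset of a sentence from the saved trie file."""
    import marisa_trie

    trie = marisa_trie.RecordTrie("<l")
    trie.mmap(trie_file)  # workers memory-map the file instead of copying it
    return trie[str(sent_idx)][0][0]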
def write_data_labels(
    num_processes,
    result_alias_offset,
    merged_entity_emb_file,
    merged_storage_type,
    sent_idx2row,
    cache_folder,
    out_file,
    entity_dump,
    train_in_candidates,
    max_candidates,
    dump_embs,
    trie_candidate_map_folder=None,
    trie_qid2eid_file=None,
):
    """Takes the flattened data from merge_sentences and writes out predictions
    to a file, one line per sentence. The embedding ids are added to the file
    if dump_embs is True.

    Args:
        num_processes: number of processes
        result_alias_offset: alias offset of this batch of examples for writing out
        merged_entity_emb_file: input memmap file after merge sentences
        merged_storage_type: input file storage type
        sent_idx2row: Dict of sentence idx to row relevant to this subbatch
        cache_folder: folder to save temporary outputs
        out_file: final output file for predictions
        entity_dump: entity dump
        train_in_candidates: whether NC entities are not in candidate lists
        max_candidates: maximum number of candidates
        dump_embs: whether to dump embeddings or not
        trie_candidate_map_folder: folder where trie of alias->candidate map is stored for parallel processing
        trie_qid2eid_file: file where trie of qid->eid map is stored for parallel processing

    Returns:
    """
    st = time.time()
    sental2embid = get_sental2embid(merged_entity_emb_file, merged_storage_type)
    log_rank_0_debug(logger, f"Finished getting sentence map {time.time() - st}s")

    total_input = len(sent_idx2row)
    if num_processes == 1:
        filt_emb_data = np.memmap(merged_entity_emb_file, dtype=merged_storage_type, mode="r+")
        write_data_labels_single(
            sentidx2row=sent_idx2row,
            output_file=out_file,
            filt_emb_data=filt_emb_data,
            sental2embid=sental2embid,
            alias_cand_map=entity_dump.get_alias2qids(),
            qid2eid=entity_dump.get_qid2eid(),
            result_alias_offset=result_alias_offset,
            train_in_cands=train_in_candidates,
            max_cands=max_candidates,
            dump_embs=dump_embs,
        )
    else:
        assert (
            trie_candidate_map_folder is not None
        ), "trie_candidate_map_folder is None and you have parallel turned on"
        assert (
            trie_qid2eid_file is not None
        ), "trie_qid2eid_file is None and you have parallel turned on"
        # Get trie of sentence map
        trie_folder = os.path.join(cache_folder, "bootleg_sental2embid")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sental2embid, out_file=trie_file)
        # Chunk file for parallel writing
        # We do not use TemporaryFolders as the temp dir may not have enough space for large files
        create_ex_indir = os.path.join(cache_folder, "_bootleg_eval_temp_indir")
        utils.ensure_dir(create_ex_indir)
        create_ex_outdir = os.path.join(cache_folder, "_bootleg_eval_temp_outdir")
        utils.ensure_dir(create_ex_outdir)
        chunk_input = int(np.ceil(total_input / num_processes))
        logger.debug(
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines"
        )
        # Chunk up dictionary of data for parallel processing
        input_files = []
        i = 0
        cur_lines = 0
        file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
        open_file = open(file_split, "w")
        for s_idx in sent_idx2row:
            if cur_lines >= chunk_input:
                open_file.close()
                input_files.append(file_split)
                cur_lines = 0
                i += 1
                file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
                open_file = open(file_split, "w")
            line = sent_idx2row[s_idx]
            open_file.write(ujson.dumps(line) + "\n")
            cur_lines += 1
        open_file.close()
        input_files.append(file_split)
        # Generate input/output pairs
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        log_rank_0_debug(logger, "Done chunking files. Starting pool")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=write_data_labels_initializer,
            initargs=[
                merged_entity_emb_file,
                merged_storage_type,
                trie_file,
                result_alias_offset,
                train_in_candidates,
                max_candidates,
                dump_embs,
                trie_candidate_map_folder,
                trie_qid2eid_file,
            ],
        )
        input_args = list(zip(input_files, output_files))
        total = 0
        for res in pool.imap(write_data_labels_hlp, input_args, chunksize=1):
            total += 1
        # Make sure all workers are done before merging their outputs
        pool.close()
        pool.join()

        # Merge output files to final file
        log_rank_0_debug(logger, "Merging output files")
        with open(out_file, "wb") as outfile:
            for filename in glob.glob(os.path.join(create_ex_outdir, "*")):
                # don't want to copy the output into the output
                if filename == out_file:
                    continue
                with open(filename, "rb") as readfile:
                    shutil.copyfileobj(readfile, outfile)
def merge_subsentences(
    num_processes,
    subset_sent_idx2num_mens,
    cache_folder,
    to_save_file,
    to_save_storage,
    to_read_file,
    to_read_storage,
    dump_embs=False,
):
    """Flatten all sentences back together over sub-sentences, removing the
    PAD aliases from the data. I.e., converts from sent_idx -> array of values
    to (sent_idx, alias_idx) -> value, with varying numbers of aliases per
    sentence.

    Args:
        num_processes: number of processes
        subset_sent_idx2num_mens: Dict of sentence index to number of mentions for this batch
        cache_folder: cache directory
        to_save_file: memmap file to save results to
        to_save_storage: save file storage type
        to_read_file: memmap file to read predictions from
        to_read_storage: read file storage type
        dump_embs: whether to save embeddings or not

    Returns:
    """
    # Compute sent idx to offset so we know where to fill in mentions
    cur_offset = 0
    sentidx2offset = {}
    for k, v in subset_sent_idx2num_mens.items():
        sentidx2offset[k] = cur_offset
        cur_offset += v
        # print("Sent Idx, Num Mens, Offset", k, v, cur_offset)
    total_num_mentions = cur_offset
    # print("TOTAL", total_num_mentions)

    full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode="r")
    M = int(full_pred_data[0]["M"])
    K = int(full_pred_data[0]["K"])
    hidden_size = int(full_pred_data[0]["hidden_size"])
    # print("TOTAL MENS", total_num_mentions)

    filt_emb_data = np.memmap(
        to_save_file,
        dtype=to_save_storage,
        mode="w+",
        shape=(total_num_mentions,),
    )
    filt_emb_data["hidden_size"] = hidden_size
    filt_emb_data["sent_idx"][:] = -1
    filt_emb_data["alias_list_pos"][:] = -1

    all_ids = list(range(0, len(full_pred_data)))
    start = time.time()
    if num_processes == 1:
        seen_ids = merge_subsentences_single(
            M,
            K,
            hidden_size,
            dump_embs,
            all_ids,
            filt_emb_data,
            full_pred_data,
            sentidx2offset,
        )
    else:
        # Get trie for sentence start map
        trie_folder = os.path.join(cache_folder, "bootleg_sent_idx2num_mens")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sentidx2offset, out_file=trie_file)
        # Chunk up data
        chunk_size = int(np.ceil(len(full_pred_data) / num_processes))
        row_idx_set_chunks = [
            all_ids[ids:ids + chunk_size]
            for ids in range(0, len(full_pred_data), chunk_size)
        ]
        # Start pool
        input_args = [[M, K, hidden_size, dump_embs, chunk] for chunk in row_idx_set_chunks]
        log_rank_0_debug(
            logger, f"Merging sentences together with {num_processes} processes"
        )
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=merge_subsentences_initializer,
            initargs=[
                to_save_file,
                to_save_storage,
                to_read_file,
                to_read_storage,
                trie_file,
            ],
        )
        seen_ids = set()
        for sent_ids_seen in pool.imap_unordered(
            merge_subsentences_hlp, input_args, chunksize=1
        ):
            for emb_id in sent_ids_seen:
                assert (
                    emb_id not in seen_ids
                ), f"{emb_id} already seen, something went wrong with sub-sentences"
                seen_ids.add(emb_id)
        # Ensure all workers have finished writing before continuing
        pool.close()
        pool.join()

    # filt_emb_data = np.memmap(to_save_file, dtype=to_save_storage, mode="r")
    # for i in range(len(filt_emb_data)):
    #     si = filt_emb_data[i]["sent_idx"]
    #     al_test = filt_emb_data[i]["alias_list_pos"]
    #     if si == -1 or al_test == -1:
    #         print("BAD", i, filt_emb_data[i])
    #         import ipdb; ipdb.set_trace()
    logger.debug(f"Saw {len(seen_ids)} sentences")
    logger.debug(f"Time to merge sub-sentences {time.time() - start}s")
    return
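# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): the offset computation used
# above to flatten (sent_idx -> number of mentions) into row positions of the
# merged mention array. With sentences holding 2, 5, and 1 mentions, the
# mentions of sentence "1" land at rows 2..6 of the output memmap.
# ---------------------------------------------------------------------------
def _compute_sent_offsets_sketch(sent_idx2num_mens):
    """Return (dict of sent_idx -> first row index, total number of mentions)."""
    offsets, cur = {}, 0
    for sent_idx, num_mens in sent_idx2num_mens.items():
        offsets[sent_idx] = cur
        cur += num_mens
    return offsets, cur


# e.g. _compute_sent_offsets_sketch({"0": 2, "1": 5, "2": 1})
#      -> ({"0": 0, "1": 2, "2": 7}, 8)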
def disambig_dump_preds(
    result_idx,
    result_alias_offset,
    config,
    res_dict,
    sent_idx2num_mens,
    sent_idx2row,
    save_folder,
    entity_symbols,
    dump_embs,
    task_name,
):
    """Dumps the predictions of a disambiguation task.

    Args:
        result_idx: batch index of the result arrays
        result_alias_offset: overall offset of the starting example (i.e., the number of previous mens already written)
        config: model config
        res_dict: result dictionary from Emmental predict
        sent_idx2num_mens: Dict sentence idx to number of mentions
        sent_idx2row: Dict sentence idx to row of eval data
        save_folder: folder to save results
        entity_symbols: entity symbols
        dump_embs: whether to save the contextualized embeddings or not
        task_name: task name

    Returns: saved prediction file, saved embedding file (will be None if dump_embs is False)
    """
    num_processes = min(
        config.run_config.dataset_threads, int(multiprocessing.cpu_count() * 0.9)
    )
    cache_dir = os.path.join(save_folder, f"cache_{result_idx}")
    utils.ensure_dir(cache_dir)
    trie_candidate_map_folder = None
    trie_qid2eid_file = None
    # Save the alias->QID candidate map and the QID->EID mapping in memory efficient structures for faster
    # prediction dumping
    if num_processes > 1:
        entity_prep_dir = data_utils.get_emb_prep_dir(config.data_config)
        trie_candidate_map_folder = os.path.join(
            entity_prep_dir, "for_dumping_preds", "alias_cand_trie"
        )
        utils.ensure_dir(trie_candidate_map_folder)
        check_and_create_alias_cand_trie(trie_candidate_map_folder, entity_symbols)
        trie_qid2eid_file = os.path.join(
            entity_prep_dir, "for_dumping_preds", "qid2eid_trie.marisa"
        )
        if not os.path.exists(trie_qid2eid_file):
            utils.create_single_item_trie(
                entity_symbols.get_qid2eid(), out_file=trie_qid2eid_file
            )

    # Extract the results for this disambiguation task only
    disambig_res_dict = {}
    for k in res_dict:
        assert task_name in res_dict[k], f"{task_name} not in res_dict for key {k}"
        disambig_res_dict[k] = res_dict[k][task_name]

    # write to file (M x hidden_size for each data point -- next step will deal with recovering original sentence
    # indices for overflowing sentences)
    unmerged_entity_emb_file = os.path.join(save_folder, "entity_embs.pt")
    merged_entity_emb_file = os.path.join(save_folder, "entity_embs_unmerged.pt")
    emb_file_config = os.path.splitext(unmerged_entity_emb_file)[0] + "_config.npy"
    M = config.data_config.max_aliases
    K = entity_symbols.max_candidates + (not config.data_config.train_in_candidates)
    if dump_embs:
        unmerged_storage_type = np.dtype(
            [
                ("M", int),
                ("K", int),
                ("hidden_size", int),
                ("sent_idx", int),
                ("subsent_idx", int),
                ("alias_list_pos", int, (M,)),
                ("entity_emb", float, M * config.model_config.hidden_size),
                ("final_loss_true", int, (M,)),
                ("final_loss_pred", int, (M,)),
                ("final_loss_prob", float, (M,)),
                ("final_loss_cand_probs", float, M * K),
            ]
        )
        merged_storage_type = np.dtype(
            [
                ("hidden_size", int),
                ("sent_idx", int),
                ("alias_list_pos", int),
                ("entity_emb", float, config.model_config.hidden_size),
                ("final_loss_pred", int),
                ("final_loss_prob", float),
                ("final_loss_cand_probs", float, K),
            ]
        )
    else:
        # don't need to extract contextualized entity embedding
        unmerged_storage_type = np.dtype(
            [
                ("M", int),
                ("K", int),
                ("hidden_size", int),
                ("sent_idx", int),
                ("subsent_idx", int),
                ("alias_list_pos", int, (M,)),
                ("final_loss_true", int, (M,)),
                ("final_loss_pred", int, (M,)),
                ("final_loss_prob", float, (M,)),
                ("final_loss_cand_probs", float, M * K),
            ]
        )
        merged_storage_type = np.dtype(
            [
                ("hidden_size", int),
                ("sent_idx", int),
                ("alias_list_pos", int),
                ("final_loss_pred", int),
                ("final_loss_prob", float),
                ("final_loss_cand_probs", float, K),
            ]
        )
    mmap_file = np.memmap(
        unmerged_entity_emb_file,
        dtype=unmerged_storage_type,
        mode="w+",
        shape=(len(disambig_res_dict["uids"]),),
    )
    # print("MEMMAP FILE SHAPE", len(disambig_res_dict["uids"]))
    # Init sent_idx to -1 for debugging
    mmap_file[:]["sent_idx"] = -1

    np.save(emb_file_config, unmerged_storage_type, allow_pickle=True)
    log_rank_0_debug(
        logger, f"Created file {unmerged_entity_emb_file} to save predictions."
    )
    log_rank_0_debug(logger, f'{len(disambig_res_dict["uids"])} samples')

    for_iteration = [
        disambig_res_dict["uids"],
        disambig_res_dict["golds"],
        disambig_res_dict["probs"],
        disambig_res_dict["preds"],
    ]
    all_sent_idx = set()
    for i, (uid, gold, probs, model_pred) in enumerate(zip(*for_iteration)):
        # disambig_res_dict["output"] is dict with keys ['_input__alias_orig_list_pos',
        # 'bootleg_pred_1', '_input__sent_idx', '_input__for_dump_gold_cand_K_idx_train', '_input__subsent_idx', 0, 1]
        sent_idx = disambig_res_dict["outputs"]["_input__sent_idx"][i]
        # print("INSIDE LOOP", sent_idx, "AT", i)
        subsent_idx = disambig_res_dict["outputs"]["_input__subsent_idx"][i]
        alias_orig_list_pos = disambig_res_dict["outputs"]["_input__alias_orig_list_pos"][i]
        gold_cand_K_idx_train = disambig_res_dict["outputs"]["_input__for_dump_gold_cand_K_idx_train"][i]
        output_embeddings = disambig_res_dict["outputs"][f"{PRED_LAYER}_ent_embs"][i]
        mmap_file[i]["M"] = M
        mmap_file[i]["K"] = K
        mmap_file[i]["hidden_size"] = config.model_config.hidden_size
        mmap_file[i]["sent_idx"] = sent_idx
        mmap_file[i]["subsent_idx"] = subsent_idx
        mmap_file[i]["alias_list_pos"] = alias_orig_list_pos
        # This will give all aliases seen by the model during training, independent of whether they are gold or not
        mmap_file[i]["final_loss_true"] = gold_cand_K_idx_train.reshape(M)

        # get max for each alias, probs is M x K
        max_probs = probs.max(axis=1)
        pred_cands = probs.argmax(axis=1)

        mmap_file[i]["final_loss_pred"] = pred_cands
        mmap_file[i]["final_loss_prob"] = max_probs
        mmap_file[i]["final_loss_cand_probs"] = probs.reshape(1, -1)

        all_sent_idx.add(str(sent_idx))
        # final_entity_embs is M x K x hidden_size, pred_cands is M
        if dump_embs:
            chosen_entity_embs = select_embs(
                embs=output_embeddings, pred_cands=pred_cands, M=M
            )
            # write chosen entity embs to file for contextualized entity embeddings
            mmap_file[i]["entity_emb"] = chosen_entity_embs.reshape(1, -1)

    # for i in range(len(mmap_file)):
    #     si = mmap_file[i]["sent_idx"]
    #     if -1 == si:
    #         import pdb
    #         pdb.set_trace()
    #     assert si != -1, f"{i} {mmap_file[i]}"

    # Store all predicted sentences to filter the sentence mapping by
    subset_sent_idx2num_mens = {
        k: v for k, v in sent_idx2num_mens.items() if k in all_sent_idx
    }
    # print("ALL SEEN", all_sent_idx)
    subsent_sent_idx2row = {
        k: v for k, v in sent_idx2row.items() if k in all_sent_idx
    }
    result_file = get_result_file(result_idx, save_folder)
    log_rank_0_debug(logger, f"Writing predictions to {result_file}...")
    merge_subsentences(
        num_processes=num_processes,
        subset_sent_idx2num_mens=subset_sent_idx2num_mens,
        cache_folder=cache_dir,
        to_save_file=merged_entity_emb_file,
        to_save_storage=merged_storage_type,
        to_read_file=unmerged_entity_emb_file,
        to_read_storage=unmerged_storage_type,
        dump_embs=dump_embs,
    )
    write_data_labels(
        num_processes=num_processes,
        result_alias_offset=result_alias_offset,
        merged_entity_emb_file=merged_entity_emb_file,
        merged_storage_type=merged_storage_type,
        sent_idx2row=subsent_sent_idx2row,
        cache_folder=cache_dir,
        out_file=result_file,
        entity_dump=entity_symbols,
        train_in_candidates=config.data_config.train_in_candidates,
        max_candidates=entity_symbols.max_candidates,
        dump_embs=dump_embs,
        trie_candidate_map_folder=trie_candidate_map_folder,
        trie_qid2eid_file=trie_qid2eid_file,
    )

    out_emb_file = None
    filt_emb_data = np.memmap(merged_entity_emb_file, dtype=merged_storage_type, mode="r+")
    total_mentions_seen = len(filt_emb_data)
    # save easier-to-use embedding file
    if dump_embs:
        hidden_size = filt_emb_data[0]["hidden_size"]
        out_emb_file = get_emb_file(result_idx, save_folder)
        np.save(out_emb_file, filt_emb_data["entity_emb"].reshape(-1, hidden_size))
        log_rank_0_debug(
            logger,
            f"Saving contextual entity embeddings for {result_idx} to {out_emb_file}",
        )
    filt_emb_data = None

    # Cleanup cache - sometimes the file in cache_dir is still open so we need to retry to delete it
    try_rmtree(cache_dir)

    log_rank_0_debug(logger, f"Wrote predictions for {result_idx} to {result_file}")
    return result_file, out_emb_file, all_sent_idx, total_mentions_seen
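# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): consuming the pair of files
# returned by disambig_dump_preds when dump_embs=True. Each line of the label
# file is a JSON object per sentence whose "ctx_emb_ids" index rows of the
# embedding matrix saved via np.save above. The function and variable names
# here are assumptions for illustration; only the field names exercised by the
# unit test below are guaranteed.
# ---------------------------------------------------------------------------
def _load_contextual_embeddings_sketch(result_file, emb_npy_file):
    """Map each (sentence idx, alias) pair to its contextualized entity embedding."""
    import numpy as np
    import ujson

    ctx_embs = np.load(emb_npy_file)  # shape: (total_mentions, hidden_size)
    mention_embs = {}
    with open(result_file) as f:
        for line in f:
            row = ujson.loads(line)
            for alias, emb_id in zip(row["aliases"], row["ctx_emb_ids"]):
                mention_embs[(row["sent_idx_unq"], alias)] = ctx_embs[emb_id]
    return mention_embs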
def test_write_out_subsentences(self):
    merged_entity_emb_file = tempfile.NamedTemporaryFile()
    out_file = tempfile.NamedTemporaryFile()
    data_file = tempfile.NamedTemporaryFile()
    cache_folder = tempfile.TemporaryDirectory()

    entity_dir = "test/entity_db"
    entity_map_dir = "entity_mappings"

    entity_symbols = EntitySymbolsSubclass()
    entity_symbols.save(save_dir=os.path.join(entity_dir, entity_map_dir))

    total_num_mentions = 7
    K = 2
    hidden_size = 2

    # create data file -- just needs aliases and sentence indices
    data = [
        {"aliases": ["a", "b"], "sent_idx_unq": 0},
        {"aliases": ["c", "d", "e", "f", "g"], "sent_idx_unq": 1},
    ]
    # Dict is a string key for trie
    sent_idx2rows = {"0": data[0], "1": data[1]}
    with jsonlines.open(data_file.name, "w") as f:
        for row in data:
            f.write(row)

    merged_storage_type = np.dtype(
        [
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("entity_emb", float, hidden_size),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ]
    )
    merged_entity_emb = np.memmap(
        merged_entity_emb_file.name,
        dtype=merged_storage_type,
        mode="w+",
        shape=(total_num_mentions,),
    )
    # 2 sentences, 1st sent has 1 subsentence, 2nd sentence has 2 subsentences - 7 mentions total
    merged_entity_emb["hidden_size"] = hidden_size
    # first men
    merged_entity_emb[0]["sent_idx"] = 0
    merged_entity_emb[0]["alias_list_pos"] = 0
    merged_entity_emb[0]["entity_emb"] = np.array([0, 1])
    merged_entity_emb[0]["final_loss_pred"] = 1
    merged_entity_emb[0]["final_loss_prob"] = 0.9
    merged_entity_emb[0]["final_loss_cand_probs"] = np.array([0.1, 0.9])
    # second men
    merged_entity_emb[1]["sent_idx"] = 0
    merged_entity_emb[1]["alias_list_pos"] = 1
    merged_entity_emb[1]["entity_emb"] = np.array([2, 3])
    merged_entity_emb[1]["final_loss_pred"] = 1
    merged_entity_emb[1]["final_loss_prob"] = 0.9
    merged_entity_emb[1]["final_loss_cand_probs"] = np.array([0.1, 0.9])
    # third men
    merged_entity_emb[2]["sent_idx"] = 1
    merged_entity_emb[2]["alias_list_pos"] = 0
    merged_entity_emb[2]["entity_emb"] = np.array([4, 5])
    merged_entity_emb[2]["final_loss_pred"] = 0
    merged_entity_emb[2]["final_loss_prob"] = 0.9
    merged_entity_emb[2]["final_loss_cand_probs"] = np.array([0.9, 0.1])
    # fourth men
    merged_entity_emb[3]["sent_idx"] = 1
    merged_entity_emb[3]["alias_list_pos"] = 1
    merged_entity_emb[3]["entity_emb"] = np.array([6, 7])
    merged_entity_emb[3]["final_loss_pred"] = 0
    merged_entity_emb[3]["final_loss_prob"] = 0.9
    merged_entity_emb[3]["final_loss_cand_probs"] = np.array([0.9, 0.1])
    # fifth men
    merged_entity_emb[4]["sent_idx"] = 1
    merged_entity_emb[4]["alias_list_pos"] = 2
    merged_entity_emb[4]["entity_emb"] = np.array([10, 11])
    merged_entity_emb[4]["final_loss_pred"] = 1
    merged_entity_emb[4]["final_loss_prob"] = 0.9
    merged_entity_emb[4]["final_loss_cand_probs"] = np.array([0.1, 0.9])
    # sixth men
    merged_entity_emb[5]["sent_idx"] = 1
    merged_entity_emb[5]["alias_list_pos"] = 3
    merged_entity_emb[5]["entity_emb"] = np.array([12, 13])
    merged_entity_emb[5]["final_loss_pred"] = 1
    merged_entity_emb[5]["final_loss_prob"] = 0.9
    merged_entity_emb[5]["final_loss_cand_probs"] = np.array([0.1, 0.9])
    # seventh men
    merged_entity_emb[6]["sent_idx"] = 1
    merged_entity_emb[6]["alias_list_pos"] = 4
    merged_entity_emb[6]["entity_emb"] = np.array([14, 15])
    merged_entity_emb[6]["final_loss_pred"] = 1
    merged_entity_emb[6]["final_loss_prob"] = 0.9
    merged_entity_emb[6]["final_loss_cand_probs"] = np.array([0.1, 0.9])

    num_processes = 1
    train_in_candidates = True
    dump_embs = True
    max_candidates = 2

    """
    Alias -> candidate QIDs (with scores) assumed by the gold labels below:
    "a":[["Q1",10.0],["Q4",6]],
    "b":[["Q2",5.0],["Q1",3]],
    "c":[["Q1",30.0],["Q2",3]],
    "d":[["Q4",20],["Q3",15.0]],
    "e":[["Q1",10.0],["Q4",6]],
    "f":[["Q2",5.0],["Q1",3]],
    "g":[["Q1",30.0],["Q2",3]]
    """
    gold_lines = [
        {
            "sent_idx_unq": 0,
            "aliases": ["a", "b"],
            "qids": ["Q4", "Q1"],
            "probs": [0.9, 0.9],
            "cands": [["Q1", "Q4"], ["Q2", "Q1"]],
            "cand_probs": [[0.1, 0.9], [0.1, 0.9]],
            "entity_ids": [4, 1],
            "ctx_emb_ids": [0, 1],
        },
        {
            "sent_idx_unq": 1,
            "aliases": ["c", "d", "e", "f", "g"],
            "qids": ["Q1", "Q4", "Q4", "Q1", "Q2"],
            "probs": [0.9, 0.9, 0.9, 0.9, 0.9],
            "cands": [
                ["Q1", "Q2"],
                ["Q4", "Q3"],
                ["Q1", "Q4"],
                ["Q2", "Q1"],
                ["Q1", "Q2"],
            ],
            "cand_probs": [
                [0.9, 0.1],
                [0.9, 0.1],
                [0.1, 0.9],
                [0.1, 0.9],
                [0.1, 0.9],
            ],
            "entity_ids": [1, 4, 4, 1, 2],
            "ctx_emb_ids": [2, 3, 4, 5, 6],
        },
    ]

    write_data_labels(
        num_processes=num_processes,
        result_alias_offset=0,
        merged_entity_emb_file=merged_entity_emb_file.name,
        merged_storage_type=merged_storage_type,
        sent_idx2row=sent_idx2rows,
        cache_folder=cache_folder.name,
        out_file=out_file.name,
        entity_dump=entity_symbols,
        train_in_candidates=train_in_candidates,
        max_candidates=max_candidates,
        dump_embs=dump_embs,
        trie_candidate_map_folder=None,
        trie_qid2eid_file=None,
    )

    all_lines = []
    with open(out_file.name) as check_f:
        for line in check_f:
            all_lines.append(ujson.loads(line))

    assert len(all_lines) == len(gold_lines)
    all_lines_sent_idx_map = {line["sent_idx_unq"]: line for line in all_lines}
    gold_lines_sent_idx_map = {line["sent_idx_unq"]: line for line in gold_lines}
    assert len(all_lines_sent_idx_map) == len(gold_lines_sent_idx_map)
    for sent_idx in all_lines_sent_idx_map:
        self.assertDictEqual(
            gold_lines_sent_idx_map[sent_idx],
            all_lines_sent_idx_map[sent_idx],
            f"{ujson.dumps(gold_lines_sent_idx_map[sent_idx], indent=4)} VS "
            f"{ujson.dumps(all_lines_sent_idx_map[sent_idx], indent=4)}",
        )

    # TRY MULTIPROCESSING
    num_processes = 2
    # create the on-disk trie files needed for multiprocessing
    trie_candidate_map_folder = tempfile.TemporaryDirectory()
    trie_qid2eid_file = tempfile.NamedTemporaryFile()
    create_single_item_trie(entity_symbols.get_qid2eid(), out_file=trie_qid2eid_file.name)
    check_and_create_alias_cand_trie(trie_candidate_map_folder.name, entity_symbols)

    write_data_labels(
        num_processes=num_processes,
        result_alias_offset=0,
        merged_entity_emb_file=merged_entity_emb_file.name,
        merged_storage_type=merged_storage_type,
        sent_idx2row=sent_idx2rows,
        cache_folder=cache_folder.name,
        out_file=out_file.name,
        entity_dump=entity_symbols,
        train_in_candidates=train_in_candidates,
        max_candidates=max_candidates,
        dump_embs=dump_embs,
        trie_candidate_map_folder=trie_candidate_map_folder.name,
        trie_qid2eid_file=trie_qid2eid_file.name,
    )

    all_lines = []
    with open(out_file.name) as check_f:
        for line in check_f:
            all_lines.append(ujson.loads(line))

    assert len(all_lines) == len(gold_lines)
    all_lines_sent_idx_map = {line["sent_idx_unq"]: line for line in all_lines}
    gold_lines_sent_idx_map = {line["sent_idx_unq"]: line for line in gold_lines}
    assert len(all_lines_sent_idx_map) == len(gold_lines_sent_idx_map)
    for sent_idx in all_lines_sent_idx_map:
        self.assertDictEqual(
            gold_lines_sent_idx_map[sent_idx],
            all_lines_sent_idx_map[sent_idx],
            f"{ujson.dumps(gold_lines_sent_idx_map[sent_idx], indent=4)} VS "
            f"{ujson.dumps(all_lines_sent_idx_map[sent_idx], indent=4)}",
        )

    # clean up
    if os.path.exists(entity_dir):
        shutil.rmtree(entity_dir, ignore_errors=True)
    merged_entity_emb_file.close()
    out_file.close()
    data_file.close()
    trie_candidate_map_folder.cleanup()
    cache_folder.cleanup()
    trie_qid2eid_file.close()