def save(self, save_dir):
    """Dumps the entity symbols.

    Args:
        save_dir: directory string to save

    Returns:
    """
    self._sort_alias_cands()
    utils.ensure_dir(save_dir)
    utils.dump_json_file(
        filename=os.path.join(save_dir, "config.json"),
        contents={
            "max_candidates": self.max_candidates,
            "datetime": str(datetime.now()),
        },
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, self.alias_cand_map_file),
        contents=self._alias2qids,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "qid2eid.json"), contents=self._qid2eid
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, self.alias_idx_file),
        contents=self._alias2id,
    )
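# A minimal usage sketch (not part of the class): round-tripping the dumped entity
# symbols. It assumes an existing EntitySymbols instance named `entity_symbols`; the
# directory path is hypothetical, and `load_from_cache` is the loader used by
# compress_topk_embeddings further below.
import os

save_dir = os.path.join("entity_db", "entity_mappings")  # hypothetical path
entity_symbols.save(save_dir)
reloaded = EntitySymbols.load_from_cache(save_dir)
# config.json stored max_candidates, so it should survive the round trip
assert reloaded.max_candidates == entity_symbols.max_candidates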
def test_build_type_table_too_many_types(self):
    type_data = {"Q1": [1, 2, 3], "Q2": [4, 5, 6], "Q3": [], "Q4": [7, 8, 9]}
    type_vocab = {
        "T1": 1,
        "T2": 2,
        "T3": 3,
        "T4": 4,
        "T5": 5,
        "T6": 6,
        "T7": 7,
        "T8": 8,
        "T9": 9,
    }
    utils.dump_json_file(self.type_file, type_data)
    utils.dump_json_file(self.type_vocab_file, type_vocab)
    true_type_table = torch.tensor([[0], [1], [4], [0], [7], [0]]).long()
    true_type2row = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}
    pred_type_table, type2row, max_labels = TypeEmb.build_type_table(
        self.type_file,
        self.type_vocab_file,
        max_types=1,
        entity_symbols=self.entity_symbols,
    )
    assert torch.equal(pred_type_table, true_type_table)
    self.assertDictEqual(true_type2row, type2row)
    # There are 9 real types, so (including unk and pad) we expect type indices up to 10
    assert max_labels == 10
def compute_histograms(save_dir, entity_symbols):
    al_counts = Counter()
    for al in entity_symbols.get_all_aliases():
        num_entities = len(entity_symbols.get_qid_cands(al))
        al_counts.update([num_entities])
    utils.dump_json_file(
        filename=os.path.join(save_dir, "candidate_counts.json"), contents=al_counts
    )
    return
def main():
    args = parse_args()
    print(json.dumps(vars(args), indent=4))
    assert 0 < args.perc_emb_drop < 1, "perc_emb_drop must be between 0 and 1"
    state_dict, model_state_dict = load_statedict(args.init_checkpoint)
    assert ENTITY_EMB_KEY in model_state_dict
    print(
        f"Loading entity symbols from {os.path.join(args.entity_dir, args.entity_map_dir)}"
    )
    entity_db = EntitySymbols(
        os.path.join(args.entity_dir, args.entity_map_dir), args.alias_cand_map_file
    )
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print("Filtering qids")
    qid2topk_eid, old2new_eid, toes_eid, new_num_topk_entities = filter_qids(
        args.perc_emb_drop, entity_db, qid2count
    )
    print("Filtering embeddings")
    model_state_dict = filter_embs(
        new_num_topk_entities,
        entity_db,
        old2new_eid,
        qid2topk_eid,
        toes_eid,
        model_state_dict,
    )
    # Generate the new old2new_eid weight vector to save in model_state_dict
    oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
    # The +2 accounts for pads and unks. The -1 is because -1 causes issues in the
    # indexing for entity embeddings, so we manually make the pad eid the last entry
    oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
    for qid, new_eid in qid2topk_eid.items():
        old_eid = entity_db.get_eid(qid)
        oldeid2topkeid[old_eid] = new_eid
    assert oldeid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
    assert (
        oldeid2topkeid[-1] == new_num_topk_entities + 2 - 1
    ), "The -1 eid should still map to the last row"
    model_state_dict[ENTITY_EID_KEY] = oldeid2topkeid
    state_dict["model"] = model_state_dict
    print(f"Saving model at {args.save_checkpoint}")
    torch.save(state_dict, args.save_checkpoint)
    print(
        f"Saving entity_db at "
        f"{os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file)}"
    )
    utils.dump_json_file(
        os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file),
        qid2topk_eid,
    )
def main():
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    num_processes = int(0.8 * multiprocessing.cpu_count())
    print(f"Getting slice counts from {args.train_file}")
    qid_cnts = get_counts(num_processes, args.train_file)
    utils.dump_json_file(args.out_file, qid_cnts)
def compute_occurrences(save_dir, data_file, entity_dump, lower, strip, num_workers=8):
    global all_aliases
    all_aliases = get_all_aliases(entity_dump._alias2qids)
    # Divide up the data into chunks
    num_lines = get_num_lines(data_file)
    num_processes = min(num_workers, int(multiprocessing.cpu_count()))
    logging.info(f"Using {num_processes} workers...")
    chunk_size = int(np.ceil(num_lines / num_processes))
    chunk_file_path = os.path.join(save_dir, "tmp")
    utils.ensure_dir(chunk_file_path)
    chunk_infiles = [
        os.path.join(chunk_file_path, f"data_chunk_{chunk_id}_in.jsonl")
        for chunk_id in range(num_processes)
    ]
    chunk_text_data(data_file, chunk_infiles, chunk_size, num_lines)
    pool = multiprocessing.Pool(processes=num_processes)
    subprocess_args = [[chunk_infiles[i], lower, strip] for i in range(num_processes)]
    results = pool.map(compute_occurrences_single, subprocess_args)
    pool.close()
    pool.join()
    logging.info("Finished collecting counts")
    logging.info("Merging counts....")
    # Merge counters together
    # entity histogram
    ent_occurrences = Counter()
    # alias histogram
    alias_occurrences = Counter()
    # alias text occurrences
    alias_text_occurrences = Counter()
    # number of aliases per sentence
    alias_pair_occurrences = Counter()
    # alias|entity histogram
    alias_entity_pair = Counter()
    for result_set in results:
        ent_occurrences += result_set["ent_occurrences"]
        alias_occurrences += result_set["alias_occurrences"]
        alias_text_occurrences += result_set["alias_text_occurrences"]
        alias_pair_occurrences += result_set["alias_pair_occurrences"]
        alias_entity_pair += result_set["alias_entity_pair"]
    # Save counters
    utils.dump_json_file(
        filename=os.path.join(save_dir, "entity_count.json"), contents=ent_occurrences
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "alias_counts.json"),
        contents=alias_occurrences,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "alias_text_counts.json"),
        contents=alias_text_occurrences,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "alias_pair_occurrences.json"),
        contents=alias_pair_occurrences,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "alias_entity_counts.json"),
        contents=alias_entity_pair,
    )
def test_load_kg_adj_indices_json(self):
    kg_data = {"Q1": {"Q2": 100}, "Q3": {"Q2": 11}}
    utils.dump_json_file(self.kg_adj_json, kg_data)
    adj_out = KGIndices.build_kg_adj(
        kg_adj_file=self.kg_adj_json,
        entity_symbols=self.entity_symbols,
        threshold=10,
        log_weight=True,
    )
    adj_out_gold = nx.adjacency_matrix(
        nx.Graph(
            np.array(
                [
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, np.log(100), 0, 0, 0],
                    [0, np.log(100), 0, np.log(11), 0, 0],
                    [0, 0, np.log(11), 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                ]
            )
        )
    )
    np.testing.assert_allclose(adj_out.toarray(), adj_out_gold.toarray())

    # We filter out weights of 10 or lower
    kg_data = {"Q1": {"Q2": 100}, "Q3": {"Q2": 10}}
    utils.dump_json_file(self.kg_adj_json, kg_data)
    adj_out = KGIndices.build_kg_adj(
        kg_adj_file=self.kg_adj_json,
        entity_symbols=self.entity_symbols,
        threshold=10,
        log_weight=True,
    )
    adj_out_gold = nx.adjacency_matrix(
        nx.Graph(
            np.array(
                [
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, np.log(100), 0, 0, 0],
                    [0, np.log(100), 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                ]
            )
        )
    )
    np.testing.assert_allclose(adj_out.toarray(), adj_out_gold.toarray())
def save(self, save_dir, prefix=""):
    """Dumps the type symbols.

    Args:
        save_dir: directory string to save
        prefix: prefix to add to the beginning of the file names

    Returns:
    """
    utils.ensure_dir(str(save_dir))
    utils.dump_json_file(
        filename=os.path.join(save_dir, "config.json"),
        contents={
            "max_types": self.max_types,
        },
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, f"{prefix}qid2typenames.json"),
        contents=self._qid2typenames,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, f"{prefix}qid2typeids.json"),
        contents=self._qid2typeid,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, f"{prefix}type_vocab.json"),
        contents=self._type_vocab,
    )
def main(args, mode):
    multiprocessing.set_start_method("forkserver", force=True)
    # =================================
    # ARGUMENTS CHECK
    # =================================
    # distributed training
    assert (args.run_config.ngpus_per_node <= torch.cuda.device_count()) or (
        not torch.cuda.is_available()
    ), "Not enough GPUs per node."
    world_size = args.run_config.ngpus_per_node * args.run_config.nodes
    if world_size > 1:
        args.run_config.distributed = True
    assert (args.run_config.distributed and world_size > 1) or (world_size == 1)
    train_utils.setup_run_folders(args, mode)
    # check slice method
    assert (
        args.train_config.slice_method in SLICE_METHODS
    ), f"Your slice_method {args.train_config.slice_method} is not in {SLICE_METHODS}."
    train_utils.setup_train_heads_and_eval_slices(args)
    # check save step
    assert (
        args.run_config.save_every_k_eval > 0
    ), "You must have save_every_k_eval set to be > 0"
    # for eval modes, make sure the resume model file is set and exists
    if mode == "eval" or mode == "dump_preds" or mode == "dump_embs":
        assert (
            args.run_config.init_checkpoint != ""
        ), f"You must specify a model checkpoint in run_config to run {mode}"
        assert os.path.exists(
            args.run_config.init_checkpoint
        ), f"The resume model file {args.run_config.init_checkpoint} doesn't exist"
    if mode == "dump_preds" or mode == "dump_embs":
        assert args.run_config.perc_eval == 1.0, (
            f"If you are running dump_preds or dump_embs, run_config.perc_eval must be 1.0. "
            f"You have {args.run_config.perc_eval}"
        )
        assert args.data_config.test_dataset.use_weak_label is True, (
            "We do not support dumping when the test dataset gold is set to false. "
            "You can filter the dataset and run with the filtered data."
        )
    utils.dump_json_file(
        filename=os.path.join(
            train_utils.get_save_folder(args.run_config), f"config_{mode}.json"
        ),
        contents=args,
    )
    if args.run_config.distributed:
        mp.spawn(
            main_worker,
            nprocs=args.run_config.ngpus_per_node,
            args=(args, mode, world_size),
        )
    else:
        main_worker(
            gpu=args.run_config.gpu, args=args, mode=mode, world_size=world_size
        )
def compute_type_occurrences(save_dir, prefix, entity_symbols, qid2typenames, data_file):
    # type histogram
    type_occurances = defaultdict(int)
    # type intersection histogram (frequency of types co-occurring together)
    type_pair_occurances = defaultdict(int)
    # number of aliases in a sentence with the number of shared types amongst those
    # aliases; just outputs the number of types for a single alias
    num_al_num_match_type = defaultdict(int)
    type_pairs = defaultdict(int)
    with open(data_file, "r") as in_file:
        for line in in_file:
            line = json.loads(line.strip())
            aliases = line["aliases"]
            qids = line["qids"]
            all_ex_types = set()
            i = 0
            # For all pairs of qid, types, get the intersection of types;
            # if the intersection > 1, write type pairs and increment
            for qid, alias in zip(qids, aliases):
                types = qid2typenames.get(qid, [])
                for ty in types:
                    type_occurances[ty] += 1
                if i == 0:
                    all_ex_types = set(types)
                else:
                    all_ex_types = all_ex_types.intersection(set(types))
                i += 1
            if len(aliases) > 1:
                num_al_num_match_type[f"{len(aliases)}|{len(set(all_ex_types))}"] += 1
                type_pair_occurances[tuple(sorted(list(all_ex_types)))] += 1
                alias_subsets = list(combinations(qids, 2))
                for qid1, qid2 in alias_subsets:
                    types1 = qid2typenames.get(qid1, [])
                    types2 = qid2typenames.get(qid2, [])
                    overlap_types = set(types1).intersection(set(types2))
                    if len(overlap_types) > 0:
                        type_pairs[tuple(overlap_types)] += 1
    logging.info("Saving type data...")
    utils.dump_json_file(
        filename=os.path.join(save_dir, f"{prefix}_type_occurances.json"),
        contents=type_occurances,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, f"{prefix}_type_pair_occurances.json"),
        contents=type_pair_occurances,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, f"{prefix}_num_al_num_match_type.json"),
        contents=num_al_num_match_type,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, f"{prefix}_num_type_pairs.json"),
        contents=type_pairs,
    )
def get_slice_stats_hlp(args):
    i, lines, offset, temp_out_dir = args
    res = defaultdict(int)  # slice -> cnt
    slice_to_sent = defaultdict(set)  # slice -> sent_idx (for sampling)
    sent_to_slices = defaultdict(
        lambda: defaultdict(int)
    )  # sent_idx -> slice -> cnt (for sampling)
    for line in tqdm(lines, total=len(lines), desc=f"Processing lines for {i}"):
        line = ujson.loads(line)
        slices = line.get("slices", {})
        anchors = line["gold"]
        for slice_name in slices:
            for al_str in slices[slice_name]:
                if anchors[int(al_str)] is True and slices[slice_name][al_str] > 0.5:
                    res[slice_name] += 1
                    slice_to_sent[slice_name].add(int(line["sent_idx_unq"]))
        sent_to_slices[int(line["sent_idx_unq"])].update(res)
    utils.dump_json_file(
        os.path.join(temp_out_dir, f"{FINAL_COUNTS_PREFIX}_{i}.json"), res
    )
    utils.dump_json_file(
        os.path.join(temp_out_dir, f"{FINAL_SENT_TO_SLICE_PREFIX}_{i}.json"),
        sent_to_slices,
    )
    utils.dump_json_file(
        os.path.join(temp_out_dir, f"{FINAL_SLICE_TO_SENT_PREFIX}_{i}.json"),
        slice_to_sent,
    )
    return i
def test_topk_embedding(self):
    topkqid2eid = {"Q1": 1, "Q2": 3, "Q3": 3, "Q4": 2}
    utils.dump_json_file(self.qid2topkeid, topkqid2eid)
    self.args.data_config.ent_embeddings[0]["args"]["perc_emb_drop"] = 0.5
    self.args.data_config.ent_embeddings[0]["args"]["qid2topk_eid"] = self.qid2topkeid
    learned_emb = TopKEntityEmb(
        main_args=self.args,
        emb_args=self.args.data_config.ent_embeddings[0]["args"],
        entity_symbols=self.entity_symbols,
        key="learned",
        cpu=True,
        normalize=False,
        dropout1d_perc=0.0,
        dropout2d_perc=0.0,
    )
    num_new_eids_padunk = 5
    eid2topkeid_gold = torch.tensor([0, 1, 3, 3, 2, num_new_eids_padunk - 1])
    assert torch.equal(eid2topkeid_gold, learned_emb.eid2topkeid)
    assert list(learned_emb.learned_entity_embedding.weight.shape) == [
        num_new_eids_padunk,
        self.learned_embedding_size,
    ]
def dump(self, save_dir):
    # Memmapped files behave badly if you try to overwrite them in memory,
    # which is what we would be doing if load_dir == save_dir
    if self._loaded_from_dir is None or self._loaded_from_dir != save_dir:
        utils.ensure_dir(save_dir)
        utils.dump_json_file(
            filename=os.path.join(save_dir, "fmt_types.json"),
            contents=self._fmt_types,
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, "max_values.json"),
            contents=self._max_values,
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, "vocabulary.json"), contents=self._stoi
        )
        np.save(
            file=os.path.join(save_dir, "itos.npy"), arr=self._itos, allow_pickle=True
        )
        for tri_name in self._fmt_types:
            self._record_tris[tri_name].save(
                os.path.join(save_dir, f"record_trie_{tri_name}.marisa")
            )
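# A minimal sketch of the record-trie save/mmap pattern used above, assuming the
# underlying tries are marisa_trie.RecordTrie objects (the format string, keys, and
# file name here are made up for illustration).
import marisa_trie

fmt = "<i"  # one signed 32-bit value per key
trie = marisa_trie.RecordTrie(fmt, [("Q123", (7,)), ("Q456", (8,))])
trie.save("record_trie_example.marisa")

# Memory-mapping gives a read-only view backed by the file on disk, which is why
# the dump above refuses to overwrite the directory it was loaded from.
mmapped = marisa_trie.RecordTrie(fmt)
mmapped.mmap("record_trie_example.marisa")
print(mmapped["Q123"])  # [(7,)]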
def __init__(
    self,
    main_args,
    dataset,
    use_weak_label,
    entity_symbols,
    dataset_threads,
    split="train",
):
    global_start = time.time()
    log_rank_0_info(logger, f"Building slice dataset for {split} from {dataset}.")
    spawn_method = main_args.run_config.spawn_method
    data_config = main_args.data_config
    orig_spawn = multiprocessing.get_start_method()
    multiprocessing.set_start_method(spawn_method, force=True)
    self.slice_names = data_utils.get_eval_slices(data_config.eval_slices)
    self.get_slice_dt = lambda max_a2p: np.dtype(
        [
            ("sent_idx", int),
            ("subslice_idx", int),
            ("alias_slice_incidence", int, (max_a2p,)),
            ("prob_labels", float, (max_a2p,)),
        ]
    )
    self.get_storage = lambda max_a2p: np.dtype(
        [(slice_name, self.get_slice_dt(max_a2p)) for slice_name in self.slice_names]
    )
    # Folder for all mmap saved files
    save_dataset_folder = data_utils.get_save_data_folder(
        data_config, use_weak_label, dataset
    )
    utils.ensure_dir(save_dataset_folder)
    # Folder for temporary output files
    temp_output_folder = os.path.join(
        data_config.data_dir, data_config.data_prep_dir, f"prep_{split}_slice_files"
    )
    utils.ensure_dir(temp_output_folder)
    # Input step 1
    create_ex_indir = os.path.join(temp_output_folder, "create_examples_input")
    utils.ensure_dir(create_ex_indir)
    # Input step 2
    create_ex_outdir = os.path.join(temp_output_folder, "create_examples_output")
    utils.ensure_dir(create_ex_outdir)
    # Metadata saved files
    meta_file = os.path.join(temp_output_folder, "meta_data.json")
    # File for standard training data
    hash = hashlib.sha1(str(self.slice_names).encode("UTF-8")).hexdigest()[:10]
    self.save_dataset_name = os.path.join(
        save_dataset_folder, f"ned_slices_{hash}.bin"
    )
    self.save_data_config_name = os.path.join(
        save_dataset_folder, "ned_slices_config.json"
    )
    # =======================================================================================
    # SLICE DATA
    # =======================================================================================
    log_rank_0_debug(logger, "Loading dataset...")
    log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists")
    if data_config.overwrite_preprocessed_data or (
        not os.path.exists(self.save_dataset_name)
    ):
        st_time = time.time()
        try:
            log_rank_0_info(
                logger,
                f"Building dataset from scratch. Saving to {save_dataset_folder}",
            )
            create_examples(
                dataset,
                create_ex_indir,
                create_ex_outdir,
                meta_file,
                data_config,
                dataset_threads,
                self.slice_names,
                use_weak_label,
                split,
            )
            max_alias2pred = utils.load_json_file(meta_file)["max_alias2pred"]
            convert_examples_to_features_and_save(
                meta_file,
                dataset_threads,
                self.slice_names,
                self.save_dataset_name,
                self.get_storage(max_alias2pred),
            )
            utils.dump_json_file(
                self.save_data_config_name, {"max_alias2pred": max_alias2pred}
            )
            log_rank_0_debug(
                logger, f"Finished prepping data in {time.time() - st_time}"
            )
        except Exception as e:
            tb = traceback.TracebackException.from_exception(e)
            logger.error(e)
            logger.error("\n".join(tb.stack.format()))
            shutil.rmtree(save_dataset_folder, ignore_errors=True)
            raise
    log_rank_0_info(
        logger,
        f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
    )
    max_alias2pred = utils.load_json_file(self.save_data_config_name)["max_alias2pred"]
    self.data, self.sent_to_row_id_dict = self.build_data_dict(
        self.save_dataset_name, self.get_storage(max_alias2pred)
    )
    assert len(self.data) > 0
    assert len(self.sent_to_row_id_dict) > 0
    log_rank_0_debug(logger, "Removing temporary output files")
    shutil.rmtree(temp_output_folder, ignore_errors=True)
    # Set spawn back to original/default, which is "fork" or "spawn". This is needed
    # for the Meta.config to be correctly passed in the collate_fn.
    multiprocessing.set_start_method(orig_spawn, force=True)
    log_rank_0_info(
        logger,
        f"Final slice data initialization time from {split} is {time.time() - global_start}s",
    )
def dump(self, save_dir, stats={}, args=None):
    self._sort_alias_cands()
    utils.ensure_dir(save_dir)
    utils.dump_json_file(
        filename=os.path.join(save_dir, "config.json"),
        contents={
            "max_candidates": self.max_candidates,
            "max_alias_len": self.max_alias_len,
            "datetime": str(datetime.now()),
        },
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, self.alias_cand_map_file),
        contents=self._alias2qids,
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "qid2eid.json"), contents=self._qid2eid
    )
    utils.dump_json_file(
        filename=os.path.join(save_dir, "filter_stats.json"), contents=stats
    )
    if args is not None:
        utils.dump_json_file(
            filename=os.path.join(save_dir, "args.json"), contents=vars(args)
        )
def create_examples(
    dataset,
    create_ex_indir,
    create_ex_outdir,
    meta_file,
    data_config,
    dataset_threads,
    slice_names,
    use_weak_label,
    split,
):
    """Creates examples from the raw input data.

    Args:
        dataset: dataset file
        create_ex_indir: temporary directory where input files are stored
        create_ex_outdir: temporary directory to store output files from the method
        meta_file: metadata file to save the file names/paths for the next step in the prep pipeline
        data_config: data config
        dataset_threads: number of threads
        slice_names: list of slices to evaluate on
        use_weak_label: whether to use weak labeling or not
        split: data split

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    num_processes = min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))

    log_rank_0_debug(logger, "Counting lines")
    total_input = sum(1 for _ in open(dataset))
    if num_processes == 1:
        out_file_name = os.path.join(create_ex_outdir, os.path.basename(dataset))
        constants_dict = {
            "slice_names": slice_names,
            "use_weak_label": use_weak_label,
            "max_aliases": data_config.max_aliases,
            "split": split,
            "train_in_candidates": data_config.train_in_candidates,
        }
        files_and_counts = {}
        res = create_examples_single(
            dataset, total_input, out_file_name, constants_dict
        )
        total_output = res["total_lines"]
        max_alias2pred = res["max_alias2pred"]
        files_and_counts[res["output_filename"]] = res["total_lines"]
    else:
        log_rank_0_info(
            logger, f"Starting to extract examples with {num_processes} threads"
        )
        log_rank_0_debug(
            logger, "Parallelizing with " + str(num_processes) + " threads."
        )
        chunk_input = int(np.ceil(total_input / num_processes))
        log_rank_0_debug(
            logger,
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines",
        )
        total_input_from_chunks, input_files_dict = utils.chunk_file(
            dataset, create_ex_indir, chunk_input
        )
        input_files = list(input_files_dict.keys())
        input_file_lines = [input_files_dict[k] for k in input_files]
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        assert (
            total_input == total_input_from_chunks
        ), f"Lengths of files {total_input} doesn't match {total_input_from_chunks}"
        log_rank_0_debug(logger, "Done chunking files")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=create_examples_initializer,
            initargs=[
                data_config,
                slice_names,
                use_weak_label,
                data_config.max_aliases,
                split,
                data_config.train_in_candidates,
            ],
        )
        total_output = 0
        max_alias2pred = 0
        input_args = list(zip(input_files, input_file_lines, output_files))
        # Store output files and counts for saving in the next step
        files_and_counts = {}
        for res in pool.imap_unordered(create_examples_hlp, input_args, chunksize=1):
            total_output += res["total_lines"]
            max_alias2pred = max(max_alias2pred, res["max_alias2pred"])
            files_and_counts[res["output_filename"]] = res["total_lines"]
        pool.close()
    utils.dump_json_file(
        meta_file,
        {
            "num_mentions": total_output,
            "files_and_counts": files_and_counts,
            "max_alias2pred": max_alias2pred,
        },
    )
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}.",
    )
    return
def compress_topk_embeddings(args):
    assert 0 < args.perc_emb_drop < 1, "perc_emb_drop must be between 0 and 1"
    print(
        f"Loading entity symbols from {os.path.join(args.entity_dir, 'entity_mappings')}"
    )
    entity_db = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, "entity_mappings")
    )
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print("Filtering qids")
    (
        qid2topk_eid,
        old2new_eid,
        toes_eid,
        new_num_topk_entities,
    ) = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    if len(args.model_path) > 0:
        assert (
            args.save_model_path is not None and len(args.save_model_path) > 0
        ), "If you give a model path, you must give a save model path"
        print("Filtering embeddings")
        state_dict, model_state_dict = load_statedict(args.model_path)
        try:
            get_nested_item(model_state_dict, ENTITY_EMB_KEYS)
        except:
            print(f"ERROR: All of {ENTITY_EMB_KEYS} are not in model_state_dict")
            raise
        model_state_dict = filter_embs(
            new_num_topk_entities,
            entity_db,
            old2new_eid,
            qid2topk_eid,
            toes_eid,
            model_state_dict,
        )
        # Generate the new old2new_eid weight vector to save in model_state_dict
        oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
        # The +2 accounts for pads and unks. The -1 is because -1 causes issues in the
        # indexing for entity embeddings, so we manually make the pad eid the last entry
        oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
        for qid, new_eid in tqdm(qid2topk_eid.items(), desc="Setting new ids"):
            old_eid = entity_db.get_eid(qid)
            oldeid2topkeid[old_eid] = new_eid
        assert oldeid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
        assert (
            oldeid2topkeid[-1] == new_num_topk_entities + 2 - 1
        ), "The -1 eid should still map to the last row"
        model_state_dict = set_nested_item(
            model_state_dict, ENTITY_EID_KEYS, oldeid2topkeid
        )
        # Remove the eid2reg value as that was built with the old entity id mapping
        try:
            model_state_dict = set_nested_item(model_state_dict, ENTITY_REG_KEYS, None)
        except:
            print(
                "Could not remove regularization. If your model was trained with regularization mapping on "
                "the learned entity embedding, this should not happen."
            )
        print(model_state_dict["module_pool"]["learned"].keys())
        state_dict["model"] = model_state_dict
        print(f"Saving model at {args.save_model_path}")
        torch.save(state_dict, args.save_model_path)
    print(
        f"Saving topk to eid at "
        f"{os.path.join(args.entity_dir, 'entity_mappings', args.save_qid2topk_file)}"
    )
    utils.dump_json_file(
        os.path.join(args.entity_dir, "entity_mappings", args.save_qid2topk_file),
        qid2topk_eid,
    )
    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            os.path.join(args.entity_dir, "entity_mappings", args.save_qid2topk_file),
            args.perc_emb_drop,
        )
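# A small worked sketch of the eid remapping arithmetic above; all numbers are
# hypothetical and the dropped-entity handling (toes_eid) is omitted. With 4 kept
# top-k entities the new embedding has 4 + 2 rows (the +2 covers the unk/"not in
# candidate" row 0 and the pad row), so the last valid row index is 4 + 2 - 1 = 5,
# and the old pad eid must be pointed there by hand.
import torch

new_num_topk_entities = 4
num_old_entities_with_pad_and_nocand = 8  # hypothetical old table size
qid2topk_eid = {"Q1": 1, "Q2": 3, "Q3": 3, "Q4": 2}  # hypothetical new eids
old_eids = {"Q1": 1, "Q2": 2, "Q3": 5, "Q4": 6}  # hypothetical old eids

oldeid2topkeid = torch.arange(0, num_old_entities_with_pad_and_nocand)
oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1  # pad row -> last new row
for qid, new_eid in qid2topk_eid.items():
    oldeid2topkeid[old_eids[qid]] = new_eid

assert oldeid2topkeid[0] == 0  # the unk/"not in candidate" row stays put
assert oldeid2topkeid[-1] == 5  # the pad row points at the last new row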