@classmethod
def load_from_cache(cls, load_dir, prefix="", edit_mode=False, verbose=False):
    """Loads type symbols from load_dir.

    Args:
        load_dir: directory to load from
        prefix: prefix to prepend to each file name
        edit_mode: edit mode flag
        verbose: verbose flag

    Returns: TypeSymbols
    """
    config = utils.load_json_file(filename=os.path.join(load_dir, "config.json"))
    max_types = config["max_types"]
    qid2typenames: Dict[str, List[str]] = utils.load_json_file(
        filename=os.path.join(load_dir, f"{prefix}qid2typenames.json"))
    qid2typeid: Dict[str, List[int]] = utils.load_json_file(
        filename=os.path.join(load_dir, f"{prefix}qid2typeids.json"))
    type_vocab: Dict[str, int] = utils.load_json_file(
        filename=os.path.join(load_dir, f"{prefix}type_vocab.json"))
    return cls(qid2typenames, qid2typeid, type_vocab, max_types, edit_mode, verbose)
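# A runnable sketch of the on-disk cache layout load_from_cache expects, inferred
# from the loader above; the directory name and the toy QID/type values are made up.
import json
import os

cache_dir = "data/type_mappings"
os.makedirs(cache_dir, exist_ok=True)
files = {
    "config.json": {"max_types": 3},
    "qid2typenames.json": {"Q95": ["company"]},
    "qid2typeids.json": {"Q95": [1]},
    "type_vocab.json": {"company": 1},
}
for name, obj in files.items():
    with open(os.path.join(cache_dir, name), "w") as f:
        json.dump(obj, f)
# type_symbols = TypeSymbols.load_from_cache(cache_dir)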
def load(self, load_dir):
    """Loads entity symbols from load_dir."""
    config = utils.load_json_file(filename=os.path.join(load_dir, "config.json"))
    self.max_candidates = config["max_candidates"]
    self.max_alias_len = config["max_alias_len"]
    self._alias2qids: Dict[str, list] = utils.load_json_file(
        filename=os.path.join(load_dir, self.alias_cand_map_file))
    self._qid2title: Dict[str, str] = utils.load_json_file(
        filename=os.path.join(load_dir, "qid2title.json"))
    self._qid2eid: Dict[str, int] = utils.load_json_file(
        filename=os.path.join(load_dir, "qid2eid.json"))
    self._sort_alias_cands()
def load_types(self, entity_symbols, emb_dir, max_types):
    """Loads all type information."""
    # Load the type vocab
    if self.type_vocab_file == "":
        print(
            "You did not give a type vocab file (from type name to type id). "
            "We will use an identity mapping."
        )
        typeid2typename = {}
    else:
        extension = os.path.splitext(self.type_vocab_file)[-1]
        if extension == ".json":
            type_vocab = utils.load_json_file(
                os.path.join(emb_dir, self.type_vocab_file))
        else:
            print(
                f"We only support loading json files for TypeSymbol. You have a file ending in {extension}"
            )
            return {}, {}, {}
        typeid2typename = {i: v for v, i in type_vocab.items()}
    # Load the mapping of entities to type ids
    qid2typenames = {qid: [] for qid in entity_symbols.get_all_qids()}
    qid2typeid = {qid: [] for qid in entity_symbols.get_all_qids()}
    print(f"Loading types from {self.type_file}")
    qid2typeid, qid2typenames = self.load_type_file(
        type_file=self.type_file,
        max_types=max_types,
        entity_symbols=entity_symbols,
        qid2typeid=qid2typeid,
        qid2typenames=qid2typenames,
        typeid2typename=typeid2typename,
    )
    return qid2typenames, qid2typeid, typeid2typename
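# Small, self-contained illustration of the vocab inversion above: the on-disk
# vocab maps type name -> type id, while the loader needs type id -> type name.
# The sample vocab here is made up.
type_vocab = {"person": 1, "location": 2, "organization": 3}
typeid2typename = {i: v for v, i in type_vocab.items()}
assert typeid2typename[2] == "location"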
def main():
    args = parse_args()
    print(json.dumps(args, indent=4))
    assert 0 < args.perc_emb_drop < 1, "perc_emb_drop must be between 0 and 1"
    state_dict, model_state_dict = load_statedict(args.init_checkpoint)
    assert ENTITY_EMB_KEY in model_state_dict
    print(f"Loading entity symbols from {os.path.join(args.entity_dir, args.entity_map_dir)}")
    entity_db = EntitySymbols(
        os.path.join(args.entity_dir, args.entity_map_dir), args.alias_cand_map_file)
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print("Filtering qids")
    qid2topk_eid, old2new_eid, toes_eid, new_num_topk_entities = filter_qids(
        args.perc_emb_drop, entity_db, qid2count)
    print("Filtering embeddings")
    model_state_dict = filter_embs(new_num_topk_entities, entity_db, old2new_eid,
                                   qid2topk_eid, toes_eid, model_state_dict)
    # Generate the new old2new_eid weight vector to save in model_state_dict
    oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
    # The +2 accounts for the pad and unk rows. The -1 is because -1 indexing into the
    # entity embeddings is problematic, so we manually make it the last entry.
    oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
    for qid, new_eid in qid2topk_eid.items():
        old_eid = entity_db.get_eid(qid)
        oldeid2topkeid[old_eid] = new_eid
    assert oldeid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
    assert (oldeid2topkeid[-1] == new_num_topk_entities + 2 - 1
            ), "The -1 eid should still map to the last row"
    model_state_dict[ENTITY_EID_KEY] = oldeid2topkeid
    state_dict["model"] = model_state_dict
    print(f"Saving model at {args.save_checkpoint}")
    torch.save(state_dict, args.save_checkpoint)
    print(
        f"Saving entity_db at "
        f"{os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file)}")
    utils.dump_json_file(
        os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file),
        qid2topk_eid)
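# Toy illustration of the eid remapping above with made-up sizes: 10 original
# entities (12 rows with the pad and "no candidate" rows) compressed to 4 top-k
# entities. Row 0 (no candidate) is untouched, and the last row maps to the new
# last row (4 top-k entities + 2 special rows - 1 = index 5).
import torch

num_entities_with_pad_and_nocand = 12
new_num_topk_entities = 4
oldeid2topkeid = torch.arange(0, num_entities_with_pad_and_nocand)
oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
assert oldeid2topkeid[0] == 0 and oldeid2topkeid[-1] == 5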
@classmethod
def load_from_cache(
    cls,
    load_dir,
    alias_cand_map_file="alias2qids.json",
    alias_idx_file="alias2id.json",
    edit_mode=False,
    verbose=False,
):
    """Loads entity symbols from load_dir.

    Args:
        load_dir: directory to load from
        alias_cand_map_file: alias2qid file
        alias_idx_file: alias2id file
        edit_mode: edit mode flag
        verbose: verbose flag

    Returns: EntitySymbols
    """
    config = utils.load_json_file(filename=os.path.join(load_dir, "config.json"))
    max_candidates = config["max_candidates"]
    alias2qids: Dict[str, list] = utils.load_json_file(
        filename=os.path.join(load_dir, alias_cand_map_file)
    )
    qid2title: Dict[str, str] = utils.load_json_file(
        filename=os.path.join(load_dir, "qid2title.json")
    )
    qid2eid: Dict[str, int] = utils.load_json_file(
        filename=os.path.join(load_dir, "qid2eid.json")
    )
    alias2id: Dict[str, int] = utils.load_json_file(
        filename=os.path.join(load_dir, alias_idx_file)
    )
    return cls(
        alias2qids,
        qid2title,
        qid2eid,
        alias2id,
        max_candidates,
        alias_cand_map_file,
        alias_idx_file,
        edit_mode,
        verbose,
    )
def get_slice_stats(num_processes, file):
    """Gets true anchor slice counts."""
    pool = multiprocessing.Pool(processes=num_processes)
    final_counts = defaultdict(int)
    final_slice_to_sent = defaultdict(set)
    final_sent_to_slices = defaultdict(lambda: defaultdict(int))
    temp_out_dir = os.path.join(os.path.dirname(file), "_temp")
    os.mkdir(temp_out_dir)

    all_lines = [li for li in open(file)]
    num_lines = len(all_lines)
    chunk_size = int(np.ceil(num_lines / num_processes))
    line_chunks = [
        all_lines[i:i + chunk_size] for i in range(0, num_lines, chunk_size)
    ]
    input_args = [[i, line_chunks[i], i * chunk_size, temp_out_dir]
                  for i in range(len(line_chunks))]

    for i in tqdm(
            pool.imap_unordered(get_slice_stats_hlp, input_args, chunksize=1),
            total=len(line_chunks),
            desc="Gathering slice counts",
    ):
        cnt_res = utils.load_json_file(
            os.path.join(temp_out_dir, f"{FINAL_COUNTS_PREFIX}_{i}.json"))
        sent_to_slices = utils.load_json_file(
            os.path.join(temp_out_dir, f"{FINAL_SENT_TO_SLICE_PREFIX}_{i}.json"))
        slice_to_sent = utils.load_json_file(
            os.path.join(temp_out_dir, f"{FINAL_SLICE_TO_SENT_PREFIX}_{i}.json"))
        for k in cnt_res:
            final_counts[k] += cnt_res[k]
        for k in slice_to_sent:
            final_slice_to_sent[k].update(slice_to_sent[k])
        for k in sent_to_slices:
            final_sent_to_slices[k].update(sent_to_slices[k])
    # Make sure all workers have exited before removing their temp files
    pool.close()
    pool.join()
    shutil.rmtree(temp_out_dir)
    return dict(final_counts), dict(final_slice_to_sent), dict(final_sent_to_slices)
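# Self-contained sketch of the chunking scheme above: split num_lines lines into
# nearly equal contiguous chunks, one per process, recording each chunk's global
# starting offset (i * chunk_size) so workers can report stable line ids.
import numpy as np

all_lines = [f"line {i}" for i in range(10)]
num_processes = 3
chunk_size = int(np.ceil(len(all_lines) / num_processes))
line_chunks = [all_lines[i:i + chunk_size] for i in range(0, len(all_lines), chunk_size)]
assert [len(c) for c in line_chunks] == [4, 4, 2]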
def load(self, load_dir):
    self._fmt_types = utils.load_json_file(
        filename=os.path.join(load_dir, "fmt_types.json"))
    self._max_values = utils.load_json_file(
        filename=os.path.join(load_dir, "max_values.json"))
    self._stoi = utils.load_json_file(
        filename=os.path.join(load_dir, "vocabulary.json"))
    self._itos = np.load(file=os.path.join(load_dir, "itos.npy"),
                         allow_pickle=True)
    assert self._fmt_types.keys() == self._max_values.keys()
    for tri_name in self._fmt_types:
        assert f'record_trie_{tri_name}.marisa' in os.listdir(
            load_dir), f"Missing record_trie_{tri_name}.marisa in {load_dir}"
    self._record_tris = {}
    for tri_name in self._fmt_types:
        self._record_tris[tri_name] = marisa_trie.RecordTrie(
            self._get_fmt_strings[self._fmt_types[tri_name]](
                self._max_values[tri_name])).mmap(
                    os.path.join(load_dir, f'record_trie_{tri_name}.marisa'))
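# Standalone example of the marisa_trie pattern used above: build a RecordTrie
# with a struct format string, save it, and reopen it memory-mapped (mmap
# returns the trie, which is why the chained call above works). The key, values,
# and file name here are made up.
import marisa_trie

fmt = "<ll"  # two little-endian longs per record
trie = marisa_trie.RecordTrie(fmt, [("Q123", (7, 42))])
trie.save("record_trie_example.marisa")
mmapped = marisa_trie.RecordTrie(fmt).mmap("record_trie_example.marisa")
assert mmapped["Q123"] == [(7, 42)]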
@classmethod
def build_static_embeddings(cls, emb_file, entity_symbols):
    """Builds the table of the embedding associated with each entity."""
    ending = os.path.splitext(emb_file)[1]
    if ending == ".glove":
        embeddings, embedding_size = data_utils.load_glove(emb_file, log_func=print)
    elif ending == ".json":
        dct = utils.load_json_file(emb_file)
        val = next(iter(dct.values()))
        if type(val) is int or type(val) is float:
            embedding_size = 1
            convert_func = lambda x: np.array([x])
        elif type(val) is list:
            embedding_size = len(val)
            convert_func = lambda x: np.array(x)
        else:
            raise ValueError(
                f"Unrecognized type for the array value of {type(val)}")
        embeddings = {}
        for k in dct:
            embeddings[k] = convert_func(dct[k])
            assert len(embeddings[k]) == embedding_size
    else:
        raise ValueError(
            f"We do not support static embeddings from {ending}. We only support .glove or .json"
        )
    entity2staticemb_table = torch.zeros(
        entity_symbols.num_entities_with_pad_and_nocand, embedding_size)
    found = 0
    for qid in tqdm(entity_symbols.get_all_qids()):
        if qid in embeddings:
            found += 1
            emb = embeddings[qid]
            eid = entity_symbols.get_eid(qid)
            entity2staticemb_table[eid, :embedding_size] = torch.from_numpy(emb)
    print(
        f"Found a static embedding for {found / len(entity_symbols.get_all_qids()):.2%} of all entities"
    )
    return entity2staticemb_table
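# The .json branch above accepts two value shapes, sketched here with made-up
# QIDs: a scalar per entity (embedding_size 1) or a fixed-length list per
# entity (embedding_size = list length).
scalar_style = {"Q1": 0.5, "Q2": 1.0}               # -> 1-d embeddings
vector_style = {"Q1": [0.1, 0.2], "Q2": [0.3, 0.4]}  # -> 2-d embeddings
val = next(iter(vector_style.values()))
embedding_size = len(val) if isinstance(val, list) else 1
assert embedding_size == 2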
def __init__(
    self,
    main_args,
    dataset,
    use_weak_label,
    entity_symbols,
    dataset_threads,
    split="train",
):
    global_start = time.time()
    log_rank_0_info(logger, f"Building slice dataset for {split} from {dataset}.")
    spawn_method = main_args.run_config.spawn_method
    data_config = main_args.data_config
    orig_spawn = multiprocessing.get_start_method()
    multiprocessing.set_start_method(spawn_method, force=True)
    self.slice_names = data_utils.get_eval_slices(data_config.eval_slices)
    self.get_slice_dt = lambda max_a2p: np.dtype([
        ("sent_idx", int),
        ("subslice_idx", int),
        ("alias_slice_incidence", int, (max_a2p,)),
        ("prob_labels", float, (max_a2p,)),
    ])
    self.get_storage = lambda max_a2p: np.dtype(
        [(slice_name, self.get_slice_dt(max_a2p))
         for slice_name in self.slice_names])
    # Folder for all mmap saved files
    save_dataset_folder = data_utils.get_save_data_folder(
        data_config, use_weak_label, dataset)
    utils.ensure_dir(save_dataset_folder)
    # Folder for temporary output files
    temp_output_folder = os.path.join(data_config.data_dir,
                                      data_config.data_prep_dir,
                                      f"prep_{split}_slice_files")
    utils.ensure_dir(temp_output_folder)
    # Input step 1
    create_ex_indir = os.path.join(temp_output_folder, "create_examples_input")
    utils.ensure_dir(create_ex_indir)
    # Input step 2
    create_ex_outdir = os.path.join(temp_output_folder, "create_examples_output")
    utils.ensure_dir(create_ex_outdir)
    # Metadata saved files
    meta_file = os.path.join(temp_output_folder, "meta_data.json")
    # File for standard training data
    hash = hashlib.sha1(str(self.slice_names).encode("UTF-8")).hexdigest()[:10]
    self.save_dataset_name = os.path.join(save_dataset_folder,
                                          f"ned_slices_{hash}.bin")
    self.save_data_config_name = os.path.join(save_dataset_folder,
                                              "ned_slices_config.json")
    # =======================================================================================
    # SLICE DATA
    # =======================================================================================
    log_rank_0_debug(logger, "Loading dataset...")
    log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists")
    if data_config.overwrite_preprocessed_data or (
            not os.path.exists(self.save_dataset_name)):
        st_time = time.time()
        try:
            log_rank_0_info(
                logger,
                f"Building dataset from scratch. Saving to {save_dataset_folder}",
            )
            create_examples(
                dataset,
                create_ex_indir,
                create_ex_outdir,
                meta_file,
                data_config,
                dataset_threads,
                self.slice_names,
                use_weak_label,
                split,
            )
            max_alias2pred = utils.load_json_file(meta_file)["max_alias2pred"]
            convert_examples_to_features_and_save(
                meta_file,
                dataset_threads,
                self.slice_names,
                self.save_dataset_name,
                self.get_storage(max_alias2pred),
            )
            utils.dump_json_file(self.save_data_config_name,
                                 {"max_alias2pred": max_alias2pred})
            log_rank_0_debug(
                logger, f"Finished prepping data in {time.time() - st_time}")
        except Exception as e:
            tb = traceback.TracebackException.from_exception(e)
            logger.error(e)
            logger.error("\n".join(tb.stack.format()))
            shutil.rmtree(save_dataset_folder, ignore_errors=True)
            raise
    log_rank_0_info(
        logger,
        f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
    )
    max_alias2pred = utils.load_json_file(self.save_data_config_name)["max_alias2pred"]
    self.data, self.sent_to_row_id_dict = self.build_data_dict(
        self.save_dataset_name, self.get_storage(max_alias2pred))
    assert len(self.data) > 0
    assert len(self.sent_to_row_id_dict) > 0
    log_rank_0_debug(logger, "Removing temporary output files")
    shutil.rmtree(temp_output_folder, ignore_errors=True)
    # Set spawn back to the original/default method ("fork" or "spawn"). This is
    # needed for the Meta.config to be correctly passed in the collate_fn.
    multiprocessing.set_start_method(orig_spawn, force=True)
    log_rank_0_info(
        logger,
        f"Final slice data initialization time from {split} is {time.time() - global_start}s",
    )
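# Self-contained sketch of the nested structured dtype built above: one record
# per sentence, with one sub-record per slice name. The sizes and slice names
# here are made up.
import numpy as np

max_a2p = 3
slice_dt = np.dtype([
    ("sent_idx", int),
    ("subslice_idx", int),
    ("alias_slice_incidence", int, (max_a2p,)),
    ("prob_labels", float, (max_a2p,)),
])
storage = np.dtype([(name, slice_dt) for name in ["slice_a", "slice_b"]])
row = np.zeros(1, dtype=storage)
row["slice_a"]["sent_idx"][:] = 17
assert row["slice_a"]["sent_idx"][0] == 17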
def convert_examples_to_features_and_save(meta_file, dataset_threads, slice_names,
                                          save_dataset_name, storage):
    """Converts the prepped examples into input features and saves them in
    memmap files. These are used in the __getitem__ method.

    Args:
        meta_file: metadata file where input file paths are saved
        dataset_threads: number of threads
        slice_names: list of slice names to evaluate on
        save_dataset_name: data file name to save
        storage: data storage type (for memmap)

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    num_processes = min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))
    log_rank_0_info(
        logger, f"Starting to build and save features with {num_processes} threads")
    log_rank_0_debug(logger, "Counting lines")
    total_input = utils.load_json_file(meta_file)["num_mentions"]
    max_alias2pred = utils.load_json_file(meta_file)["max_alias2pred"]
    files_and_counts = utils.load_json_file(meta_file)["files_and_counts"]

    # IMPORTANT: for distributed writing to memmap files, you must create them in w+ mode
    # before they are opened in r+ mode by workers
    memmap_file = np.memmap(save_dataset_name,
                            dtype=storage,
                            mode="w+",
                            shape=(total_input,),
                            order="C")
    # Save -1 in sent_idx to check that things are loaded correctly later
    memmap_file[slice_names[0]]["sent_idx"][:] = -1

    input_args = []
    # Saves where in the memmap file to start writing
    offset = 0
    for i, in_file_name in enumerate(files_and_counts.keys()):
        input_args.append({
            "file_name": in_file_name,
            "in_file_lines": files_and_counts[in_file_name],
            "save_file_offset": offset,
            "ex_print_mod": int(np.ceil(total_input / 20)),
            "slice_names": slice_names,
            "max_alias2pred": max_alias2pred,
        })
        offset += files_and_counts[in_file_name]

    if num_processes == 1:
        assert len(input_args) == 1
        total_output = convert_examples_to_features_and_save_single(
            input_args[0], memmap_file)
    else:
        log_rank_0_debug(
            logger,
            "Initializing pool. This may take a few minutes.",
        )
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=convert_examples_to_features_and_save_initializer,
            initargs=[save_dataset_name, storage],
        )
        total_output = 0
        for res in pool.imap_unordered(
                convert_examples_to_features_and_save_hlp, input_args,
                chunksize=1):
            total_output += res
        pool.close()
        pool.join()

    # Verify that sentences are unique and saved correctly
    mmap_file = np.memmap(save_dataset_name, dtype=storage, mode="r")
    all_uniq_ids = set()
    for i in tqdm(range(total_input), desc="Checking sentence uniqueness"):
        assert (mmap_file[slice_names[0]]["sent_idx"][i] !=
                -1), f"Index {i} has -1 sent idx"
        uniq_id = (
            f"{mmap_file[slice_names[0]]['sent_idx'][i]}.{mmap_file[slice_names[0]]['subslice_idx'][i]}"
        )
        assert (uniq_id not in all_uniq_ids
                ), f"Idx {uniq_id} is not unique and already in data"
        all_uniq_ids.add(uniq_id)
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}",
    )
    return
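# Minimal sketch of the distributed memmap-writing pattern described in the
# IMPORTANT comment above: the parent creates the file once in "w+" mode, then
# each worker reopens it in "r+" mode and writes only its own offset range.
# The file name, dtype, and offsets here are made up.
import numpy as np

dt = np.dtype([("sent_idx", int)])
parent = np.memmap("features_example.bin", dtype=dt, mode="w+", shape=(8,))
parent["sent_idx"][:] = -1  # sentinel to verify every row gets written
del parent  # flush to disk

worker = np.memmap("features_example.bin", dtype=dt, mode="r+", shape=(8,))
worker["sent_idx"][2:4] = [10, 11]  # this worker's offset range
del worker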
def __init__(self, args, use_weak_label, input_src, dataset_name,
             is_writer, distributed, word_symbols, entity_symbols,
             slice_dataset=None, dataset_is_eval=False):
    # Need to save args to reinstantiate logger
    self.args = args
    self.logger = logging_utils.get_logger(args)
    # Number of candidates, including NIL if a NIL model (train_in_candidates is False)
    self.K = entity_symbols.max_candidates + (
        not args.data_config.train_in_candidates)
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.dataset_name = dataset_name
    self.slice_dataset = slice_dataset
    self.dataset_is_eval = dataset_is_eval
    # Slice names used for eval slices and a slicing model
    self.slice_names = train_utils.get_data_slices(args, dataset_is_eval)
    self.storage_type_file = data_utils.get_storage_file(self.dataset_name)
    # Mapping from sent_idx to row_id in dataset
    self.sent_idx_file = os.path.splitext(dataset_name)[0] + "_sent_idx.json"
    self.type_pred = False
    if args.data_config.type_prediction.use_type_pred:
        self.type_pred = True
        self.eid2typeid, self.num_types_with_pad = self.load_coarse_type_table(
            args, entity_symbols)
    # Load memory mapped file
    self.logger.info("Loading dataset...")
    self.logger.debug("Seeing if " + dataset_name + " exists")
    if (args.data_config.overwrite_preprocessed_data
            or (not os.path.exists(self.dataset_name))
            or (not os.path.exists(self.sent_idx_file))
            or (not os.path.exists(self.storage_type_file))
            or (not os.path.exists(
                data_utils.get_batch_prep_config(self.dataset_name)))):
        start = time.time()
        self.logger.debug(f"Building dataset with {input_src}")
        # Only prep data once per node
        if is_writer:
            prep_data(args,
                      use_weak_label=use_weak_label,
                      dataset_is_eval=self.dataset_is_eval,
                      input_src=input_src,
                      dataset_name=dataset_name,
                      prep_dir=data_utils.get_data_prep_dir(args))
        if distributed:
            # Make sure all processes wait for data to be created
            dist.barrier()
        self.logger.debug(
            f"Finished building and saving dataset in {round(time.time() - start, 2)}s."
        )
    start = time.time()
    # Storage type for loading memory mapped file of dataset
    self.storage_type = pickle.load(open(self.storage_type_file, 'rb'))
    self.data = np.memmap(self.dataset_name, dtype=self.storage_type, mode='r')
    self.data_len = len(self.data)
    # Mapping from sentence idx to rows in the dataset (indices).
    # Needed when sampling sentence indices from slices for evaluation.
    sent_idx_to_idx_str = utils.load_json_file(self.sent_idx_file)
    self.sent_idx_to_idx = {
        int(i): val for i, val in sent_idx_to_idx_str.items()
    }
    self.logger.info("Finished loading dataset.")
    # Stores info about the batch prepped embedding memory mapped files and their
    # shapes and datatypes so we can load them
    self.batch_prep_config = utils.load_json_file(
        data_utils.get_batch_prep_config(self.dataset_name))
    self.batch_prepped_emb_files = {}
    self.batch_prepped_emb_file_names = {}
    for emb in args.data_config.ent_embeddings:
        if 'batch_prep' in emb and emb['batch_prep']:
            assert emb.key in self.batch_prep_config, (
                f'Need to prep {emb.key}. Please call prep instead of run '
                f'with batch_prep_embeddings set to true.')
            self.batch_prepped_emb_file_names[emb.key] = os.path.join(
                os.path.dirname(self.dataset_name),
                os.path.basename(
                    self.batch_prep_config[emb.key]['file_name']))
            self.batch_prepped_emb_files[emb.key] = np.memmap(
                self.batch_prepped_emb_file_names[emb.key],
                dtype=self.batch_prep_config[emb.key]['dtype'],
                shape=tuple(self.batch_prep_config[emb.key]['shape']),
                mode='r')
            assert len(self.batch_prepped_emb_files[emb.key]) == self.data_len, \
                f'Preprocessed emb data file {self.batch_prep_config[emb.key]["file_name"]} does not match length of main data file.'
    # Stores embeddings that we compute on the fly; these are embeddings where
    # batch_on_the_fly is set to true.
    self.batch_on_the_fly_embs = {}
    for emb in args.data_config.ent_embeddings:
        if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
            mod, load_class = import_class("bootleg.embeddings", emb.load_class)
            try:
                self.batch_on_the_fly_embs[emb.key] = getattr(mod, load_class)(
                    main_args=args,
                    emb_args=emb['args'],
                    entity_symbols=entity_symbols,
                    model_device=None,
                    word_symbols=None,
                    key=emb.key)
            except AttributeError as e:
                self.logger.warning(
                    f'No prep method found for {emb.load_class} with error {e}')
            except Exception as e:
                print("ERROR", e)
    # The data in this table shouldn't be pickled since we delete it in the class __getstate__
    self.alias2entity_table = AliasEntityTable(args=args,
                                               entity_symbols=entity_symbols)
    # Random NIL percent
    self.mask_perc = args.train_config.random_nil_perc
    self.random_nil = False
    # Don't want to randomly mask for eval
    if not dataset_is_eval:
        # Whether to use a random NIL training regime
        self.random_nil = args.train_config.random_nil
        if self.random_nil:
            self.logger.info(
                f'Using random NILs during training with {self.mask_perc} percent')
@classmethod
def build_static_embeddings(cls, emb_file, entity_symbols):
    """Builds the table of the embedding associated with each entity.

    Args:
        emb_file: embedding file to load
        entity_symbols: entity symbols

    Returns: numpy embedding matrix where each row is an embedding
    """
    ending = os.path.splitext(emb_file)[1]
    found = 0
    raw_num_ents = 0
    if ending == ".json":
        dct = utils.load_json_file(emb_file)
        val = next(iter(dct.values()))
        if type(val) is int or type(val) is float:
            embedding_size = 1
            convert_func = lambda x: np.array([x])
        elif type(val) is list:
            embedding_size = len(val)
            convert_func = lambda x: np.array(x)
        else:
            raise ValueError(
                f"Unrecognized type for the array value of {type(val)}"
            )
        embeddings = {}
        for k in dct:
            embeddings[k] = convert_func(dct[k])
            assert len(embeddings[k]) == embedding_size
        entity2staticemb_table = np.zeros(
            (entity_symbols.num_entities_with_pad_and_nocand, embedding_size)
        )
        raw_num_ents = len(embeddings)
        for qid in tqdm(entity_symbols.get_all_qids()):
            if qid in embeddings:
                found += 1
                emb = embeddings[qid]
                eid = entity_symbols.get_eid(qid)
                entity2staticemb_table[eid, :embedding_size] = emb
    elif ending == ".pt":
        log_rank_0_debug(
            logger,
            "We are reading in the embedding file from a .pt. We assume this is already mapped to eids",
        )
        (qid2eid_map, entity2staticemb_table_raw) = torch.load(emb_file)
        entity2staticemb_table_raw = (
            entity2staticemb_table_raw.detach().cpu().numpy()
        )
        raw_num_ents = entity2staticemb_table_raw.shape[0]
        # The +2 handles the PAD and UNK entities
        assert entity2staticemb_table_raw.shape[0] == len(qid2eid_map) + 2, (
            f"The saved static embeddings file had mismatched shapes between qid2eid {len(qid2eid_map)} and "
            f"weights {entity2staticemb_table_raw.shape[0]}"
        )
        entity2staticemb_table = np.zeros(
            (
                entity_symbols.num_entities_with_pad_and_nocand,
                entity2staticemb_table_raw.shape[1],
            )
        )
        found = 0
        for qid in tqdm(entity_symbols.get_all_qids()):
            if qid in qid2eid_map:
                found += 1
                raw_eid = qid2eid_map[qid]
                emb = entity2staticemb_table_raw[raw_eid]
                new_eid = entity_symbols.get_eid(qid)
                entity2staticemb_table[new_eid, :] = emb
    else:
        raise ValueError(
            f"We do not support static embeddings from {ending}. We only support .json and .pt"
        )
    log_rank_0_debug(
        logger,
        f"Found a static embedding for {found} of {len(entity_symbols.get_all_qids())} entities "
        f"({found / len(entity_symbols.get_all_qids()):.2%}) after reading {raw_num_ents} raw entities",
    )
    return entity2staticemb_table
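# Sketch of the .pt file layout the branch above expects, with made-up QIDs and
# a made-up file name: a (qid2eid_map, weight_tensor) pair where the tensor has
# two extra rows (PAD and UNK) beyond len(qid2eid_map).
import torch

qid2eid_map = {"Q1": 1, "Q2": 2}
weights = torch.randn(len(qid2eid_map) + 2, 5)  # +2 for the PAD and UNK rows
torch.save((qid2eid_map, weights), "static_embs_example.pt")
loaded_map, loaded_weights = torch.load("static_embs_example.pt")
assert loaded_weights.shape[0] == len(loaded_map) + 2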
def __init__(
    self,
    main_args,
    emb_args,
    entity_symbols,
    key,
    cpu,
    normalize,
    dropout1d_perc,
    dropout2d_perc,
):
    super(TopKEntityEmb, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
        key=key,
        cpu=cpu,
        normalize=normalize,
        dropout1d_perc=dropout1d_perc,
        dropout2d_perc=dropout2d_perc,
    )
    allowable_keys = {
        "learned_embedding_size",
        "perc_emb_drop",
        "qid2topk_eid",
        "regularize_mapping",
        "tail_init",
        "tail_init_zeros",
    }
    correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
    if not correct:
        raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
    assert (
        "learned_embedding_size" in emb_args
    ), "TopKEntityEmb must have learned_embedding_size in args"
    assert "perc_emb_drop" in emb_args, (
        "To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings"
        " removed."
    )
    self.learned_embedding_size = emb_args.learned_embedding_size
    # We remove perc_emb_drop percent of the embeddings and add one to represent the new toes embedding
    num_topk_entities_with_pad_and_nocand = (
        entity_symbols.num_entities_with_pad_and_nocand
        - int(emb_args.perc_emb_drop * entity_symbols.num_entities)
        + 1
    )
    # Mapping of each entity to its new eid
    eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
    # There are issues with using -1 to index into the embeddings, so we manually set it to the last value
    eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
    if "qid2topk_eid" not in emb_args:
        assert self.from_pretrained, (
            "If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
            "you must be loading a model from a checkpoint to build this index mapping"
        )
    self.learned_entity_embedding = nn.Embedding(
        num_topk_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=False,
    )
    self._dim = main_args.model_config.hidden_size
    if "regularize_mapping" in emb_args:
        eid2reg = torch.zeros(num_topk_entities_with_pad_and_nocand)
    else:
        eid2reg = None
    # If tail_init is false, all embeddings are randomly initialized.
    # If tail_init is true, we initialize all embeddings to be the same.
    self.tail_init = True
    self.tail_init_zeros = False
    # A None init_vec means random initialization
    init_vec = None
    if not self.from_pretrained:
        qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
        assert (
            len(qid2topk_eid) == entity_symbols.num_entities
        ), "You must have an item in qid2topk_eid for each qid in entity_symbols"
        for qid in entity_symbols.get_all_qids():
            old_eid = entity_symbols.get_eid(qid)
            new_eid = qid2topk_eid[qid]
            eid2topkeid[old_eid] = new_eid
        assert eid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
        assert (
            eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1
        ), "The -1 eid should still map to the last row"
        if "tail_init" in emb_args:
            self.tail_init = emb_args.tail_init
        if "tail_init_zeros" in emb_args:
            self.tail_init_zeros = emb_args.tail_init_zeros
            self.tail_init = False
            init_vec = torch.zeros(1, self.learned_embedding_size)
        assert not (
            self.tail_init and self.tail_init_zeros
        ), "Can only have one of tail_init or tail_init_zeros set"
        if self.tail_init or self.tail_init_zeros:
            if self.tail_init_zeros:
                log_rank_0_debug(
                    logger,
                    "All learned entity embeddings are initialized to zero.",
                )
            else:
                log_rank_0_debug(
                    logger,
                    "All learned entity embeddings are initialized to the same value.",
                )
            init_vec = model_utils.init_embeddings_to_vec(
                self.learned_entity_embedding, pad_idx=-1, vec=init_vec
            )
            vec_save_file = os.path.join(
                emmental.Meta.log_path, "init_vec_entity_embs.npy"
            )
            log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
            if (
                torch.distributed.is_initialized()
                and torch.distributed.get_rank() == 0
            ):
                np.save(vec_save_file, init_vec)
        else:
            log_rank_0_debug(
                logger, "All learned embeddings are randomly initialized."
            )
        # Regularization mapping goes from eid to 2d dropout percent
        if "regularize_mapping" in emb_args:
            log_rank_0_debug(
                logger,
                "You are using regularization mapping with a topK entity embedding. "
                "This means all QIDs that are mapped to the same"
                " EID will get the same regularization value.",
            )
            if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                log_rank_0_debug(
                    logger,
                    "You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?",
                )
            log_rank_0_debug(
                logger,
                f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}",
            )
            eid2reg = self.load_regularization_mapping(
                main_args.data_config,
                entity_symbols,
                qid2topk_eid,
                num_topk_entities_with_pad_and_nocand,
                emb_args.regularize_mapping,
            )
    # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
    self.register_buffer("eid2topkeid", eid2topkeid)
    self.register_buffer("eid2reg", eid2reg)
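# Worked example of the size arithmetic above with made-up numbers: 100
# entities (102 rows with the pad and no-candidate rows), dropping 25% of the
# embeddings and adding one shared "toes" row for the dropped tail.
num_entities = 100
num_entities_with_pad_and_nocand = num_entities + 2
perc_emb_drop = 0.25
num_topk_entities_with_pad_and_nocand = (
    num_entities_with_pad_and_nocand - int(perc_emb_drop * num_entities) + 1
)
assert num_topk_entities_with_pad_and_nocand == 78  # 102 - 25 + 1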
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(TopKEntityEmb, self).__init__(main_args=main_args,
                                        emb_args=emb_args,
                                        model_device=model_device,
                                        entity_symbols=entity_symbols,
                                        word_symbols=word_symbols,
                                        word_emb=word_emb,
                                        key=key)
    self.logger = logging_utils.get_logger(main_args)
    self.learned_embedding_size = emb_args.learned_embedding_size
    self.normalize = True
    assert "perc_emb_drop" in emb_args, (
        "To use TopKEntityEmb we need perc_emb_drop to be in the args. "
        "This gives the percentage of embeddings removed.")
    # We remove perc_emb_drop percent of the embeddings and add one to represent the new toes embedding
    num_topk_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand - int(
        emb_args.perc_emb_drop * entity_symbols.num_entities) + 1
    # Mapping of each entity to its new eid
    qid2topk_eid = {}
    eid2topkeid = torch.arange(
        0, entity_symbols.num_entities_with_pad_and_nocand)
    # There are issues with using -1 to index into the embeddings, so we manually set it to the last value
    eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
    if "qid2topk_eid" not in emb_args:
        assert len(main_args.run_config.init_checkpoint) > 0, (
            "If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
            "you must be loading a model from a checkpoint to build this index mapping")
    else:
        qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
        assert len(qid2topk_eid) == entity_symbols.num_entities, \
            "You must have an item in qid2topk_eid for each qid in entity_symbols"
        for qid in entity_symbols.get_all_qids():
            old_eid = entity_symbols.get_eid(qid)
            new_eid = qid2topk_eid[qid]
            eid2topkeid[old_eid] = new_eid
        assert eid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
        assert eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1, \
            "The -1 eid should still map to the last row"
    self.learned_entity_embedding = nn.Embedding(
        num_topk_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=True)
    # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
    self.register_buffer("eid2topkeid", eid2topkeid)
    # If tail_init is false, all embeddings are randomly initialized.
    # If tail_init is true, we initialize all embeddings to be the same.
    self.tail_init = True
    self.tail_init_zeros = False
    # A None init_vec means random initialization
    init_vec = None
    if "tail_init" in emb_args:
        self.tail_init = emb_args.tail_init
    if "tail_init_zeros" in emb_args:
        self.tail_init_zeros = emb_args.tail_init_zeros
        self.tail_init = False
        init_vec = torch.zeros(1, self.learned_embedding_size)
    assert not (self.tail_init and self.tail_init_zeros), \
        "Can only have one of tail_init or tail_init_zeros set"
    if self.tail_init or self.tail_init_zeros:
        if self.tail_init_zeros:
            self.logger.debug(
                "All learned entity embeddings are initialized to zero.")
        else:
            self.logger.debug(
                "All learned entity embeddings are initialized to the same value.")
        init_vec = model_utils.init_embeddings_to_vec(
            self.learned_entity_embedding, pad_idx=-1, vec=init_vec)
        vec_save_file = os.path.join(
            train_utils.get_save_folder(main_args.run_config),
            "init_vec_entity_embs.npy")
        self.logger.debug(f"Saving init vector to {vec_save_file}")
        np.save(vec_save_file, init_vec)
    else:
        self.logger.debug("All learned embeddings are randomly initialized.")
    self._dim = main_args.model_config.hidden_size
    self.eid2reg = None
    # Regularization mapping goes from eid to 2d dropout percent
    if "regularize_mapping" in emb_args:
        self.logger.warning(
            "You are using regularization mapping with a topK entity embedding. "
            "This means all QIDs that are mapped to the same EID will get the same regularization value.")
        self.logger.debug(
            f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}")
        self.eid2reg = self.load_regularization_mapping(
            main_args, qid2topk_eid, num_topk_entities_with_pad_and_nocand,
            emb_args.regularize_mapping, self.logger.debug)
        self.eid2reg = self.eid2reg.to(model_device)
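# Standalone sketch (not Bootleg's model_utils.init_embeddings_to_vec) of
# initializing every row of an embedding to one shared vector while keeping the
# pad row at zero, mirroring the tail_init behavior described above. The sizes
# are made up.
import torch
import torch.nn as nn

emb = nn.Embedding(5, 4, padding_idx=-1)  # torch normalizes -1 to the last row
init_vec = torch.zeros(1, 4)  # tail_init_zeros-style init
with torch.no_grad():
    emb.weight.copy_(init_vec.expand_as(emb.weight))
    emb.weight[emb.padding_idx].zero_()  # keep the pad row at zero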
def compress_topk_embeddings(args):
    assert 0 < args.perc_emb_drop < 1, "perc_emb_drop must be between 0 and 1"
    print(
        f"Loading entity symbols from {os.path.join(args.entity_dir, 'entity_mappings')}"
    )
    entity_db = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, "entity_mappings"))
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print("Filtering qids")
    (
        qid2topk_eid,
        old2new_eid,
        toes_eid,
        new_num_topk_entities,
    ) = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    if len(args.model_path) > 0:
        assert (
            args.save_model_path is not None and len(args.save_model_path) > 0
        ), "If you give a model path, you must give a save checkpoint"
        print("Filtering embeddings")
        state_dict, model_state_dict = load_statedict(args.model_path)
        try:
            get_nested_item(model_state_dict, ENTITY_EMB_KEYS)
        except Exception:
            print(f"ERROR: All of {ENTITY_EMB_KEYS} are not in model_state_dict")
            raise
        model_state_dict = filter_embs(
            new_num_topk_entities,
            entity_db,
            old2new_eid,
            qid2topk_eid,
            toes_eid,
            model_state_dict,
        )
        # Generate the new old2new_eid weight vector to save in model_state_dict
        oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
        # The +2 accounts for the pad and unk rows. The -1 is because -1 indexing into the
        # entity embeddings is problematic, so we manually make it the last entry.
        oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
        for qid, new_eid in tqdm(qid2topk_eid.items(), desc="Setting new ids"):
            old_eid = entity_db.get_eid(qid)
            oldeid2topkeid[old_eid] = new_eid
        assert oldeid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
        assert (oldeid2topkeid[-1] == new_num_topk_entities + 2 - 1
                ), "The -1 eid should still map to the last row"
        model_state_dict = set_nested_item(model_state_dict, ENTITY_EID_KEYS,
                                           oldeid2topkeid)
        # Remove the eid2reg value as it was built for the old entity eid mapping
        try:
            model_state_dict = set_nested_item(model_state_dict,
                                               ENTITY_REG_KEYS, None)
        except Exception:
            print(
                "Could not remove regularization. If your model was trained with a regularization "
                "mapping on the learned entity embedding, this should not happen.")
        print(model_state_dict["module_pool"]["learned"].keys())
        state_dict["model"] = model_state_dict
        print(f"Saving model at {args.save_model_path}")
        torch.save(state_dict, args.save_model_path)
    print(
        f"Saving topk to eid at "
        f"{os.path.join(args.entity_dir, 'entity_mappings', args.save_qid2topk_file)}"
    )
    utils.dump_json_file(
        os.path.join(args.entity_dir, "entity_mappings", args.save_qid2topk_file),
        qid2topk_eid,
    )
    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            os.path.join(args.entity_dir, "entity_mappings", args.save_qid2topk_file),
            args.perc_emb_drop,
        )
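# Hypothetical stand-ins for the get_nested_item / set_nested_item helpers used
# above (the real ones live in Bootleg's utils): walk a list of keys through a
# nested dict such as a model state dict. The example key path below is made up
# beyond the "module_pool"/"learned" keys printed in the function above.
from functools import reduce

def get_nested_item(dct, keys):
    # e.g. keys = ["module_pool", "learned", "some_leaf_key"]
    return reduce(lambda d, k: d[k], keys, dct)

def set_nested_item(dct, keys, value):
    get_nested_item(dct, keys[:-1])[keys[-1]] = value
    return dct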