def check_and_create_alias_cand_trie(save_folder, entity_symbols): """Creates a mmap memory trie object if it doesn't exist for storing the alias-candidate mappings. Args: save_folder: save folder for alias trie entity_symbols: entity symbols Returns: """ try: AliasCandRecordTrie(load_dir=save_folder) except FileNotFoundError: log_rank_0_debug( logger, "Creating the alias candidate trie for faster parallel processing. " "This is a one time cost", ) alias_trie = AliasCandRecordTrie( input_dict=entity_symbols.get_alias2qids(), vocabulary=entity_symbols.get_qid2title(), max_value=entity_symbols.max_candidates, ) alias_trie.dump(save_folder) return
def get_sent_idx2num_mens(data_file):
    """Gets the map from sentence index to number of mentions and to data. Used for calculating offsets and
    chunking the file.

    Args:
        data_file: eval file

    Returns: Dict of sentence index -> number of mentions per sentence, Dict of sentence index -> input line
    """
    sent_idx2num_mens = {}
    sent_idx2row = {}
    total_num_mentions = 0
    with open(data_file) as f:
        for line in f:
            line = ujson.loads(line)
            # Keep track of the start idx in the condensed memory mapped file for each sentence (varying number of
            # aliases)
            assert (
                line["sent_idx_unq"] not in sent_idx2num_mens
            ), f'Sentence indices must be unique. {line["sent_idx_unq"]} already seen.'
            sent_idx2row[str(line["sent_idx_unq"])] = line
            # Save as string for the Marisa trie later
            sent_idx2num_mens[str(line["sent_idx_unq"])] = len(line["aliases"])
            # We include false aliases for debugging (and alias_pos includes them)
            total_num_mentions += len(line["aliases"])
    log_rank_0_debug(
        logger, f"Total number of mentions across all sentences: {total_num_mentions}"
    )
    return sent_idx2num_mens, sent_idx2row
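# Illustrative usage sketch (not part of the library): the eval file is jsonl with one sentence per
# line containing at least "sent_idx_unq" and "aliases". The path, QIDs, and aliases below are
# hypothetical and only show the expected shapes of the returned dicts.
def _example_get_sent_idx2num_mens():
    import tempfile

    import ujson

    line = {"sent_idx_unq": 0, "aliases": ["abraham lincoln", "president"], "qids": ["Q91", "Q11696"]}
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
        f.write(ujson.dumps(line) + "\n")
        data_file = f.name
    sent_idx2num_mens, sent_idx2row = get_sent_idx2num_mens(data_file)
    # Keys are saved as strings so they can be used with a Marisa trie later
    assert sent_idx2num_mens == {"0": 2}
    assert sent_idx2row["0"]["qids"] == ["Q91", "Q11696"]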
def __init__(
    self,
    main_args,
    emb_args,
    entity_symbols,
    key,
    cpu,
    normalize,
    dropout1d_perc,
    dropout2d_perc,
):
    super(KGIndices, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
        key=key,
        cpu=cpu,
        normalize=normalize,
        dropout1d_perc=dropout1d_perc,
        dropout2d_perc=dropout2d_perc,
    )
    self._dim = 0
    # Weight for the diagonal addition to the KG indices - allows for summing an entity with other connections
    self.kg_bias_weight = torch.nn.Parameter(torch.tensor(2.0))
    self.kg_softmax = torch.nn.Softmax(dim=2)
    # This determines that, when prepping the embeddings, we will query the kg_adj matrix - generating
    # M*K values per entity candidate
    self.kg_adj_process_func = embedding_utils.prep_kg_feature_matrix
    log_rank_0_debug(
        logger,
        f"You are using the KGIndices class with key {key}."
        f" This key needs to be used in the attention network to access the KG bias matrix."
        f" This embedding is not appended to the payload.",
    )
def prep( cls, data_config, entity_symbols, num_aliases_with_pad_and_unk, num_cands_K, ): """Preps the alias to entity EID table. Args: data_config: data config entity_symbols: entity symbols num_aliases_with_pad_and_unk: number of aliases including pad and unk num_cands_K: number of candidates per alias (aka K) Returns: torch Tensor of the alias to EID table, save pt file """ # we pass num_aliases_with_pad_and_unk and num_cands_K to remove the dependence on entity_symbols # when the alias table is already prepped data_shape = (num_aliases_with_pad_and_unk, num_cands_K) # dependent on train_in_candidates flag prep_dir = data_utils.get_emb_prep_dir(data_config) alias_str = os.path.splitext( data_config.alias_cand_map.replace("/", "_"))[0] prep_file = os.path.join( prep_dir, f"alias2entity_table_{alias_str}_InC{int(data_config.train_in_candidates)}.pt", ) log_rank_0_debug(logger, f"Looking for alias table in {prep_file}") if not data_config.overwrite_preprocessed_data and os.path.exists( prep_file): log_rank_0_debug(logger, f"Loading alias table from {prep_file}") start = time.time() alias2entity_table = np.memmap(prep_file, dtype="int64", mode="r+", shape=data_shape) log_rank_0_debug( logger, f"Loaded alias table in {round(time.time() - start, 2)}s") else: start = time.time() log_rank_0_debug(logger, f"Building alias table") utils.ensure_dir(prep_dir) mmap_file = np.memmap(prep_file, dtype="int64", mode="w+", shape=data_shape) alias2entity_table = cls.build_alias_table(data_config, entity_symbols) mmap_file[:] = alias2entity_table[:] mmap_file.flush() log_rank_0_debug( logger, f"Finished building and saving alias table in {round(time.time() - start, 2)}s.", ) alias2entity_table = torch.from_numpy(alias2entity_table) return alias2entity_table, prep_file
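# Illustrative sketch (assumption, not library API): reloading a prepped alias table requires the same
# dtype and shape used at write time; `prep_file` and the shape arguments are whatever `prep` returned
# or was called with.
def _example_reload_alias_table(prep_file, num_aliases_with_pad_and_unk, num_cands_K):
    import numpy as np
    import torch

    table = np.memmap(
        prep_file,
        dtype="int64",
        mode="r",  # read-only view; prep uses "w+"/"r+" when it needs to write
        shape=(num_aliases_with_pad_and_unk, num_cands_K),
    )
    return torch.from_numpy(np.asarray(table))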
def __init__( self, main_args, emb_args, entity_symbols, key, cpu, normalize, dropout1d_perc, dropout2d_perc, ): super(TitleEmb, self).__init__( main_args=main_args, emb_args=emb_args, entity_symbols=entity_symbols, key=key, cpu=cpu, normalize=normalize, dropout1d_perc=dropout1d_perc, dropout2d_perc=dropout2d_perc, ) allowable_keys = {"proj", "requires_grad"} correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args) if not correct: raise ValueError(f"The key {bad_key} is not in {allowable_keys}") self.orig_dim = BERT_WORD_DIM self.merge_func = self.average_titles self.M = main_args.data_config.max_aliases self.K = get_max_candidates(entity_symbols, main_args.data_config) self._dim = main_args.model_config.hidden_size if "proj" in emb_args: self._dim = emb_args.proj self.requires_grad = True if "requires_grad" in emb_args: self.requires_grad = emb_args.requires_grad self.title_proj = torch.nn.Linear(self.orig_dim, self._dim) log_rank_0_debug( logger, f'Setting the "proj" parameter to {self._dim} and the "requires_grad" parameter to {self.requires_grad}', ) ( entity2titleid_table, entity2titlemask_table, entity2tokentypeid_table, ) = self.prep( data_config=main_args.data_config, entity_symbols=entity_symbols, ) self.register_buffer("entity2titleid_table", entity2titleid_table, persistent=False) self.register_buffer("entity2titlemask_table", entity2titlemask_table, persistent=False) self.register_buffer("entity2tokentypeid_table", entity2tokentypeid_table, persistent=False)
def __init__( self, main_args, emb_args, entity_symbols, key, cpu, normalize, dropout1d_perc, dropout2d_perc, ): super(StaticEmb, self).__init__( main_args=main_args, emb_args=emb_args, entity_symbols=entity_symbols, key=key, cpu=cpu, normalize=normalize, dropout1d_perc=dropout1d_perc, dropout2d_perc=dropout2d_perc, ) allowable_keys = {"emb_file", "proj"} correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args) if not correct: raise ValueError(f"The key {bad_key} is not in {allowable_keys}") assert "emb_file" in emb_args, f"Must have emb_file in args for StaticEmb" self.entity2static = self.prep( data_config=main_args.data_config, emb_args=emb_args, entity_symbols=entity_symbols, ) self.orig_dim = self.entity2static.shape[1] # entity2static_embedding = torch.nn.Embedding( # entity_symbols.num_entities_with_pad_and_nocand, # self.orig_dim, # padding_idx=-1, # sparse=True, # ) self.normalize = False if self.orig_dim > 1: self.normalize = True self.proj = None self._dim = self.orig_dim if "proj" in emb_args: log_rank_0_debug( logger, f"Adding a projection layer to the static emb to go to dim {emb_args.proj}", ) self._dim = emb_args.proj self.proj = MLP( input_size=self.orig_dim, num_hidden_units=None, output_size=self._dim, num_layers=1, )
def __init__( self, main_args, emb_args, entity_symbols, key, cpu, normalize, dropout1d_perc, dropout2d_perc, ): super(KGAdjEmb, self).__init__( main_args=main_args, emb_args=emb_args, entity_symbols=entity_symbols, key=key, cpu=cpu, normalize=normalize, dropout1d_perc=dropout1d_perc, dropout2d_perc=dropout2d_perc, ) allowable_keys = {"kg_adj", "threshold", "log_weight"} correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args) if not correct: raise ValueError(f"The key {bad_key} is not in {allowable_keys}") assert "kg_adj" in emb_args, f"KG embedding requires kg_adj to be set in args" assert (self.normalize is False), f"We can't normalize a KGAdjEmb as it has hidden dim 1" assert (self.dropout1d_perc == 0.0 ), f"We can't dropout 1d a KGAdjEmb as it has hidden dim 1" assert (self.dropout2d_perc == 0.0 ), f"We can't dropout 2d a KGAdjEmb as it has hidden dim 1" # This determines that, when prepping the embeddings, we will query the kg_adj matrix and sum the results - # generating one value per entity candidate self.kg_adj_process_func = embedding_utils.prep_kg_feature_sum self._dim = 1 self.threshold_weight = 0 if "threshold" in emb_args: self.threshold_weight = float(emb_args.threshold) self.log_weight = False if "log_weight" in emb_args: self.log_weight = emb_args.log_weight assert type(self.log_weight) is bool log_rank_0_debug( logger, f"Setting log_weight to be {self.log_weight} and threshold to be {self.threshold_weight} in {key}", ) self.kg_adj, self.prep_file = self.prep( data_config=main_args.data_config, emb_args=emb_args, entity_symbols=entity_symbols, threshold=self.threshold_weight, log_weight=self.log_weight, )
def prep(cls, data_config, entity_symbols): """Prep the title data. Args: data_config: data config entity_symbols: entity symbols Returns: torch tensor EID to title token IDs, EID to title token mask, EID to title token type ID (for BERT) """ prep_dir = data_utils.get_emb_prep_dir(data_config) prep_file_token_ids = os.path.join( prep_dir, f"title_token_ids_{data_config.word_embedding.bert_model}.pt") prep_file_attn_mask = os.path.join( prep_dir, f"title_attn_mask_{data_config.word_embedding.bert_model}.pt") prep_file_token_type_ids = os.path.join( prep_dir, f"title_token_type_ids_{data_config.word_embedding.bert_model}.pt") utils.ensure_dir(os.path.dirname(prep_file_token_ids)) log_rank_0_debug( logger, f"Looking for title table mapping in {prep_file_token_ids}") if (not data_config.overwrite_preprocessed_data and os.path.exists(prep_file_token_ids) and os.path.exists(prep_file_attn_mask) and os.path.exists(prep_file_token_type_ids)): log_rank_0_debug( logger, f"Loading existing title table from {prep_file_token_ids}") start = time.time() entity2titleid = torch.load(prep_file_token_ids) entity2titlemask = torch.load(prep_file_attn_mask) entity2tokentypeid = torch.load(prep_file_token_type_ids) log_rank_0_debug( logger, f"Loaded existing title table in {round(time.time() - start, 2)}s", ) else: start = time.time() log_rank_0_debug(logger, f"Loading tokenizer") tokenizer = load_tokenizer(data_config) ( entity2titleid, entity2titlemask, entity2tokentypeid, ) = cls.build_title_table(tokenizer=tokenizer, entity_symbols=entity_symbols) torch.save(entity2titleid, prep_file_token_ids) torch.save(entity2titlemask, prep_file_attn_mask) torch.save(entity2tokentypeid, prep_file_token_type_ids) log_rank_0_debug( logger, f"Finished building and saving title table in {round(time.time() - start, 2)}s.", ) return entity2titleid, entity2titlemask, entity2tokentypeid
def configure_optimizer(config):
    """Configures the optimizer for Bootleg. By default, we use SparseDenseAdamW. We always create a separate
    parameter group for layer norms following standard BERT finetuning methods.

    Args:
        config: config

    Returns:
    """
    # Set default Bootleg optimizer if config doesn't override it
    if config.learner_config.optimizer_config.optimizer is None:
        log_rank_0_debug(logger, f"Setting default optimizer to be SparseDenseAdamW")
        custom_optimizer = partial(
            SparseDenseAdamW,
            lr=config.learner_config.optimizer_config.lr,
            weight_decay=config.learner_config.optimizer_config.l2,
            betas=config.learner_config.optimizer_config.adamw_config.betas,
            eps=config.learner_config.optimizer_config.adamw_config.eps,
        )
        custom_optim_config = {
            "learner_config": {"optimizer_config": {"optimizer": custom_optimizer}}
        }
        emmental.Meta.update_config(custom_optim_config)

    # Specify parameter groups for AdamW and BERT: layer norms and biases get no weight decay
    def grouped_parameters(model):
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        return [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": emmental.Meta.config["learner_config"]["optimizer_config"]["l2"],
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]

    emmental.Meta.config["learner_config"]["optimizer_config"]["parameters"] = grouped_parameters
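# Illustrative sketch of the no-decay grouping above on a toy module (the module is hypothetical);
# only parameter names containing "bias", "LayerNorm.bias", or "LayerNorm.weight" skip weight decay.
def _example_grouped_parameter_names():
    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dense = torch.nn.Linear(4, 4)
            self.LayerNorm = torch.nn.LayerNorm(4)

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    model = Toy()
    decay = [n for n, _ in model.named_parameters() if not any(nd in n for nd in no_decay)]
    skip = [n for n, _ in model.named_parameters() if any(nd in n for nd in no_decay)]
    assert decay == ["dense.weight"]
    assert skip == ["dense.bias", "LayerNorm.weight", "LayerNorm.bias"]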
def load_regularization_mapping(cls, data_config, entity_symbols, reg_file):
    """Reads in a csv file with columns [qid, regularization]. In the forward pass, the entity id
    associated with each qid will be regularized with probability regularization.

    Args:
        data_config: data config
        entity_symbols: entity symbols
        reg_file: regularization csv file

    Returns: Tensor where each value is the regularization value for the EID at that index
    """
    reg_str = os.path.splitext(os.path.basename(reg_file.replace("/", "_")))[0]
    prep_dir = data_utils.get_data_prep_dir(data_config)
    prep_file = os.path.join(prep_dir, f"entity_regularization_mapping_{reg_str}.pt")
    utils.ensure_dir(os.path.dirname(prep_file))
    log_rank_0_debug(logger, f"Looking for regularization mapping in {prep_file}")
    if not data_config.overwrite_preprocessed_data and os.path.exists(prep_file):
        log_rank_0_debug(
            logger,
            f"Loading existing entity regularization mapping from {prep_file}",
        )
        start = time.time()
        eid2reg = torch.load(prep_file)
        log_rank_0_debug(
            logger,
            f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s",
        )
    else:
        start = time.time()
        log_rank_0_info(
            logger, f"Building entity regularization mapping from {reg_file}"
        )
        qid2reg = pd.read_csv(reg_file)
        assert (
            "qid" in qid2reg.columns and "regularization" in qid2reg.columns
        ), f"Expected qid and regularization as the column names for {reg_file}"
        # Default of no mask
        eid2reg_arr = [0.0] * entity_symbols.num_entities_with_pad_and_nocand
        for row_idx, row in qid2reg.iterrows():
            if entity_symbols.qid_exists(row["qid"]):
                eid = entity_symbols.get_eid(row["qid"])
                eid2reg_arr[eid] = row["regularization"]
        eid2reg = torch.tensor(eid2reg_arr)
        torch.save(eid2reg, prep_file)
        log_rank_0_debug(
            logger,
            f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.",
        )
    return eid2reg
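# Illustrative sketch of the csv format load_regularization_mapping expects: two columns named
# "qid" and "regularization" (QIDs and values below are hypothetical).
def _example_write_regularization_csv(reg_file):
    import pandas as pd

    pd.DataFrame(
        {"qid": ["Q123", "Q456"], "regularization": [0.5, 0.05]}
    ).to_csv(reg_file, index=False)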
def prep(cls, data_config, emb_args, entity_symbols): """Static embedding prep. Args: data_config: data config emb_args: embedding args entity_symbols: entity symbols Returns: numpy embedding array where each row is the embedding for an EID """ static_str = os.path.splitext(os.path.basename(emb_args.emb_file))[0] prep_dir = data_utils.get_emb_prep_dir(data_config) prep_file = os.path.join(prep_dir, f"static_table_{static_str}.npy") log_rank_0_debug(logger, f"Looking for static embedding saved at {prep_file}") if not data_config.overwrite_preprocessed_data and os.path.exists( prep_file): log_rank_0_debug( logger, f"Loading existing static embedding table from {prep_file}") start = time.time() entity2staticemb_table = np.load(prep_file, mmap_mode="r") log_rank_0_debug( logger, f"Loaded existing static embedding table in {round(time.time() - start, 2)}s", ) else: start = time.time() emb_file = emb_args.emb_file log_rank_0_debug(logger, f"Building static table from file {emb_file}") entity2staticemb_table = cls.build_static_embeddings( emb_file=emb_file, entity_symbols=entity_symbols) np.save(prep_file, entity2staticemb_table) entity2staticemb_table = np.load(prep_file, mmap_mode="r") log_rank_0_debug( logger, f"Finished building and saving static embedding table in {round(time.time() - start, 2)}s.", ) return entity2staticemb_table
def count_parameters(model, requires_grad, logger):
    """Counts the number of parameters with param.requires_grad == requires_grad, printing each along the way.

    Args:
        model: model to count
        requires_grad: whether to look at grad or no grad params
        logger: logger

    Returns: total number of matching parameters
    """
    for p in [
        p for p in model.named_parameters() if p[1].requires_grad is requires_grad
    ]:
        log_rank_0_debug(
            logger,
            "{:s} {:d} {:.2f} MB".format(p[0], p[1].numel(), p[1].numel() * 4 / 1024**2),
        )
    return sum(
        p.numel() for p in model.parameters() if p.requires_grad is requires_grad
    )
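# Illustrative sketch: a model typically has both trainable and frozen parameters, so both calls are
# useful; `model` and `logger` here are any torch module and a standard logging.Logger.
def _example_count_parameters(model, logger):
    num_trainable = count_parameters(model, requires_grad=True, logger=logger)
    num_frozen = count_parameters(model, requires_grad=False, logger=logger)
    # The per-parameter MB figure logged above assumes 4 bytes (fp32) per value
    return num_trainable, num_frozen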
def prep(cls, data_config, emb_args, entity_symbols):
    """Prep the type id table.

    Args:
        data_config: data config
        emb_args: embedding args
        entity_symbols: entity symbols

    Returns: torch tensor from EID to type IDs, type ID to row in type embedding matrix,
             number of types with unk type, and prep file
    """
    type_str = os.path.splitext(emb_args.type_labels.replace("/", "_"))[0]
    prep_dir = data_utils.get_emb_prep_dir(data_config)
    prep_file = os.path.join(
        prep_dir, f"type_table_{type_str}_{emb_args.max_types}.pt"
    )
    utils.ensure_dir(os.path.dirname(prep_file))
    if not data_config.overwrite_preprocessed_data and os.path.exists(prep_file):
        log_rank_0_debug(logger, f"Loading existing type table from {prep_file}")
        start = time.time()
        eid2typeids_table, type2row_dict, num_types_with_unk = torch.load(prep_file)
        log_rank_0_debug(
            logger,
            f"Loaded existing type table in {round(time.time() - start, 2)}s",
        )
    else:
        start = time.time()
        type_labels = os.path.join(data_config.emb_dir, emb_args.type_labels)
        type_vocab = os.path.join(data_config.emb_dir, emb_args.type_vocab)
        log_rank_0_debug(logger, f"Building type table from {type_labels}")
        eid2typeids_table, type2row_dict, num_types_with_unk = cls.build_type_table(
            type_labels=type_labels,
            type_vocab=type_vocab,
            max_types=emb_args.max_types,
            entity_symbols=entity_symbols,
        )
        torch.save(
            (eid2typeids_table, type2row_dict, num_types_with_unk), prep_file
        )
        log_rank_0_debug(
            logger,
            f"Finished building and saving type table in {round(time.time() - start, 2)}s.",
        )
    return eid2typeids_table, type2row_dict, num_types_with_unk, prep_file
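# Illustrative sketch of the two json files the type prep reads (QIDs and type names are hypothetical).
# Type ids must start at 1 because id 0 is reserved for the UNK type; label values may be type ids or
# type names from the vocab.
def _example_write_type_files(type_labels_path, type_vocab_path):
    import json

    with open(type_vocab_path, "w") as f:
        json.dump({"person": 1, "politician": 2}, f)
    with open(type_labels_path, "w") as f:
        json.dump({"Q91": ["person", "politician"], "Q30": []}, f)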
def prep(
    cls,
    data_config,
    emb_args,
    entity_symbols,
    threshold,
    log_weight,
):
    """Preps the KG information.

    Args:
        data_config: data config
        emb_args: embedding args
        entity_symbols: entity symbols
        threshold: weight threshold for counting an edge
        log_weight: whether to take the log of the weight value after the threshold

    Returns: scipy sparse KG adjacency matrix, prep file
    """
    file_tag = os.path.splitext(emb_args.kg_adj.replace("/", "_"))[0]
    prep_dir = data_utils.get_emb_prep_dir(data_config)
    prep_file = os.path.join(
        prep_dir, f"kg_adj_file_{file_tag}_{threshold}_{log_weight}.npz"
    )
    utils.ensure_dir(os.path.dirname(prep_file))
    if not data_config.overwrite_preprocessed_data and os.path.exists(prep_file):
        log_rank_0_debug(logger, f"Loading existing KG adj from {prep_file}")
        start = time.time()
        kg_adj = scipy.sparse.load_npz(prep_file)
        log_rank_0_debug(
            logger, f"Loaded existing KG adj in {round(time.time() - start, 2)}s"
        )
    else:
        start = time.time()
        kg_adj_file = os.path.join(data_config.emb_dir, emb_args.kg_adj)
        log_rank_0_debug(logger, f"Building KG adj from {kg_adj_file}")
        kg_adj = cls.build_kg_adj(kg_adj_file, entity_symbols, threshold, log_weight)
        scipy.sparse.save_npz(prep_file, kg_adj)
        log_rank_0_debug(
            logger,
            f"Finished building and saving KG adj in {round(time.time() - start, 2)}s.",
        )
    return kg_adj, prep_file
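# Illustrative sketch (assumption): the prepped adjacency is a scipy sparse matrix saved with
# save_npz, so it can be inspected independently of the embedding class.
def _example_inspect_kg_adj(prep_file):
    import scipy.sparse

    kg_adj = scipy.sparse.load_npz(prep_file)
    # Shape and number of nonzero (connected) entity pairs after thresholding
    return kg_adj.shape, kg_adj.nnz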
def convert_examples_to_features_and_save(meta_file, dataset_threads, slice_names, save_dataset_name, storage):
    """Converts the prepped examples into input features and saves in memmap files. These are used in the
    __get_item__ method.

    Args:
        meta_file: metadata file where input file paths are saved
        dataset_threads: number of threads
        slice_names: list of slice names to evaluate on
        save_dataset_name: data file name to save
        storage: data storage type (for memmap)

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    num_processes = min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))
    log_rank_0_info(
        logger, f"Starting to build and save features with {num_processes} threads"
    )
    log_rank_0_debug(logger, f"Counting lines")
    total_input = utils.load_json_file(meta_file)["num_mentions"]
    max_alias2pred = utils.load_json_file(meta_file)["max_alias2pred"]
    files_and_counts = utils.load_json_file(meta_file)["files_and_counts"]

    # IMPORTANT: for distributed writing to memmap files, you must create them in w+ mode before
    # they are opened in r+ mode by workers
    memmap_file = np.memmap(
        save_dataset_name, dtype=storage, mode="w+", shape=(total_input,), order="C"
    )
    # Save -1 in sent_idx to check that things are loaded correctly later
    memmap_file[slice_names[0]]["sent_idx"][:] = -1

    input_args = []
    # Saves where in the memmap file to start writing
    offset = 0
    for i, in_file_name in enumerate(files_and_counts.keys()):
        input_args.append(
            {
                "file_name": in_file_name,
                "in_file_lines": files_and_counts[in_file_name],
                "save_file_offset": offset,
                "ex_print_mod": int(np.ceil(total_input / 20)),
                "slice_names": slice_names,
                "max_alias2pred": max_alias2pred,
            }
        )
        offset += files_and_counts[in_file_name]

    if num_processes == 1:
        assert len(input_args) == 1
        total_output = convert_examples_to_features_and_save_single(
            input_args[0], memmap_file
        )
    else:
        log_rank_0_debug(
            logger,
            "Initializing pool. This may take a few minutes.",
        )
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=convert_examples_to_features_and_save_initializer,
            initargs=[save_dataset_name, storage],
        )
        total_output = 0
        for res in pool.imap_unordered(
            convert_examples_to_features_and_save_hlp, input_args, chunksize=1
        ):
            total_output += res
        pool.close()

    # Verify that sentences are unique and saved correctly
    mmap_file = np.memmap(save_dataset_name, dtype=storage, mode="r")
    all_uniq_ids = set()
    for i in tqdm(range(total_input), desc="Checking sentence uniqueness"):
        assert (
            mmap_file[slice_names[0]]["sent_idx"][i] != -1
        ), f"Index {i} has -1 sent idx"
        uniq_id = str(
            f"{mmap_file[slice_names[0]]['sent_idx'][i]}.{mmap_file[slice_names[0]]['subslice_idx'][i]}"
        )
        assert (
            uniq_id not in all_uniq_ids
        ), f"Idx {uniq_id} is not unique and already in data"
        all_uniq_ids.add(uniq_id)
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}",
    )
    return
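# Illustrative sketch of the structured memmap pattern used above: a nested dtype keyed by slice name,
# with sent_idx initialized to -1 so unwritten rows can be detected later (path, fields, and sizes are
# hypothetical).
def _example_structured_memmap(path):
    import numpy as np

    slice_dt = np.dtype([("sent_idx", int), ("subslice_idx", int)])
    storage = np.dtype([("final_loss", slice_dt)])
    mm = np.memmap(path, dtype=storage, mode="w+", shape=(4,))
    mm["final_loss"]["sent_idx"][:] = -1
    mm.flush()
    return mm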
def __init__( self, main_args, dataset, use_weak_label, entity_symbols, dataset_threads, split="train", ): global_start = time.time() log_rank_0_info(logger, f"Building slice dataset for {split} from {dataset}.") spawn_method = main_args.run_config.spawn_method data_config = main_args.data_config orig_spawn = multiprocessing.get_start_method() multiprocessing.set_start_method(spawn_method, force=True) self.slice_names = data_utils.get_eval_slices(data_config.eval_slices) self.get_slice_dt = lambda max_a2p: np.dtype([ ("sent_idx", int), ("subslice_idx", int), ("alias_slice_incidence", int, (max_a2p, )), ("prob_labels", float, (max_a2p, )), ]) self.get_storage = lambda max_a2p: np.dtype( [(slice_name, self.get_slice_dt(max_a2p)) for slice_name in self.slice_names]) # Folder for all mmap saved files save_dataset_folder = data_utils.get_save_data_folder( data_config, use_weak_label, dataset) utils.ensure_dir(save_dataset_folder) # Folder for temporary output files temp_output_folder = os.path.join(data_config.data_dir, data_config.data_prep_dir, f"prep_{split}_slice_files") utils.ensure_dir(temp_output_folder) # Input step 1 create_ex_indir = os.path.join(temp_output_folder, "create_examples_input") utils.ensure_dir(create_ex_indir) # Input step 2 create_ex_outdir = os.path.join(temp_output_folder, "create_examples_output") utils.ensure_dir(create_ex_outdir) # Meta data saved files meta_file = os.path.join(temp_output_folder, "meta_data.json") # File for standard training data hash = hashlib.sha1(str( self.slice_names).encode("UTF-8")).hexdigest()[:10] self.save_dataset_name = os.path.join(save_dataset_folder, f"ned_slices_{hash}.bin") self.save_data_config_name = os.path.join(save_dataset_folder, "ned_slices_config.json") # ======================================================================================= # SLICE DATA # ======================================================================================= log_rank_0_debug(logger, "Loading dataset...") log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists") if data_config.overwrite_preprocessed_data or (not os.path.exists( self.save_dataset_name)): st_time = time.time() try: log_rank_0_info( logger, f"Building dataset from scratch. Saving to {save_dataset_folder}", ) create_examples( dataset, create_ex_indir, create_ex_outdir, meta_file, data_config, dataset_threads, self.slice_names, use_weak_label, split, ) max_alias2pred = utils.load_json_file( meta_file)["max_alias2pred"] convert_examples_to_features_and_save( meta_file, dataset_threads, self.slice_names, self.save_dataset_name, self.get_storage(max_alias2pred), ) utils.dump_json_file(self.save_data_config_name, {"max_alias2pred": max_alias2pred}) log_rank_0_debug( logger, f"Finished prepping data in {time.time() - st_time}") except Exception as e: tb = traceback.TracebackException.from_exception(e) logger.error(e) logger.error("\n".join(tb.stack.format())) shutil.rmtree(save_dataset_folder, ignore_errors=True) raise log_rank_0_info( logger, f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}", ) max_alias2pred = utils.load_json_file( self.save_data_config_name)["max_alias2pred"] self.data, self.sent_to_row_id_dict = self.build_data_dict( self.save_dataset_name, self.get_storage(max_alias2pred)) assert len(self.data) > 0 assert len(self.sent_to_row_id_dict) > 0 log_rank_0_debug(logger, f"Removing temporary output files") shutil.rmtree(temp_output_folder, ignore_errors=True) # Set spawn back to original/default, which is "fork" or "spawn". 
# This is needed for the Meta.config to be correctly passed in the collate_fn.
        multiprocessing.set_start_method(orig_spawn, force=True)
        log_rank_0_info(
            logger,
            f"Final slice data initialization time from {split} is {time.time() - global_start}s",
        )
def run_model(mode, config, run_config_path=None): """ Main run method for Emmental Bootleg models. Args: mode: run mode (train, eval, dump_preds, dump_embs) config: parsed model config run_config_path: original config path (for saving) Returns: """ # Set up distributed backend and save configuration files setup(config, run_config_path) # Load entity symbols log_rank_0_info(logger, f"Loading entity symbols...") entity_symbols = EntitySymbols.load_from_cache( load_dir=os.path.join(config.data_config.entity_dir, config.data_config.entity_map_dir), alias_cand_map_file=config.data_config.alias_cand_map, alias_idx_file=config.data_config.alias_idx_map, ) # Create tasks tasks = [NED_TASK] if config.data_config.type_prediction.use_type_pred is True: tasks.append(TYPE_PRED_TASK) # Create splits for data loaders data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT] # Slices are for eval so we only split on test/dev slice_splits = [DEV_SPLIT, TEST_SPLIT] # If doing eval, only run on test data if mode in ["eval", "dump_preds", "dump_embs"]: data_splits = [TEST_SPLIT] slice_splits = [TEST_SPLIT] # We only do dumping if weak labels is True if mode in ["dump_preds", "dump_embs"]: if config.data_config[ f"{TEST_SPLIT}_dataset"].use_weak_label is False: raise ValueError( f"When calling dump_preds or dump_embs, we require use_weak_label to be True." ) # Gets embeddings that need to be prepped during data prep or in the __get_item__ method batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols) # Gets dataloaders dataloaders = get_dataloaders( config, tasks, data_splits, entity_symbols, batch_on_the_fly_kg_adj, ) slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols) configure_optimizer(config) # Create models and add tasks if config.model_config.attn_class == "BERTNED": log_rank_0_info(logger, f"Starting NED-Base Model") assert (config.data_config.type_prediction.use_type_pred is False), f"NED-Base does not support type prediction" assert ( config.data_config.word_embedding.use_sent_proj is False ), f"NED-Base requires word_embeddings.use_sent_proj to be False" model = EmmentalModel(name="NED-Base") model.add_tasks( ned_task.create_task(config, entity_symbols, slice_datasets)) else: log_rank_0_info(logger, f"Starting Bootleg Model") model = EmmentalModel(name="Bootleg") # TODO: make this more general for other tasks -- iterate through list of tasks # and add task for each model.add_task( ned_task.create_task(config, entity_symbols, slice_datasets)) if TYPE_PRED_TASK in tasks: model.add_task( type_pred_task.create_task(config, entity_symbols, slice_datasets)) # Add the mention type embedding to the embedding payload type_pred_task.update_ned_task(model) # Print param counts if mode == "train": log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30) total_params = count_parameters(model, requires_grad=True, logger=logger) log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}") log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30) total_params = count_parameters(model, requires_grad=False, logger=logger) log_rank_0_info(logger, f"===> Total Params Without Grad: {total_params}") # Load the best model from the pretrained model if config["model_config"]["model_path"] is not None: model.load(config["model_config"]["model_path"]) # Barrier if config["learner_config"]["local_rank"] == 0: torch.distributed.barrier() # Train model if mode == "train": emmental_learner = EmmentalLearner() emmental_learner._set_optimizer(model) emmental_learner.learn(model, 
dataloaders) if config.learner_config.local_rank in [0, -1]: model.save(f"{emmental.Meta.log_path}/last_model.pth") # Multi-gpu DataParallel eval (NOT distributed) if mode in ["eval", "dump_embs", "dump_preds"]: # This happens inside EmmentalLearner for training if (config["learner_config"]["local_rank"] == -1 and config["model_config"]["dataparallel"]): model._to_dataparallel() # If just finished training a model or in eval mode, run eval if mode in ["train", "eval"]: scores = model.score(dataloaders) # Save metrics and models log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}") log_rank_0_info(logger, f"Metrics: {scores}") scores["log_path"] = emmental.Meta.log_path if config.learner_config.local_rank in [0, -1]: write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt", scores) eval_utils.write_disambig_metrics_to_csv( f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv", scores) return scores # If you want detailed dumps, save model outputs assert mode in [ "dump_preds", "dump_embs", ], 'Mode must be "dump_preds" or "dump_embs"' dump_embs = False if mode != "dump_embs" else True assert ( len(dataloaders) == 1 ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!" final_result_file, final_out_emb_file = None, None if config.learner_config.local_rank in [0, -1]: # Setup files/folders filename = os.path.basename(dataloaders[0].dataset.raw_filename) log_rank_0_debug( logger, f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}", ) sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens( os.path.join(config.data_config.data_dir, filename)) log_rank_0_debug(logger, f"Done collecting sentence to mention map") eval_folder = eval_utils.get_eval_folder(filename) subeval_folder = os.path.join(eval_folder, "batch_results") utils.ensure_dir(subeval_folder) # Will keep track of sentences dumped already. 
These will only be ones with mentions all_dumped_sentences = set() number_dumped_batches = 0 total_mentions_seen = 0 all_result_files = [] all_out_emb_files = [] # Iterating over batches of predictions for res_i, res_dict in enumerate( eval_utils.batched_pred_iter( model, dataloaders[0], config.run_config.eval_accumulation_steps, sentidx2num_mentions, )): ( result_file, out_emb_file, final_sent_idxs, mentions_seen, ) = eval_utils.disambig_dump_preds( res_i, total_mentions_seen, config, res_dict, sentidx2num_mentions, sent_idx2row, subeval_folder, entity_symbols, dump_embs, NED_TASK, ) all_dumped_sentences.update(final_sent_idxs) all_result_files.append(result_file) all_out_emb_files.append(out_emb_file) total_mentions_seen += mentions_seen number_dumped_batches += 1 # Dump the sentences that had no mentions and were not already dumped # Assert all remaining sentences have no mentions assert all( v == 0 for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences ), (f"Sentences with mentions were not dumped: " f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}" ) empty_sentidx2row = { k: v for k, v in sent_idx2row.items() if k not in all_dumped_sentences } empty_resultfile = eval_utils.get_result_file(number_dumped_batches, subeval_folder) all_result_files.append(empty_resultfile) # Dump the outputs eval_utils.write_data_labels_single( sentidx2row=empty_sentidx2row, output_file=empty_resultfile, filt_emb_data=None, sental2embid={}, alias_cand_map=entity_symbols.get_alias2qids(), qid2eid=entity_symbols.get_qid2eid(), result_alias_offset=total_mentions_seen, train_in_cands=config.data_config.train_in_candidates, max_cands=entity_symbols.max_candidates, dump_embs=dump_embs, ) log_rank_0_info( logger, f"Finished dumping. Merging results across accumulation steps.") # Final result files for labels and embeddings final_result_file = os.path.join(eval_folder, config.run_config.result_label_file) # Copy labels output = open(final_result_file, "wb") for file in all_result_files: shutil.copyfileobj(open(file, "rb"), output) output.close() log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}") # Try to copy embeddings if dump_embs: final_out_emb_file = os.path.join( eval_folder, config.run_config.result_emb_file) log_rank_0_info( logger, f"Trying to merge numpy embedding arrays. " f"If your machine is limited in memory, this may cause OOM errors. " f"Is that happens, result files should be saved in {subeval_folder}.", ) all_arrays = [] for i, npfile in enumerate(all_out_emb_files): all_arrays.append(np.load(npfile)) np.save(final_out_emb_file, np.concatenate(all_arrays)) log_rank_0_info( logger, f"Bootleg embeddings saved at {final_out_emb_file}") # Cleanup try_rmtree(subeval_folder) return final_result_file, final_out_emb_file
def create_examples(
    dataset,
    create_ex_indir,
    create_ex_outdir,
    meta_file,
    data_config,
    dataset_threads,
    slice_names,
    use_weak_label,
    split,
):
    """Creates examples from the raw input data.

    Args:
        dataset: dataset file
        create_ex_indir: temporary directory where input files are stored
        create_ex_outdir: temporary directory to store output files from the method
        meta_file: metadata file to save the file names/paths for the next step in the prep pipeline
        data_config: data config
        dataset_threads: number of threads
        slice_names: list of slices to evaluate on
        use_weak_label: whether to use weak labeling or not
        split: data split

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    num_processes = min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))

    log_rank_0_debug(logger, f"Counting lines")
    total_input = sum(1 for _ in open(dataset))

    if num_processes == 1:
        out_file_name = os.path.join(create_ex_outdir, os.path.basename(dataset))
        constants_dict = {
            "slice_names": slice_names,
            "use_weak_label": use_weak_label,
            "max_aliases": data_config.max_aliases,
            "split": split,
            "train_in_candidates": data_config.train_in_candidates,
        }
        files_and_counts = {}
        res = create_examples_single(
            dataset, total_input, out_file_name, constants_dict
        )
        total_output = res["total_lines"]
        max_alias2pred = res["max_alias2pred"]
        files_and_counts[res["output_filename"]] = res["total_lines"]
    else:
        log_rank_0_info(
            logger, f"Starting to extract examples with {num_processes} threads"
        )
        log_rank_0_debug(
            logger, "Parallelizing with " + str(num_processes) + " threads."
        )
        chunk_input = int(np.ceil(total_input / num_processes))
        log_rank_0_debug(
            logger,
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines",
        )
        total_input_from_chunks, input_files_dict = utils.chunk_file(
            dataset, create_ex_indir, chunk_input
        )

        input_files = list(input_files_dict.keys())
        input_file_lines = [input_files_dict[k] for k in input_files]
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        assert (
            total_input == total_input_from_chunks
        ), f"Lengths of files {total_input} doesn't match {total_input_from_chunks}"
        log_rank_0_debug(logger, f"Done chunking files")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=create_examples_initializer,
            initargs=[
                data_config,
                slice_names,
                use_weak_label,
                data_config.max_aliases,
                split,
                data_config.train_in_candidates,
            ],
        )
        total_output = 0
        max_alias2pred = 0
        input_args = list(zip(input_files, input_file_lines, output_files))
        # Store output files and counts for saving in the next step
        files_and_counts = {}
        for res in pool.imap_unordered(create_examples_hlp, input_args, chunksize=1):
            total_output += res["total_lines"]
            max_alias2pred = max(max_alias2pred, res["max_alias2pred"])
            files_and_counts[res["output_filename"]] = res["total_lines"]
        pool.close()

    utils.dump_json_file(
        meta_file,
        {
            "num_mentions": total_output,
            "files_and_counts": files_and_counts,
            "max_alias2pred": max_alias2pred,
        },
    )
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}.",
    )
    return
def build_type_table(cls, type_labels, type_vocab, max_types, entity_symbols): """Builds the EID to type ids table. Args: type_labels: QID to type ids or type names json mapping type_vocab: type name to type ids max_types: maximum number of types for an entity entity_symbols: entity symbols Returns: torch tensor from EID to type IDS, type ID to row in type embedding matrix, and number of types with unk type """ with open(type_vocab) as f: vocab = json.load(f) all_type_ids = set(list(vocab.values())) assert ( 0 not in all_type_ids ), f"The type id of 0 is reserved for UNK type. Please offset the typeids by 1" # all eids are initially assigned to unk types # if they occur in the type file, then they are assigned the types in the file plus padded types eid2typeids = torch.zeros( entity_symbols.num_entities_with_pad_and_nocand, max_types) eid2typeids[0] = torch.zeros(1, max_types) max_type_id_all = max(all_type_ids) type_hit = 0 type2row_dict = {} with open(type_labels) as f: qid2typeid = json.load(f) for qid, row_types in qid2typeid.items(): if not entity_symbols.qid_exists(qid): continue # assign padded types to the last row typeids = torch.ones(max_types) * -1 if len(row_types) > 0: type_hit += 1 # increment by 1 to account for unk row typeids_list = [] for type_id_or_name in row_types: # If typename, map to typeid if type(type_id_or_name) is str: type_id = vocab[type_id_or_name] else: type_id = type_id_or_name assert ( type_id > 0 ), f"Typeid for {qid} is 0. That is reserved. Please offset by 1" assert (type_id in all_type_ids ), f"Typeid for {qid} isn't in vocab" typeids_list.append(type_id) type2row_dict[type_id] = type_id num_types = min(len(typeids_list), max_types) typeids[:num_types] = torch.tensor( typeids_list)[:num_types] eid2typeids[entity_symbols.get_eid(qid)] = typeids # + 1 bc we need to account for pad row labeled_num_types = max_type_id_all + 1 # assign padded types to the last row of the type embedding # make sure adding type labels doesn't add new types assert (max_type_id_all + 1) <= labeled_num_types eid2typeids[eid2typeids == -1] = labeled_num_types log_rank_0_debug( logger, f"{round(type_hit/entity_symbols.num_entities, 2)*100}% of entities are assigned types", ) return eid2typeids.long(), type2row_dict, labeled_num_types
def build_static_embeddings(cls, emb_file, entity_symbols):
    """Builds the table of the embedding associated with each entity.

    Args:
        emb_file: embedding file to load
        entity_symbols: entity symbols

    Returns: numpy embedding matrix where each row is an embedding
    """
    ending = os.path.splitext(emb_file)[1]
    found = 0
    raw_num_ents = 0
    if ending == ".json":
        dct = utils.load_json_file(emb_file)
        val = next(iter(dct.values()))
        if type(val) is int or type(val) is float:
            embedding_size = 1
            convert_func = lambda x: np.array([x])
        elif type(val) is list:
            embedding_size = len(val)
            convert_func = lambda x: np.array([y for y in x])
        else:
            raise ValueError(
                f"Unrecognized type for the array value of {type(val)}"
            )
        embeddings = {}
        for k in dct:
            embeddings[k] = convert_func(dct[k])
            assert len(embeddings[k]) == embedding_size
        entity2staticemb_table = np.zeros(
            (entity_symbols.num_entities_with_pad_and_nocand, embedding_size)
        )
        raw_num_ents = len(embeddings)
        for qid in tqdm(entity_symbols.get_all_qids()):
            if qid in embeddings:
                found += 1
                emb = embeddings[qid]
                eid = entity_symbols.get_eid(qid)
                entity2staticemb_table[eid, :embedding_size] = emb
    elif ending == ".pt":
        log_rank_0_debug(
            logger,
            f"We are reading in the embedding file from a .pt. We assume this is already mapped to eids",
        )
        (qid2eid_map, entity2staticemb_table_raw) = torch.load(emb_file)
        entity2staticemb_table_raw = (
            entity2staticemb_table_raw.detach().cpu().numpy()
        )
        raw_num_ents = entity2staticemb_table_raw.shape[0]
        # +2 handles the PAD and UNK entities
        assert entity2staticemb_table_raw.shape[0] == len(qid2eid_map) + 2, (
            f"The saved static embeddings file had mismatched shapes between qid2eid {len(qid2eid_map)} and "
            f"weights {entity2staticemb_table_raw.shape[0]}"
        )
        entity2staticemb_table = np.zeros(
            (
                entity_symbols.num_entities_with_pad_and_nocand,
                entity2staticemb_table_raw.shape[1],
            )
        )
        found = 0
        for qid in tqdm(entity_symbols.get_all_qids()):
            if qid in qid2eid_map:
                found += 1
                raw_eid = qid2eid_map[qid]
                emb = entity2staticemb_table_raw[raw_eid]
                new_eid = entity_symbols.get_eid(qid)
                entity2staticemb_table[new_eid, :] = emb
    else:
        raise ValueError(
            f"We do not support static embeddings from {ending}. We only support .json and .pt"
        )
    log_rank_0_debug(
        logger,
        f"Found {found} ({100 * found / len(entity_symbols.get_all_qids()):.2f}%) of all entities with a "
        f"static embedding after reading {raw_num_ents} original entities",
    )
    return entity2staticemb_table
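# Illustrative sketch of the .json static embedding format read above: QID -> scalar or fixed-length
# list, and QIDs missing from the file keep all-zero rows (QIDs and values are hypothetical).
def _example_write_static_emb_json(emb_file):
    import ujson

    embs = {"Q91": [0.1, 0.2, 0.3], "Q30": [0.0, 1.0, -1.0]}  # every value must have the same length
    with open(emb_file, "w") as f:
        f.write(ujson.dumps(embs))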
def __init__(
    self,
    main_args,
    emb_args,
    entity_symbols,
    key,
    cpu,
    normalize,
    dropout1d_perc,
    dropout2d_perc,
):
    super(TopKEntityEmb, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
        key=key,
        cpu=cpu,
        normalize=normalize,
        dropout1d_perc=dropout1d_perc,
        dropout2d_perc=dropout2d_perc,
    )
    allowable_keys = {
        "learned_embedding_size",
        "perc_emb_drop",
        "qid2topk_eid",
        "regularize_mapping",
        "tail_init",
        "tail_init_zeros",
    }
    correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
    if not correct:
        raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
    assert (
        "learned_embedding_size" in emb_args
    ), f"TopKEntityEmb must have learned_embedding_size in args"
    assert "perc_emb_drop" in emb_args, (
        f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings"
        f" removed."
    )
    self.learned_embedding_size = emb_args.learned_embedding_size
    # We remove perc_emb_drop percent of the embeddings and add one row to represent the new shared tail embedding
    num_topk_entities_with_pad_and_nocand = (
        entity_symbols.num_entities_with_pad_and_nocand
        - int(emb_args.perc_emb_drop * entity_symbols.num_entities)
        + 1
    )
    # Mapping of entity to the new eid mapping
    eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
    # There are issues with using -1 to index into the embeddings, so we manually set it to be the last value
    eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
    if "qid2topk_eid" not in emb_args:
        assert self.from_pretrained, (
            f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
            f"you must be loading a model from a checkpoint to build this index mapping"
        )
    self.learned_entity_embedding = nn.Embedding(
        num_topk_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=False,
    )
    self._dim = main_args.model_config.hidden_size
    if "regularize_mapping" in emb_args:
        eid2reg = torch.zeros(num_topk_entities_with_pad_and_nocand)
    else:
        eid2reg = None
    # If tail_init is false, all embeddings are randomly initialized.
    # If tail_init is true, we initialize all embeddings to be the same.
self.tail_init = True self.tail_init_zeros = False # None init vec will be random init_vec = None if not self.from_pretrained: qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid) assert ( len(qid2topk_eid) == entity_symbols.num_entities ), f"You must have an item in qid2topk_eid for each qid in entity_symbols" for qid in entity_symbols.get_all_qids(): old_eid = entity_symbols.get_eid(qid) new_eid = qid2topk_eid[qid] eid2topkeid[old_eid] = new_eid assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed" assert ( eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1 ), "The -1 eid should still map to -1" if "tail_init" in emb_args: self.tail_init = emb_args.tail_init if "tail_init_zeros" in emb_args: self.tail_init_zeros = emb_args.tail_init_zeros self.tail_init = False init_vec = torch.zeros(1, self.learned_embedding_size) assert not ( self.tail_init and self.tail_init_zeros ), f"Can only have one of tail_init or tail_init_zeros set" if self.tail_init or self.tail_init_zeros: if self.tail_init_zeros: log_rank_0_debug( logger, f"All learned entity embeddings are initialized to zero.", ) else: log_rank_0_debug( logger, f"All learned entity embeddings are initialized to the same value.", ) init_vec = model_utils.init_embeddings_to_vec( self.learned_entity_embedding, pad_idx=-1, vec=init_vec ) vec_save_file = os.path.join( emmental.Meta.log_path, "init_vec_entity_embs.npy" ) log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}") if ( torch.distributed.is_initialized() and torch.distributed.get_rank() == 0 ): np.save(vec_save_file, init_vec) else: log_rank_0_debug( logger, f"All learned embeddings are randomly initialized." ) # Regularization mapping goes from eid to 2d dropout percent if "regularize_mapping" in emb_args: log_rank_0_debug( logger, f"You are using regularization mapping with a topK entity embedding. " f"This means all QIDs that are mapped to the same" f" EID will get the same regularization value.", ) if self.dropout1d_perc > 0 or self.dropout2d_perc > 0: log_rank_0_debug( logger, f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?", ) log_rank_0_debug( logger, f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}", ) eid2reg = self.load_regularization_mapping( main_args.data_config, entity_symbols, qid2topk_eid, num_topk_entities_with_pad_and_nocand, emb_args.regularize_mapping, ) # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping self.register_buffer("eid2topkeid", eid2topkeid) self.register_buffer("eid2reg", eid2reg)
def __init__( self, main_args, emb_args, entity_symbols, key, cpu, normalize, dropout1d_perc, dropout2d_perc, ): super(LearnedEntityEmb, self).__init__( main_args=main_args, emb_args=emb_args, entity_symbols=entity_symbols, key=key, cpu=cpu, normalize=normalize, dropout1d_perc=dropout1d_perc, dropout2d_perc=dropout2d_perc, ) allowable_keys = { "learned_embedding_size", "regularize_mapping", "tail_init", "tail_init_zeros", } correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args) if not correct: raise ValueError(f"The key {bad_key} is not in {allowable_keys}") assert ( "learned_embedding_size" in emb_args ), f"LearnedEntityEmb must have learned_embedding_size in args" self.learned_embedding_size = emb_args.learned_embedding_size # Set sparsity based on optimizer and fp16. The None optimizer is Bootleg's SparseDenseAdam. # If fp16 is True, must use dense. optimiz = main_args.learner_config.optimizer_config.optimizer if optimiz in [None, "sparse_adam"] and main_args.learner_config.fp16 is False: sparse = True else: sparse = False if ( torch.distributed.is_initialized() and main_args.model_config.distributed_backend == "nccl" ): sparse = False log_rank_0_debug( logger, f"Setting sparsity for entity embeddings to be {sparse}" ) self.learned_entity_embedding = nn.Embedding( entity_symbols.num_entities_with_pad_and_nocand, self.learned_embedding_size, padding_idx=-1, sparse=sparse, ) self._dim = main_args.model_config.hidden_size if "regularize_mapping" in emb_args: eid2reg = torch.zeros(entity_symbols.num_entities_with_pad_and_nocand) else: eid2reg = None # If tail_init is false, all embeddings are randomly intialized. # If tail_init is true, we initialize all embeddings to be the same. self.tail_init = True self.tail_init_zeros = False # None init vec will be random init_vec = None if not self.from_pretrained: if "tail_init" in emb_args: self.tail_init = emb_args.tail_init if "tail_init_zeros" in emb_args: self.tail_init_zeros = emb_args.tail_init_zeros self.tail_init = False init_vec = torch.zeros(1, self.learned_embedding_size) assert not ( self.tail_init and self.tail_init_zeros ), f"Can only have one of tail_init or tail_init_zeros set" if self.tail_init or self.tail_init_zeros: if self.tail_init_zeros: log_rank_0_debug( logger, f"All learned entity embeddings are initialized to zero.", ) else: log_rank_0_debug( logger, f"All learned entity embeddings are initialized to the same value.", ) init_vec = model_utils.init_embeddings_to_vec( self.learned_entity_embedding, pad_idx=-1, vec=init_vec ) vec_save_file = os.path.join( emmental.Meta.log_path, "init_vec_entity_embs.npy" ) log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}") if ( torch.distributed.is_initialized() and torch.distributed.get_rank() == 0 ): np.save(vec_save_file, init_vec) else: log_rank_0_debug( logger, f"All learned embeddings are randomly initialized." ) # Regularization mapping goes from eid to 2d dropout percent if "regularize_mapping" in emb_args: if self.dropout1d_perc > 0 or self.dropout2d_perc > 0: log_rank_0_debug( logger, f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?", ) log_rank_0_debug( logger, f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}", ) eid2reg = self.load_regularization_mapping( main_args.data_config, entity_symbols, emb_args.regularize_mapping, ) self.register_buffer("eid2reg", eid2reg)
def merge_subsentences( num_processes, subset_sent_idx2num_mens, cache_folder, to_save_file, to_save_storage, to_read_file, to_read_storage, dump_embs=False, ): """Flatten all sentences back together over sub-sentences; removing the PAD aliases from the data I.e., converts from sent_idx -> array of values to (sent_idx, alias_idx) -> value with varying numbers of aliases per sentence. Args: num_processes: number of processes subset_sent_idx2num_mens: Dict of sentence index to number of mentions for this batch cache_folder: cache directory to_save_file: memmap file to save results to to_save_storage: save file storage type to_read_file: memmap file to read predictions from to_read_storage: read file storage type dump_embs: whether to save embeddings or not Returns: """ # Compute sent idx to offset so we know where to fill in mentions cur_offset = 0 sentidx2offset = {} for k, v in subset_sent_idx2num_mens.items(): sentidx2offset[k] = cur_offset cur_offset += v # print("Sent Idx, Num Mens, Offset", k, v, cur_offset) total_num_mentions = cur_offset # print("TOTAL", total_num_mentions) full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode="r") M = int(full_pred_data[0]["M"]) K = int(full_pred_data[0]["K"]) hidden_size = int(full_pred_data[0]["hidden_size"]) # print("TOTAL MENS", total_num_mentions) filt_emb_data = np.memmap(to_save_file, dtype=to_save_storage, mode="w+", shape=(total_num_mentions, )) filt_emb_data["hidden_size"] = hidden_size filt_emb_data["sent_idx"][:] = -1 filt_emb_data["alias_list_pos"][:] = -1 all_ids = list(range(0, len(full_pred_data))) start = time.time() if num_processes == 1: seen_ids = merge_subsentences_single( M, K, hidden_size, dump_embs, all_ids, filt_emb_data, full_pred_data, sentidx2offset, ) else: # Get trie for sentence start map trie_folder = os.path.join(cache_folder, "bootleg_sent_idx2num_mens") utils.ensure_dir(trie_folder) trie_file = os.path.join(trie_folder, "sentidx.marisa") utils.create_single_item_trie(sentidx2offset, out_file=trie_file) # Chunk up date chunk_size = int(np.ceil(len(full_pred_data) / num_processes)) row_idx_set_chunks = [ all_ids[ids:ids + chunk_size] for ids in range(0, len(full_pred_data), chunk_size) ] # Start pool input_args = [[M, K, hidden_size, dump_embs, chunk] for chunk in row_idx_set_chunks] log_rank_0_debug( logger, f"Merging sentences together with {num_processes} processes") pool = multiprocessing.Pool( processes=num_processes, initializer=merge_subsentences_initializer, initargs=[ to_save_file, to_save_storage, to_read_file, to_read_storage, trie_file, ], ) seen_ids = set() for sent_ids_seen in pool.imap_unordered(merge_subsentences_hlp, input_args, chunksize=1): for emb_id in sent_ids_seen: assert ( emb_id not in seen_ids ), f"{emb_id} already seen, something went wrong with sub-sentences" seen_ids.add(emb_id) # filt_emb_data = np.memmap(to_save_file, dtype=to_save_storage, mode="r") # for i in range(len(filt_emb_data)): # si = filt_emb_data[i]["sent_idx"] # al_test = filt_emb_data[i]["alias_list_pos"] # if si == -1 or al_test == -1: # print("BAD", i, filt_emb_data[i]) # import ipdb; ipdb.set_trace() logging.debug(f"Saw {len(seen_ids)} sentences") logging.debug(f"Time to merge sub-sentences {time.time() - start}s") return
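# Illustrative sketch of the offset computation at the top of merge_subsentences: per-sentence mention
# counts become row offsets into the flattened mention array (counts below are hypothetical).
def _example_sentence_offsets():
    subset_sent_idx2num_mens = {"0": 2, "1": 1, "2": 3}
    cur_offset, sentidx2offset = 0, {}
    for k, v in subset_sent_idx2num_mens.items():
        sentidx2offset[k] = cur_offset
        cur_offset += v
    assert sentidx2offset == {"0": 0, "1": 2, "2": 3}
    assert cur_offset == 6  # total_num_mentions, the length of the merged memmap
    return sentidx2offset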
def write_data_labels(
    num_processes,
    result_alias_offset,
    merged_entity_emb_file,
    merged_storage_type,
    sent_idx2row,
    cache_folder,
    out_file,
    entity_dump,
    train_in_candidates,
    max_candidates,
    dump_embs,
    trie_candidate_map_folder=None,
    trie_qid2eid_file=None,
):
    """Takes the flattened data from merge_subsentences and writes out predictions to a file, one line per sentence.

    The embedding ids are added to the file if dump_embs is True.

    Args:
        num_processes: number of processes
        result_alias_offset: alias offset of this batch of examples for writing out
        merged_entity_emb_file: input memmap file after merging sub-sentences
        merged_storage_type: input file storage type
        sent_idx2row: Dict of sentence idx to row relevant to this sub-batch
        cache_folder: folder to save temporary outputs
        out_file: final output file for predictions
        entity_dump: entity dump
        train_in_candidates: whether the model was trained assuming the gold entity is always in the candidate list (no NC entity)
        max_candidates: maximum number of candidates
        dump_embs: whether to dump embeddings or not
        trie_candidate_map_folder: folder where the trie of the alias->candidate map is stored for parallel processing
        trie_qid2eid_file: file where the trie of the qid->eid map is stored for parallel processing

    Returns:
    """
    st = time.time()
    sental2embid = get_sental2embid(merged_entity_emb_file, merged_storage_type)
    log_rank_0_debug(logger, f"Finished getting sentence map {time.time() - st}s")

    total_input = len(sent_idx2row)
    if num_processes == 1:
        filt_emb_data = np.memmap(
            merged_entity_emb_file, dtype=merged_storage_type, mode="r+"
        )
        write_data_labels_single(
            sentidx2row=sent_idx2row,
            output_file=out_file,
            filt_emb_data=filt_emb_data,
            sental2embid=sental2embid,
            alias_cand_map=entity_dump.get_alias2qids(),
            qid2eid=entity_dump.get_qid2eid(),
            result_alias_offset=result_alias_offset,
            train_in_cands=train_in_candidates,
            max_cands=max_candidates,
            dump_embs=dump_embs,
        )
    else:
        assert (
            trie_candidate_map_folder is not None
        ), "trie_candidate_map_folder is None and you have parallel turned on"
        assert (
            trie_qid2eid_file is not None
        ), "trie_qid2eid_file is None and you have parallel turned on"
        # Get trie of sentence map
        trie_folder = os.path.join(cache_folder, "bootleg_sental2embid")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sental2embid, out_file=trie_file)
        # Chunk file for parallel writing
        # We do not use TemporaryFolders as the temp dir may not have enough space for large files
        create_ex_indir = os.path.join(cache_folder, "_bootleg_eval_temp_indir")
        utils.ensure_dir(create_ex_indir)
        create_ex_outdir = os.path.join(cache_folder, "_bootleg_eval_temp_outdir")
        utils.ensure_dir(create_ex_outdir)
        chunk_input = int(np.ceil(total_input / num_processes))
        logger.debug(
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines"
        )
        # Chunk up dictionary of data for parallel processing
        input_files = []
        i = 0
        cur_lines = 0
        file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
        open_file = open(file_split, "w")
        for s_idx in sent_idx2row:
            if cur_lines >= chunk_input:
                open_file.close()
                input_files.append(file_split)
                cur_lines = 0
                i += 1
                file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
                open_file = open(file_split, "w")
            line = sent_idx2row[s_idx]
            open_file.write(ujson.dumps(line) + "\n")
            cur_lines += 1
        open_file.close()
        input_files.append(file_split)
        # Generate input/output pairs
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        log_rank_0_debug(logger, f"Done chunking files. Starting pool")
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=write_data_labels_initializer,
            initargs=[
                merged_entity_emb_file,
                merged_storage_type,
                trie_file,
                result_alias_offset,
                train_in_candidates,
                max_candidates,
                dump_embs,
                trie_candidate_map_folder,
                trie_qid2eid_file,
            ],
        )
        input_args = list(zip(input_files, output_files))
        total = 0
        for res in pool.imap(write_data_labels_hlp, input_args, chunksize=1):
            total += 1
        # Merge output files to final file
        log_rank_0_debug(logger, f"Merging output files")
        with open(out_file, "wb") as outfile:
            for filename in glob.glob(os.path.join(create_ex_outdir, "*")):
                if filename == out_file:
                    # don't want to copy the output into the output
                    continue
                with open(filename, "rb") as readfile:
                    shutil.copyfileobj(readfile, outfile)
def disambig_dump_preds(
    result_idx,
    result_alias_offset,
    config,
    res_dict,
    sent_idx2num_mens,
    sent_idx2row,
    save_folder,
    entity_symbols,
    dump_embs,
    task_name,
):
    """Dumps the predictions of a disambiguation task.

    Args:
        result_idx: batch index of the result arrays
        result_alias_offset: overall offset of the starting example (i.e., the number of previous mentions already written)
        config: model config
        res_dict: result dictionary from Emmental predict
        sent_idx2num_mens: Dict sentence idx to number of mentions
        sent_idx2row: Dict sentence idx to row of eval data
        save_folder: folder to save results
        entity_symbols: entity symbols
        dump_embs: whether to save the contextualized embeddings or not
        task_name: task name

    Returns: saved prediction file, saved embedding file (will be None if dump_embs is False)
    """
    num_processes = min(
        config.run_config.dataset_threads, int(multiprocessing.cpu_count() * 0.9)
    )
    cache_dir = os.path.join(save_folder, f"cache_{result_idx}")
    utils.ensure_dir(cache_dir)
    trie_candidate_map_folder = None
    trie_qid2eid_file = None
    # Save the alias->QID candidate map and the QID->EID mapping in memory efficient structures for faster
    # prediction dumping
    if num_processes > 1:
        entity_prep_dir = data_utils.get_emb_prep_dir(config.data_config)
        trie_candidate_map_folder = os.path.join(
            entity_prep_dir, "for_dumping_preds", "alias_cand_trie"
        )
        utils.ensure_dir(trie_candidate_map_folder)
        check_and_create_alias_cand_trie(trie_candidate_map_folder, entity_symbols)
        trie_qid2eid_file = os.path.join(
            entity_prep_dir, "for_dumping_preds", "qid2eid_trie.marisa"
        )
        if not os.path.exists(trie_qid2eid_file):
            utils.create_single_item_trie(
                entity_symbols.get_qid2eid(), out_file=trie_qid2eid_file
            )

    disambig_res_dict = {}
    for k in res_dict:
        assert task_name in res_dict[k], f"{task_name} not in res_dict for key {k}"
        disambig_res_dict[k] = res_dict[k][task_name]

    # write to file (M x hidden x size for each data point -- next step will deal with recovering original sentence
    # indices for overflowing sentences)
    unmerged_entity_emb_file = os.path.join(save_folder, f"entity_embs.pt")
    merged_entity_emb_file = os.path.join(save_folder, f"entity_embs_unmerged.pt")
    emb_file_config = os.path.splitext(unmerged_entity_emb_file)[0] + "_config.npy"
    M = config.data_config.max_aliases
    K = entity_symbols.max_candidates + (not config.data_config.train_in_candidates)
    if dump_embs:
        unmerged_storage_type = np.dtype(
            [
                ("M", int),
                ("K", int),
                ("hidden_size", int),
                ("sent_idx", int),
                ("subsent_idx", int),
                ("alias_list_pos", int, (M,)),
                ("entity_emb", float, M * config.model_config.hidden_size),
                ("final_loss_true", int, (M,)),
                ("final_loss_pred", int, (M,)),
                ("final_loss_prob", float, (M,)),
                ("final_loss_cand_probs", float, M * K),
            ]
        )
        merged_storage_type = np.dtype(
            [
                ("hidden_size", int),
                ("sent_idx", int),
                ("alias_list_pos", int),
                ("entity_emb", float, config.model_config.hidden_size),
                ("final_loss_pred", int),
                ("final_loss_prob", float),
                ("final_loss_cand_probs", float, K),
            ]
        )
    else:
        # don't need to extract contextualized entity embedding
        unmerged_storage_type = np.dtype(
            [
                ("M", int),
                ("K", int),
                ("hidden_size", int),
                ("sent_idx", int),
                ("subsent_idx", int),
                ("alias_list_pos", int, (M,)),
                ("final_loss_true", int, (M,)),
                ("final_loss_pred", int, (M,)),
                ("final_loss_prob", float, (M,)),
                ("final_loss_cand_probs", float, M * K),
            ]
        )
        merged_storage_type = np.dtype(
            [
                ("hidden_size", int),
                ("sent_idx", int),
                ("alias_list_pos", int),
                ("final_loss_pred", int),
                ("final_loss_prob", float),
                ("final_loss_cand_probs", float, K),
            ]
        )
    mmap_file = np.memmap(
        unmerged_entity_emb_file,
        dtype=unmerged_storage_type,
        mode="w+",
        shape=(len(disambig_res_dict["uids"]),),
    )
    # Init sent_idx to -1 for debugging
    mmap_file[:]["sent_idx"] = -1
    np.save(emb_file_config, unmerged_storage_type, allow_pickle=True)
    log_rank_0_debug(
        logger, f"Created file {unmerged_entity_emb_file} to save predictions."
    )
    log_rank_0_debug(logger, f'{len(disambig_res_dict["uids"])} samples')
    for_iteration = [
        disambig_res_dict["uids"],
        disambig_res_dict["golds"],
        disambig_res_dict["probs"],
        disambig_res_dict["preds"],
    ]
    all_sent_idx = set()
    for i, (uid, gold, probs, model_pred) in enumerate(zip(*for_iteration)):
        # disambig_res_dict["outputs"] is a dict with keys ['_input__alias_orig_list_pos',
        # 'bootleg_pred_1', '_input__sent_idx', '_input__for_dump_gold_cand_K_idx_train', '_input__subsent_idx', 0, 1]
        sent_idx = disambig_res_dict["outputs"]["_input__sent_idx"][i]
        subsent_idx = disambig_res_dict["outputs"]["_input__subsent_idx"][i]
        alias_orig_list_pos = disambig_res_dict["outputs"][
            "_input__alias_orig_list_pos"
        ][i]
        gold_cand_K_idx_train = disambig_res_dict["outputs"][
            "_input__for_dump_gold_cand_K_idx_train"
        ][i]
        output_embeddings = disambig_res_dict["outputs"][f"{PRED_LAYER}_ent_embs"][i]
        mmap_file[i]["M"] = M
        mmap_file[i]["K"] = K
        mmap_file[i]["hidden_size"] = config.model_config.hidden_size
        mmap_file[i]["sent_idx"] = sent_idx
        mmap_file[i]["subsent_idx"] = subsent_idx
        mmap_file[i]["alias_list_pos"] = alias_orig_list_pos
        # This will give all aliases seen by the model during training, independent of whether they are gold or not
        mmap_file[i][f"final_loss_true"] = gold_cand_K_idx_train.reshape(M)
        # get max for each alias, probs is M x K
        max_probs = probs.max(axis=1)
        pred_cands = probs.argmax(axis=1)
        mmap_file[i]["final_loss_pred"] = pred_cands
        mmap_file[i]["final_loss_prob"] = max_probs
        mmap_file[i]["final_loss_cand_probs"] = probs.reshape(1, -1)
        all_sent_idx.add(str(sent_idx))
        # final_entity_embs is M x K x hidden_size, pred_cands is M
        if dump_embs:
            chosen_entity_embs = select_embs(
                embs=output_embeddings, pred_cands=pred_cands, M=M
            )
            # write chosen entity embs to file for contextualized entity embeddings
            mmap_file[i]["entity_emb"] = chosen_entity_embs.reshape(1, -1)

    # Store all predicted sentences to filter the sentence mapping by
    subset_sent_idx2num_mens = {
        k: v for k, v in sent_idx2num_mens.items() if k in all_sent_idx
    }
    subsent_sent_idx2row = {
        k: v for k, v in sent_idx2row.items() if k in all_sent_idx
    }
    result_file = get_result_file(result_idx, save_folder)
    log_rank_0_debug(logger, f"Writing predictions to {result_file}...")
    merge_subsentences(
        num_processes=num_processes,
        subset_sent_idx2num_mens=subset_sent_idx2num_mens,
        cache_folder=cache_dir,
        to_save_file=merged_entity_emb_file,
        to_save_storage=merged_storage_type,
        to_read_file=unmerged_entity_emb_file,
        to_read_storage=unmerged_storage_type,
        dump_embs=dump_embs,
    )
    write_data_labels(
        num_processes=num_processes,
        result_alias_offset=result_alias_offset,
        merged_entity_emb_file=merged_entity_emb_file,
        merged_storage_type=merged_storage_type,
        sent_idx2row=subsent_sent_idx2row,
        cache_folder=cache_dir,
        out_file=result_file,
        entity_dump=entity_symbols,
        train_in_candidates=config.data_config.train_in_candidates,
        max_candidates=entity_symbols.max_candidates,
        dump_embs=dump_embs,
        trie_candidate_map_folder=trie_candidate_map_folder,
        trie_qid2eid_file=trie_qid2eid_file,
    )

    out_emb_file = None
    filt_emb_data = np.memmap(
        merged_entity_emb_file, dtype=merged_storage_type, mode="r+"
    )
    total_mentions_seen = len(filt_emb_data)
    # save easier-to-use embedding file
    if dump_embs:
        hidden_size = filt_emb_data[0]["hidden_size"]
        out_emb_file = get_emb_file(result_idx, save_folder)
        np.save(out_emb_file, filt_emb_data["entity_emb"].reshape(-1, hidden_size))
        log_rank_0_debug(
            logger,
            f"Saving contextual entity embeddings for {result_idx} to {out_emb_file}",
        )
    filt_emb_data = None
    # Cleanup cache - sometimes the file in cache_dir is still open so we need to retry to delete it
    try_rmtree(cache_dir)
    log_rank_0_debug(logger, f"Wrote predictions for {result_idx} to {result_file}")
    return result_file, out_emb_file, all_sent_idx, total_mentions_seen
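# A small, self-contained sketch (toy field layout, not the library's exact dtype) of the
# structured-memmap pattern used above: prediction records are written one by one into a
# fixed-shape np.memmap and later re-opened with the same dtype to read them back.
import os
import tempfile

import numpy as np

toy_dtype = np.dtype(
    [("sent_idx", int), ("final_loss_pred", int), ("final_loss_prob", float)]
)
toy_path = os.path.join(tempfile.mkdtemp(), "toy_preds.pt")
writer = np.memmap(toy_path, dtype=toy_dtype, mode="w+", shape=(3,))
writer["sent_idx"] = -1  # init for debugging, mirroring the code above
writer[0] = (7, 2, 0.91)  # one toy prediction record
writer.flush()

reader = np.memmap(toy_path, dtype=toy_dtype, mode="r", shape=(3,))
assert reader[0]["final_loss_pred"] == 2 and reader[1]["sent_idx"] == -1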
def load_regularization_mapping(
    cls, data_config, num_types_with_pad_and_unk, type2row_dict, reg_file
):
    """Reads in a csv file with columns [typeid, regularization].

    In the forward pass, the type embedding with the associated type ID will be regularized (2D dropout)
    with probability ``regularization``.

    Args:
        data_config: data config
        num_types_with_pad_and_unk: number of types including pad and unk
        type2row_dict: Dict from typeID to row id in the type embedding matrix
        reg_file: regularization csv file

    Returns: Tensor where each value is the regularization value for that type ID
    """
    reg_str = os.path.splitext(os.path.basename(reg_file.replace("/", "_")))[0]
    prep_dir = data_utils.get_data_prep_dir(data_config)
    prep_file = os.path.join(prep_dir, f"type_regularization_mapping_{reg_str}.pt")
    utils.ensure_dir(os.path.dirname(prep_file))
    log_rank_0_debug(logger, f"Looking for regularization mapping in {prep_file}")
    if not data_config.overwrite_preprocessed_data and os.path.exists(prep_file):
        log_rank_0_debug(
            logger,
            f"Loading existing type regularization mapping from {prep_file}",
        )
        start = time.time()
        typeid2reg = torch.load(prep_file)
        log_rank_0_debug(
            logger,
            f"Loaded existing type regularization mapping in {round(time.time() - start, 2)}s",
        )
    else:
        start = time.time()
        log_rank_0_debug(
            logger, f"Building type regularization mapping from {reg_file}"
        )
        typeid2reg_raw = pd.read_csv(reg_file)
        assert (
            "typeid" in typeid2reg_raw.columns
            and "regularization" in typeid2reg_raw.columns
        ), f"Expected typeid and regularization as the column names for {reg_file}"
        # default of no mask
        typeid2reg_arr = [0.0] * num_types_with_pad_and_unk
        for row_idx, row in typeid2reg_raw.iterrows():
            # Happens when we filter QIDs not in our entity db and the max typeid is smaller than the total number
            if int(row["typeid"]) not in type2row_dict:
                continue
            typeid = type2row_dict[int(row["typeid"])]
            typeid2reg_arr[typeid] = row["regularization"]
        typeid2reg = torch.Tensor(typeid2reg_arr)
        torch.save(typeid2reg, prep_file)
        log_rank_0_debug(
            logger,
            f"Finished building and saving type regularization mapping in {round(time.time() - start, 2)}s.",
        )
    return typeid2reg
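# A hedged sketch of what the regularization CSV is expected to look like, based on the column
# assertion above ("typeid" and "regularization"); the concrete type IDs and dropout values are
# made up for illustration.
import io

import pandas as pd

toy_reg_csv = io.StringIO(
    "typeid,regularization\n"
    "12,0.5\n"  # type 12 is 2D-dropped with probability 0.5
    "87,0.9\n"  # a rarely useful type, regularized aggressively
)
toy_reg_df = pd.read_csv(toy_reg_csv)
assert {"typeid", "regularization"}.issubset(toy_reg_df.columns)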
def __init__(
    self,
    main_args,
    emb_args,
    entity_symbols,
    key,
    cpu,
    normalize,
    dropout1d_perc,
    dropout2d_perc,
):
    super(TypeEmb, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
        key=key,
        cpu=cpu,
        normalize=normalize,
        dropout1d_perc=dropout1d_perc,
        dropout2d_perc=dropout2d_perc,
    )
    allowable_keys = {
        "max_types",
        "type_dim",
        "type_labels",
        "type_vocab",
        "merge_func",
        "attn_hidden_size",
        "regularize_mapping",
    }
    correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
    if not correct:
        raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
    assert (
        "max_types" in emb_args
    ), "Type embedding requires max_types to be set in args"
    assert (
        "type_dim" in emb_args
    ), "Type embedding requires type_dim to be set in args"
    assert (
        "type_labels" in emb_args
    ), "Type embedding requires type_labels to be set in args. A Dict from QID -> TypeId or TypeName"
    assert (
        "type_vocab" in emb_args
    ), "Type embedding requires type_vocab to be set in args. A Dict from TypeName -> TypeId"
    assert (
        self.cpu is False
    ), f"We don't support putting type embeddings on CPU right now"
    self.merge_func = self.average_types
    self.orig_dim = emb_args.type_dim
    self.add_attn = None
    # Function for merging multiple types
    if "merge_func" in emb_args:
        assert emb_args.merge_func in [
            "average",
            "addattn",
        ], (
            f"{key}: You have set the type merge_func to be {emb_args.merge_func} but"
            f" that is not in the allowable list of [average, addattn]"
        )
        if emb_args.merge_func == "addattn":
            if "attn_hidden_size" in emb_args:
                attn_hidden_size = emb_args.attn_hidden_size
            else:
                attn_hidden_size = 100
            # Softmax of types using the sentence context
            self.add_attn = PositionAwareAttention(
                input_size=self.orig_dim, attn_size=attn_hidden_size, feature_size=0
            )
            self.merge_func = self.add_attn_merge
    self.max_types = emb_args.max_types
    (
        eid2typeids_table,
        self.type2row_dict,
        num_types_with_unk,
        self.prep_file,
    ) = self.prep(
        data_config=main_args.data_config,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
    )
    self.register_buffer("eid2typeids_table", eid2typeids_table, persistent=False)
    # self.eid2typeids_table.requires_grad = False
    self.num_types_with_pad_and_unk = num_types_with_unk + 1
    # Regularization mapping goes from typeid to 2d dropout percent
    if "regularize_mapping" in emb_args:
        typeid2reg = torch.zeros(self.num_types_with_pad_and_unk)
    else:
        typeid2reg = None
    if not self.from_pretrained:
        if "regularize_mapping" in emb_args:
            if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                logger.warning(
                    f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?"
                )
            log_rank_0_info(
                logger,
                f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}",
            )
            typeid2reg = self.load_regularization_mapping(
                main_args.data_config,
                self.num_types_with_pad_and_unk,
                self.type2row_dict,
                emb_args.regularize_mapping,
            )
    self.register_buffer("typeid2reg", typeid2reg)
    assert self.eid2typeids_table.shape[1] == emb_args.max_types, (
        f"Something went wrong with loading the type file."
        f" The given max types {emb_args.max_types} does not match that "
        f"of the type table {self.eid2typeids_table.shape[1]}"
    )
    log_rank_0_debug(
        logger,
        f"{key}: Type embedding with {self.max_types} types with dim {self.orig_dim}. "
        f"Setting merge_func to be {self.merge_func.__name__} in type emb.",
    )
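# A hedged sketch of what an ent_embeddings config entry for this type embedding might look like,
# reflecting the allowable_keys and assertions above. The surrounding structure ("key",
# "load_class", "args"), file names, and dimensions are illustrative assumptions only; consult the
# shipped Bootleg configs for real values.
type_emb_config = {
    "key": "learned_type",
    "load_class": "TypeEmb",
    "args": {
        "max_types": 3,  # required: max types kept per entity
        "type_dim": 128,  # required: type embedding dimension
        "type_labels": "type_mappings/qid2typeids.json",  # required: QID -> TypeId or TypeName
        "type_vocab": "type_mappings/type_vocab.json",  # required: TypeName -> TypeId
        "merge_func": "addattn",  # optional: "average" (default) or "addattn"
        "attn_hidden_size": 100,  # optional: only used when merge_func == "addattn"
        "regularize_mapping": "type_regularization.csv",  # optional: typeid -> 2D dropout percent
    },
}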
def get_dataloader_embeddings(main_args, entity_symbols):
    """Gets KG embeddings that need to be processed in the __getitem__ method of a dataset (e.g., querying a
    sparse numpy matrix). We save, for each KG embedding class that needs this preprocessing, the adjacency
    matrix (for KG connections), the processing function to run in __getitem__, and the file to load the adj
    matrix from for dumping/loading.

    Args:
        main_args: main arguments
        entity_symbols: entity symbols

    Returns: Dict of KG metadata for use in the __getitem__ method.
    """
    batch_on_the_fly_kg_adj = {}
    for emb in main_args.data_config.ent_embeddings:
        batch_on_fly = "batch_on_the_fly" in emb and emb["batch_on_the_fly"] is True
        # Find embeddings that have a "batch_on_the_fly" key
        if batch_on_fly:
            log_rank_0_debug(
                logger,
                f"Loading class {emb.load_class} for preprocessing as on the fly or in data prep embeddings",
            )
            (
                cpu,
                dropout1d_perc,
                dropout2d_perc,
                emb_args,
                freeze,
                normalize,
                through_bert,
            ) = embedding_utils.get_embedding_args(emb)
            try:
                # Load the object
                mod, load_class = import_class("bootleg.embeddings", emb.load_class)
                kg_class = getattr(mod, load_class)(
                    main_args=main_args,
                    emb_args=emb_args,
                    entity_symbols=entity_symbols,
                    key=emb.key,
                    cpu=cpu,
                    normalize=normalize,
                    dropout1d_perc=dropout1d_perc,
                    dropout2d_perc=dropout2d_perc,
                )
                # Extract its kg adj, we'll use this later
                # Extract the kg_adj_process_func (how to process the embeddings in __getitem__ or dataset prep)
                # Extract the prep_file. We use this to load the kg_adj back after
                # saving/loading state using scipy.sparse.load_npz(prep_file)
                assert hasattr(
                    kg_class, "kg_adj"
                ), f"The embedding class {emb.key} does not have a kg_adj attribute and it needs to."
                assert hasattr(
                    kg_class, "kg_adj_process_func"
                ), f"The embedding class {emb.key} does not have a kg_adj_process_func attribute and it needs to."
                assert hasattr(kg_class, "prep_file"), (
                    f"The embedding class {emb.key} does not have a prep_file attribute and it needs to. We will call"
                    f" `scipy.sparse.load_npz(prep_file)` to load the kg_adj matrix."
                )
                batch_on_the_fly_kg_adj[emb.key] = {
                    "kg_adj": kg_class.kg_adj,
                    "kg_adj_process_func": kg_class.kg_adj_process_func,
                    "prep_file": kg_class.prep_file,
                }
            except AttributeError as e:
                logger.warning(
                    f"No prep method found for {emb.load_class} with error {e}"
                )
                raise
            except Exception as e:
                print("ERROR", e)
                raise
    return batch_on_the_fly_kg_adj
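# A hedged illustration of the ent_embeddings config flag this function looks for. Any KG
# embedding entry with "batch_on_the_fly": True gets its adjacency queried per example in
# __getitem__ instead of being appended as a payload embedding. The entry structure, key names,
# and file path below are illustrative assumptions, not verbatim library config.
kg_emb_entry = {
    "key": "adj_index",
    "load_class": "KGIndices",
    "batch_on_the_fly": True,  # triggers the preprocessing path above
    "args": {"kg_adj": "kg_mappings/kg_adj.txt"},
}

# What the function would hand back per such entry (values come from the instantiated class):
#   {"adj_index": {"kg_adj": <scipy sparse matrix>,
#                  "kg_adj_process_func": <callable run in __getitem__ or data prep>,
#                  "prep_file": <path usable with scipy.sparse.load_npz>}}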
def build_kg_adj(cls, kg_adj_file, entity_symbols, threshold, log_weight):
    """Builds the KG adjacency matrix from inputs.

    Args:
        kg_adj_file: KG adjacency file
        entity_symbols: entity symbols
        threshold: weight threshold for counting an edge
        log_weight: whether to take the log of the weight after thresholding

    Returns: KG adjacency
    """
    G = nx.Graph()
    qids = set(entity_symbols.get_all_qids())
    edges_to_add = []
    num_added = 0
    num_total = 0
    # Get ending to determine if txt or json file
    file_ending = os.path.splitext(kg_adj_file)[1][1:]
    assert file_ending in [
        "json",
        "txt",
    ], f"We only support loading txt or json files for edge weights. You provided {file_ending}"
    with open(kg_adj_file) as f:
        if file_ending == "json":
            all_edges = json.load(f)
            for head in all_edges:
                for tail in all_edges[head]:
                    weight = all_edges[head][tail]
                    num_total += 1
                    if head in qids and tail in qids and weight > threshold:
                        num_added += 1
                        if log_weight:
                            edges_to_add.append((head, tail, np.log(weight)))
                        else:
                            edges_to_add.append((head, tail, weight))
        else:
            for line in f:
                splt = line.strip().split()
                if len(splt) == 2:
                    head, tail = splt
                    weight = 1.0
                elif len(splt) == 3:
                    head, tail, weight = splt
                    # weights are read as strings; cast for the threshold comparison below
                    weight = float(weight)
                else:
                    raise ValueError(
                        f"A line {line} in {kg_adj_file} does not have 2 or 3 values after calling split()."
                    )
                num_total += 1
                # head and tail must be in the list of qids
                if head in qids and tail in qids and weight > threshold:
                    num_added += 1
                    if log_weight:
                        edges_to_add.append((head, tail, np.log(weight)))
                    else:
                        edges_to_add.append((head, tail, weight))
    log_rank_0_debug(
        logger, f"Adding {num_added} out of {num_total} items from {kg_adj_file}"
    )
    G.add_weighted_edges_from(edges_to_add)
    # convert to entity ids
    G = nx.relabel_nodes(G, entity_symbols.get_qid2eid())
    # create adjacency matrix
    adj = nx.adjacency_matrix(
        G, nodelist=range(entity_symbols.num_entities_with_pad_and_nocand)
    )
    assert (
        adj.sum() > 0
    ), f"Your KG Adj matrix has all 0 values. Something was likely parsed wrong."
    # assert that the padded entity has no connections
    assert (adj[:, -1] != 0).sum() == 0
    assert (adj[-1, :] != 0).sum() == 0
    # assert that the unk entity has no connections
    assert (adj[:, 0] != 0).sum() == 0
    assert (adj[0, :] != 0).sum() == 0
    return adj
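# A hedged illustration of the two edge-file formats the loader above accepts. The QIDs and
# weights are made up. A txt line is either "head tail" (implicit weight 1.0) or
# "head tail weight"; the json form is a nested dict of head -> {tail: weight}.
#
#   kg_adj.txt:
#       Q42 Q5
#       Q42 Q145 3.0
#
#   kg_adj.json:
#       {"Q42": {"Q5": 1.0, "Q145": 3.0}}
#
# And a tiny standalone demo of the graph -> adjacency step used at the end of build_kg_adj,
# with toy integer entity IDs standing in for the relabeled EIDs.
import networkx as nx

toy_G = nx.Graph()
toy_G.add_nodes_from(range(4))  # include the unk/pad rows explicitly so they become empty rows
toy_G.add_weighted_edges_from([(1, 2, 1.0), (1, 3, 3.0)])
toy_adj = nx.adjacency_matrix(toy_G, nodelist=range(4))
assert toy_adj.sum() > 0 and (toy_adj[0, :] != 0).sum() == 0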