def prep(cls, main_args, emb_args, entity_symbols, log_func=print, word_symbols=None): """Loads KG information into matrix format or loads from saved state""" file_tag = os.path.splitext(emb_args.kg_adj)[0] prep_dir = data_utils.get_emb_prep_dir(main_args) prep_file = os.path.join(prep_dir, f'kg_adj_file_{file_tag}.npz') utils.ensure_dir(os.path.dirname(prep_file)) if (not main_args.data_config.overwrite_preprocessed_data and os.path.exists(prep_file)): log_func(f'Loading existing KG adj from {prep_file}') start = time.time() kg_adj = scipy.sparse.load_npz(prep_file) log_func( f'Loaded existing KG adj in {round(time.time() - start, 2)}s') else: start = time.time() kg_adj_file = os.path.join(main_args.data_config.emb_dir, emb_args.kg_adj) log_func(f'Building KG adj from {kg_adj_file}') kg_adj = cls.build_kg_adj(kg_adj_file, entity_symbols, emb_args) scipy.sparse.save_npz(prep_file, kg_adj) log_func( f"Finished building and saving KG adj in {round(time.time() - start, 2)}s." ) return kg_adj, prep_file
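# The prep functions in this listing all follow the same load-or-build caching pattern shown
# above. A minimal, self-contained sketch of that pattern for a sparse adjacency matrix; the
# helper name load_or_build_adj and the file paths are hypothetical, not Bootleg's API:
import os
import time

import scipy.sparse


def load_or_build_adj(prep_file, build_fn, overwrite=False):
    """Load a cached sparse adjacency matrix if present, otherwise build and cache it."""
    if not overwrite and os.path.exists(prep_file):
        start = time.time()
        adj = scipy.sparse.load_npz(prep_file)
        print(f"Loaded cached adj in {round(time.time() - start, 2)}s")
    else:
        adj = build_fn()  # assumed to return a scipy.sparse matrix (e.g., csr_matrix)
        scipy.sparse.save_npz(prep_file, adj)
    return adj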
def prep(cls, main_args, emb_args, entity_symbols, log_func=print, word_symbols=None): type_str = os.path.splitext(emb_args.type_labels)[0] prep_dir = data_utils.get_emb_prep_dir(main_args) prep_file = os.path.join( prep_dir, f'type_table_{type_str}_{emb_args.max_types}.pt') utils.ensure_dir(os.path.dirname(prep_file)) if (not main_args.data_config.overwrite_preprocessed_data and os.path.exists(prep_file)): log_func(f'Loading existing type table from {prep_file}') start = time.time() eid2typeids_table, type2row, num_types_with_unk = torch.load( prep_file) log_func( f'Loaded existing type table in {round(time.time() - start, 2)}s' ) else: start = time.time() log_func(f'Building type table') type_labels = os.path.join(main_args.data_config.emb_dir, emb_args.type_labels) eid2typeids_table, type2row, num_types_with_unk = cls.build_type_table( type_labels=type_labels, max_types=emb_args.max_types, entity_symbols=entity_symbols) torch.save((eid2typeids_table, type2row, num_types_with_unk), prep_file) log_func( f"Finished building and saving type table in {round(time.time() - start, 2)}s." ) return eid2typeids_table, type2row, num_types_with_unk, prep_file
def save(self, save_dir, prefix=""): """Dumps the type symbols. Args: save_dir: directory string to save prefix: prefix to add to beginning to file Returns: """ utils.ensure_dir(str(save_dir)) utils.dump_json_file( filename=os.path.join(save_dir, "config.json"), contents={ "max_types": self.max_types, }, ) utils.dump_json_file( filename=os.path.join(save_dir, f"{prefix}qid2typenames.json"), contents=self._qid2typenames, ) utils.dump_json_file( filename=os.path.join(save_dir, f"{prefix}qid2typeids.json"), contents=self._qid2typeid, ) utils.dump_json_file( filename=os.path.join(save_dir, f"{prefix}type_vocab.json"), contents=self._type_vocab, )
def save(self, save_dir): """Dumps the entity symbols. Args: save_dir: directory string to save Returns: """ self._sort_alias_cands() utils.ensure_dir(save_dir) utils.dump_json_file( filename=os.path.join(save_dir, "config.json"), contents={ "max_candidates": self.max_candidates, "datetime": str(datetime.now()), }, ) utils.dump_json_file( filename=os.path.join(save_dir, self.alias_cand_map_file), contents=self._alias2qids, ) utils.dump_json_file( filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title ) utils.dump_json_file( filename=os.path.join(save_dir, "qid2eid.json"), contents=self._qid2eid ) utils.dump_json_file( filename=os.path.join(save_dir, self.alias_idx_file), contents=self._alias2id, )
def main(): args = parse_args() logging.info(json.dumps(vars(args), indent=4)) entity_symbols = EntitySymbols( load_dir=os.path.join(args.data_dir, args.entity_symbols_dir)) train_file = os.path.join(args.data_dir, args.train_file) save_dir = os.path.join(args.save_dir, "stats") logging.info(f"Will save data to {save_dir}") utils.ensure_dir(save_dir) # compute_histograms(save_dir, entity_symbols) compute_occurrences(save_dir, train_file, entity_symbols, args.lower, args.strip, num_workers=args.num_workers) if not args.no_types: type_symbols = TypeSymbols( entity_symbols=entity_symbols, emb_dir=args.emb_dir, max_types=args.max_types, emb_file="hyena_type_emb.pkl", type_vocab_file="hyena_type_graph.vocab.pkl", type_file="hyena_types.txt") compute_type_occurrences(save_dir, "orig", entity_symbols, type_symbols.qid2typenames, train_file)
def compute_occurrences(save_dir, data_file, entity_dump, lower, strip, num_workers=8): global all_aliases all_aliases = get_all_aliases(entity_dump._alias2qids) # divide up data into chunks num_lines = get_num_lines(data_file) num_processes = min(num_workers, int(multiprocessing.cpu_count())) logging.info(f'Using {num_processes} workers...') chunk_size = int(np.ceil(num_lines / (num_processes))) chunk_file_path = os.path.join(save_dir, 'tmp') utils.ensure_dir(chunk_file_path) chunk_infiles = [ os.path.join(f'{chunk_file_path}', f'data_chunk_{chunk_id}_in.jsonl') for chunk_id in range(num_processes) ] chunk_text_data(data_file, chunk_infiles, chunk_size, num_lines) pool = multiprocessing.Pool(processes=num_processes) subprocess_args = [[chunk_infiles[i], lower, strip] for i in range(num_processes)] results = pool.map(compute_occurrences_single, subprocess_args) pool.close() pool.join() logging.info('Finished collecting counts') logging.info('Merging counts....') # merge counters together ent_occurrences = Counter() # alias histogram alias_occurrences = Counter() # alias text occurrances alias_text_occurrences = Counter() # number of aliases per sentence alias_pair_occurrences = Counter() # alias|entity histogram alias_entity_pair = Counter() for result_set in results: ent_occurrences += result_set['ent_occurrences'] alias_occurrences += result_set['alias_occurrences'] alias_text_occurrences += result_set['alias_text_occurrences'] alias_pair_occurrences += result_set['alias_pair_occurrences'] alias_entity_pair += result_set['alias_entity_pair'] # save counters utils.dump_json_file(filename=os.path.join(save_dir, "entity_count.json"), contents=ent_occurrences) utils.dump_json_file(filename=os.path.join(save_dir, "alias_counts.json"), contents=alias_occurrences) utils.dump_json_file(filename=os.path.join(save_dir, "alias_text_counts.json"), contents=alias_text_occurrences) utils.dump_json_file(filename=os.path.join(save_dir, "alias_pair_occurrences.json"), contents=alias_pair_occurrences) utils.dump_json_file(filename=os.path.join(save_dir, "alias_entity_counts.json"), contents=alias_entity_pair)
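# The per-chunk worker compute_occurrences_single is not shown in this listing. A rough sketch
# of what such a worker might do, assuming each JSONL line carries "aliases" and "qids" fields
# (the field names and the worker signature are assumptions; the real worker also fills the
# alias_text_occurrences and alias_pair_occurrences counters merged above):
from collections import Counter

import ujson


def count_chunk(chunk_infile, lower=False, strip=False):
    ent_occurrences = Counter()
    alias_occurrences = Counter()
    alias_entity_pair = Counter()
    with open(chunk_infile) as in_f:
        for line in in_f:
            sent = ujson.loads(line)
            for alias, qid in zip(sent["aliases"], sent["qids"]):
                if lower:
                    alias = alias.lower()
                ent_occurrences[qid] += 1
                alias_occurrences[alias] += 1
                alias_entity_pair[f"{alias}|{qid}"] += 1
    return {
        "ent_occurrences": ent_occurrences,
        "alias_occurrences": alias_occurrences,
        "alias_entity_pair": alias_entity_pair,
    }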
def load_regularization_mapping(cls, data_config, entity_symbols, reg_file): """Reads in a csv file with columns [qid, regularization]. In the forward pass, the entity id with associated qid will be regularized with probability regularization. Args: data_config: data config entity_symbols: entity symbols reg_file: regularization csv file Returns: Tensor where each value is the regularization value for the entity with that EID """ reg_str = os.path.splitext(os.path.basename(reg_file.replace("/", "_")))[0] prep_dir = data_utils.get_data_prep_dir(data_config) prep_file = os.path.join( prep_dir, f"entity_regularization_mapping_{reg_str}.pt") utils.ensure_dir(os.path.dirname(prep_file)) log_rank_0_debug(logger, f"Looking for regularization mapping in {prep_file}") if not data_config.overwrite_preprocessed_data and os.path.exists( prep_file): log_rank_0_debug( logger, f"Loading existing entity regularization mapping from {prep_file}", ) start = time.time() eid2reg = torch.load(prep_file) log_rank_0_debug( logger, f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s", ) else: start = time.time() log_rank_0_info( logger, f"Building entity regularization mapping from {reg_file}") qid2reg = pd.read_csv(reg_file) assert ( "qid" in qid2reg.columns and "regularization" in qid2reg.columns ), f"Expected qid and regularization as the column names for {reg_file}" # default of no mask eid2reg_arr = [0.0 ] * entity_symbols.num_entities_with_pad_and_nocand for row_idx, row in qid2reg.iterrows(): if entity_symbols.qid_exists(row["qid"]): eid = entity_symbols.get_eid(row["qid"]) eid2reg_arr[eid] = row["regularization"] eid2reg = torch.tensor(eid2reg_arr) torch.save(eid2reg, prep_file) log_rank_0_debug( logger, f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.", ) return eid2reg
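# The expected CSV is just two columns, qid and regularization (the same format written by the
# end-to-end test further down in this listing):
#
#   qid,regularization
#   Q1,0.5
#   Q2,0.3
#
# The resulting eid2reg tensor is indexed by entity ID, so the per-entity regularization
# probability can be gathered directly in the forward pass. A hedged sketch with toy values
# (the candidate_eids tensor is hypothetical):
import torch

eid2reg = torch.tensor([0.0, 0.5, 0.3, 0.0])     # index = EID
candidate_eids = torch.tensor([[1, 2], [3, 0]])  # batch of candidate EIDs
reg_probs = eid2reg[candidate_eids]              # same shape as candidate_eids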
def setup_run_folders(args, mode): if args.run_config.timestamp == "": # create a timestamp for directory for saving results start_date = strftime("%Y%m%d") start_time = strftime("%H%M%S") args.run_config.timestamp = "{:s}_{:s}".format(start_date, start_time) utils.ensure_dir(get_save_folder(args.run_config)) return
def prep( cls, data_config, entity_symbols, num_aliases_with_pad_and_unk, num_cands_K, ): """Preps the alias to entity EID table. Args: data_config: data config entity_symbols: entity symbols num_aliases_with_pad_and_unk: number of aliases including pad and unk num_cands_K: number of candidates per alias (aka K) Returns: torch Tensor of the alias to EID table, save pt file """ # we pass num_aliases_with_pad_and_unk and num_cands_K to remove the dependence on entity_symbols # when the alias table is already prepped data_shape = (num_aliases_with_pad_and_unk, num_cands_K) # dependent on train_in_candidates flag prep_dir = data_utils.get_emb_prep_dir(data_config) alias_str = os.path.splitext( data_config.alias_cand_map.replace("/", "_"))[0] prep_file = os.path.join( prep_dir, f"alias2entity_table_{alias_str}_InC{int(data_config.train_in_candidates)}.pt", ) log_rank_0_debug(logger, f"Looking for alias table in {prep_file}") if not data_config.overwrite_preprocessed_data and os.path.exists( prep_file): log_rank_0_debug(logger, f"Loading alias table from {prep_file}") start = time.time() alias2entity_table = np.memmap(prep_file, dtype="int64", mode="r+", shape=data_shape) log_rank_0_debug( logger, f"Loaded alias table in {round(time.time() - start, 2)}s") else: start = time.time() log_rank_0_debug(logger, f"Building alias table") utils.ensure_dir(prep_dir) mmap_file = np.memmap(prep_file, dtype="int64", mode="w+", shape=data_shape) alias2entity_table = cls.build_alias_table(data_config, entity_symbols) mmap_file[:] = alias2entity_table[:] mmap_file.flush() log_rank_0_debug( logger, f"Finished building and saving alias table in {round(time.time() - start, 2)}s.", ) alias2entity_table = torch.from_numpy(alias2entity_table) return alias2entity_table, prep_file
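# The prepped table maps an alias index to its K candidate EIDs; padding and the
# "not in candidates" slot depend on train_in_candidates. A hedged usage sketch with toy
# values (the alias-to-index lookup is assumed to come from entity_symbols):
import torch

alias2entity_table = torch.tensor([
    [3, 7, 0],  # alias index 0 -> candidate EIDs, padded with 0
    [5, 0, 0],  # alias index 1
])
alias_idx = 1
candidate_eids = alias2entity_table[alias_idx]  # tensor([5, 0, 0])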
def prep(cls, data_config, entity_symbols): """Prep the title data. Args: data_config: data config entity_symbols: entity symbols Returns: torch tensor EID to title token IDs, EID to title token mask, EID to title token type ID (for BERT) """ prep_dir = data_utils.get_emb_prep_dir(data_config) prep_file_token_ids = os.path.join( prep_dir, f"title_token_ids_{data_config.word_embedding.bert_model}.pt") prep_file_attn_mask = os.path.join( prep_dir, f"title_attn_mask_{data_config.word_embedding.bert_model}.pt") prep_file_token_type_ids = os.path.join( prep_dir, f"title_token_type_ids_{data_config.word_embedding.bert_model}.pt") utils.ensure_dir(os.path.dirname(prep_file_token_ids)) log_rank_0_debug( logger, f"Looking for title table mapping in {prep_file_token_ids}") if (not data_config.overwrite_preprocessed_data and os.path.exists(prep_file_token_ids) and os.path.exists(prep_file_attn_mask) and os.path.exists(prep_file_token_type_ids)): log_rank_0_debug( logger, f"Loading existing title table from {prep_file_token_ids}") start = time.time() entity2titleid = torch.load(prep_file_token_ids) entity2titlemask = torch.load(prep_file_attn_mask) entity2tokentypeid = torch.load(prep_file_token_type_ids) log_rank_0_debug( logger, f"Loaded existing title table in {round(time.time() - start, 2)}s", ) else: start = time.time() log_rank_0_debug(logger, f"Loading tokenizer") tokenizer = load_tokenizer(data_config) ( entity2titleid, entity2titlemask, entity2tokentypeid, ) = cls.build_title_table(tokenizer=tokenizer, entity_symbols=entity_symbols) torch.save(entity2titleid, prep_file_token_ids) torch.save(entity2titlemask, prep_file_attn_mask) torch.save(entity2tokentypeid, prep_file_token_type_ids) log_rank_0_debug( logger, f"Finished building and saving title table in {round(time.time() - start, 2)}s.", ) return entity2titleid, entity2titlemask, entity2tokentypeid
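# build_title_table is not shown in this listing; it presumably tokenizes every entity title
# with the configured BERT tokenizer and stacks the results into the three tensors returned
# above. A minimal sketch for a single title with a Hugging Face tokenizer (the max length of
# 32 is an arbitrary assumption):
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased", do_lower_case=False)
enc = tokenizer(
    "Abraham Lincoln",
    max_length=32,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
title_token_ids = enc["input_ids"]         # shape [1, 32]
title_attn_mask = enc["attention_mask"]    # shape [1, 32]
title_token_type = enc["token_type_ids"]   # shape [1, 32]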
def get_data_prep_dir(data_config): """Get data prep directory for saving prep files. Lives inside data_dir. Args: data_config: data config Returns: directory path """ prep_dir = os.path.join(data_config.data_dir, data_config.data_prep_dir) utils.ensure_dir(prep_dir) return prep_dir
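# utils.ensure_dir is used throughout these snippets; it is presumably a thin wrapper along
# these lines (an assumption, not the actual implementation):
import os


def ensure_dir(path):
    """Create the directory (and any parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)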
def dump(self, save_dir, stats={}, args=None): self._sort_alias_cands() utils.ensure_dir(save_dir) utils.dump_json_file(filename=os.path.join(save_dir, "config.json"), contents={"max_candidates":self.max_candidates, "max_alias_len":self.max_alias_len, "datetime": str(datetime.now())}) utils.dump_json_file(filename=os.path.join(save_dir, self.alias_cand_map_file), contents=self._alias2qids) utils.dump_json_file(filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title) utils.dump_json_file(filename=os.path.join(save_dir, "qid2eid.json"), contents=self._qid2eid) utils.dump_json_file(filename=os.path.join(save_dir, "filter_stats.json"), contents=stats) if args is not None: utils.dump_json_file(filename=os.path.join(save_dir, "args.json"), contents=vars(args))
def __init__(self, args, is_writer, distributed): super(BERTWordSymbols, self).__init__(args, is_writer, distributed) self.is_bert = True self.unk_id = None self.pad_id = 0 cache_dir = args.word_embedding.cache_dir utils.ensure_dir(cache_dir) # import torch # import os # from transformers import BertTokenizer # cache_dir = "pretrained_bert_models" # tokenizer = BertTokenizer.from_pretrained('bert-base-cased', cache_dir=cache_dir, do_lower_case=False) # torch.save(tokenizer, os.path.join(cache_dir, "bert_base_cased_tokenizer.pt")) self.tokenizer = torch.load(os.path.join(cache_dir, "bert_base_cased_tokenizer.pt")) self.vocab = self.tokenizer.vocab self.num_words = len(self.vocab) self.word_embedding_dim = 768
def prep(cls, data_config, emb_args, entity_symbols): """Prep the type id table. Args: data_config: data config emb_args: embedding args entity_symbols: entity symbols Returns: torch tensor from EID to type IDs, type ID to row in type embedding matrix, and number of types with unk type """ type_str = os.path.splitext(emb_args.type_labels.replace("/", "_"))[0] prep_dir = data_utils.get_emb_prep_dir(data_config) prep_file = os.path.join( prep_dir, f"type_table_{type_str}_{emb_args.max_types}.pt") utils.ensure_dir(os.path.dirname(prep_file)) if not data_config.overwrite_preprocessed_data and os.path.exists( prep_file): log_rank_0_debug(logger, f"Loading existing type table from {prep_file}") start = time.time() eid2typeids_table, type2row_dict, num_types_with_unk = torch.load( prep_file) log_rank_0_debug( logger, f"Loaded existing type table in {round(time.time() - start, 2)}s", ) else: start = time.time() type_labels = os.path.join(data_config.emb_dir, emb_args.type_labels) type_vocab = os.path.join(data_config.emb_dir, emb_args.type_vocab) log_rank_0_debug(logger, f"Building type table from {type_labels}") eid2typeids_table, type2row_dict, num_types_with_unk = cls.build_type_table( type_labels=type_labels, type_vocab=type_vocab, max_types=emb_args.max_types, entity_symbols=entity_symbols, ) torch.save((eid2typeids_table, type2row_dict, num_types_with_unk), prep_file) log_rank_0_debug( logger, f"Finished building and saving type table in {round(time.time() - start, 2)}s.", ) return eid2typeids_table, type2row_dict, num_types_with_unk, prep_file
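# build_type_table is not shown; conceptually it gives every entity a fixed-length row of
# type IDs, truncating to max_types and padding shorter lists. A hedged sketch of that
# padding step (the 0-for-pad/unk convention here is an assumption):
import torch


def pad_types(typeids, max_types):
    typeids = typeids[:max_types]
    return typeids + [0] * (max_types - len(typeids))


eid2typeids_table = torch.tensor([
    pad_types([], 3),            # e.g., the pad / "no candidate" entity
    pad_types([4, 9], 3),        # entity with two types
    pad_types([2, 5, 7, 8], 3),  # truncated to max_types
])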
def prep( cls, data_config, emb_args, entity_symbols, threshold, log_weight, ): """Preps the KG information. Args: data_config: data config emb_args: embedding args entity_symbols: entity symbols threshold: weight threshold for counting an edge log_weight: whether to take the log of the weight value after the threshold Returns: scipy sparse KG adjacency matrix, prep file """ file_tag = os.path.splitext(emb_args.kg_adj.replace("/", "_"))[0] prep_dir = data_utils.get_emb_prep_dir(data_config) prep_file = os.path.join( prep_dir, f"kg_adj_file_{file_tag}_{threshold}_{log_weight}.npz") utils.ensure_dir(os.path.dirname(prep_file)) if not data_config.overwrite_preprocessed_data and os.path.exists( prep_file): log_rank_0_debug(logger, f"Loading existing KG adj from {prep_file}") start = time.time() kg_adj = scipy.sparse.load_npz(prep_file) log_rank_0_debug( logger, f"Loaded existing KG adj in {round(time.time() - start, 2)}s") else: start = time.time() kg_adj_file = os.path.join(data_config.emb_dir, emb_args.kg_adj) log_rank_0_debug(logger, f"Building KG adj from {kg_adj_file}") kg_adj = cls.build_kg_adj(kg_adj_file, entity_symbols, threshold, log_weight) scipy.sparse.save_npz(prep_file, kg_adj) log_rank_0_debug( logger, f"Finished building and saving KG adj in {round(time.time() - start, 2)}s.", ) return kg_adj, prep_file
def main(): args = parse_args() logging.info(json.dumps(vars(args), indent=4)) entity_symbols = EntitySymbols.load_from_cache( load_dir=os.path.join(args.data_dir, args.entity_symbols_dir)) train_file = os.path.join(args.data_dir, args.train_file) save_dir = os.path.join(args.save_dir, "stats") logging.info(f"Will save data to {save_dir}") utils.ensure_dir(save_dir) # compute_histograms(save_dir, entity_symbols) compute_occurrences( save_dir, train_file, entity_symbols, args.lower, args.strip, num_workers=args.num_workers, )
def dump(self, save_dir): # memmapped files behave badly if you try to overwrite them in memory, which is what we'd be doing if load_dir == save_dir if self._loaded_from_dir is None or self._loaded_from_dir != save_dir: utils.ensure_dir(save_dir) utils.dump_json_file(filename=os.path.join(save_dir, "fmt_types.json"), contents=self._fmt_types) utils.dump_json_file(filename=os.path.join(save_dir, "max_values.json"), contents=self._max_values) utils.dump_json_file(filename=os.path.join(save_dir, "vocabulary.json"), contents=self._stoi) np.save(file=os.path.join(save_dir, "itos.npy"), arr=self._itos, allow_pickle=True) for tri_name in self._fmt_types: self._record_tris[tri_name].save( os.path.join(save_dir, f'record_trie_{tri_name}.marisa'))
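# The record tries saved above can be reloaded with marisa_trie. A hedged round-trip sketch;
# the "<l" struct format and the key/value contents are illustrative assumptions:
import marisa_trie

fmt = "<l"
trie = marisa_trie.RecordTrie(fmt, [("Q123", (7,)), ("Q456", (11,))])
trie.save("record_trie_example.marisa")

loaded = marisa_trie.RecordTrie(fmt)
loaded.load("record_trie_example.marisa")
assert loaded["Q123"] == [(7,)]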
def load_regularization_mapping(cls, main_args, num_types_with_pad_and_unk, type2row_dict, reg_file, log_func): """ Reads in a csv file with columns [typeid, regularization]. In the forward pass, the type id will be regularized with probability regularization. """ reg_str = reg_file.split(".csv")[0] prep_dir = data_utils.get_data_prep_dir(main_args) prep_file = os.path.join( prep_dir, f'entity_regularization_mapping_{reg_str}.pt') utils.ensure_dir(os.path.dirname(prep_file)) log_func(f"Looking for regularization mapping in {prep_file}") if (not main_args.data_config.overwrite_preprocessed_data and os.path.exists(prep_file)): log_func( f'Loading existing entity regularization mapping from {prep_file}' ) start = time.time() typeid2reg = torch.load(prep_file) log_func( f'Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s' ) else: start = time.time() reg_file = os.path.join(main_args.data_config.data_dir, reg_file) log_func(f'Building entity regularization mapping from {reg_file}') typeid2reg_raw = pd.read_csv(reg_file) assert "typeid" in typeid2reg_raw.columns and "regularization" in typeid2reg_raw.columns, f"Expected typeid and regularization as the column names for {reg_file}" # default of no mask typeid2reg_arr = [0.0] * num_types_with_pad_and_unk for row_idx, row in typeid2reg_raw.iterrows(): # Happens when we filter QIDs not in our entity dump and the max typeid is smaller than the total number if int(row["typeid"]) not in type2row_dict: continue typeid = type2row_dict[int(row["typeid"])] typeid2reg_arr[typeid] = row["regularization"] typeid2reg = torch.tensor(typeid2reg_arr) torch.save(typeid2reg, prep_file) log_func( f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s." ) return typeid2reg
def prep(cls, args, entity_symbols, num_aliases_with_pad, num_cands_K, log_func=print): # we pass num_aliases_with_pad and num_cands_K to remove the dependence on entity_symbols # when the alias table is already prepped data_shape = (num_aliases_with_pad, num_cands_K) # dependent on train_in_candidates flag prep_dir = data_utils.get_emb_prep_dir(args) alias_str = os.path.splitext(args.data_config.alias_cand_map)[0] prep_file = os.path.join( prep_dir, f'alias2entity_table_{alias_str}_InC{int(args.data_config.train_in_candidates)}.pt' ) if (not args.data_config.overwrite_preprocessed_data and os.path.exists(prep_file)): log_func(f'Loading alias table from {prep_file}') start = time.time() alias2entity_table = np.memmap(prep_file, dtype='int64', mode='r', shape=data_shape) log_func(f'Loaded alias table in {round(time.time() - start, 2)}s') else: start = time.time() log_func(f'Building alias table') utils.ensure_dir(prep_dir) mmap_file = np.memmap(prep_file, dtype='int64', mode='w+', shape=data_shape) alias2entity_table = cls.build_alias_table(args, entity_symbols) mmap_file[:] = alias2entity_table[:] mmap_file.flush() log_func( f"Finished building and saving alias table in {round(time.time() - start, 2)}s." ) return alias2entity_table, prep_file
def load_regularization_mapping(cls, main_args, entity_symbols, reg_file, log_func): """ Reads in a csv file with columns [qid, regularization]. In the forward pass, the entity id with associated qid will be regularized with probability regularization. """ reg_str = reg_file.split(".csv")[0] prep_dir = data_utils.get_data_prep_dir(main_args) prep_file = os.path.join( prep_dir, f'entity_regularization_mapping_{reg_str}.pt') utils.ensure_dir(os.path.dirname(prep_file)) log_func(f"Looking for regularization mapping in {prep_file}") if (not main_args.data_config.overwrite_preprocessed_data and os.path.exists(prep_file)): log_func( f'Loading existing entity regularization mapping from {prep_file}' ) start = time.time() eid2reg = torch.load(prep_file) log_func( f'Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s' ) else: start = time.time() log_func(f'Building entity regularization mapping from {reg_file}') qid2reg = pd.read_csv(reg_file) assert "qid" in qid2reg.columns and "regularization" in qid2reg.columns, f"Expected qid and regularization as the column names for {reg_file}" # default of no mask eid2reg_arr = [0.0 ] * entity_symbols.num_entities_with_pad_and_nocand for row_idx, row in qid2reg.iterrows(): eid = entity_symbols.get_eid(row["qid"]) eid2reg_arr[eid] = row["regularization"] eid2reg = torch.tensor(eid2reg_arr) torch.save(eid2reg, prep_file) log_func( f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s." ) return eid2reg
def test_end2end_withreg_evalbatch(self): reg_file = "test/temp/reg_file.csv" utils.ensure_dir("test/temp") reg_data = [ ["qid", "regularization"], ["Q1", "0.5"], ["Q2", "0.3"], ["Q3", "0.2"], ["Q4", "0.9"], ] self.args.data_config.eval_accumulation_steps = 2 self.args.run_config.dataset_threads = 2 self.args.run_config.eval_batch_size = 2 with open(reg_file, "w") as out_f: for item in reg_data: out_f.write(",".join(item) + "\n") self.args.data_config.ent_embeddings[0]["args"][ "regularize_mapping"] = reg_file scores = run_model(mode="train", config=self.args) assert type(scores) is dict assert len(scores) > 0 assert scores["model/all/train/loss"] < 0.05 self.args["model_config"][ "model_path"] = f"{emmental.Meta.log_path}/last_model.pth" emmental.Meta.config["model_config"][ "model_path"] = f"{emmental.Meta.log_path}/last_model.pth" result_file, out_emb_file = run_model(mode="dump_embs", config=self.args) assert os.path.exists(result_file) results = [ujson.loads(li) for li in open(result_file)] assert 18 == len(results) # 18 total sentences assert set([f for li in results for f in li["ctx_emb_ids"] ]) == set(range(51)) # 38 total mentions assert os.path.exists(out_emb_file) shutil.rmtree("test/temp", ignore_errors=True)
def __init__( self, main_args, dataset, use_weak_label, entity_symbols, dataset_threads, split="train", ): global_start = time.time() log_rank_0_info(logger, f"Building slice dataset for {split} from {dataset}.") spawn_method = main_args.run_config.spawn_method data_config = main_args.data_config orig_spawn = multiprocessing.get_start_method() multiprocessing.set_start_method(spawn_method, force=True) self.slice_names = data_utils.get_eval_slices(data_config.eval_slices) self.get_slice_dt = lambda max_a2p: np.dtype([ ("sent_idx", int), ("subslice_idx", int), ("alias_slice_incidence", int, (max_a2p, )), ("prob_labels", float, (max_a2p, )), ]) self.get_storage = lambda max_a2p: np.dtype( [(slice_name, self.get_slice_dt(max_a2p)) for slice_name in self.slice_names]) # Folder for all mmap saved files save_dataset_folder = data_utils.get_save_data_folder( data_config, use_weak_label, dataset) utils.ensure_dir(save_dataset_folder) # Folder for temporary output files temp_output_folder = os.path.join(data_config.data_dir, data_config.data_prep_dir, f"prep_{split}_slice_files") utils.ensure_dir(temp_output_folder) # Input step 1 create_ex_indir = os.path.join(temp_output_folder, "create_examples_input") utils.ensure_dir(create_ex_indir) # Input step 2 create_ex_outdir = os.path.join(temp_output_folder, "create_examples_output") utils.ensure_dir(create_ex_outdir) # Meta data saved files meta_file = os.path.join(temp_output_folder, "meta_data.json") # File for standard training data hash = hashlib.sha1(str( self.slice_names).encode("UTF-8")).hexdigest()[:10] self.save_dataset_name = os.path.join(save_dataset_folder, f"ned_slices_{hash}.bin") self.save_data_config_name = os.path.join(save_dataset_folder, "ned_slices_config.json") # ======================================================================================= # SLICE DATA # ======================================================================================= log_rank_0_debug(logger, "Loading dataset...") log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists") if data_config.overwrite_preprocessed_data or (not os.path.exists( self.save_dataset_name)): st_time = time.time() try: log_rank_0_info( logger, f"Building dataset from scratch. Saving to {save_dataset_folder}", ) create_examples( dataset, create_ex_indir, create_ex_outdir, meta_file, data_config, dataset_threads, self.slice_names, use_weak_label, split, ) max_alias2pred = utils.load_json_file( meta_file)["max_alias2pred"] convert_examples_to_features_and_save( meta_file, dataset_threads, self.slice_names, self.save_dataset_name, self.get_storage(max_alias2pred), ) utils.dump_json_file(self.save_data_config_name, {"max_alias2pred": max_alias2pred}) log_rank_0_debug( logger, f"Finished prepping data in {time.time() - st_time}") except Exception as e: tb = traceback.TracebackException.from_exception(e) logger.error(e) logger.error("\n".join(tb.stack.format())) shutil.rmtree(save_dataset_folder, ignore_errors=True) raise log_rank_0_info( logger, f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}", ) max_alias2pred = utils.load_json_file( self.save_data_config_name)["max_alias2pred"] self.data, self.sent_to_row_id_dict = self.build_data_dict( self.save_dataset_name, self.get_storage(max_alias2pred)) assert len(self.data) > 0 assert len(self.sent_to_row_id_dict) > 0 log_rank_0_debug(logger, f"Removing temporary output files") shutil.rmtree(temp_output_folder, ignore_errors=True) # Set spawn back to original/default, which is "fork" or "spawn". 
This is needed for the Meta.config to be correctly passed in the collate_fn. multiprocessing.set_start_method(orig_spawn, force=True) log_rank_0_info( logger, f"Final slice data initialization time from {split} is {time.time() - global_start}s", )
def run_dump_preds(args, test_data_file, trainer, dataloader, logger, entity_symbols, dump_embs=False): """ Dumpes Preds: Remember that if a sentence has all gold=False anchors, it's dropped and will not be seen If a subsplit of a sentence has all gold=False anchors, it will also be dropped and not seen """ # we only care about the entity embeddings for the final slice head eval_folder = train_utils.get_eval_folder(args, test_data_file) utils.ensure_dir(eval_folder) # write to file (M x hidden x size for each data point -- next step will deal with recovering original sentence indices for overflowing sentences) test_file_tag = test_data_file.split('.jsonl')[0] entity_emb_file = os.path.join(eval_folder, f'{test_file_tag}_entity_embs.pt') emb_file_config = entity_emb_file.split('.pt')[0] + '_config' M = args.data_config.max_aliases K = entity_symbols.max_candidates + ( not args.data_config.train_in_candidates) # TODO: fix extra dimension issue if dump_embs: storage_type = np.dtype([('M', int), ('K', int), ('hidden_size', int), ('sent_idx', int), ('subsent_idx', int), ('alias_list_pos', int, M), ('entity_emb', float, M * args.model_config.hidden_size), ('final_loss_true', int, M), ('final_loss_pred', int, M), ('final_loss_prob', float, M), ('final_loss_cand_probs', float, M * K)]) else: # don't need to extract contextualized entity embedding storage_type = np.dtype([('M', int), ('K', int), ('hidden_size', int), ('sent_idx', int), ('subsent_idx', int), ('alias_list_pos', int, M), ('final_loss_true', int, M), ('final_loss_pred', int, M), ('final_loss_prob', float, M), ('final_loss_cand_probs', float, M * K)]) mmap_file = np.memmap(entity_emb_file, dtype=storage_type, mode='w+', shape=(len(dataloader.dataset), )) # Init sent_idx to -1 for debugging mmap_file[:]['sent_idx'] = -1 np.save(emb_file_config, storage_type, allow_pickle=True) logger.debug(f'Created file {entity_emb_file} to save predictions.') start_idx = 0 logger.info( f'{len(dataloader)*args.run_config.eval_batch_size} samples, {len(dataloader)} batches, {len(dataloader.dataset)} len dataset' ) for i, batch in enumerate(dataloader): curr_batch_size = batch["sent_idx"].shape[0] end_idx = start_idx + curr_batch_size preds, _, entity_pack, final_entity_embs = trainer.update(batch, eval=True) model_preds = preds[DISAMBIG][FINAL_LOSS] # don't want to choose padded entity indices probs = torch.exp( masked_class_logsoftmax(pred=model_preds, mask=~entity_pack.mask, dim=2)) mmap_file[start_idx:end_idx]['M'] = M mmap_file[start_idx:end_idx]['K'] = K mmap_file[start_idx:end_idx][ 'hidden_size'] = args.model_config.hidden_size mmap_file[start_idx:end_idx]['sent_idx'] = batch["sent_idx"] mmap_file[start_idx:end_idx]['subsent_idx'] = batch["subsent_idx"] mmap_file[start_idx:end_idx]['alias_list_pos'] = batch[ 'alias_list_pos'] # This will give all aliases seen by the model during training, independent of if it's gold or not mmap_file[start_idx:end_idx][f'final_loss_true'] = batch[ 'true_entity_idx_for_train'].reshape(curr_batch_size, M).cpu().numpy() # get max for each alias, probs is batch x M x K max_probs, pred_cands = probs.max(dim=2) mmap_file[start_idx:end_idx]['final_loss_pred'] = pred_cands.cpu( ).numpy() mmap_file[start_idx:end_idx]['final_loss_prob'] = max_probs.cpu( ).numpy() mmap_file[start_idx:end_idx]['final_loss_cand_probs'] = probs.cpu( ).numpy().reshape(curr_batch_size, -1) # final_entity_embs is batch x M x K x hidden_size, pred_cands in batch x M if dump_embs: chosen_entity_embs = select_embs(embs=final_entity_embs, 
pred_cands=pred_cands, batch_size=curr_batch_size, M=M) # write chosen entity embs to file for contextualized entity embeddings mmap_file[start_idx:end_idx][ 'entity_emb'] = chosen_entity_embs.reshape( curr_batch_size, -1).cpu().numpy() start_idx += curr_batch_size if i % 100 == 0 and i != 0: logger.info(f'Saved {i} batches of predictions') # restitch together and write data file result_file = os.path.join(eval_folder, args.run_config.result_label_file) logger.info(f'Writing predictions to {result_file}...') filt_pred_data = merge_subsentences(os.path.join(args.data_config.data_dir, test_data_file), mmap_file, dump_embs=dump_embs) sent_idx_map = get_sent_idx_map(filt_pred_data) write_data_labels(filt_pred_data=filt_pred_data, data_file=os.path.join(args.data_config.data_dir, test_data_file), out_file=result_file, sent_idx_map=sent_idx_map, entity_dump=entity_symbols, train_in_candidates=args.data_config.train_in_candidates, dump_embs=dump_embs) out_emb_file = None # save easier-to-use embedding file if dump_embs: hidden_size = filt_pred_data[0]['hidden_size'] out_emb_file = os.path.join(eval_folder, args.run_config.result_emb_file) np.save(out_emb_file, filt_pred_data['entity_emb'].reshape(-1, hidden_size)) logger.info(f'Saving contextual entity embeddings to {out_emb_file}') logger.info(f'Wrote predictions to {result_file}') return result_file, out_emb_file
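# The predictions above are written as a structured numpy memmap, with its dtype saved
# alongside it as a *_config.npy file. A hedged sketch of reading such a dump back; the file
# names below are placeholders for whatever test_file_tag produced:
import numpy as np

storage_type = np.load("test_data_entity_embs_config.npy", allow_pickle=True).item()
preds = np.memmap("test_data_entity_embs.pt", dtype=storage_type, mode="r")
print(preds[0]["sent_idx"], preds[0]["final_loss_pred"])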
def run_model(mode, config, run_config_path=None): """ Main run method for Emmental Bootleg models. Args: mode: run mode (train, eval, dump_preds, dump_embs) config: parsed model config run_config_path: original config path (for saving) Returns: """ # Set up distributed backend and save configuration files setup(config, run_config_path) # Load entity symbols log_rank_0_info(logger, f"Loading entity symbols...") entity_symbols = EntitySymbols.load_from_cache( load_dir=os.path.join(config.data_config.entity_dir, config.data_config.entity_map_dir), alias_cand_map_file=config.data_config.alias_cand_map, alias_idx_file=config.data_config.alias_idx_map, ) # Create tasks tasks = [NED_TASK] if config.data_config.type_prediction.use_type_pred is True: tasks.append(TYPE_PRED_TASK) # Create splits for data loaders data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT] # Slices are for eval so we only split on test/dev slice_splits = [DEV_SPLIT, TEST_SPLIT] # If doing eval, only run on test data if mode in ["eval", "dump_preds", "dump_embs"]: data_splits = [TEST_SPLIT] slice_splits = [TEST_SPLIT] # We only do dumping if weak labels is True if mode in ["dump_preds", "dump_embs"]: if config.data_config[ f"{TEST_SPLIT}_dataset"].use_weak_label is False: raise ValueError( f"When calling dump_preds or dump_embs, we require use_weak_label to be True." ) # Gets embeddings that need to be prepped during data prep or in the __get_item__ method batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols) # Gets dataloaders dataloaders = get_dataloaders( config, tasks, data_splits, entity_symbols, batch_on_the_fly_kg_adj, ) slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols) configure_optimizer(config) # Create models and add tasks if config.model_config.attn_class == "BERTNED": log_rank_0_info(logger, f"Starting NED-Base Model") assert (config.data_config.type_prediction.use_type_pred is False), f"NED-Base does not support type prediction" assert ( config.data_config.word_embedding.use_sent_proj is False ), f"NED-Base requires word_embeddings.use_sent_proj to be False" model = EmmentalModel(name="NED-Base") model.add_tasks( ned_task.create_task(config, entity_symbols, slice_datasets)) else: log_rank_0_info(logger, f"Starting Bootleg Model") model = EmmentalModel(name="Bootleg") # TODO: make this more general for other tasks -- iterate through list of tasks # and add task for each model.add_task( ned_task.create_task(config, entity_symbols, slice_datasets)) if TYPE_PRED_TASK in tasks: model.add_task( type_pred_task.create_task(config, entity_symbols, slice_datasets)) # Add the mention type embedding to the embedding payload type_pred_task.update_ned_task(model) # Print param counts if mode == "train": log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30) total_params = count_parameters(model, requires_grad=True, logger=logger) log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}") log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30) total_params = count_parameters(model, requires_grad=False, logger=logger) log_rank_0_info(logger, f"===> Total Params Without Grad: {total_params}") # Load the best model from the pretrained model if config["model_config"]["model_path"] is not None: model.load(config["model_config"]["model_path"]) # Barrier if config["learner_config"]["local_rank"] == 0: torch.distributed.barrier() # Train model if mode == "train": emmental_learner = EmmentalLearner() emmental_learner._set_optimizer(model) emmental_learner.learn(model, 
dataloaders) if config.learner_config.local_rank in [0, -1]: model.save(f"{emmental.Meta.log_path}/last_model.pth") # Multi-gpu DataParallel eval (NOT distributed) if mode in ["eval", "dump_embs", "dump_preds"]: # This happens inside EmmentalLearner for training if (config["learner_config"]["local_rank"] == -1 and config["model_config"]["dataparallel"]): model._to_dataparallel() # If just finished training a model or in eval mode, run eval if mode in ["train", "eval"]: scores = model.score(dataloaders) # Save metrics and models log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}") log_rank_0_info(logger, f"Metrics: {scores}") scores["log_path"] = emmental.Meta.log_path if config.learner_config.local_rank in [0, -1]: write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt", scores) eval_utils.write_disambig_metrics_to_csv( f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv", scores) return scores # If you want detailed dumps, save model outputs assert mode in [ "dump_preds", "dump_embs", ], 'Mode must be "dump_preds" or "dump_embs"' dump_embs = False if mode != "dump_embs" else True assert ( len(dataloaders) == 1 ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!" final_result_file, final_out_emb_file = None, None if config.learner_config.local_rank in [0, -1]: # Setup files/folders filename = os.path.basename(dataloaders[0].dataset.raw_filename) log_rank_0_debug( logger, f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}", ) sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens( os.path.join(config.data_config.data_dir, filename)) log_rank_0_debug(logger, f"Done collecting sentence to mention map") eval_folder = eval_utils.get_eval_folder(filename) subeval_folder = os.path.join(eval_folder, "batch_results") utils.ensure_dir(subeval_folder) # Will keep track of sentences dumped already. 
These will only be ones with mentions all_dumped_sentences = set() number_dumped_batches = 0 total_mentions_seen = 0 all_result_files = [] all_out_emb_files = [] # Iterating over batches of predictions for res_i, res_dict in enumerate( eval_utils.batched_pred_iter( model, dataloaders[0], config.run_config.eval_accumulation_steps, sentidx2num_mentions, )): ( result_file, out_emb_file, final_sent_idxs, mentions_seen, ) = eval_utils.disambig_dump_preds( res_i, total_mentions_seen, config, res_dict, sentidx2num_mentions, sent_idx2row, subeval_folder, entity_symbols, dump_embs, NED_TASK, ) all_dumped_sentences.update(final_sent_idxs) all_result_files.append(result_file) all_out_emb_files.append(out_emb_file) total_mentions_seen += mentions_seen number_dumped_batches += 1 # Dump the sentences that had no mentions and were not already dumped # Assert all remaining sentences have no mentions assert all( v == 0 for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences ), (f"Sentences with mentions were not dumped: " f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}" ) empty_sentidx2row = { k: v for k, v in sent_idx2row.items() if k not in all_dumped_sentences } empty_resultfile = eval_utils.get_result_file(number_dumped_batches, subeval_folder) all_result_files.append(empty_resultfile) # Dump the outputs eval_utils.write_data_labels_single( sentidx2row=empty_sentidx2row, output_file=empty_resultfile, filt_emb_data=None, sental2embid={}, alias_cand_map=entity_symbols.get_alias2qids(), qid2eid=entity_symbols.get_qid2eid(), result_alias_offset=total_mentions_seen, train_in_cands=config.data_config.train_in_candidates, max_cands=entity_symbols.max_candidates, dump_embs=dump_embs, ) log_rank_0_info( logger, f"Finished dumping. Merging results across accumulation steps.") # Final result files for labels and embeddings final_result_file = os.path.join(eval_folder, config.run_config.result_label_file) # Copy labels output = open(final_result_file, "wb") for file in all_result_files: shutil.copyfileobj(open(file, "rb"), output) output.close() log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}") # Try to copy embeddings if dump_embs: final_out_emb_file = os.path.join( eval_folder, config.run_config.result_emb_file) log_rank_0_info( logger, f"Trying to merge numpy embedding arrays. " f"If your machine is limited in memory, this may cause OOM errors. " f"Is that happens, result files should be saved in {subeval_folder}.", ) all_arrays = [] for i, npfile in enumerate(all_out_emb_files): all_arrays.append(np.load(npfile)) np.save(final_out_emb_file, np.concatenate(all_arrays)) log_rank_0_info( logger, f"Bootleg embeddings saved at {final_out_emb_file}") # Cleanup try_rmtree(subeval_folder) return final_result_file, final_out_emb_file
def write_data_labels( num_processes, result_alias_offset, merged_entity_emb_file, merged_storage_type, sent_idx2row, cache_folder, out_file, entity_dump, train_in_candidates, max_candidates, dump_embs, trie_candidate_map_folder=None, trie_qid2eid_file=None, ): """Takes the flattened data from merge_sentences and writes out predictions to a file, one line per sentence. The embedding ids are added to the file if dump_embs is True. Args: num_processes: number of processes result_alias_offset: alias offset of this batch of examples for writing out merged_entity_emb_file: input memmap file after merge sentences merged_storage_type: input file storage type sent_idx2row: Dict of sentence idx to row relevant to this subbatch cache_folder: folder to save temporary outputs out_file: final output file for predictions entity_dump: entity dump train_in_candidates: whether NC entities are not in candidate lists max_candidates: maximum number of candidates dump_embs: whether to dump embeddings or not trie_candidate_map_folder: folder where trie of alias->candidate map is stored for parallel proccessing trie_qid2eid_file: file where trie of qid->eid map is stored for parallel proccessing Returns: """ st = time.time() sental2embid = get_sental2embid(merged_entity_emb_file, merged_storage_type) log_rank_0_debug(logger, f"Finished getting sentence map {time.time() - st}s") total_input = len(sent_idx2row) if num_processes == 1: filt_emb_data = np.memmap(merged_entity_emb_file, dtype=merged_storage_type, mode="r+") write_data_labels_single( sentidx2row=sent_idx2row, output_file=out_file, filt_emb_data=filt_emb_data, sental2embid=sental2embid, alias_cand_map=entity_dump.get_alias2qids(), qid2eid=entity_dump.get_qid2eid(), result_alias_offset=result_alias_offset, train_in_cands=train_in_candidates, max_cands=max_candidates, dump_embs=dump_embs, ) else: assert ( trie_candidate_map_folder is not None ), "trie_candidate_map_folder is None and you have parallel turned on" assert (trie_qid2eid_file is not None ), "trie_qid2eid_file is None and you have parallel turned on" # Get trie of sentence map trie_folder = os.path.join(cache_folder, "bootleg_sental2embid") utils.ensure_dir(trie_folder) trie_file = os.path.join(trie_folder, "sentidx.marisa") utils.create_single_item_trie(sental2embid, out_file=trie_file) # Chunk file for parallel writing # We do not use TemporaryFolders as the temp dir may not have enough space for large files create_ex_indir = os.path.join(cache_folder, "_bootleg_eval_temp_indir") utils.ensure_dir(create_ex_indir) create_ex_outdir = os.path.join(cache_folder, "_bootleg_eval_temp_outdir") utils.ensure_dir(create_ex_outdir) chunk_input = int(np.ceil(total_input / num_processes)) logger.debug( f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines" ) # Chunk up dictionary of data for parallel processing input_files = [] i = 0 cur_lines = 0 file_split = os.path.join(create_ex_indir, f"out{i}.jsonl") open_file = open(file_split, "w") for s_idx in sent_idx2row: if cur_lines >= chunk_input: open_file.close() input_files.append(file_split) cur_lines = 0 i += 1 file_split = os.path.join(create_ex_indir, f"out{i}.jsonl") open_file = open(file_split, "w") line = sent_idx2row[s_idx] open_file.write(ujson.dumps(line) + "\n") cur_lines += 1 open_file.close() input_files.append(file_split) # Generation input/output pairs output_files = [ in_file_name.replace(create_ex_indir, create_ex_outdir) for in_file_name in input_files ] log_rank_0_debug(logger, f"Done chunking files. 
Starting pool") pool = multiprocessing.Pool( processes=num_processes, initializer=write_data_labels_initializer, initargs=[ merged_entity_emb_file, merged_storage_type, trie_file, result_alias_offset, train_in_candidates, max_candidates, dump_embs, trie_candidate_map_folder, trie_qid2eid_file, ], ) input_args = list(zip(input_files, output_files)) total = 0 for res in pool.imap(write_data_labels_hlp, input_args, chunksize=1): total += 1 # Merge output files to final file log_rank_0_debug(logger, f"Merging output files") with open(out_file, "wb") as outfile: for filename in glob.glob(os.path.join(create_ex_outdir, "*")): if filename == out_file: # don't want to copy the output into the output continue with open(filename, "rb") as readfile: shutil.copyfileobj(readfile, outfile)
def merge_subsentences( num_processes, subset_sent_idx2num_mens, cache_folder, to_save_file, to_save_storage, to_read_file, to_read_storage, dump_embs=False, ): """Flatten all sentences back together over sub-sentences; removing the PAD aliases from the data I.e., converts from sent_idx -> array of values to (sent_idx, alias_idx) -> value with varying numbers of aliases per sentence. Args: num_processes: number of processes subset_sent_idx2num_mens: Dict of sentence index to number of mentions for this batch cache_folder: cache directory to_save_file: memmap file to save results to to_save_storage: save file storage type to_read_file: memmap file to read predictions from to_read_storage: read file storage type dump_embs: whether to save embeddings or not Returns: """ # Compute sent idx to offset so we know where to fill in mentions cur_offset = 0 sentidx2offset = {} for k, v in subset_sent_idx2num_mens.items(): sentidx2offset[k] = cur_offset cur_offset += v # print("Sent Idx, Num Mens, Offset", k, v, cur_offset) total_num_mentions = cur_offset # print("TOTAL", total_num_mentions) full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode="r") M = int(full_pred_data[0]["M"]) K = int(full_pred_data[0]["K"]) hidden_size = int(full_pred_data[0]["hidden_size"]) # print("TOTAL MENS", total_num_mentions) filt_emb_data = np.memmap(to_save_file, dtype=to_save_storage, mode="w+", shape=(total_num_mentions, )) filt_emb_data["hidden_size"] = hidden_size filt_emb_data["sent_idx"][:] = -1 filt_emb_data["alias_list_pos"][:] = -1 all_ids = list(range(0, len(full_pred_data))) start = time.time() if num_processes == 1: seen_ids = merge_subsentences_single( M, K, hidden_size, dump_embs, all_ids, filt_emb_data, full_pred_data, sentidx2offset, ) else: # Get trie for sentence start map trie_folder = os.path.join(cache_folder, "bootleg_sent_idx2num_mens") utils.ensure_dir(trie_folder) trie_file = os.path.join(trie_folder, "sentidx.marisa") utils.create_single_item_trie(sentidx2offset, out_file=trie_file) # Chunk up date chunk_size = int(np.ceil(len(full_pred_data) / num_processes)) row_idx_set_chunks = [ all_ids[ids:ids + chunk_size] for ids in range(0, len(full_pred_data), chunk_size) ] # Start pool input_args = [[M, K, hidden_size, dump_embs, chunk] for chunk in row_idx_set_chunks] log_rank_0_debug( logger, f"Merging sentences together with {num_processes} processes") pool = multiprocessing.Pool( processes=num_processes, initializer=merge_subsentences_initializer, initargs=[ to_save_file, to_save_storage, to_read_file, to_read_storage, trie_file, ], ) seen_ids = set() for sent_ids_seen in pool.imap_unordered(merge_subsentences_hlp, input_args, chunksize=1): for emb_id in sent_ids_seen: assert ( emb_id not in seen_ids ), f"{emb_id} already seen, something went wrong with sub-sentences" seen_ids.add(emb_id) # filt_emb_data = np.memmap(to_save_file, dtype=to_save_storage, mode="r") # for i in range(len(filt_emb_data)): # si = filt_emb_data[i]["sent_idx"] # al_test = filt_emb_data[i]["alias_list_pos"] # if si == -1 or al_test == -1: # print("BAD", i, filt_emb_data[i]) # import ipdb; ipdb.set_trace() logging.debug(f"Saw {len(seen_ids)} sentences") logging.debug(f"Time to merge sub-sentences {time.time() - start}s") return
def disambig_dump_preds( result_idx, result_alias_offset, config, res_dict, sent_idx2num_mens, sent_idx2row, save_folder, entity_symbols, dump_embs, task_name, ): """Dumps the predictions of a disambiguation task. Args: result_idx: batch index of the result arrays result_alias_offset: overall offset of the starting example (i.e., the number of previous mens already written) config: model config res_dict: result dictionary from Emmental predict sent_idx2num_mens: Dict sentence idx to number of mentions sent_idx2row: Dict sentence idx to row of eval data save_folder: folder to save results entity_symbols: entity symbols dump_embs: whether to save the contextualized embeddings or not task_name: task name Returns: saved prediction file, saved embedding file (will be None if dump_embs is False) """ num_processes = min(config.run_config.dataset_threads, int(multiprocessing.cpu_count() * 0.9)) cache_dir = os.path.join(save_folder, f"cache_{result_idx}") utils.ensure_dir(cache_dir) trie_candidate_map_folder = None trie_qid2eid_file = None # Save the alias->QID candidate map and the QID->EID mapping in memory efficient structures for faster # prediction dumping if num_processes > 1: entity_prep_dir = data_utils.get_emb_prep_dir(config.data_config) trie_candidate_map_folder = os.path.join(entity_prep_dir, "for_dumping_preds", "alias_cand_trie") utils.ensure_dir(trie_candidate_map_folder) check_and_create_alias_cand_trie(trie_candidate_map_folder, entity_symbols) trie_qid2eid_file = os.path.join(entity_prep_dir, "for_dumping_preds", "qid2eid_trie.marisa") if not os.path.exists(trie_qid2eid_file): utils.create_single_item_trie(entity_symbols.get_qid2eid(), out_file=trie_qid2eid_file) # This is dumping disambig_res_dict = {} for k in res_dict: assert task_name in res_dict[ k], f"{task_name} not in res_dict for key {k}" disambig_res_dict[k] = res_dict[k][task_name] # write to file (M x hidden x size for each data point -- next step will deal with recovering original sentence # indices for overflowing sentences) unmerged_entity_emb_file = os.path.join(save_folder, f"entity_embs.pt") merged_entity_emb_file = os.path.join(save_folder, f"entity_embs_unmerged.pt") emb_file_config = os.path.splitext( unmerged_entity_emb_file)[0] + "_config.npy" M = config.data_config.max_aliases K = entity_symbols.max_candidates + ( not config.data_config.train_in_candidates) if dump_embs: unmerged_storage_type = np.dtype([ ("M", int), ("K", int), ("hidden_size", int), ("sent_idx", int), ("subsent_idx", int), ("alias_list_pos", int, (M, )), ("entity_emb", float, M * config.model_config.hidden_size), ("final_loss_true", int, (M, )), ("final_loss_pred", int, (M, )), ("final_loss_prob", float, (M, )), ("final_loss_cand_probs", float, M * K), ]) merged_storage_type = np.dtype([ ("hidden_size", int), ("sent_idx", int), ("alias_list_pos", int), ("entity_emb", float, config.model_config.hidden_size), ("final_loss_pred", int), ("final_loss_prob", float), ("final_loss_cand_probs", float, K), ]) else: # don't need to extract contextualized entity embedding unmerged_storage_type = np.dtype([ ("M", int), ("K", int), ("hidden_size", int), ("sent_idx", int), ("subsent_idx", int), ("alias_list_pos", int, (M, )), ("final_loss_true", int, (M, )), ("final_loss_pred", int, (M, )), ("final_loss_prob", float, (M, )), ("final_loss_cand_probs", float, M * K), ]) merged_storage_type = np.dtype([ ("hidden_size", int), ("sent_idx", int), ("alias_list_pos", int), ("final_loss_pred", int), ("final_loss_prob", float), ("final_loss_cand_probs", float, K), 
]) mmap_file = np.memmap( unmerged_entity_emb_file, dtype=unmerged_storage_type, mode="w+", shape=(len(disambig_res_dict["uids"]), ), ) # print("MEMMAP FILE SHAPE", len(disambig_res_dict["uids"])) # Init sent_idx to -1 for debugging mmap_file[:]["sent_idx"] = -1 np.save(emb_file_config, unmerged_storage_type, allow_pickle=True) log_rank_0_debug( logger, f"Created file {unmerged_entity_emb_file} to save predictions.") log_rank_0_debug(logger, f'{len(disambig_res_dict["uids"])} samples') for_iteration = [ disambig_res_dict["uids"], disambig_res_dict["golds"], disambig_res_dict["probs"], disambig_res_dict["preds"], ] all_sent_idx = set() for i, (uid, gold, probs, model_pred) in enumerate(zip(*for_iteration)): # disambig_res_dict["output"] is dict with keys ['_input__alias_orig_list_pos', # 'bootleg_pred_1', '_input__sent_idx', '_input__for_dump_gold_cand_K_idx_train', '_input__subsent_idx', 0, 1] sent_idx = disambig_res_dict["outputs"]["_input__sent_idx"][i] # print("INSIDE LOOP", sent_idx, "AT", i) subsent_idx = disambig_res_dict["outputs"]["_input__subsent_idx"][i] alias_orig_list_pos = disambig_res_dict["outputs"][ "_input__alias_orig_list_pos"][i] gold_cand_K_idx_train = disambig_res_dict["outputs"][ "_input__for_dump_gold_cand_K_idx_train"][i] output_embeddings = disambig_res_dict["outputs"][ f"{PRED_LAYER}_ent_embs"][i] mmap_file[i]["M"] = M mmap_file[i]["K"] = K mmap_file[i]["hidden_size"] = config.model_config.hidden_size mmap_file[i]["sent_idx"] = sent_idx mmap_file[i]["subsent_idx"] = subsent_idx mmap_file[i]["alias_list_pos"] = alias_orig_list_pos # This will give all aliases seen by the model during training, independent of if it's gold or not mmap_file[i][f"final_loss_true"] = gold_cand_K_idx_train.reshape(M) # get max for each alias, probs is M x K max_probs = probs.max(axis=1) pred_cands = probs.argmax(axis=1) mmap_file[i]["final_loss_pred"] = pred_cands mmap_file[i]["final_loss_prob"] = max_probs mmap_file[i]["final_loss_cand_probs"] = probs.reshape(1, -1) all_sent_idx.add(str(sent_idx)) # final_entity_embs is M x K x hidden_size, pred_cands is M if dump_embs: chosen_entity_embs = select_embs(embs=output_embeddings, pred_cands=pred_cands, M=M) # write chosen entity embs to file for contextualized entity embeddings mmap_file[i]["entity_emb"] = chosen_entity_embs.reshape(1, -1) # for i in range(len(mmap_file)): # si = mmap_file[i]["sent_idx"] # if -1 == si: # import pdb # pdb.set_trace() # assert si != -1, f"{i} {mmap_file[i]}" # Store all predicted sentences to filter the sentence mapping by subset_sent_idx2num_mens = { k: v for k, v in sent_idx2num_mens.items() if k in all_sent_idx } # print("ALL SEEN", all_sent_idx) subsent_sent_idx2row = { k: v for k, v in sent_idx2row.items() if k in all_sent_idx } result_file = get_result_file(result_idx, save_folder) log_rank_0_debug(logger, f"Writing predictions to {result_file}...") merge_subsentences( num_processes=num_processes, subset_sent_idx2num_mens=subset_sent_idx2num_mens, cache_folder=cache_dir, to_save_file=merged_entity_emb_file, to_save_storage=merged_storage_type, to_read_file=unmerged_entity_emb_file, to_read_storage=unmerged_storage_type, dump_embs=dump_embs, ) write_data_labels( num_processes=num_processes, result_alias_offset=result_alias_offset, merged_entity_emb_file=merged_entity_emb_file, merged_storage_type=merged_storage_type, sent_idx2row=subsent_sent_idx2row, cache_folder=cache_dir, out_file=result_file, entity_dump=entity_symbols, train_in_candidates=config.data_config.train_in_candidates, 
max_candidates=entity_symbols.max_candidates, dump_embs=dump_embs, trie_candidate_map_folder=trie_candidate_map_folder, trie_qid2eid_file=trie_qid2eid_file, ) out_emb_file = None filt_emb_data = np.memmap(merged_entity_emb_file, dtype=merged_storage_type, mode="r+") total_mentions_seen = len(filt_emb_data) # save easier-to-use embedding file if dump_embs: hidden_size = filt_emb_data[0]["hidden_size"] out_emb_file = get_emb_file(result_idx, save_folder) np.save(out_emb_file, filt_emb_data["entity_emb"].reshape(-1, hidden_size)) log_rank_0_debug( logger, f"Saving contextual entity embeddings for {result_idx} to {out_emb_file}", ) filt_emb_data = None # Cleanup cache - sometimes the file in cache_dir is still open so we need to retry to delete it try_rmtree(cache_dir) log_rank_0_debug(logger, f"Wrote predictions for {result_idx} to {result_file}") return result_file, out_emb_file, all_sent_idx, total_mentions_seen
def merge_data( num_processes, train_in_candidates, keep_orig, max_candidates, file_pairs, entity_dump_f, ): # File pair is in file, cand map file, out file, is_train # Chunk file for parallel writing create_ex_indir = os.path.join(os.path.dirname(file_pairs[0]), "_bootleg_temp_indir") utils.ensure_dir(create_ex_indir) create_ex_indir_cands = os.path.join(os.path.dirname(file_pairs[0]), "_bootleg_temp_indir2") utils.ensure_dir(create_ex_indir_cands) create_ex_outdir = os.path.join(os.path.dirname(file_pairs[0]), "_bootleg_temp_outdir") utils.ensure_dir(create_ex_outdir) print(f"Counting lines") total_input = sum(1 for _ in open(file_pairs[0])) total_input_cands = sum(1 for _ in open(file_pairs[1])) assert ( total_input_cands == total_input ), f"{total_input} lines of orig data != {total_input_cands} of cand data" chunk_input_size = int(np.ceil(total_input / num_processes)) total_input_from_chunks, input_files_dict = utils.chunk_file( file_pairs[0], create_ex_indir, chunk_input_size) total_input_cands_from_chunks, input_files_cands_dict = utils.chunk_file( file_pairs[1], create_ex_indir_cands, chunk_input_size) input_files = list(input_files_dict.keys()) input_cand_files = list(input_files_cands_dict.keys()) assert len(input_cand_files) == len(input_files) input_file_lines = [input_files_dict[k] for k in input_files] input_cand_file_lines = [ input_files_cands_dict[k] for k in input_cand_files ] for p_l, p_r in zip(input_file_lines, input_cand_file_lines): assert ( p_l == p_r ), f"The matching chunk files don't have matching sizes {p_l} versus {p_r}" output_files = [ in_file_name.replace(create_ex_indir, create_ex_outdir) for in_file_name in input_files ] assert ( total_input == total_input_from_chunks ), f"Lengths of files {total_input} doesn't match {total_input_from_chunks}" assert ( total_input_cands == total_input_cands_from_chunks ), f"Lengths of files {total_input_cands} doesn't match {total_input_cands_from_chunks}" # file_pairs is input file, cand map file, output file, is_train input_args = [[ train_in_candidates, keep_orig, max_candidates, input_files[i], input_file_lines[i], input_cand_files[i], output_files[i], file_pairs[3], ] for i in range(len(input_files))] pool = multiprocessing.Pool(processes=num_processes, initializer=init_process, initargs=[entity_dump_f]) new_alias2qids = {} total_seen = 0 total_dropped = 0 for res in pool.imap(merge_data_hlp, input_args, chunksize=1): temp_alias2qids, seen, dropped = res total_seen += seen total_dropped += dropped for k in temp_alias2qids: assert k not in new_alias2qids, f"{k}" new_alias2qids[k] = temp_alias2qids[k] print( f"Overall Recall for {file_pairs[0]}: {(total_seen - total_dropped) / total_seen} for seeing {total_seen}" ) # Merge output files to final file print(f"Merging output files") with open(file_pairs[2], "wb") as outfile: for filename in glob.glob(os.path.join(create_ex_outdir, "*")): if filename == file_pairs[2]: # don't want to copy the output into the output continue with open(filename, "rb") as readfile: shutil.copyfileobj(readfile, outfile) # Remove temporary files/folders shutil.rmtree(create_ex_indir) shutil.rmtree(create_ex_indir_cands) shutil.rmtree(create_ex_outdir) return new_alias2qids
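# utils.chunk_file is not shown in this listing; from the call site above it appears to split
# a file into shards of at most chunk_size lines and return the total line count plus a dict
# of shard path -> line count. A hedged sketch of such a helper (name and return shape are
# assumptions inferred from the call site):
import os


def chunk_file(in_file, out_dir, chunk_size):
    chunk_line_counts = {}
    total = 0
    chunk_id, lines_in_chunk = 0, 0
    out_f = open(os.path.join(out_dir, f"out_{chunk_id}.jsonl"), "w")
    with open(in_file) as in_f:
        for line in in_f:
            if lines_in_chunk >= chunk_size:
                out_f.close()
                chunk_line_counts[out_f.name] = lines_in_chunk
                chunk_id, lines_in_chunk = chunk_id + 1, 0
                out_f = open(os.path.join(out_dir, f"out_{chunk_id}.jsonl"), "w")
            out_f.write(line)
            lines_in_chunk += 1
            total += 1
    out_f.close()
    chunk_line_counts[out_f.name] = lines_in_chunk
    return total, chunk_line_counts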
def get_data_prep_dir(args): prep_dir = os.path.join(args.data_config.data_dir, args.data_config.data_prep_dir) utils.ensure_dir(prep_dir) return prep_dir
def load_regularization_mapping(cls, data_config, num_types_with_pad_and_unk, type2row_dict, reg_file): """Reads in a csv file with columns [typeid, regularization]. In the forward pass, the type id will be regularized with probability regularization. Args: data_config: data config num_types_with_pad_and_unk: number of types including pad and unk type2row_dict: Dict from typeID to row id in the type embedding matrix reg_file: regularization csv file Returns: Tensor where each value is the regularization value for the type ID """ reg_str = os.path.splitext(os.path.basename(reg_file.replace("/", "_")))[0] prep_dir = data_utils.get_data_prep_dir(data_config) prep_file = os.path.join(prep_dir, f"type_regularization_mapping_{reg_str}.pt") utils.ensure_dir(os.path.dirname(prep_file)) log_rank_0_debug(logger, f"Looking for regularization mapping in {prep_file}") if not data_config.overwrite_preprocessed_data and os.path.exists( prep_file): log_rank_0_debug( logger, f"Loading existing entity regularization mapping from {prep_file}", ) start = time.time() typeid2reg = torch.load(prep_file) log_rank_0_debug( logger, f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s", ) else: start = time.time() log_rank_0_debug( logger, f"Building entity regularization mapping from {reg_file}") typeid2reg_raw = pd.read_csv(reg_file) assert ( "typeid" in typeid2reg_raw.columns and "regularization" in typeid2reg_raw.columns ), f"Expected typeid and regularization as the column names for {reg_file}" # default of no mask typeid2reg_arr = [0.0] * num_types_with_pad_and_unk for row_idx, row in typeid2reg_raw.iterrows(): # Happens when we filter QIDs not in our entity db and the max typeid is smaller than the total number if int(row["typeid"]) not in type2row_dict: continue typeid = type2row_dict[int(row["typeid"])] typeid2reg_arr[typeid] = row["regularization"] typeid2reg = torch.Tensor(typeid2reg_arr) torch.save(typeid2reg, prep_file) log_rank_0_debug( logger, f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.", ) return typeid2reg