Code Example #1
 def prep(cls,
          main_args,
          emb_args,
          entity_symbols,
          log_func=print,
          word_symbols=None):
     """Loads KG information into matrix format or loads from saved state"""
     file_tag = os.path.splitext(emb_args.kg_adj)[0]
     prep_dir = data_utils.get_emb_prep_dir(main_args)
     prep_file = os.path.join(prep_dir, f'kg_adj_file_{file_tag}.npz')
     utils.ensure_dir(os.path.dirname(prep_file))
     if (not main_args.data_config.overwrite_preprocessed_data
             and os.path.exists(prep_file)):
         log_func(f'Loading existing KG adj from {prep_file}')
         start = time.time()
         kg_adj = scipy.sparse.load_npz(prep_file)
         log_func(
             f'Loaded existing KG adj in {round(time.time() - start, 2)}s')
     else:
         start = time.time()
         kg_adj_file = os.path.join(main_args.data_config.emb_dir,
                                    emb_args.kg_adj)
         log_func(f'Building KG adj from {kg_adj_file}')
         kg_adj = cls.build_kg_adj(kg_adj_file, entity_symbols, emb_args)
         scipy.sparse.save_npz(prep_file, kg_adj)
         log_func(
             f"Finished building and saving KG adj in {round(time.time() - start, 2)}s."
         )
     return kg_adj, prep_file
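
Every example on this page calls utils.ensure_dir before writing a prep or save file. The helper itself is not shown here; a minimal sketch of what such a directory-creation utility typically looks like (an assumption, not necessarily Bootleg's exact implementation):

import os

def ensure_dir(directory):
    """Create directory (and any missing parents) if it does not already exist."""
    # exist_ok avoids an error when the directory is already present
    os.makedirs(directory, exist_ok=True)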
Code Example #2
 def prep(cls,
          main_args,
          emb_args,
          entity_symbols,
          log_func=print,
          word_symbols=None):
     type_str = os.path.splitext(emb_args.type_labels)[0]
     prep_dir = data_utils.get_emb_prep_dir(main_args)
     prep_file = os.path.join(
         prep_dir, f'type_table_{type_str}_{emb_args.max_types}.pt')
     utils.ensure_dir(os.path.dirname(prep_file))
     if (not main_args.data_config.overwrite_preprocessed_data
             and os.path.exists(prep_file)):
         log_func(f'Loading existing type table from {prep_file}')
         start = time.time()
         eid2typeids_table, type2row, num_types_with_unk = torch.load(
             prep_file)
         log_func(
             f'Loaded existing type table in {round(time.time() - start, 2)}s'
         )
     else:
         start = time.time()
         log_func(f'Building type table')
         type_labels = os.path.join(main_args.data_config.emb_dir,
                                    emb_args.type_labels)
         eid2typeids_table, type2row, num_types_with_unk = cls.build_type_table(
             type_labels=type_labels,
             max_types=emb_args.max_types,
             entity_symbols=entity_symbols)
         torch.save((eid2typeids_table, type2row, num_types_with_unk),
                    prep_file)
         log_func(
             f"Finished building and saving type table in {round(time.time() - start, 2)}s."
         )
     return eid2typeids_table, type2row, num_types_with_unk, prep_file
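
Code Examples #1 and #2 (and most of the prep methods below) follow the same cache-or-build pattern: derive a prep file path, make sure its directory exists, load the cached file when it exists and overwriting is disabled, otherwise build the object, save it, and log the timing. A stripped-down, generic sketch of that pattern (cached_prep and its arguments are illustrative names, not Bootleg APIs):

import os
import time

import torch


def cached_prep(prep_dir, prep_name, build_fn, overwrite=False, log_func=print):
    """Load prep_name from prep_dir if cached, otherwise build it with build_fn and save it."""
    prep_file = os.path.join(prep_dir, prep_name)
    os.makedirs(os.path.dirname(prep_file), exist_ok=True)
    if not overwrite and os.path.exists(prep_file):
        start = time.time()
        obj = torch.load(prep_file)
        log_func(f"Loaded {prep_file} in {round(time.time() - start, 2)}s")
    else:
        start = time.time()
        obj = build_fn()
        torch.save(obj, prep_file)
        log_func(f"Built and saved {prep_file} in {round(time.time() - start, 2)}s")
    return obj, prep_file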
Code Example #3
File: type_symbols.py Project: lorr1/bootleg
    def save(self, save_dir, prefix=""):
        """Dumps the type symbols.

        Args:
            save_dir: directory string to save
            prefix: prefix to prepend to each output file name

        Returns:
        """
        utils.ensure_dir(str(save_dir))
        utils.dump_json_file(
            filename=os.path.join(save_dir, "config.json"),
            contents={
                "max_types": self.max_types,
            },
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, f"{prefix}qid2typenames.json"),
            contents=self._qid2typenames,
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, f"{prefix}qid2typeids.json"),
            contents=self._qid2typeid,
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, f"{prefix}type_vocab.json"),
            contents=self._type_vocab,
        )
Code Example #4
File: entity_symbols.py Project: pombredanne/bootleg
    def save(self, save_dir):
        """Dumps the entity symbols.

        Args:
            save_dir: directory string to save

        Returns:
        """
        self._sort_alias_cands()
        utils.ensure_dir(save_dir)
        utils.dump_json_file(
            filename=os.path.join(save_dir, "config.json"),
            contents={
                "max_candidates": self.max_candidates,
                "datetime": str(datetime.now()),
            },
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, self.alias_cand_map_file),
            contents=self._alias2qids,
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, "qid2eid.json"), contents=self._qid2eid
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, self.alias_idx_file),
            contents=self._alias2id,
        )
Code Example #5
def main():
    args = parse_args()
    logging.info(json.dumps(args, indent=4))
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir))
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(save_dir,
                        train_file,
                        entity_symbols,
                        args.lower,
                        args.strip,
                        num_workers=args.num_workers)
    if not args.no_types:
        type_symbols = TypeSymbols(
            entity_symbols=entity_symbols,
            emb_dir=args.emb_dir,
            max_types=args.max_types,
            emb_file="hyena_type_emb.pkl",
            type_vocab_file="hyena_type_graph.vocab.pkl",
            type_file="hyena_types.txt")
        compute_type_occurrences(save_dir, "orig", entity_symbols,
                                 type_symbols.qid2typenames, train_file)
Code Example #6
def compute_occurrences(save_dir,
                        data_file,
                        entity_dump,
                        lower,
                        strip,
                        num_workers=8):
    global all_aliases
    all_aliases = get_all_aliases(entity_dump._alias2qids)

    # divide up data into chunks
    num_lines = get_num_lines(data_file)
    num_processes = min(num_workers, int(multiprocessing.cpu_count()))
    logging.info(f'Using {num_processes} workers...')
    chunk_size = int(np.ceil(num_lines / (num_processes)))
    chunk_file_path = os.path.join(save_dir, 'tmp')
    utils.ensure_dir(chunk_file_path)
    chunk_infiles = [
        os.path.join(f'{chunk_file_path}', f'data_chunk_{chunk_id}_in.jsonl')
        for chunk_id in range(num_processes)
    ]
    chunk_text_data(data_file, chunk_infiles, chunk_size, num_lines)

    pool = multiprocessing.Pool(processes=num_processes)
    subprocess_args = [[chunk_infiles[i], lower, strip]
                       for i in range(num_processes)]
    results = pool.map(compute_occurrences_single, subprocess_args)
    pool.close()
    pool.join()
    logging.info('Finished collecting counts')
    logging.info('Merging counts....')
    # merge counters together
    ent_occurrences = Counter()
    # alias histogram
    alias_occurrences = Counter()
    # alias text occurrences
    alias_text_occurrences = Counter()
    # number of aliases per sentence
    alias_pair_occurrences = Counter()
    # alias|entity histogram
    alias_entity_pair = Counter()
    for result_set in results:
        ent_occurrences += result_set['ent_occurrences']
        alias_occurrences += result_set['alias_occurrences']
        alias_text_occurrences += result_set['alias_text_occurrences']
        alias_pair_occurrences += result_set['alias_pair_occurrences']
        alias_entity_pair += result_set['alias_entity_pair']
    # save counters
    utils.dump_json_file(filename=os.path.join(save_dir, "entity_count.json"),
                         contents=ent_occurrences)
    utils.dump_json_file(filename=os.path.join(save_dir, "alias_counts.json"),
                         contents=alias_occurrences)
    utils.dump_json_file(filename=os.path.join(save_dir,
                                               "alias_text_counts.json"),
                         contents=alias_text_occurrences)
    utils.dump_json_file(filename=os.path.join(save_dir,
                                               "alias_pair_occurrences.json"),
                         contents=alias_pair_occurrences)
    utils.dump_json_file(filename=os.path.join(save_dir,
                                               "alias_entity_counts.json"),
                         contents=alias_entity_pair)
Code Example #7
    def load_regularization_mapping(cls, data_config, entity_symbols,
                                    reg_file):
        """Reads in a csv file with columns [qid, regularization].

        In the forward pass, the entity id with associated qid will be
        regularized with probability regularization.

        Args:
            data_config: data config
            entity_symbols: entity symbols (for QID to EID lookup)
            reg_file: regularization csv file

        Returns: Tensor where each value is the regularization value for EID
        """
        reg_str = os.path.splitext(os.path.basename(reg_file.replace("/",
                                                                     "_")))[0]
        prep_dir = data_utils.get_data_prep_dir(data_config)
        prep_file = os.path.join(
            prep_dir, f"entity_regularization_mapping_{reg_str}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        log_rank_0_debug(logger,
                         f"Looking for regularization mapping in {prep_file}")
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(
                logger,
                f"Loading existing entity regularization mapping from {prep_file}",
            )
            start = time.time()
            eid2reg = torch.load(prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s",
            )
        else:
            start = time.time()
            log_rank_0_info(
                logger,
                f"Building entity regularization mapping from {reg_file}")
            qid2reg = pd.read_csv(reg_file)
            assert (
                "qid" in qid2reg.columns
                and "regularization" in qid2reg.columns
            ), f"Expected qid and regularization as the column names for {reg_file}"
            # default of no mask
            eid2reg_arr = [0.0
                           ] * entity_symbols.num_entities_with_pad_and_nocand
            for row_idx, row in qid2reg.iterrows():
                if entity_symbols.qid_exists(row["qid"]):
                    eid = entity_symbols.get_eid(row["qid"])
                    eid2reg_arr[eid] = row["regularization"]
            eid2reg = torch.tensor(eid2reg_arr)
            torch.save(eid2reg, prep_file)
            log_rank_0_debug(
                logger,
                f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.",
            )
        return eid2reg
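
The docstring above describes a CSV with qid and regularization columns, and the assert enforces exactly those headers. A tiny sketch of producing a compatible file (the QIDs and values are made up for illustration):

import pandas as pd

# Columns must be named exactly "qid" and "regularization" to pass the assert above.
pd.DataFrame(
    {"qid": ["Q1", "Q2", "Q3"], "regularization": [0.5, 0.3, 0.2]}
).to_csv("reg_file.csv", index=False)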
Code Example #8
def setup_run_folders(args, mode):
    if args.run_config.timestamp == "":
        # create a timestamp for directory for saving results
        start_date = strftime("%Y%m%d")
        start_time = strftime("%H%M%S")
        args.run_config.timestamp = "{:s}_{:s}".format(start_date, start_time)
        utils.ensure_dir(get_save_folder(args.run_config))
    return
Code Example #9
File: alias_to_ent_encoder.py Project: lorr1/bootleg
    def prep(
        cls,
        data_config,
        entity_symbols,
        num_aliases_with_pad_and_unk,
        num_cands_K,
    ):
        """Preps the alias to entity EID table.

        Args:
            data_config: data config
            entity_symbols: entity symbols
            num_aliases_with_pad_and_unk: number of aliases including pad and unk
            num_cands_K: number of candidates per alias (aka K)

        Returns: torch Tensor of the alias to EID table, path to the saved pt file
        """
        # we pass num_aliases_with_pad_and_unk and num_cands_K to remove the dependence on entity_symbols
        # when the alias table is already prepped
        data_shape = (num_aliases_with_pad_and_unk, num_cands_K)
        # dependent on train_in_candidates flag
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        alias_str = os.path.splitext(
            data_config.alias_cand_map.replace("/", "_"))[0]
        prep_file = os.path.join(
            prep_dir,
            f"alias2entity_table_{alias_str}_InC{int(data_config.train_in_candidates)}.pt",
        )
        log_rank_0_debug(logger, f"Looking for alias table in {prep_file}")
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(logger, f"Loading alias table from {prep_file}")
            start = time.time()
            alias2entity_table = np.memmap(prep_file,
                                           dtype="int64",
                                           mode="r+",
                                           shape=data_shape)
            log_rank_0_debug(
                logger,
                f"Loaded alias table in {round(time.time() - start, 2)}s")
        else:
            start = time.time()
            log_rank_0_debug(logger, f"Building alias table")
            utils.ensure_dir(prep_dir)
            mmap_file = np.memmap(prep_file,
                                  dtype="int64",
                                  mode="w+",
                                  shape=data_shape)
            alias2entity_table = cls.build_alias_table(data_config,
                                                       entity_symbols)
            mmap_file[:] = alias2entity_table[:]
            mmap_file.flush()
            log_rank_0_debug(
                logger,
                f"Finished building and saving alias table in {round(time.time() - start, 2)}s.",
            )
        alias2entity_table = torch.from_numpy(alias2entity_table)
        return alias2entity_table, prep_file
Code Example #10
    def prep(cls, data_config, entity_symbols):
        """Prep the title data.

        Args:
            data_config: data config
            entity_symbols: entity symbols

        Returns: torch tensor EID to title token IDs, EID to title token mask, EID to title token type ID (for BERT)
        """
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        prep_file_token_ids = os.path.join(
            prep_dir,
            f"title_token_ids_{data_config.word_embedding.bert_model}.pt")
        prep_file_attn_mask = os.path.join(
            prep_dir,
            f"title_attn_mask_{data_config.word_embedding.bert_model}.pt")
        prep_file_token_type_ids = os.path.join(
            prep_dir,
            f"title_token_type_ids_{data_config.word_embedding.bert_model}.pt")
        utils.ensure_dir(os.path.dirname(prep_file_token_ids))
        log_rank_0_debug(
            logger,
            f"Looking for title table mapping in {prep_file_token_ids}")
        if (not data_config.overwrite_preprocessed_data
                and os.path.exists(prep_file_token_ids)
                and os.path.exists(prep_file_attn_mask)
                and os.path.exists(prep_file_token_type_ids)):
            log_rank_0_debug(
                logger,
                f"Loading existing title table from {prep_file_token_ids}")
            start = time.time()
            entity2titleid = torch.load(prep_file_token_ids)
            entity2titlemask = torch.load(prep_file_attn_mask)
            entity2tokentypeid = torch.load(prep_file_token_type_ids)
            log_rank_0_debug(
                logger,
                f"Loaded existing title table in {round(time.time() - start, 2)}s",
            )
        else:
            start = time.time()
            log_rank_0_debug(logger, f"Loading tokenizer")
            tokenizer = load_tokenizer(data_config)
            (
                entity2titleid,
                entity2titlemask,
                entity2tokentypeid,
            ) = cls.build_title_table(tokenizer=tokenizer,
                                      entity_symbols=entity_symbols)
            torch.save(entity2titleid, prep_file_token_ids)
            torch.save(entity2titlemask, prep_file_attn_mask)
            torch.save(entity2tokentypeid, prep_file_token_type_ids)
            log_rank_0_debug(
                logger,
                f"Finished building and saving title table in {round(time.time() - start, 2)}s.",
            )
        return entity2titleid, entity2titlemask, entity2tokentypeid
Code Example #11
def get_data_prep_dir(data_config):
    """Get data prep directory for saving prep files. Lives inside data_dir.

    Args:
        data_config: data config

    Returns: directory path
    """
    prep_dir = os.path.join(data_config.data_dir, data_config.data_prep_dir)
    utils.ensure_dir(prep_dir)
    return prep_dir
Code Example #12
File: entity_symbols.py Project: paper2code/bootleg
 def dump(self, save_dir, stats={}, args=None):
     self._sort_alias_cands()
     utils.ensure_dir(save_dir)
     utils.dump_json_file(filename=os.path.join(save_dir, "config.json"), contents={"max_candidates":self.max_candidates,
                                                                                    "max_alias_len":self.max_alias_len,
                                                                                    "datetime": str(datetime.now())})
     utils.dump_json_file(filename=os.path.join(save_dir, self.alias_cand_map_file), contents=self._alias2qids)
     utils.dump_json_file(filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title)
     utils.dump_json_file(filename=os.path.join(save_dir, "qid2eid.json"), contents=self._qid2eid)
     utils.dump_json_file(filename=os.path.join(save_dir, "filter_stats.json"), contents=stats)
     if args is not None:
         utils.dump_json_file(filename=os.path.join(save_dir, "args.json"), contents=vars(args))
Code Example #13
 def __init__(self, args, is_writer, distributed):
     super(BERTWordSymbols, self).__init__(args, is_writer, distributed)
     self.is_bert = True
     self.unk_id = None
     self.pad_id = 0
     cache_dir = args.word_embedding.cache_dir
     utils.ensure_dir(cache_dir)
     # import torch
     # import os
     # from transformers import BertTokenizer
     # cache_dir = "pretrained_bert_models"
     # tokenizer = BertTokenizer.from_pretrained('bert-base-cased', cache_dir=cache_dir, do_lower_case=False)
     # torch.save(tokenizer, os.path.join(cache_dir, "bert_base_cased_tokenizer.pt"))
     self.tokenizer = torch.load(os.path.join(cache_dir, "bert_base_cased_tokenizer.pt"))
     self.vocab = self.tokenizer.vocab
     self.num_words = len(self.vocab)
     self.word_embedding_dim = 768
Code Example #14
File: type_embs.py Project: pombredanne/bootleg
    def prep(cls, data_config, emb_args, entity_symbols):
        """Prep the type id table.

        Args:
            data_config: data config
            emb_args: embedding args
            entity_symbols: entity symbols

        Returns: torch tensor from EID to type IDs, type ID to row in type embedding matrix,
                 and number of types with unk type
        """
        type_str = os.path.splitext(emb_args.type_labels.replace("/", "_"))[0]
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        prep_file = os.path.join(
            prep_dir, f"type_table_{type_str}_{emb_args.max_types}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(logger,
                             f"Loading existing type table from {prep_file}")
            start = time.time()
            eid2typeids_table, type2row_dict, num_types_with_unk = torch.load(
                prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing type table in {round(time.time() - start, 2)}s",
            )
        else:
            start = time.time()
            type_labels = os.path.join(data_config.emb_dir,
                                       emb_args.type_labels)
            type_vocab = os.path.join(data_config.emb_dir, emb_args.type_vocab)
            log_rank_0_debug(logger, f"Building type table from {type_labels}")
            eid2typeids_table, type2row_dict, num_types_with_unk = cls.build_type_table(
                type_labels=type_labels,
                type_vocab=type_vocab,
                max_types=emb_args.max_types,
                entity_symbols=entity_symbols,
            )
            torch.save((eid2typeids_table, type2row_dict, num_types_with_unk),
                       prep_file)
            log_rank_0_debug(
                logger,
                f"Finished building and saving type table in {round(time.time() - start, 2)}s.",
            )
        return eid2typeids_table, type2row_dict, num_types_with_unk, prep_file
Code Example #15
File: kg_embs.py Project: pombredanne/bootleg
    def prep(
        cls,
        data_config,
        emb_args,
        entity_symbols,
        threshold,
        log_weight,
    ):
        """Preps the KG information.

        Args:
            data_config: data config
            emb_args: embedding args
            entity_symbols: entity symbols
            threshold: weight threshold for counting an edge
            log_weight: whether to take the log of the weight value after the threshold

        Returns: scipy sparse KG adjacency matrix, prep file
        """
        file_tag = os.path.splitext(emb_args.kg_adj.replace("/", "_"))[0]
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        prep_file = os.path.join(
            prep_dir, f"kg_adj_file_{file_tag}_{threshold}_{log_weight}.npz")
        utils.ensure_dir(os.path.dirname(prep_file))
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(logger,
                             f"Loading existing KG adj from {prep_file}")
            start = time.time()
            kg_adj = scipy.sparse.load_npz(prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing KG adj in {round(time.time() - start, 2)}s")
        else:
            start = time.time()
            kg_adj_file = os.path.join(data_config.emb_dir, emb_args.kg_adj)
            log_rank_0_debug(logger, f"Building KG adj from {kg_adj_file}")
            kg_adj = cls.build_kg_adj(kg_adj_file, entity_symbols, threshold,
                                      log_weight)
            scipy.sparse.save_npz(prep_file, kg_adj)
            log_rank_0_debug(
                logger,
                f"Finished building and saving KG adj in {round(time.time() - start, 2)}s.",
            )
        return kg_adj, prep_file
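
The threshold and log_weight arguments are forwarded to cls.build_kg_adj, which is not shown here. A hedged sketch of how edge weights might be thresholded and log-scaled (illustrative only, not Bootleg's actual build_kg_adj):

import numpy as np
import scipy.sparse


def sketch_weight_filter(weights, threshold, log_weight):
    """Zero out edges whose weight is <= threshold and optionally log-scale the rest."""
    kept = np.where(weights > threshold, weights, 0.0)
    if log_weight:
        kept = np.log1p(kept)  # log1p keeps zero entries at zero
    return scipy.sparse.csr_matrix(kept)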
Code Example #16
def main():
    args = parse_args()
    logging.info(json.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir))
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(
        save_dir,
        train_file,
        entity_symbols,
        args.lower,
        args.strip,
        num_workers=args.num_workers,
    )
Code Example #17
 def dump(self, save_dir):
     # memmapped files behave badly if you try to overwrite them in memory, which is what we'd be doing if load_dir == save_dir
     if self._loaded_from_dir is None or self._loaded_from_dir != save_dir:
         utils.ensure_dir(save_dir)
         utils.dump_json_file(filename=os.path.join(save_dir,
                                                    "fmt_types.json"),
                              contents=self._fmt_types)
         utils.dump_json_file(filename=os.path.join(save_dir,
                                                    "max_values.json"),
                              contents=self._max_values)
         utils.dump_json_file(filename=os.path.join(save_dir,
                                                    "vocabulary.json"),
                              contents=self._stoi)
         np.save(file=os.path.join(save_dir, "itos.npy"),
                 arr=self._itos,
                 allow_pickle=True)
         for tri_name in self._fmt_types:
             self._record_tris[tri_name].save(
                 os.path.join(save_dir, f'record_trie_{tri_name}.marisa'))
Code Example #18
File: type_embs.py Project: syyunn/bootleg
 def load_regularization_mapping(cls, main_args, num_types_with_pad_and_unk,
                                 type2row_dict, reg_file, log_func):
     """
     Reads in a csv file with columns [typeid, regularization].
     In the forward pass, the type with the associated typeid will be regularized with probability regularization.
     """
     reg_str = reg_file.split(".csv")[0]
     prep_dir = data_utils.get_data_prep_dir(main_args)
     prep_file = os.path.join(
         prep_dir, f'entity_regularization_mapping_{reg_str}.pt')
     utils.ensure_dir(os.path.dirname(prep_file))
     log_func(f"Looking for regularization mapping in {prep_file}")
     if (not main_args.data_config.overwrite_preprocessed_data
             and os.path.exists(prep_file)):
         log_func(
             f'Loading existing entity regularization mapping from {prep_file}'
         )
         start = time.time()
         typeid2reg = torch.load(prep_file)
         log_func(
             f'Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s'
         )
     else:
         start = time.time()
         reg_file = os.path.join(main_args.data_config.data_dir, reg_file)
         log_func(f'Building entity regularization mapping from {reg_file}')
         typeid2reg_raw = pd.read_csv(reg_file)
         assert "typeid" in typeid2reg_raw.columns and "regularization" in typeid2reg_raw.columns, f"Expected typeid and regularization as the column names for {reg_file}"
         # default of no mask
         typeid2reg_arr = [0.0] * num_types_with_pad_and_unk
         for row_idx, row in typeid2reg_raw.iterrows():
             # Happens when we filter QIDs not in our entity dump and the max typeid is smaller than the total number
             if int(row["typeid"]) not in type2row_dict:
                 continue
             typeid = type2row_dict[int(row["typeid"])]
             typeid2reg_arr[typeid] = row["regularization"]
         typeid2reg = torch.tensor(typeid2reg_arr)
         torch.save(typeid2reg, prep_file)
         log_func(
             f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s."
         )
     return typeid2reg
Code Example #19
 def prep(cls,
          args,
          entity_symbols,
          num_aliases_with_pad,
          num_cands_K,
          log_func=print):
     # we pass num_aliases_with_pad and num_cands_K to remove the dependence on entity_symbols
     # when the alias table is already prepped
     data_shape = (num_aliases_with_pad, num_cands_K)
     # dependent on train_in_candidates flag
     prep_dir = data_utils.get_emb_prep_dir(args)
     alias_str = os.path.splitext(args.data_config.alias_cand_map)[0]
     prep_file = os.path.join(
         prep_dir,
         f'alias2entity_table_{alias_str}_InC{int(args.data_config.train_in_candidates)}.pt'
     )
     if (not args.data_config.overwrite_preprocessed_data
             and os.path.exists(prep_file)):
         log_func(f'Loading alias table from {prep_file}')
         start = time.time()
         alias2entity_table = np.memmap(prep_file,
                                        dtype='int64',
                                        mode='r',
                                        shape=data_shape)
         log_func(f'Loaded alias table in {round(time.time() - start, 2)}s')
     else:
         start = time.time()
         log_func(f'Building alias table')
         utils.ensure_dir(prep_dir)
         mmap_file = np.memmap(prep_file,
                               dtype='int64',
                               mode='w+',
                               shape=data_shape)
         alias2entity_table = cls.build_alias_table(args, entity_symbols)
         mmap_file[:] = alias2entity_table[:]
         mmap_file.flush()
         log_func(
             f"Finished building and saving alias table in {round(time.time() - start, 2)}s."
         )
     return alias2entity_table, prep_file
Code Example #20
File: entity_embs.py Project: syyunn/bootleg
 def load_regularization_mapping(cls, main_args, entity_symbols, reg_file,
                                 log_func):
     """
     Reads in a csv file with columns [qid, regularization].
     In the forward pass, the entity id with associated qid will be regularized with probability regularization.
     """
     reg_str = reg_file.split(".csv")[0]
     prep_dir = data_utils.get_data_prep_dir(main_args)
     prep_file = os.path.join(
         prep_dir, f'entity_regularization_mapping_{reg_str}.pt')
     utils.ensure_dir(os.path.dirname(prep_file))
     log_func(f"Looking for regularization mapping in {prep_file}")
     if (not main_args.data_config.overwrite_preprocessed_data
             and os.path.exists(prep_file)):
         log_func(
             f'Loading existing entity regularization mapping from {prep_file}'
         )
         start = time.time()
         eid2reg = torch.load(prep_file)
         log_func(
             f'Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s'
         )
     else:
         start = time.time()
         log_func(f'Building entity regularization mapping from {reg_file}')
         qid2reg = pd.read_csv(reg_file)
         assert "qid" in qid2reg.columns and "regularization" in qid2reg.columns, f"Expected qid and regularization as the column names for {reg_file}"
         # default of no mask
         eid2reg_arr = [0.0
                        ] * entity_symbols.num_entities_with_pad_and_nocand
         for row_idx, row in qid2reg.iterrows():
             eid = entity_symbols.get_eid(row["qid"])
             eid2reg_arr[eid] = row["regularization"]
         eid2reg = torch.tensor(eid2reg_arr)
         torch.save(eid2reg, prep_file)
         log_func(
             f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s."
         )
     return eid2reg
Code Example #21
    def test_end2end_withreg_evalbatch(self):
        reg_file = "test/temp/reg_file.csv"
        utils.ensure_dir("test/temp")
        reg_data = [
            ["qid", "regularization"],
            ["Q1", "0.5"],
            ["Q2", "0.3"],
            ["Q3", "0.2"],
            ["Q4", "0.9"],
        ]
        self.args.data_config.eval_accumulation_steps = 2
        self.args.run_config.dataset_threads = 2
        self.args.run_config.eval_batch_size = 2
        with open(reg_file, "w") as out_f:
            for item in reg_data:
                out_f.write(",".join(item) + "\n")

        self.args.data_config.ent_embeddings[0]["args"][
            "regularize_mapping"] = reg_file
        scores = run_model(mode="train", config=self.args)
        assert type(scores) is dict
        assert len(scores) > 0
        assert scores["model/all/train/loss"] < 0.05

        self.args["model_config"][
            "model_path"] = f"{emmental.Meta.log_path}/last_model.pth"
        emmental.Meta.config["model_config"][
            "model_path"] = f"{emmental.Meta.log_path}/last_model.pth"

        result_file, out_emb_file = run_model(mode="dump_embs",
                                              config=self.args)
        assert os.path.exists(result_file)
        results = [ujson.loads(li) for li in open(result_file)]
        assert 18 == len(results)  # 18 total sentences
        assert set([f for li in results for f in li["ctx_emb_ids"]
                    ]) == set(range(51))  # 38 total mentions
        assert os.path.exists(out_emb_file)

        shutil.rmtree("test/temp", ignore_errors=True)
Code Example #22
    def __init__(
        self,
        main_args,
        dataset,
        use_weak_label,
        entity_symbols,
        dataset_threads,
        split="train",
    ):
        global_start = time.time()
        log_rank_0_info(logger,
                        f"Building slice dataset for {split} from {dataset}.")
        spawn_method = main_args.run_config.spawn_method
        data_config = main_args.data_config
        orig_spawn = multiprocessing.get_start_method()
        multiprocessing.set_start_method(spawn_method, force=True)
        self.slice_names = data_utils.get_eval_slices(data_config.eval_slices)
        self.get_slice_dt = lambda max_a2p: np.dtype([
            ("sent_idx", int),
            ("subslice_idx", int),
            ("alias_slice_incidence", int, (max_a2p, )),
            ("prob_labels", float, (max_a2p, )),
        ])
        self.get_storage = lambda max_a2p: np.dtype(
            [(slice_name, self.get_slice_dt(max_a2p))
             for slice_name in self.slice_names])
        # Folder for all mmap saved files
        save_dataset_folder = data_utils.get_save_data_folder(
            data_config, use_weak_label, dataset)
        utils.ensure_dir(save_dataset_folder)
        # Folder for temporary output files
        temp_output_folder = os.path.join(data_config.data_dir,
                                          data_config.data_prep_dir,
                                          f"prep_{split}_slice_files")
        utils.ensure_dir(temp_output_folder)
        # Input step 1
        create_ex_indir = os.path.join(temp_output_folder,
                                       "create_examples_input")
        utils.ensure_dir(create_ex_indir)
        # Input step 2
        create_ex_outdir = os.path.join(temp_output_folder,
                                        "create_examples_output")
        utils.ensure_dir(create_ex_outdir)
        # Meta data saved files
        meta_file = os.path.join(temp_output_folder, "meta_data.json")
        # File for standard training data
        hash = hashlib.sha1(str(
            self.slice_names).encode("UTF-8")).hexdigest()[:10]
        self.save_dataset_name = os.path.join(save_dataset_folder,
                                              f"ned_slices_{hash}.bin")
        self.save_data_config_name = os.path.join(save_dataset_folder,
                                                  "ned_slices_config.json")

        # =======================================================================================
        # SLICE DATA
        # =======================================================================================
        log_rank_0_debug(logger, "Loading dataset...")
        log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists")
        if data_config.overwrite_preprocessed_data or (not os.path.exists(
                self.save_dataset_name)):
            st_time = time.time()
            try:
                log_rank_0_info(
                    logger,
                    f"Building dataset from scratch. Saving to {save_dataset_folder}",
                )
                create_examples(
                    dataset,
                    create_ex_indir,
                    create_ex_outdir,
                    meta_file,
                    data_config,
                    dataset_threads,
                    self.slice_names,
                    use_weak_label,
                    split,
                )
                max_alias2pred = utils.load_json_file(
                    meta_file)["max_alias2pred"]
                convert_examples_to_features_and_save(
                    meta_file,
                    dataset_threads,
                    self.slice_names,
                    self.save_dataset_name,
                    self.get_storage(max_alias2pred),
                )
                utils.dump_json_file(self.save_data_config_name,
                                     {"max_alias2pred": max_alias2pred})

                log_rank_0_debug(
                    logger,
                    f"Finished prepping data in {time.time() - st_time}")
            except Exception as e:
                tb = traceback.TracebackException.from_exception(e)
                logger.error(e)
                logger.error("\n".join(tb.stack.format()))
                shutil.rmtree(save_dataset_folder, ignore_errors=True)
                raise

        log_rank_0_info(
            logger,
            f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
        )
        max_alias2pred = utils.load_json_file(
            self.save_data_config_name)["max_alias2pred"]
        self.data, self.sent_to_row_id_dict = self.build_data_dict(
            self.save_dataset_name, self.get_storage(max_alias2pred))
        assert len(self.data) > 0
        assert len(self.sent_to_row_id_dict) > 0
        log_rank_0_debug(logger, f"Removing temporary output files")
        shutil.rmtree(temp_output_folder, ignore_errors=True)
        # Set spawn back to original/default, which is "fork" or "spawn". This is needed for the Meta.config to
        # be correctly passed in the collate_fn.
        multiprocessing.set_start_method(orig_spawn, force=True)
        log_rank_0_info(
            logger,
            f"Final slice data initialization time from {split} is {time.time() - global_start}s",
        )
Code Example #23
File: eval_utils.py Project: parakalan/bootleg
def run_dump_preds(args,
                   test_data_file,
                   trainer,
                   dataloader,
                   logger,
                   entity_symbols,
                   dump_embs=False):
    """
    Dumps Preds:
    Remember that if a sentence has all gold=False anchors, it's dropped and will not be seen
    If a subsplit of a sentence has all gold=False anchors, it will also be dropped and not seen
    """
    # we only care about the entity embeddings for the final slice head
    eval_folder = train_utils.get_eval_folder(args, test_data_file)
    utils.ensure_dir(eval_folder)

    # write to file (M x hidden_size for each data point -- next step will deal with recovering original sentence indices for overflowing sentences)
    test_file_tag = test_data_file.split('.jsonl')[0]
    entity_emb_file = os.path.join(eval_folder,
                                   f'{test_file_tag}_entity_embs.pt')
    emb_file_config = entity_emb_file.split('.pt')[0] + '_config'
    M = args.data_config.max_aliases
    K = entity_symbols.max_candidates + (
        not args.data_config.train_in_candidates)
    # TODO: fix extra dimension issue
    if dump_embs:
        storage_type = np.dtype([('M', int), ('K', int), ('hidden_size', int),
                                 ('sent_idx', int), ('subsent_idx', int),
                                 ('alias_list_pos', int, M),
                                 ('entity_emb', float,
                                  M * args.model_config.hidden_size),
                                 ('final_loss_true', int, M),
                                 ('final_loss_pred', int, M),
                                 ('final_loss_prob', float, M),
                                 ('final_loss_cand_probs', float, M * K)])
    else:
        # don't need to extract contextualized entity embedding
        storage_type = np.dtype([('M', int), ('K', int), ('hidden_size', int),
                                 ('sent_idx', int), ('subsent_idx', int),
                                 ('alias_list_pos', int, M),
                                 ('final_loss_true', int, M),
                                 ('final_loss_pred', int, M),
                                 ('final_loss_prob', float, M),
                                 ('final_loss_cand_probs', float, M * K)])
    mmap_file = np.memmap(entity_emb_file,
                          dtype=storage_type,
                          mode='w+',
                          shape=(len(dataloader.dataset), ))
    # Init sent_idx to -1 for debugging
    mmap_file[:]['sent_idx'] = -1
    np.save(emb_file_config, storage_type, allow_pickle=True)
    logger.debug(f'Created file {entity_emb_file} to save predictions.')

    start_idx = 0
    logger.info(
        f'{len(dataloader)*args.run_config.eval_batch_size} samples, {len(dataloader)} batches, {len(dataloader.dataset)} len dataset'
    )
    for i, batch in enumerate(dataloader):
        curr_batch_size = batch["sent_idx"].shape[0]
        end_idx = start_idx + curr_batch_size
        preds, _, entity_pack, final_entity_embs = trainer.update(batch,
                                                                  eval=True)
        model_preds = preds[DISAMBIG][FINAL_LOSS]
        # don't want to choose padded entity indices
        probs = torch.exp(
            masked_class_logsoftmax(pred=model_preds,
                                    mask=~entity_pack.mask,
                                    dim=2))

        mmap_file[start_idx:end_idx]['M'] = M
        mmap_file[start_idx:end_idx]['K'] = K
        mmap_file[start_idx:end_idx][
            'hidden_size'] = args.model_config.hidden_size
        mmap_file[start_idx:end_idx]['sent_idx'] = batch["sent_idx"]
        mmap_file[start_idx:end_idx]['subsent_idx'] = batch["subsent_idx"]
        mmap_file[start_idx:end_idx]['alias_list_pos'] = batch[
            'alias_list_pos']
        # This will give all aliases seen by the model during training, independent of if it's gold or not
        mmap_file[start_idx:end_idx][f'final_loss_true'] = batch[
            'true_entity_idx_for_train'].reshape(curr_batch_size,
                                                 M).cpu().numpy()

        # get max for each alias, probs is batch x M x K
        max_probs, pred_cands = probs.max(dim=2)

        mmap_file[start_idx:end_idx]['final_loss_pred'] = pred_cands.cpu(
        ).numpy()
        mmap_file[start_idx:end_idx]['final_loss_prob'] = max_probs.cpu(
        ).numpy()
        mmap_file[start_idx:end_idx]['final_loss_cand_probs'] = probs.cpu(
        ).numpy().reshape(curr_batch_size, -1)

        # final_entity_embs is batch x M x K x hidden_size, pred_cands in batch x M
        if dump_embs:
            chosen_entity_embs = select_embs(embs=final_entity_embs,
                                             pred_cands=pred_cands,
                                             batch_size=curr_batch_size,
                                             M=M)

            # write chosen entity embs to file for contextualized entity embeddings
            mmap_file[start_idx:end_idx][
                'entity_emb'] = chosen_entity_embs.reshape(
                    curr_batch_size, -1).cpu().numpy()

        start_idx += curr_batch_size
        if i % 100 == 0 and i != 0:
            logger.info(f'Saved {i} batches of predictions')

    # restitch together and write data file
    result_file = os.path.join(eval_folder, args.run_config.result_label_file)
    logger.info(f'Writing predictions to {result_file}...')
    filt_pred_data = merge_subsentences(os.path.join(args.data_config.data_dir,
                                                     test_data_file),
                                        mmap_file,
                                        dump_embs=dump_embs)
    sent_idx_map = get_sent_idx_map(filt_pred_data)

    write_data_labels(filt_pred_data=filt_pred_data,
                      data_file=os.path.join(args.data_config.data_dir,
                                             test_data_file),
                      out_file=result_file,
                      sent_idx_map=sent_idx_map,
                      entity_dump=entity_symbols,
                      train_in_candidates=args.data_config.train_in_candidates,
                      dump_embs=dump_embs)

    out_emb_file = None
    # save easier-to-use embedding file
    if dump_embs:
        hidden_size = filt_pred_data[0]['hidden_size']
        out_emb_file = os.path.join(eval_folder,
                                    args.run_config.result_emb_file)
        np.save(out_emb_file,
                filt_pred_data['entity_emb'].reshape(-1, hidden_size))
        logger.info(f'Saving contextual entity embeddings to {out_emb_file}')
    logger.info(f'Wrote predictions to {result_file}')
    return result_file, out_emb_file
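
Because the structured dtype is saved next to the memmap (the np.save of storage_type above), the prediction dump can be reopened later without re-deriving its layout. A hedged sketch of reading it back (the file paths are placeholders; in practice they come from eval_folder and test_file_tag above):

import numpy as np

config_file = "test_entity_embs_config.npy"  # placeholder path
emb_file = "test_entity_embs.pt"             # placeholder path

# np.save stored the dtype as a 0-d object array, so .item() recovers the dtype itself.
storage_type = np.load(config_file, allow_pickle=True).item()
preds = np.memmap(emb_file, dtype=storage_type, mode="r")
print(preds["sent_idx"][:10], preds["final_loss_pred"][:10])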
Code Example #24
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.
    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns:

    """

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if weak labels is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is
                False), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = False if mode != "dump_embs" else True
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0 for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
            )
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            f"Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels
        output = open(final_result_file, "wb")
        for file in all_result_files:
            shutil.copyfileobj(open(file, "rb"), output)
        output.close()
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"Is that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = []
            for i, npfile in enumerate(all_out_emb_files):
                all_arrays.append(np.load(npfile))
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(
                logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
Code Example #25
def write_data_labels(
    num_processes,
    result_alias_offset,
    merged_entity_emb_file,
    merged_storage_type,
    sent_idx2row,
    cache_folder,
    out_file,
    entity_dump,
    train_in_candidates,
    max_candidates,
    dump_embs,
    trie_candidate_map_folder=None,
    trie_qid2eid_file=None,
):
    """Takes the flattened data from merge_sentences and writes out predictions
    to a file, one line per sentence.

    The embedding ids are added to the file if dump_embs is True.

    Args:
        num_processes: number of processes
        result_alias_offset: alias offset of this batch of examples for writing out
        merged_entity_emb_file: input memmap file after merge sentences
        merged_storage_type: input file storage type
        sent_idx2row: Dict of sentence idx to row relevant to this subbatch
        cache_folder: folder to save temporary outputs
        out_file: final output file for predictions
        entity_dump: entity dump
        train_in_candidates: whether NC entities are not in candidate lists
        max_candidates: maximum number of candidates
        dump_embs: whether to dump embeddings or not
        trie_candidate_map_folder: folder where the trie of the alias->candidate map is stored for parallel processing
        trie_qid2eid_file: file where the trie of the qid->eid map is stored for parallel processing

    Returns:
    """
    st = time.time()
    sental2embid = get_sental2embid(merged_entity_emb_file,
                                    merged_storage_type)
    log_rank_0_debug(logger,
                     f"Finished getting sentence map {time.time() - st}s")

    total_input = len(sent_idx2row)
    if num_processes == 1:
        filt_emb_data = np.memmap(merged_entity_emb_file,
                                  dtype=merged_storage_type,
                                  mode="r+")
        write_data_labels_single(
            sentidx2row=sent_idx2row,
            output_file=out_file,
            filt_emb_data=filt_emb_data,
            sental2embid=sental2embid,
            alias_cand_map=entity_dump.get_alias2qids(),
            qid2eid=entity_dump.get_qid2eid(),
            result_alias_offset=result_alias_offset,
            train_in_cands=train_in_candidates,
            max_cands=max_candidates,
            dump_embs=dump_embs,
        )
    else:
        assert (
            trie_candidate_map_folder is not None
        ), "trie_candidate_map_folder is None and you have parallel turned on"
        assert (trie_qid2eid_file is not None
                ), "trie_qid2eid_file is None and you have parallel turned on"

        # Get trie of sentence map
        trie_folder = os.path.join(cache_folder, "bootleg_sental2embid")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sental2embid, out_file=trie_file)
        # Chunk file for parallel writing
        # We do not use TemporaryFolders as the temp dir may not have enough space for large files
        create_ex_indir = os.path.join(cache_folder,
                                       "_bootleg_eval_temp_indir")
        utils.ensure_dir(create_ex_indir)
        create_ex_outdir = os.path.join(cache_folder,
                                        "_bootleg_eval_temp_outdir")
        utils.ensure_dir(create_ex_outdir)
        chunk_input = int(np.ceil(total_input / num_processes))
        logger.debug(
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines"
        )
        # Chunk up dictionary of data for parallel processing
        input_files = []
        i = 0
        cur_lines = 0
        file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
        open_file = open(file_split, "w")
        for s_idx in sent_idx2row:
            if cur_lines >= chunk_input:
                open_file.close()
                input_files.append(file_split)
                cur_lines = 0
                i += 1
                file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
                open_file = open(file_split, "w")
            line = sent_idx2row[s_idx]
            open_file.write(ujson.dumps(line) + "\n")
            cur_lines += 1
        open_file.close()
        input_files.append(file_split)
        # Generate input/output pairs
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        log_rank_0_debug(logger, f"Done chunking files. Starting pool")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=write_data_labels_initializer,
            initargs=[
                merged_entity_emb_file,
                merged_storage_type,
                trie_file,
                result_alias_offset,
                train_in_candidates,
                max_candidates,
                dump_embs,
                trie_candidate_map_folder,
                trie_qid2eid_file,
            ],
        )

        input_args = list(zip(input_files, output_files))

        total = 0
        for res in pool.imap(write_data_labels_hlp, input_args, chunksize=1):
            total += 1
        pool.close()
        pool.join()

        # Merge output files to final file
        log_rank_0_debug(logger, f"Merging output files")
        with open(out_file, "wb") as outfile:
            for filename in glob.glob(os.path.join(create_ex_outdir, "*")):
                if filename == out_file:
                    # don't want to copy the output into the output
                    continue
                with open(filename, "rb") as readfile:
                    shutil.copyfileobj(readfile, outfile)
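
The parallel branch of write_data_labels splits sent_idx2row into JSONL chunk files, roughly one per worker, before handing them to the pool. A minimal sketch of that chunking step on its own (chunk_rows_to_jsonl is a hypothetical helper, not part of the library):

import math
import os

import ujson


def chunk_rows_to_jsonl(sent_idx2row, out_dir, num_chunks):
    # Split a dict of sentence rows into roughly equal JSONL files, mirroring
    # the chunking loop inside write_data_labels above.
    os.makedirs(out_dir, exist_ok=True)
    rows = list(sent_idx2row.values())
    chunk_size = max(1, math.ceil(len(rows) / num_chunks))
    chunk_files = []
    for start in range(0, len(rows), chunk_size):
        path = os.path.join(out_dir, f"out{len(chunk_files)}.jsonl")
        with open(path, "w") as f:
            for row in rows[start:start + chunk_size]:
                f.write(ujson.dumps(row) + "\n")
        chunk_files.append(path)
    return chunk_files
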
Code example #26
0
def merge_subsentences(
    num_processes,
    subset_sent_idx2num_mens,
    cache_folder,
    to_save_file,
    to_save_storage,
    to_read_file,
    to_read_storage,
    dump_embs=False,
):
    """Flatten all sentences back together over sub-sentences; removing the PAD
    aliases from the data I.e., converts from sent_idx -> array of values to
    (sent_idx, alias_idx) -> value with varying numbers of aliases per
    sentence.

    Args:
        num_processes: number of processes
        subset_sent_idx2num_mens: Dict of sentence index to number of mentions for this batch
        cache_folder: cache directory
        to_save_file: memmap file to save results to
        to_save_storage: save file storage type
        to_read_file: memmap file to read predictions from
        to_read_storage: read file storage type
        dump_embs: whether to save embeddings or not

    Returns:
    """
    # Compute sent idx to offset so we know where to fill in mentions
    cur_offset = 0
    sentidx2offset = {}
    for k, v in subset_sent_idx2num_mens.items():
        sentidx2offset[k] = cur_offset
        cur_offset += v
        # print("Sent Idx, Num Mens, Offset", k, v, cur_offset)
    total_num_mentions = cur_offset
    # print("TOTAL", total_num_mentions)
    full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode="r")
    M = int(full_pred_data[0]["M"])
    K = int(full_pred_data[0]["K"])
    hidden_size = int(full_pred_data[0]["hidden_size"])
    # print("TOTAL MENS", total_num_mentions)
    filt_emb_data = np.memmap(to_save_file,
                              dtype=to_save_storage,
                              mode="w+",
                              shape=(total_num_mentions, ))
    filt_emb_data["hidden_size"] = hidden_size
    filt_emb_data["sent_idx"][:] = -1
    filt_emb_data["alias_list_pos"][:] = -1

    all_ids = list(range(0, len(full_pred_data)))
    start = time.time()
    if num_processes == 1:
        seen_ids = merge_subsentences_single(
            M,
            K,
            hidden_size,
            dump_embs,
            all_ids,
            filt_emb_data,
            full_pred_data,
            sentidx2offset,
        )
    else:
        # Get trie for sentence start map
        trie_folder = os.path.join(cache_folder, "bootleg_sent_idx2num_mens")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sentidx2offset, out_file=trie_file)
        # Chunk up data
        chunk_size = int(np.ceil(len(full_pred_data) / num_processes))
        row_idx_set_chunks = [
            all_ids[ids:ids + chunk_size]
            for ids in range(0, len(full_pred_data), chunk_size)
        ]
        # Start pool
        input_args = [[M, K, hidden_size, dump_embs, chunk]
                      for chunk in row_idx_set_chunks]
        log_rank_0_debug(
            logger,
            f"Merging sentences together with {num_processes} processes")
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=merge_subsentences_initializer,
            initargs=[
                to_save_file,
                to_save_storage,
                to_read_file,
                to_read_storage,
                trie_file,
            ],
        )

        seen_ids = set()
        for sent_ids_seen in pool.imap_unordered(merge_subsentences_hlp,
                                                 input_args,
                                                 chunksize=1):
            for emb_id in sent_ids_seen:
                assert (
                    emb_id not in seen_ids
                ), f"{emb_id} already seen, something went wrong with sub-sentences"
                seen_ids.add(emb_id)
        pool.close()
        pool.join()
    # filt_emb_data = np.memmap(to_save_file, dtype=to_save_storage, mode="r")
    # for i in range(len(filt_emb_data)):
    #     si = filt_emb_data[i]["sent_idx"]
    #     al_test = filt_emb_data[i]["alias_list_pos"]
    #     if si == -1 or al_test == -1:
    #         print("BAD", i, filt_emb_data[i])
    #         import ipdb; ipdb.set_trace()
    logging.debug(f"Saw {len(seen_ids)} sentences")
    logging.debug(f"Time to merge sub-sentences {time.time() - start}s")
    return
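
The bookkeeping that makes the flattening work is the sentence-to-offset map: each sentence's mentions land in a contiguous block of rows of the output memmap. A small sketch of just that computation (build_sentidx2offset is an illustrative name, not a library function):

def build_sentidx2offset(sent_idx2num_mens):
    # Cumulative-sum offsets: sentence k's mentions occupy rows
    # [offset[k], offset[k] + num_mens[k]) in the flattened array.
    offsets, cur = {}, 0
    for sent_idx, num_mens in sent_idx2num_mens.items():
        offsets[sent_idx] = cur
        cur += num_mens
    return offsets, cur  # cur is the total number of mentions


# e.g. build_sentidx2offset({"12": 2, "13": 1}) == ({"12": 0, "13": 2}, 3)
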
Code example #27
0
def disambig_dump_preds(
    result_idx,
    result_alias_offset,
    config,
    res_dict,
    sent_idx2num_mens,
    sent_idx2row,
    save_folder,
    entity_symbols,
    dump_embs,
    task_name,
):
    """Dumps the predictions of a disambiguation task.

    Args:
        result_idx: batch index of the result arrays
        result_alias_offset: overall offset of the starting example (i.e., the number of previous mens already written)
        config: model config
        res_dict: result dictionary from Emmental predict
        sent_idx2num_mens: Dict sentence idx to number of mentions
        sent_idx2row: Dict sentence idx to row of eval data
        save_folder: folder to save results
        entity_symbols: entity symbols
        dump_embs: whether to save the contextualized embeddings or not
        task_name: task name

    Returns: saved prediction file, saved embedding file (will be None if dump_embs is False)
    """
    num_processes = min(config.run_config.dataset_threads,
                        int(multiprocessing.cpu_count() * 0.9))
    cache_dir = os.path.join(save_folder, f"cache_{result_idx}")
    utils.ensure_dir(cache_dir)
    trie_candidate_map_folder = None
    trie_qid2eid_file = None
    # Save the alias->QID candidate map and the QID->EID mapping in memory efficient structures for faster
    # prediction dumping
    if num_processes > 1:
        entity_prep_dir = data_utils.get_emb_prep_dir(config.data_config)
        trie_candidate_map_folder = os.path.join(entity_prep_dir,
                                                 "for_dumping_preds",
                                                 "alias_cand_trie")
        utils.ensure_dir(trie_candidate_map_folder)
        check_and_create_alias_cand_trie(trie_candidate_map_folder,
                                         entity_symbols)
        trie_qid2eid_file = os.path.join(entity_prep_dir, "for_dumping_preds",
                                         "qid2eid_trie.marisa")
        if not os.path.exists(trie_qid2eid_file):
            utils.create_single_item_trie(entity_symbols.get_qid2eid(),
                                          out_file=trie_qid2eid_file)

    # This is dumping
    disambig_res_dict = {}
    for k in res_dict:
        assert task_name in res_dict[
            k], f"{task_name} not in res_dict for key {k}"
        disambig_res_dict[k] = res_dict[k][task_name]

    # write to file (M x hidden_size for each data point -- the next step will recover the original sentence
    # indices for overflowing sentences)
    unmerged_entity_emb_file = os.path.join(save_folder, f"entity_embs.pt")
    merged_entity_emb_file = os.path.join(save_folder,
                                          f"entity_embs_unmerged.pt")
    emb_file_config = os.path.splitext(
        unmerged_entity_emb_file)[0] + "_config.npy"
    M = config.data_config.max_aliases
    K = entity_symbols.max_candidates + (
        not config.data_config.train_in_candidates)
    if dump_embs:
        unmerged_storage_type = np.dtype([
            ("M", int),
            ("K", int),
            ("hidden_size", int),
            ("sent_idx", int),
            ("subsent_idx", int),
            ("alias_list_pos", int, (M, )),
            ("entity_emb", float, M * config.model_config.hidden_size),
            ("final_loss_true", int, (M, )),
            ("final_loss_pred", int, (M, )),
            ("final_loss_prob", float, (M, )),
            ("final_loss_cand_probs", float, M * K),
        ])
        merged_storage_type = np.dtype([
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("entity_emb", float, config.model_config.hidden_size),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ])
    else:
        # don't need to extract contextualized entity embedding
        unmerged_storage_type = np.dtype([
            ("M", int),
            ("K", int),
            ("hidden_size", int),
            ("sent_idx", int),
            ("subsent_idx", int),
            ("alias_list_pos", int, (M, )),
            ("final_loss_true", int, (M, )),
            ("final_loss_pred", int, (M, )),
            ("final_loss_prob", float, (M, )),
            ("final_loss_cand_probs", float, M * K),
        ])
        merged_storage_type = np.dtype([
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ])
    mmap_file = np.memmap(
        unmerged_entity_emb_file,
        dtype=unmerged_storage_type,
        mode="w+",
        shape=(len(disambig_res_dict["uids"]), ),
    )
    # print("MEMMAP FILE SHAPE", len(disambig_res_dict["uids"]))
    # Init sent_idx to -1 for debugging
    mmap_file[:]["sent_idx"] = -1
    np.save(emb_file_config, unmerged_storage_type, allow_pickle=True)
    log_rank_0_debug(
        logger,
        f"Created file {unmerged_entity_emb_file} to save predictions.")

    log_rank_0_debug(logger, f'{len(disambig_res_dict["uids"])} samples')
    for_iteration = [
        disambig_res_dict["uids"],
        disambig_res_dict["golds"],
        disambig_res_dict["probs"],
        disambig_res_dict["preds"],
    ]
    all_sent_idx = set()
    for i, (uid, gold, probs, model_pred) in enumerate(zip(*for_iteration)):
        # disambig_res_dict["output"] is dict with keys ['_input__alias_orig_list_pos',
        # 'bootleg_pred_1', '_input__sent_idx', '_input__for_dump_gold_cand_K_idx_train', '_input__subsent_idx', 0, 1]
        sent_idx = disambig_res_dict["outputs"]["_input__sent_idx"][i]
        # print("INSIDE LOOP", sent_idx, "AT", i)
        subsent_idx = disambig_res_dict["outputs"]["_input__subsent_idx"][i]
        alias_orig_list_pos = disambig_res_dict["outputs"][
            "_input__alias_orig_list_pos"][i]
        gold_cand_K_idx_train = disambig_res_dict["outputs"][
            "_input__for_dump_gold_cand_K_idx_train"][i]
        output_embeddings = disambig_res_dict["outputs"][
            f"{PRED_LAYER}_ent_embs"][i]
        mmap_file[i]["M"] = M
        mmap_file[i]["K"] = K
        mmap_file[i]["hidden_size"] = config.model_config.hidden_size
        mmap_file[i]["sent_idx"] = sent_idx
        mmap_file[i]["subsent_idx"] = subsent_idx
        mmap_file[i]["alias_list_pos"] = alias_orig_list_pos
        # This gives all aliases seen by the model during training, regardless of whether they are gold or not
        mmap_file[i][f"final_loss_true"] = gold_cand_K_idx_train.reshape(M)

        # get max for each alias, probs is M x K
        max_probs = probs.max(axis=1)
        pred_cands = probs.argmax(axis=1)

        mmap_file[i]["final_loss_pred"] = pred_cands
        mmap_file[i]["final_loss_prob"] = max_probs
        mmap_file[i]["final_loss_cand_probs"] = probs.reshape(1, -1)

        all_sent_idx.add(str(sent_idx))
        # final_entity_embs is M x K x hidden_size, pred_cands is M
        if dump_embs:
            chosen_entity_embs = select_embs(embs=output_embeddings,
                                             pred_cands=pred_cands,
                                             M=M)

            # write chosen entity embs to file for contextualized entity embeddings
            mmap_file[i]["entity_emb"] = chosen_entity_embs.reshape(1, -1)

    # for i in range(len(mmap_file)):
    #     si = mmap_file[i]["sent_idx"]
    #     if -1 == si:
    #         import pdb
    #         pdb.set_trace()
    #     assert si != -1, f"{i} {mmap_file[i]}"
    # Store all predicted sentences to filter the sentence mapping by
    subset_sent_idx2num_mens = {
        k: v
        for k, v in sent_idx2num_mens.items() if k in all_sent_idx
    }
    # print("ALL SEEN", all_sent_idx)
    subsent_sent_idx2row = {
        k: v
        for k, v in sent_idx2row.items() if k in all_sent_idx
    }
    result_file = get_result_file(result_idx, save_folder)
    log_rank_0_debug(logger, f"Writing predictions to {result_file}...")
    merge_subsentences(
        num_processes=num_processes,
        subset_sent_idx2num_mens=subset_sent_idx2num_mens,
        cache_folder=cache_dir,
        to_save_file=merged_entity_emb_file,
        to_save_storage=merged_storage_type,
        to_read_file=unmerged_entity_emb_file,
        to_read_storage=unmerged_storage_type,
        dump_embs=dump_embs,
    )
    write_data_labels(
        num_processes=num_processes,
        result_alias_offset=result_alias_offset,
        merged_entity_emb_file=merged_entity_emb_file,
        merged_storage_type=merged_storage_type,
        sent_idx2row=subsent_sent_idx2row,
        cache_folder=cache_dir,
        out_file=result_file,
        entity_dump=entity_symbols,
        train_in_candidates=config.data_config.train_in_candidates,
        max_candidates=entity_symbols.max_candidates,
        dump_embs=dump_embs,
        trie_candidate_map_folder=trie_candidate_map_folder,
        trie_qid2eid_file=trie_qid2eid_file,
    )

    out_emb_file = None
    filt_emb_data = np.memmap(merged_entity_emb_file,
                              dtype=merged_storage_type,
                              mode="r+")
    total_mentions_seen = len(filt_emb_data)
    # save easier-to-use embedding file
    if dump_embs:
        hidden_size = filt_emb_data[0]["hidden_size"]
        out_emb_file = get_emb_file(result_idx, save_folder)
        np.save(out_emb_file,
                filt_emb_data["entity_emb"].reshape(-1, hidden_size))
        log_rank_0_debug(
            logger,
            f"Saving contextual entity embeddings for {result_idx} to {out_emb_file}",
        )
    filt_emb_data = None

    # Cleanup cache - sometimes the file in cache_dir is still open so we need to retry to delete it
    try_rmtree(cache_dir)

    log_rank_0_debug(logger,
                     f"Wrote predictions for {result_idx} to {result_file}")
    return result_file, out_emb_file, all_sent_idx, total_mentions_seen
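
The dump above relies on structured numpy memmaps for both the unmerged and merged prediction files. A toy sketch of the same pattern with a simplified field set (the sizes and the toy_entity_embs.pt path are made up; the real dtypes are defined in the function above):

import numpy as np

K, hidden_size, n_rows = 3, 4, 10
toy_dtype = np.dtype([
    ("sent_idx", int),
    ("alias_list_pos", int),
    ("entity_emb", float, hidden_size),
    ("final_loss_pred", int),
    ("final_loss_prob", float),
    ("final_loss_cand_probs", float, K),
])
mmap = np.memmap("toy_entity_embs.pt", dtype=toy_dtype, mode="w+", shape=(n_rows,))
mmap["sent_idx"][:] = -1  # init to -1 for debugging, as in the code above
mmap[0]["sent_idx"] = 7
mmap[0]["entity_emb"] = np.arange(hidden_size, dtype=float)
mmap.flush()
# Re-open read-only and recover the embedding matrix, as done for out_emb_file
readback = np.memmap("toy_entity_embs.pt", dtype=toy_dtype, mode="r")
emb_matrix = readback["entity_emb"].reshape(-1, hidden_size)
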
Code example #28
0
def merge_data(
    num_processes,
    train_in_candidates,
    keep_orig,
    max_candidates,
    file_pairs,
    entity_dump_f,
):
    # Each file pair is (input file, cand map file, output file, is_train)

    # Chunk file for parallel writing
    create_ex_indir = os.path.join(os.path.dirname(file_pairs[0]),
                                   "_bootleg_temp_indir")
    utils.ensure_dir(create_ex_indir)
    create_ex_indir_cands = os.path.join(os.path.dirname(file_pairs[0]),
                                         "_bootleg_temp_indir2")
    utils.ensure_dir(create_ex_indir_cands)
    create_ex_outdir = os.path.join(os.path.dirname(file_pairs[0]),
                                    "_bootleg_temp_outdir")
    utils.ensure_dir(create_ex_outdir)
    print(f"Counting lines")
    total_input = sum(1 for _ in open(file_pairs[0]))
    total_input_cands = sum(1 for _ in open(file_pairs[1]))
    assert (
        total_input_cands == total_input
    ), f"{total_input} lines of orig data != {total_input_cands} of cand data"
    chunk_input_size = int(np.ceil(total_input / num_processes))
    total_input_from_chunks, input_files_dict = utils.chunk_file(
        file_pairs[0], create_ex_indir, chunk_input_size)
    total_input_cands_from_chunks, input_files_cands_dict = utils.chunk_file(
        file_pairs[1], create_ex_indir_cands, chunk_input_size)

    input_files = list(input_files_dict.keys())
    input_cand_files = list(input_files_cands_dict.keys())
    assert len(input_cand_files) == len(input_files)
    input_file_lines = [input_files_dict[k] for k in input_files]
    input_cand_file_lines = [
        input_files_cands_dict[k] for k in input_cand_files
    ]
    for p_l, p_r in zip(input_file_lines, input_cand_file_lines):
        assert (
            p_l == p_r
        ), f"The matching chunk files don't have matching sizes {p_l} versus {p_r}"
    output_files = [
        in_file_name.replace(create_ex_indir, create_ex_outdir)
        for in_file_name in input_files
    ]
    assert (
        total_input == total_input_from_chunks
    ), f"Lengths of files {total_input} doesn't match {total_input_from_chunks}"
    assert (
        total_input_cands == total_input_cands_from_chunks
    ), f"Lengths of files {total_input_cands} doesn't match {total_input_cands_from_chunks}"
    # file_pairs is input file, cand map file, output file, is_train
    input_args = [[
        train_in_candidates,
        keep_orig,
        max_candidates,
        input_files[i],
        input_file_lines[i],
        input_cand_files[i],
        output_files[i],
        file_pairs[3],
    ] for i in range(len(input_files))]

    pool = multiprocessing.Pool(processes=num_processes,
                                initializer=init_process,
                                initargs=[entity_dump_f])

    new_alias2qids = {}
    total_seen = 0
    total_dropped = 0
    for res in pool.imap(merge_data_hlp, input_args, chunksize=1):
        temp_alias2qids, seen, dropped = res
        total_seen += seen
        total_dropped += dropped
        for k in temp_alias2qids:
            assert k not in new_alias2qids, f"{k}"
            new_alias2qids[k] = temp_alias2qids[k]
    pool.close()
    pool.join()
    print(
        f"Overall recall for {file_pairs[0]}: {(total_seen - total_dropped) / total_seen} over {total_seen} mentions seen"
    )
    # Merge output files to final file
    print(f"Merging output files")
    with open(file_pairs[2], "wb") as outfile:
        for filename in glob.glob(os.path.join(create_ex_outdir, "*")):
            if filename == file_pairs[2]:
                # don't want to copy the output into the output
                continue
            with open(filename, "rb") as readfile:
                shutil.copyfileobj(readfile, outfile)
    # Remove temporary files/folders
    shutil.rmtree(create_ex_indir)
    shutil.rmtree(create_ex_indir_cands)
    shutil.rmtree(create_ex_outdir)
    return new_alias2qids
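
Like write_data_labels and merge_subsentences, merge_data uses a multiprocessing.Pool whose initializer loads shared read-only state once per worker instead of pickling it for every task. A generic sketch of that pattern under assumed names (_init_worker, _process_chunk, and the file names are hypothetical):

import multiprocessing

_worker_state = None  # populated once per worker process


def _init_worker(state_path):
    # In the real code this is where the entity dump / marisa tries get loaded.
    global _worker_state
    _worker_state = state_path


def _process_chunk(in_out_pair):
    in_file, out_file = in_out_pair
    # ... read in_file, use _worker_state, write out_file ...
    return in_file


if __name__ == "__main__":
    pairs = [("in0.jsonl", "out0.jsonl"), ("in1.jsonl", "out1.jsonl")]
    with multiprocessing.Pool(processes=2,
                              initializer=_init_worker,
                              initargs=["entity_db.json"]) as pool:
        for finished in pool.imap(_process_chunk, pairs, chunksize=1):
            pass  # per-chunk results would be aggregated here
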
Code example #29
0
File: data_utils.py Project: paper2code/bootleg
def get_data_prep_dir(args):
    prep_dir = os.path.join(args.data_config.data_dir,
                            args.data_config.data_prep_dir)
    utils.ensure_dir(prep_dir)
    return prep_dir
Code example #30
0
File: type_embs.py Project: pombredanne/bootleg
    def load_regularization_mapping(cls, data_config,
                                    num_types_with_pad_and_unk, type2row_dict,
                                    reg_file):
        """Reads in a csv file with columns [qid, regularization].

        In the forward pass, the entity id with associated qid will be
        regularized with probability regularization.

        Args:
            data_config: data config
            num_entities_with_pad_and_nocand: number of types including pad and null option
            type2row_dict: Dict from typeID to row id in the type embedding matrix
            reg_file: regularization csv file

        Returns: Tensor where each value is the regularization value for EID
        """
        reg_str = os.path.splitext(os.path.basename(reg_file.replace("/",
                                                                     "_")))[0]
        prep_dir = data_utils.get_data_prep_dir(data_config)
        prep_file = os.path.join(prep_dir,
                                 f"type_regularization_mapping_{reg_str}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        log_rank_0_debug(logger,
                         f"Looking for regularization mapping in {prep_file}")
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(
                logger,
                f"Loading existing entity regularization mapping from {prep_file}",
            )
            start = time.time()
            typeid2reg = torch.load(prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s",
            )
        else:
            start = time.time()
            log_rank_0_debug(
                logger,
                f"Building entity regularization mapping from {reg_file}")
            typeid2reg_raw = pd.read_csv(reg_file)
            assert (
                "typeid" in typeid2reg_raw.columns
                and "regularization" in typeid2reg_raw.columns
            ), f"Expected typeid and regularization as the column names for {reg_file}"
            # default of no mask
            typeid2reg_arr = [0.0] * num_types_with_pad_and_unk
            for row_idx, row in typeid2reg_raw.iterrows():
                # Happens when we filter QIDs not in our entity db and the max typeid is smaller than the total number
                if int(row["typeid"]) not in type2row_dict:
                    continue
                typeid = type2row_dict[int(row["typeid"])]
                typeid2reg_arr[typeid] = row["regularization"]
            typeid2reg = torch.Tensor(typeid2reg_arr)
            torch.save(typeid2reg, prep_file)
            log_rank_0_debug(
                logger,
                f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.",
            )
        return typeid2reg
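
As a toy illustration of the mapping built above, a small typeid/regularization table (an in-memory DataFrame standing in for the CSV) is turned into a per-row regularization tensor, with filtered type ids keeping the 0.0 default; the concrete numbers here are made up:

import pandas as pd
import torch

type2row_dict = {10: 1, 11: 2}  # raw type id -> row in the type embedding matrix
num_types_with_pad_and_unk = 4
raw = pd.DataFrame({"typeid": [10, 11, 99], "regularization": [0.2, 0.5, 0.9]})

typeid2reg_arr = [0.0] * num_types_with_pad_and_unk
for _, row in raw.iterrows():
    if int(row["typeid"]) not in type2row_dict:
        continue  # e.g. typeid 99 was filtered out of the entity db
    typeid2reg_arr[type2row_dict[int(row["typeid"])]] = row["regularization"]
typeid2reg = torch.Tensor(typeid2reg_arr)
# typeid2reg == tensor([0.0000, 0.2000, 0.5000, 0.0000])
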