示例#1
0
def check_and_create_alias_cand_trie(save_folder, entity_symbols):
    """Makes sure an alias-candidate record trie can be loaded from ``save_folder``.

    First tries to load an ``AliasCandRecordTrie`` from ``save_folder``. If that
    raises ``FileNotFoundError``, a new trie is built from
    ``entity_symbols.get_alias2qids()``, ``entity_symbols.get_qid2title()`` and
    ``entity_symbols.max_candidates`` and dumped to ``save_folder``.

    Args:
        save_folder: save folder for alias trie
        entity_symbols: entity symbols

    Returns: None
    """
    try:
        # Only the ability to load matters here; the loaded trie is discarded.
        AliasCandRecordTrie(load_dir=save_folder)
    except FileNotFoundError:
        log_rank_0_debug(
            logger,
            "Creating the alias candidate trie for faster parallel processing. "
            "This is a one time cost",
        )
        trie_kwargs = {
            "input_dict": entity_symbols.get_alias2qids(),
            "vocabulary": entity_symbols.get_qid2title(),
            "max_value": entity_symbols.max_candidates,
        }
        AliasCandRecordTrie(**trie_kwargs).dump(save_folder)
    return
示例#2
0
def get_sent_idx2num_mens(data_file):
    """Gets the map from sentence index to number of mentions and to data. Used
    for calculating offsets and chunking file.

    Every line of ``data_file`` is parsed as a JSON object that must provide the
    keys ``sent_idx_unq`` and ``aliases``. Sentence indices are stored as strings
    in both returned dicts.

    Args:
        data_file: eval file (one JSON object per line)

    Returns: Dict of sentence index -> number of mention per sentence, Dict of sentence index -> input line

    Raises:
        AssertionError: if the same sentence index is seen more than once
    """
    sent_idx2num_mens = {}
    sent_idx2row = {}
    total_num_mentions = 0
    with open(data_file) as f:
        for line in f:
            line = ujson.loads(line)
            # Keys are saved as strings (for the Marisa trie later), so the uniqueness check has to use the
            # string form as well: a raw integer index never compares equal to a str key.
            sent_idx = str(line["sent_idx_unq"])
            assert (
                sent_idx not in sent_idx2num_mens
            ), f'Sentence indices must be unique. {line["sent_idx_unq"]} already seen.'
            sent_idx2row[sent_idx] = line
            # We include false aliases for debugging (and alias_pos includes them)
            num_mentions = len(line["aliases"])
            sent_idx2num_mens[sent_idx] = num_mentions
            total_num_mentions += num_mentions

    log_rank_0_debug(
        logger,
        f"Total number of mentions across all sentences: {total_num_mentions}")
    return sent_idx2num_mens, sent_idx2row
示例#3
0
 def __init__(
     self,
     main_args,
     emb_args,
     entity_symbols,
     key,
     cpu,
     normalize,
     dropout1d_perc,
     dropout2d_perc,
 ):
     """Initializes the KG indices embedding.

     All arguments are forwarded unchanged to the parent constructor. The
     embedding dimension ``_dim`` is set to 0, a scalar parameter
     ``kg_bias_weight`` (initial value 2.0) and a softmax over dim 2 are
     created, and ``kg_adj_process_func`` is set to
     ``embedding_utils.prep_kg_feature_matrix``.

     Args:
         main_args: main args
         emb_args: embedding args
         entity_symbols: entity symbols
         key: embedding key (reported in the debug log message)
         cpu: forwarded to the parent constructor
         normalize: forwarded to the parent constructor
         dropout1d_perc: forwarded to the parent constructor
         dropout2d_perc: forwarded to the parent constructor
     """
     parent_kwargs = {
         "main_args": main_args,
         "emb_args": emb_args,
         "entity_symbols": entity_symbols,
         "key": key,
         "cpu": cpu,
         "normalize": normalize,
         "dropout1d_perc": dropout1d_perc,
         "dropout2d_perc": dropout2d_perc,
     }
     super(KGIndices, self).__init__(**parent_kwargs)
     self._dim = 0
     # Weight for the diagonal addition to the KG indices - allows for summing an entity with other connections
     bias_init = torch.tensor(2.0)
     self.kg_bias_weight = torch.nn.Parameter(bias_init)
     self.kg_softmax = torch.nn.Softmax(dim=2)
     # With this processing function, prepping the embeddings queries the kg_adj matrix - generating
     # M*K values per entity candidate
     self.kg_adj_process_func = embedding_utils.prep_kg_feature_matrix
     usage_msg = (
         f"You are using the KGIndices class with key {key}."
         f" This key need to be used in the attention network to access the kg bias matrix."
         f" This embedding is not appended to the payload."
     )
     log_rank_0_debug(logger, usage_msg)
示例#4
0
    def prep(
        cls,
        data_config,
        entity_symbols,
        num_aliases_with_pad_and_unk,
        num_cands_K,
    ):
        """Loads the alias to entity EID table from its memmap file, or builds and saves it.

        The table is rebuilt when ``data_config.overwrite_preprocessed_data`` is
        set or when the prep file does not exist yet; otherwise the existing file
        is opened as an int64 memmap of shape
        ``(num_aliases_with_pad_and_unk, num_cands_K)``.

        Args:
            data_config: data config
            entity_symbols: entity symbols (only used when the table is built)
            num_aliases_with_pad_and_unk: number of aliases including pad and unk
            num_cands_K: number of candidates per alias (aka K)

        Returns: torch Tensor of the alias to EID table, save pt file
        """
        # The shape is passed in explicitly so that entity_symbols is not needed
        # when the alias table is already prepped
        table_shape = (num_aliases_with_pad_and_unk, num_cands_K)
        # The prep file name depends on the train_in_candidates flag
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        alias_tag = os.path.splitext(
            data_config.alias_cand_map.replace("/", "_"))[0]
        in_cand_flag = int(data_config.train_in_candidates)
        prep_file = os.path.join(
            prep_dir, f"alias2entity_table_{alias_tag}_InC{in_cand_flag}.pt")
        log_rank_0_debug(logger, f"Looking for alias table in {prep_file}")
        if data_config.overwrite_preprocessed_data or not os.path.exists(
                prep_file):
            start = time.time()
            log_rank_0_debug(logger, f"Building alias table")
            utils.ensure_dir(prep_dir)
            saved_table = np.memmap(prep_file,
                                    dtype="int64",
                                    mode="w+",
                                    shape=table_shape)
            alias2entity_table = cls.build_alias_table(data_config,
                                                       entity_symbols)
            # Copy the freshly built table into the memmap-backed file
            saved_table[:] = alias2entity_table[:]
            saved_table.flush()
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Finished building and saving alias table in {elapsed}s.",
            )
        else:
            log_rank_0_debug(logger, f"Loading alias table from {prep_file}")
            start = time.time()
            alias2entity_table = np.memmap(prep_file,
                                           dtype="int64",
                                           mode="r+",
                                           shape=table_shape)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(logger, f"Loaded alias table in {elapsed}s")
        return torch.from_numpy(alias2entity_table), prep_file
示例#5
0
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        """Initializes the title embedding.

        Validates the keys of ``emb_args`` (only "proj" and "requires_grad" are
        allowed), sets up a linear projection from ``BERT_WORD_DIM`` to the
        output dimension and registers the three title tables returned by
        ``prep`` as non-persistent buffers.

        Args:
            main_args: main args
            emb_args: embedding args; optional keys "proj" (output dimension,
                defaults to model_config.hidden_size) and "requires_grad"
                (defaults to True)
            entity_symbols: entity symbols
            key: forwarded to the parent constructor
            cpu: forwarded to the parent constructor
            normalize: forwarded to the parent constructor
            dropout1d_perc: forwarded to the parent constructor
            dropout2d_perc: forwarded to the parent constructor

        Raises:
            ValueError: if ``emb_args`` holds a key outside the allowed set
        """
        parent_kwargs = {
            "main_args": main_args,
            "emb_args": emb_args,
            "entity_symbols": entity_symbols,
            "key": key,
            "cpu": cpu,
            "normalize": normalize,
            "dropout1d_perc": dropout1d_perc,
            "dropout2d_perc": dropout2d_perc,
        }
        super(TitleEmb, self).__init__(**parent_kwargs)
        allowable_keys = {"proj", "requires_grad"}
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        self.orig_dim = BERT_WORD_DIM
        self.merge_func = self.average_titles
        data_config = main_args.data_config
        self.M = data_config.max_aliases
        self.K = get_max_candidates(entity_symbols, data_config)
        # Output dimension: hidden size unless an explicit projection size is given
        hidden_size = main_args.model_config.hidden_size
        self._dim = emb_args.proj if "proj" in emb_args else hidden_size
        if "requires_grad" in emb_args:
            self.requires_grad = emb_args.requires_grad
        else:
            self.requires_grad = True
        self.title_proj = torch.nn.Linear(self.orig_dim, self._dim)
        log_rank_0_debug(
            logger,
            f'Setting the "proj" parameter to {self._dim} and the "requires_grad" parameter to {self.requires_grad}',
        )

        titleid_table, titlemask_table, tokentypeid_table = self.prep(
            data_config=main_args.data_config,
            entity_symbols=entity_symbols,
        )
        named_tables = (
            ("entity2titleid_table", titleid_table),
            ("entity2titlemask_table", titlemask_table),
            ("entity2tokentypeid_table", tokentypeid_table),
        )
        for buffer_name, table in named_tables:
            self.register_buffer(buffer_name, table, persistent=False)
示例#6
0
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        """Initializes the static embedding.

        Validates ``emb_args`` (allowed keys: "emb_file", "proj"; "emb_file" is
        required), loads the static table through ``prep`` and optionally adds a
        one-layer MLP projection. ``self.normalize`` is overwritten here: it is
        True only when the table has more than one column.

        Args:
            main_args: main args
            emb_args: embedding args; "emb_file" is required, "proj" optionally
                gives the projected output dimension
            entity_symbols: entity symbols
            key: forwarded to the parent constructor
            cpu: forwarded to the parent constructor
            normalize: forwarded to the parent constructor
            dropout1d_perc: forwarded to the parent constructor
            dropout2d_perc: forwarded to the parent constructor

        Raises:
            ValueError: if ``emb_args`` holds a key outside the allowed set
        """
        parent_kwargs = {
            "main_args": main_args,
            "emb_args": emb_args,
            "entity_symbols": entity_symbols,
            "key": key,
            "cpu": cpu,
            "normalize": normalize,
            "dropout1d_perc": dropout1d_perc,
            "dropout2d_perc": dropout2d_perc,
        }
        super(StaticEmb, self).__init__(**parent_kwargs)
        allowable_keys = {"emb_file", "proj"}
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert "emb_file" in emb_args, f"Must have emb_file in args for StaticEmb"
        self.entity2static = self.prep(
            data_config=main_args.data_config,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
        )
        # Second axis of the table is the embedding width
        self.orig_dim = self.entity2static.shape[1]

        # Only normalize when there is more than a single value per entity
        self.normalize = self.orig_dim > 1
        self.proj = None
        self._dim = self.orig_dim
        if "proj" in emb_args:
            proj_dim = emb_args.proj
            log_rank_0_debug(
                logger,
                f"Adding a projection layer to the static emb to go to dim {proj_dim}",
            )
            self._dim = proj_dim
            self.proj = MLP(
                input_size=self.orig_dim,
                num_hidden_units=None,
                output_size=self._dim,
                num_layers=1,
            )
示例#7
0
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        """Initializes the KG adjacency embedding (hidden dim 1).

        Validates ``emb_args`` (allowed keys: "kg_adj", "threshold",
        "log_weight"; "kg_adj" is required), checks that neither normalization
        nor dropout is requested, and loads the adjacency data through ``prep``.

        Args:
            main_args: main args
            emb_args: embedding args; "threshold" defaults to 0 and
                "log_weight" (must be a bool) defaults to False
            entity_symbols: entity symbols
            key: embedding key (reported in the debug log message)
            cpu: forwarded to the parent constructor
            normalize: forwarded to the parent constructor; must end up False
            dropout1d_perc: forwarded to the parent constructor; must end up 0.0
            dropout2d_perc: forwarded to the parent constructor; must end up 0.0

        Raises:
            ValueError: if ``emb_args`` holds a key outside the allowed set
        """
        parent_kwargs = {
            "main_args": main_args,
            "emb_args": emb_args,
            "entity_symbols": entity_symbols,
            "key": key,
            "cpu": cpu,
            "normalize": normalize,
            "dropout1d_perc": dropout1d_perc,
            "dropout2d_perc": dropout2d_perc,
        }
        super(KGAdjEmb, self).__init__(**parent_kwargs)
        allowable_keys = {"kg_adj", "threshold", "log_weight"}
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert "kg_adj" in emb_args, f"KG embedding requires kg_adj to be set in args"
        # A single-valued embedding can be neither normalized nor dropped out
        assert self.normalize is False, (
            f"We can't normalize a KGAdjEmb as it has hidden dim 1")
        assert self.dropout1d_perc == 0.0, (
            f"We can't dropout 1d a KGAdjEmb as it has hidden dim 1")
        assert self.dropout2d_perc == 0.0, (
            f"We can't dropout 2d a KGAdjEmb as it has hidden dim 1")
        # With this processing function, prepping the embeddings queries the kg_adj matrix and sums the
        # results - generating one value per entity candidate
        self.kg_adj_process_func = embedding_utils.prep_kg_feature_sum
        self._dim = 1

        if "threshold" in emb_args:
            self.threshold_weight = float(emb_args.threshold)
        else:
            self.threshold_weight = 0
        if "log_weight" in emb_args:
            self.log_weight = emb_args.log_weight
            assert type(self.log_weight) is bool
        else:
            self.log_weight = False
        log_rank_0_debug(
            logger,
            f"Setting log_weight to be {self.log_weight} and threshold to be {self.threshold_weight} in {key}",
        )

        prep_result = self.prep(
            data_config=main_args.data_config,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            threshold=self.threshold_weight,
            log_weight=self.log_weight,
        )
        self.kg_adj, self.prep_file = prep_result
示例#8
0
    def prep(cls, data_config, entity_symbols):
        """Loads the cached title tables, or builds them and saves them to disk.

        Three files (token ids, attention mask, token type ids) are kept in the
        embedding prep directory, named after
        ``data_config.word_embedding.bert_model``. They are reused only when
        ``data_config.overwrite_preprocessed_data`` is not set and all three
        exist; otherwise the tables are rebuilt with the tokenizer and saved.

        Args:
            data_config: data config
            entity_symbols: entity symbols

        Returns: torch tensor EID to title token IDs, EID to title token mask, EID to title token type ID (for BERT)
        """
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        bert_model = data_config.word_embedding.bert_model
        token_ids_file = os.path.join(prep_dir,
                                      f"title_token_ids_{bert_model}.pt")
        attn_mask_file = os.path.join(prep_dir,
                                      f"title_attn_mask_{bert_model}.pt")
        token_type_ids_file = os.path.join(
            prep_dir, f"title_token_type_ids_{bert_model}.pt")
        # Same order as the returned tables
        table_files = (token_ids_file, attn_mask_file, token_type_ids_file)
        utils.ensure_dir(os.path.dirname(token_ids_file))
        log_rank_0_debug(
            logger, f"Looking for title table mapping in {token_ids_file}")
        reuse_saved = not data_config.overwrite_preprocessed_data and all(
            os.path.exists(path) for path in table_files)
        if reuse_saved:
            log_rank_0_debug(
                logger, f"Loading existing title table from {token_ids_file}")
            start = time.time()
            entity2titleid, entity2titlemask, entity2tokentypeid = (
                torch.load(path) for path in table_files)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Loaded existing title table in {elapsed}s",
            )
        else:
            start = time.time()
            log_rank_0_debug(logger, f"Loading tokenizer")
            tokenizer = load_tokenizer(data_config)
            entity2titleid, entity2titlemask, entity2tokentypeid = cls.build_title_table(
                tokenizer=tokenizer, entity_symbols=entity_symbols)
            built_tables = (entity2titleid, entity2titlemask,
                            entity2tokentypeid)
            for table, path in zip(built_tables, table_files):
                torch.save(table, path)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Finished building and saving title table in {elapsed}s.",
            )
        return entity2titleid, entity2titlemask, entity2tokentypeid
示例#9
0
def configure_optimizer(config):
    """Sets up the optimizer entries of the Emmental meta config.

    When the config names no optimizer, a ``SparseDenseAdamW`` partial (with lr,
    l2, betas and eps taken from the config) is installed as the optimizer.
    Independently of that, a ``grouped_parameters`` callable is always stored
    under the "parameters" key: it splits the model parameters into a group that
    uses the configured l2 weight decay and a group (names containing "bias",
    "LayerNorm.bias" or "LayerNorm.weight") with weight decay 0.0.

    Args:
        config: config

    Returns: None
    """
    optimizer_config = config.learner_config.optimizer_config
    # Fall back to the default Bootleg optimizer when the config doesn't pick one
    if optimizer_config.optimizer is None:
        log_rank_0_debug(logger,
                         f"Setting default optimizer to be SparseDenseAdam")
        adamw_config = optimizer_config.adamw_config
        custom_optimizer = partial(
            SparseDenseAdamW,
            lr=optimizer_config.lr,
            weight_decay=optimizer_config.l2,
            betas=adamw_config.betas,
            eps=adamw_config.eps,
        )
        emmental.Meta.update_config(
            {"learner_config": {"optimizer_config": {"optimizer": custom_optimizer}}}
        )

    # Parameter groups for Adam BERT
    def grouped_parameters(model):
        no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight")
        with_decay, without_decay = [], []
        for name, param in model.named_parameters():
            if any(tag in name for tag in no_decay):
                without_decay.append(param)
            else:
                with_decay.append(param)
        # Read at call time so later config updates are picked up
        l2 = emmental.Meta.config["learner_config"]["optimizer_config"]["l2"]
        return [
            {"params": with_decay, "weight_decay": l2},
            {"params": without_decay, "weight_decay": 0.0},
        ]

    meta_optimizer_config = emmental.Meta.config["learner_config"][
        "optimizer_config"]
    meta_optimizer_config["parameters"] = grouped_parameters
示例#10
0
    def load_regularization_mapping(cls, data_config, entity_symbols,
                                    reg_file):
        """Loads or builds the per-entity regularization tensor.

        ``reg_file`` is a csv file with columns [qid, regularization]. The
        result has one value per EID (0.0 by default); rows whose qid is not
        known to ``entity_symbols`` are ignored. The tensor is cached in the
        data prep directory and reused unless
        ``data_config.overwrite_preprocessed_data`` is set.

        Args:
            data_config: data config
            entity_symbols: entity symbols (used for qid_exists, get_eid and
                num_entities_with_pad_and_nocand)
            reg_file: regularization csv file

        Returns: Tensor where each value is the regularization value for EID
        """
        reg_basename = os.path.basename(reg_file.replace("/", "_"))
        reg_tag = os.path.splitext(reg_basename)[0]
        prep_dir = data_utils.get_data_prep_dir(data_config)
        prep_file = os.path.join(
            prep_dir, f"entity_regularization_mapping_{reg_tag}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        log_rank_0_debug(logger,
                         f"Looking for regularization mapping in {prep_file}")
        if data_config.overwrite_preprocessed_data or not os.path.exists(
                prep_file):
            start = time.time()
            log_rank_0_info(
                logger,
                f"Building entity regularization mapping from {reg_file}")
            qid2reg = pd.read_csv(reg_file)
            assert all(
                column in qid2reg.columns
                for column in ("qid", "regularization")
            ), f"Expected qid and regularization as the column names for {reg_file}"
            # Entities without a csv row keep the default of no mask
            eid2reg_arr = [
                0.0
                for _ in range(entity_symbols.num_entities_with_pad_and_nocand)
            ]
            for _, row in qid2reg.iterrows():
                qid = row["qid"]
                if not entity_symbols.qid_exists(qid):
                    continue
                eid = entity_symbols.get_eid(qid)
                eid2reg_arr[eid] = row["regularization"]
            eid2reg = torch.tensor(eid2reg_arr)
            torch.save(eid2reg, prep_file)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Finished building and saving entity regularization mapping in {elapsed}s.",
            )
        else:
            log_rank_0_debug(
                logger,
                f"Loading existing entity regularization mapping from {prep_file}",
            )
            start = time.time()
            eid2reg = torch.load(prep_file)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Loaded existing entity regularization mapping in {elapsed}s",
            )
        return eid2reg
示例#11
0
    def prep(cls, data_config, emb_args, entity_symbols):
        """Loads the saved static embedding table, or builds and saves it.

        The table is stored as ``static_table_<emb file stem>.npy`` in the
        embedding prep directory. In both cases the returned array is the saved
        file re-opened read-only with ``mmap_mode="r"``.

        Args:
            data_config: data config
            emb_args: embedding args (``emb_file`` is read)
            entity_symbols: entity symbols (only used when the table is built)

        Returns: numpy embedding array where each row is the embedding for an EID
        """
        emb_file_stem, _ = os.path.splitext(
            os.path.basename(emb_args.emb_file))
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        prep_file = os.path.join(prep_dir,
                                 f"static_table_{emb_file_stem}.npy")
        log_rank_0_debug(logger,
                         f"Looking for static embedding saved at {prep_file}")
        if data_config.overwrite_preprocessed_data or not os.path.exists(
                prep_file):
            start = time.time()
            emb_file = emb_args.emb_file
            log_rank_0_debug(logger,
                             f"Building static table from file {emb_file}")
            built_table = cls.build_static_embeddings(
                emb_file=emb_file, entity_symbols=entity_symbols)
            np.save(prep_file, built_table)
            # Hand back the memory-mapped file rather than the in-memory array
            entity2staticemb_table = np.load(prep_file, mmap_mode="r")
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Finished building and saving static embedding table in {elapsed}s.",
            )
        else:
            log_rank_0_debug(
                logger,
                f"Loading existing static embedding table from {prep_file}")
            start = time.time()
            entity2staticemb_table = np.load(prep_file, mmap_mode="r")
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Loaded existing static embedding table in {elapsed}s",
            )
        return entity2staticemb_table
示例#12
0
def count_parameters(model, requires_grad, logger):
    """Counts the elements of all parameters whose ``requires_grad`` flag is
    ``requires_grad``, logging name, element count and size for each of them.

    The logged size in MB is computed with 4 bytes per element.

    Args:
        model: model to count
        requires_grad: whether to look at grad or no grad params
        logger: logger

    Returns: total number of elements over the selected parameters
    """
    selected = [(name, param) for name, param in model.named_parameters()
                if param.requires_grad is requires_grad]
    for name, param in selected:
        num_elements = param.numel()
        size_mb = param.numel() * 4 / 1024**2
        log_rank_0_debug(
            logger,
            "{:s} {:d} {:.2f} MB".format(name, num_elements, size_mb),
        )
    total = 0
    for param in model.parameters():
        if param.requires_grad is requires_grad:
            total += param.numel()
    return total
示例#13
0
    def prep(cls, data_config, emb_args, entity_symbols):
        """Loads the saved type id table, or builds and saves it.

        The prep file name is derived from ``emb_args.type_labels`` and
        ``emb_args.max_types``. The saved tuple is reused unless
        ``data_config.overwrite_preprocessed_data`` is set or the file is
        missing.

        Args:
            data_config: data config
            emb_args: embedding args (``type_labels``, ``type_vocab`` and
                ``max_types`` are read)
            entity_symbols: entity symbols

        Returns: torch tensor from EID to type IDS, type ID to row in type embedding matrix,
                 number of types with unk type, and the prep file path
        """
        type_tag = os.path.splitext(emb_args.type_labels.replace("/", "_"))[0]
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        prep_file = os.path.join(
            prep_dir, f"type_table_{type_tag}_{emb_args.max_types}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        if data_config.overwrite_preprocessed_data or not os.path.exists(
                prep_file):
            start = time.time()
            emb_dir = data_config.emb_dir
            type_labels = os.path.join(emb_dir, emb_args.type_labels)
            type_vocab = os.path.join(data_config.emb_dir, emb_args.type_vocab)
            log_rank_0_debug(logger, f"Building type table from {type_labels}")
            eid2typeids_table, type2row_dict, num_types_with_unk = cls.build_type_table(
                type_labels=type_labels,
                type_vocab=type_vocab,
                max_types=emb_args.max_types,
                entity_symbols=entity_symbols,
            )
            # All three pieces are stored together in a single file
            saved_bundle = (eid2typeids_table, type2row_dict,
                            num_types_with_unk)
            torch.save(saved_bundle, prep_file)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Finished building and saving type table in {elapsed}s.",
            )
        else:
            log_rank_0_debug(logger,
                             f"Loading existing type table from {prep_file}")
            start = time.time()
            eid2typeids_table, type2row_dict, num_types_with_unk = torch.load(
                prep_file)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Loaded existing type table in {elapsed}s",
            )
        return eid2typeids_table, type2row_dict, num_types_with_unk, prep_file
示例#14
0
    def prep(
        cls,
        data_config,
        emb_args,
        entity_symbols,
        threshold,
        log_weight,
    ):
        """Loads the saved KG adjacency matrix, or builds and saves it.

        The prep file name encodes ``emb_args.kg_adj``, ``threshold`` and
        ``log_weight``. The saved ``.npz`` file is reused unless
        ``data_config.overwrite_preprocessed_data`` is set or the file is
        missing.

        Args:
            data_config: data config
            emb_args: embedding args (``kg_adj`` is read)
            entity_symbols: entity symbols
            threshold: weight threshold for counting an edge
            log_weight: whether to take the log of the weight value after the threshold

        Returns: sparse KG adjacency matrix, prep file
        """
        file_tag, _ = os.path.splitext(emb_args.kg_adj.replace("/", "_"))
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        prep_file = os.path.join(
            prep_dir, f"kg_adj_file_{file_tag}_{threshold}_{log_weight}.npz")
        utils.ensure_dir(os.path.dirname(prep_file))
        if data_config.overwrite_preprocessed_data or not os.path.exists(
                prep_file):
            start = time.time()
            kg_adj_file = os.path.join(data_config.emb_dir, emb_args.kg_adj)
            log_rank_0_debug(logger, f"Building KG adj from {kg_adj_file}")
            kg_adj = cls.build_kg_adj(kg_adj_file, entity_symbols, threshold,
                                      log_weight)
            scipy.sparse.save_npz(prep_file, kg_adj)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(
                logger,
                f"Finished building and saving KG adj in {elapsed}s.",
            )
        else:
            log_rank_0_debug(logger,
                             f"Loading existing KG adj from {prep_file}")
            start = time.time()
            kg_adj = scipy.sparse.load_npz(prep_file)
            elapsed = round(time.time() - start, 2)
            log_rank_0_debug(logger, f"Loaded existing KG adj in {elapsed}s")
        return kg_adj, prep_file
示例#15
0
def convert_examples_to_features_and_save(meta_file, dataset_threads,
                                          slice_names, save_dataset_name,
                                          storage):
    """Converts the prepped examples into input features and saves in memmap
    files. These are used in the __get_item__ method.

    The memmap file is created here with one record per mention, the
    ``sent_idx`` field of the first slice is filled with -1 as a sentinel, the
    input files listed in the metadata are converted (in this process or in a
    worker pool), and finally every record is checked to have been written and
    to carry a unique ``sent_idx.subslice_idx`` id.

    Args:
        meta_file: metadata file where input file paths are saved; must hold the
            keys "num_mentions", "max_alias2pred" and "files_and_counts"
        dataset_threads: number of threads
        slice_names: list of slice names to evaluation on
        save_dataset_name: data file name to save
        storage: data storage type (for memmap)

    Returns: None

    Raises:
        AssertionError: if a single process is used with more than one input
            file, if a record still has the -1 sentinel, or if a record id is
            duplicated
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    # Clamp to at least one process: int(0.8 * cpu_count) is 0 on a single-core machine, which would make
    # multiprocessing.Pool raise
    num_processes = max(
        1, min(dataset_threads, int(0.8 * multiprocessing.cpu_count())))

    log_rank_0_info(
        logger,
        f"Starting to build and save features with {num_processes} threads")

    log_rank_0_debug(logger, f"Counting lines")
    # Parse the metadata file once and read all needed keys from it
    meta_data = utils.load_json_file(meta_file)
    total_input = meta_data["num_mentions"]
    max_alias2pred = meta_data["max_alias2pred"]
    files_and_counts = meta_data["files_and_counts"]

    # IMPORTANT: for distributed writing to memmap files, you must create them in w+ mode before
    # being opened in r+ mode by workers
    memmap_file = np.memmap(save_dataset_name,
                            dtype=storage,
                            mode="w+",
                            shape=(total_input, ),
                            order="C")
    # Save -1 in sent_idx to check that things are loaded correctly later
    memmap_file[slice_names[0]]["sent_idx"][:] = -1

    input_args = []
    # Saves where in memap file to start writing
    offset = 0
    ex_print_mod = int(np.ceil(total_input / 20))
    for in_file_name, in_file_lines in files_and_counts.items():
        input_args.append({
            "file_name": in_file_name,
            "in_file_lines": in_file_lines,
            "save_file_offset": offset,
            "ex_print_mod": ex_print_mod,
            "slice_names": slice_names,
            "max_alias2pred": max_alias2pred,
        })
        offset += in_file_lines

    if num_processes == 1:
        assert len(input_args) == 1
        total_output = convert_examples_to_features_and_save_single(
            input_args[0], memmap_file)
    else:
        log_rank_0_debug(
            logger,
            "Initializing pool. This make take a few minutes.",
        )
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=convert_examples_to_features_and_save_initializer,
            initargs=[save_dataset_name, storage],
        )

        total_output = 0
        try:
            for res in pool.imap_unordered(
                    convert_examples_to_features_and_save_hlp, input_args,
                    chunksize=1):
                total_output += res
            pool.close()
        except BaseException:
            # Do not leave workers running when a task (or the parent) fails
            pool.terminate()
            raise
        finally:
            # Wait for the workers to exit before the file is verified
            pool.join()

    # Verify that sentences are unique and saved correctly
    mmap_file = np.memmap(save_dataset_name, dtype=storage, mode="r")
    # Field views are created once instead of on every loop iteration
    first_slice = mmap_file[slice_names[0]]
    sent_idxs = first_slice["sent_idx"]
    subslice_idxs = first_slice["subslice_idx"]
    all_uniq_ids = set()
    for i in tqdm(range(total_input), desc="Checking sentence uniqueness"):
        assert sent_idxs[i] != -1, f"Index {i} has -1 sent idx"
        uniq_id = f"{sent_idxs[i]}.{subslice_idxs[i]}"
        assert (uniq_id not in all_uniq_ids
                ), f"Idx {uniq_id} is not unique and already in data"
        all_uniq_ids.add(uniq_id)

    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}",
    )
    return
示例#16
0
    def __init__(
        self,
        main_args,
        dataset,
        use_weak_label,
        entity_symbols,
        dataset_threads,
        split="train",
    ):
        """Builds (or loads) the slice data for one split.

        The multiprocessing start method is switched to
        ``main_args.run_config.spawn_method`` while the data is prepared and is
        always restored afterwards, also when preparation fails. If the saved
        slice file is missing or ``overwrite_preprocessed_data`` is set, the
        data is rebuilt from ``dataset``; if that rebuild raises, the error is
        logged, the save folder is removed and the exception is re-raised.
        Temporary prep files are removed at the end.

        Args:
            main_args: main args (``run_config.spawn_method`` and ``data_config`` are read)
            dataset: dataset passed on to ``create_examples`` and used for the save folder
            use_weak_label: whether weak labels are used
            entity_symbols: entity symbols (not referenced in this method)
            dataset_threads: number of threads for data preparation
            split: data split name
        """
        global_start = time.time()
        log_rank_0_info(logger,
                        f"Building slice dataset for {split} from {dataset}.")
        spawn_method = main_args.run_config.spawn_method
        data_config = main_args.data_config
        orig_spawn = multiprocessing.get_start_method()
        multiprocessing.set_start_method(spawn_method, force=True)
        try:
            self.slice_names = data_utils.get_eval_slices(
                data_config.eval_slices)
            # Record layout of one slice; array fields have max_a2p entries
            self.get_slice_dt = lambda max_a2p: np.dtype([
                ("sent_idx", int),
                ("subslice_idx", int),
                ("alias_slice_incidence", int, (max_a2p, )),
                ("prob_labels", float, (max_a2p, )),
            ])
            # Full record: one slice record per slice name
            self.get_storage = lambda max_a2p: np.dtype(
                [(slice_name, self.get_slice_dt(max_a2p))
                 for slice_name in self.slice_names])
            # Folder for all mmap saved files
            save_dataset_folder = data_utils.get_save_data_folder(
                data_config, use_weak_label, dataset)
            utils.ensure_dir(save_dataset_folder)
            # Folder for temporary output files
            temp_output_folder = os.path.join(data_config.data_dir,
                                              data_config.data_prep_dir,
                                              f"prep_{split}_slice_files")
            utils.ensure_dir(temp_output_folder)
            # Input step 1
            create_ex_indir = os.path.join(temp_output_folder,
                                           "create_examples_input")
            utils.ensure_dir(create_ex_indir)
            # Input step 2
            create_ex_outdir = os.path.join(temp_output_folder,
                                            "create_examples_output")
            utils.ensure_dir(create_ex_outdir)
            # Meta data saved files
            meta_file = os.path.join(temp_output_folder, "meta_data.json")
            # File for standard training data; the name depends on the slice names
            # (local renamed so the builtin ``hash`` is not shadowed)
            slices_hash = hashlib.sha1(
                str(self.slice_names).encode("UTF-8")).hexdigest()[:10]
            self.save_dataset_name = os.path.join(
                save_dataset_folder, f"ned_slices_{slices_hash}.bin")
            self.save_data_config_name = os.path.join(
                save_dataset_folder, "ned_slices_config.json")

            # =======================================================================================
            # SLICE DATA
            # =======================================================================================
            log_rank_0_debug(logger, "Loading dataset...")
            log_rank_0_debug(logger,
                             f"Seeing if {self.save_dataset_name} exists")
            if data_config.overwrite_preprocessed_data or (not os.path.exists(
                    self.save_dataset_name)):
                st_time = time.time()
                try:
                    log_rank_0_info(
                        logger,
                        f"Building dataset from scratch. Saving to {save_dataset_folder}",
                    )
                    create_examples(
                        dataset,
                        create_ex_indir,
                        create_ex_outdir,
                        meta_file,
                        data_config,
                        dataset_threads,
                        self.slice_names,
                        use_weak_label,
                        split,
                    )
                    max_alias2pred = utils.load_json_file(
                        meta_file)["max_alias2pred"]
                    convert_examples_to_features_and_save(
                        meta_file,
                        dataset_threads,
                        self.slice_names,
                        self.save_dataset_name,
                        self.get_storage(max_alias2pred),
                    )
                    utils.dump_json_file(self.save_data_config_name,
                                         {"max_alias2pred": max_alias2pred})

                    log_rank_0_debug(
                        logger,
                        f"Finished prepping data in {time.time() - st_time}")
                except Exception as e:
                    # Log the failure and drop the partially written save folder before re-raising
                    tb = traceback.TracebackException.from_exception(e)
                    logger.error(e)
                    logger.error("\n".join(tb.stack.format()))
                    shutil.rmtree(save_dataset_folder, ignore_errors=True)
                    raise

            log_rank_0_info(
                logger,
                f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
            )
            max_alias2pred = utils.load_json_file(
                self.save_data_config_name)["max_alias2pred"]
            self.data, self.sent_to_row_id_dict = self.build_data_dict(
                self.save_dataset_name, self.get_storage(max_alias2pred))
            assert len(self.data) > 0
            assert len(self.sent_to_row_id_dict) > 0
            log_rank_0_debug(logger, f"Removing temporary output files")
            shutil.rmtree(temp_output_folder, ignore_errors=True)
        finally:
            # Set spawn back to original/default, which is "fork" or "spawn". This is needed for the
            # Meta.config to be correctly passed in the collate_fn. Done in ``finally`` so a failed build
            # does not leave the process-wide start method changed.
            multiprocessing.set_start_method(orig_spawn, force=True)
        log_rank_0_info(
            logger,
            f"Final slice data initialization time from {split} is {time.time() - global_start}s",
        )
示例#17
0
def run_model(mode, config, run_config_path=None):
    """Main run method for Emmental Bootleg models.

    Sets up the run, loads the entity symbols, builds the dataloaders and the
    model, optionally trains it, and then either scores the model (train/eval)
    or dumps per-mention predictions and, optionally, embeddings (dump_preds/
    dump_embs).

    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns: for train and eval, the scores dict returned by the model with an added "log_path" entry.
             For dump_preds and dump_embs, the tuple (final_result_file, final_out_emb_file); both entries
             are None on ranks other than 0/-1 and final_out_emb_file is None unless mode is dump_embs.

    Raises:
        ValueError: if a dump mode is requested while the test dataset has use_weak_label set to False
    """

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if weak labels is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is
                False), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    # NOTE(review): only local rank 0 calls the barrier here; presumably it pairs with a barrier
    # entered by the other ranks elsewhere -- confirm against setup().
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions. Only the offending sentences (those that
        # do have mentions) are reported, not every sentence that was skipped.
        undumped_with_mentions = [
            k for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences and v != 0
        ]
        assert not undumped_with_mentions, (
            f"Sentences with mentions were not dumped: "
            f"{undumped_with_mentions}")
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            f"Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels. Context managers make sure every per-batch file handle is closed again, even
        # if copying fails half way through.
        with open(final_result_file, "wb") as output:
            for file in all_result_files:
                with open(file, "rb") as batch_in:
                    shutil.copyfileobj(batch_in, output)
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"Is that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = [np.load(npfile) for npfile in all_out_emb_files]
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(
                logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
示例#18
0
def create_examples(
    dataset,
    create_ex_indir,
    create_ex_outdir,
    meta_file,
    data_config,
    dataset_threads,
    slice_names,
    use_weak_label,
    split,
):
    """Creates examples from the raw input data.

    The input file is processed either in this process (one worker) or split
    into chunks that are handled by a multiprocessing pool. The names and line
    counts of the produced files, the total number of kept lines and the
    largest "max_alias2pred" reported by any worker are written to meta_file.

    Args:
        dataset: dataset file
        create_ex_indir: temporary directory where input files are stored
        create_ex_outdir: temporary directory to store output files from method
        meta_file: metadata file to save the file names/paths for the next step in prep pipeline
        data_config: data config
        dataset_threads: number of threads
        slice_names: list of slices to evaluate on
        use_weak_label: whether to use weak labeling or not
        split: data split

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    # int(0.8 * cpu_count) is 0 on a single-core machine (and dataset_threads may be given as 0);
    # always keep at least one worker so we never divide by zero when chunking the input.
    num_processes = max(
        1, min(dataset_threads, int(0.8 * multiprocessing.cpu_count())))

    log_rank_0_debug(logger, f"Counting lines")
    with open(dataset) as in_f:
        total_input = sum(1 for _ in in_f)
    if num_processes == 1:
        out_file_name = os.path.join(create_ex_outdir,
                                     os.path.basename(dataset))
        constants_dict = {
            "slice_names": slice_names,
            "use_weak_label": use_weak_label,
            "max_aliases": data_config.max_aliases,
            "split": split,
            "train_in_candidates": data_config.train_in_candidates,
        }
        files_and_counts = {}
        res = create_examples_single(dataset, total_input, out_file_name,
                                     constants_dict)
        total_output = res["total_lines"]
        max_alias2pred = res["max_alias2pred"]
        files_and_counts[res["output_filename"]] = res["total_lines"]
    else:
        log_rank_0_info(
            logger,
            f"Strating to extract examples with {num_processes} threads")
        log_rank_0_debug(
            logger, "Parallelizing with " + str(num_processes) + " threads.")
        chunk_input = int(np.ceil(total_input / num_processes))
        log_rank_0_debug(
            logger,
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines",
        )
        total_input_from_chunks, input_files_dict = utils.chunk_file(
            dataset, create_ex_indir, chunk_input)

        input_files = list(input_files_dict.keys())
        input_file_lines = [input_files_dict[k] for k in input_files]
        # Output chunks mirror the input chunk names, just under the output directory
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        assert (
            total_input == total_input_from_chunks
        ), f"Lengths of files {total_input} doesn't mathc {total_input_from_chunks}"
        log_rank_0_debug(logger, f"Done chunking files")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=create_examples_initializer,
            initargs=[
                data_config,
                slice_names,
                use_weak_label,
                data_config.max_aliases,
                split,
                data_config.train_in_candidates,
            ],
        )
        total_output = 0
        max_alias2pred = 0
        input_args = list(zip(input_files, input_file_lines, output_files))
        # Store output files and counts for saving in next step
        files_and_counts = {}
        try:
            for res in pool.imap_unordered(create_examples_hlp,
                                           input_args,
                                           chunksize=1):
                total_output += res["total_lines"]
                max_alias2pred = max(max_alias2pred, res["max_alias2pred"])
                files_and_counts[res["output_filename"]] = res["total_lines"]
        except BaseException:
            # A chunk failed (or we were interrupted): stop the remaining workers right away
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            # Wait for the worker processes to exit so none are left behind
            pool.join()
    utils.dump_json_file(
        meta_file,
        {
            "num_mentions": total_output,
            "files_and_counts": files_and_counts,
            "max_alias2pred": max_alias2pred,
        },
    )
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time()-start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}.",
    )
    return
示例#19
0
    def build_type_table(cls, type_labels, type_vocab, max_types,
                         entity_symbols):
        """Constructs the lookup tensor from entity id (EID) to its type ids.

        Every row starts out as all zeros (the reserved UNK type). An entity
        that has at least one type in type_labels gets its first max_types
        type ids, with the remaining slots filled by the pad id (largest
        vocab type id + 1).

        Args:
            type_labels: path to the json mapping of QID to a list of type ids or type names
            type_vocab: path to the json mapping of type name to type id
            max_types: maximum number of types for an entity
            entity_symbols: entity symbols

        Returns: long tensor from EID to type ids, dict mapping every used type id to itself,
                 and the pad type id (largest vocab type id + 1)
        """
        with open(type_vocab) as vocab_f:
            name2typeid = json.load(vocab_f)
        known_type_ids = set(name2typeid.values())
        assert (
            0 not in known_type_ids
        ), f"The type id of 0 is reserved for UNK type. Please offset the typeids by 1"
        # Start with every eid on the UNK type; rows are overwritten below for typed entities
        eid2typeids = torch.zeros(
            entity_symbols.num_entities_with_pad_and_nocand, max_types)
        eid2typeids[0] = torch.zeros(1, max_types)

        largest_type_id = max(known_type_ids)
        with open(type_labels) as labels_f:
            qid2types = json.load(labels_f)

        entities_with_types = 0
        type2row_dict = {}
        for qid, row_types in qid2types.items():
            # Skip entities we do not know about and entities without any type
            if not entity_symbols.qid_exists(qid) or len(row_types) == 0:
                continue
            entities_with_types += 1
            resolved_ids = []
            for raw_type in row_types:
                # Type names are looked up in the vocab; anything else is taken as the id itself
                if type(raw_type) is str:
                    type_id = name2typeid[raw_type]
                else:
                    type_id = raw_type
                assert (
                    type_id > 0
                ), f"Typeid for {qid} is 0. That is reserved. Please offset by 1"
                assert (type_id in known_type_ids
                        ), f"Typeid for {qid} isn't in vocab"
                resolved_ids.append(type_id)
                type2row_dict[type_id] = type_id
            # -1 marks the padded slots; they are swapped for the pad id at the end
            row = -torch.ones(max_types)
            num_kept = min(len(resolved_ids), max_types)
            row[:num_kept] = torch.tensor(resolved_ids)[:num_kept]
            eid2typeids[entity_symbols.get_eid(qid)] = row

        # + 1 bc we need to account for pad row
        labeled_num_types = largest_type_id + 1
        # make sure adding type labels doesn't add new types
        assert (largest_type_id + 1) <= labeled_num_types
        # replace the -1 placeholders with the pad type id
        eid2typeids[eid2typeids == -1] = labeled_num_types
        log_rank_0_debug(
            logger,
            f"{round(entities_with_types/entity_symbols.num_entities, 2)*100}% of entities are assigned types",
        )
        return eid2typeids.long(), type2row_dict, labeled_num_types
示例#20
0
    def build_static_embeddings(cls, emb_file, entity_symbols):
        """Builds the table of the embedding associated with each entity.

        Two file formats are supported. A .json file maps QID to either a
        single number or a list of numbers; the kind (and embedding size) is
        taken from the first value in the file. A .pt file holds a
        (qid2eid_map, embedding tensor) pair. In both cases the embeddings are
        re-indexed by the EIDs of entity_symbols; entities without an
        embedding keep an all-zero row.

        Args:
            emb_file: embedding file to load
            entity_symbols: entity symbols

        Returns: numpy embedding matrix where each row is an embedding

        Raises:
            ValueError: if the file extension is neither .json nor .pt, if the json file holds no
                        embeddings, or if its values are neither numbers nor lists
        """
        ending = os.path.splitext(emb_file)[1]
        found = 0
        raw_num_ents = 0
        if ending == ".json":
            dct = utils.load_json_file(emb_file)
            if not dct:
                # Without a first value we cannot infer the embedding size
                raise ValueError(
                    f"The static embedding file {emb_file} does not contain any embeddings"
                )
            val = next(iter(dct.values()))
            is_scalar = type(val) is int or type(val) is float
            if is_scalar:
                embedding_size = 1
            elif type(val) is list:
                embedding_size = len(val)
            else:
                raise ValueError(
                    f"Unrecognized type for the array value of {type(val)}"
                )
            embeddings = {}
            for k, raw_emb in dct.items():
                # Scalars become length-1 vectors so every embedding is a 1D array
                if is_scalar:
                    embeddings[k] = np.array([raw_emb])
                else:
                    embeddings[k] = np.array(list(raw_emb))
                assert len(embeddings[k]) == embedding_size
            entity2staticemb_table = np.zeros(
                (entity_symbols.num_entities_with_pad_and_nocand, embedding_size)
            )
            raw_num_ents = len(embeddings)
            for qid in tqdm(entity_symbols.get_all_qids()):
                if qid in embeddings:
                    found += 1
                    emb = embeddings[qid]
                    eid = entity_symbols.get_eid(qid)
                    entity2staticemb_table[eid, :embedding_size] = emb
        elif ending == ".pt":
            log_rank_0_debug(
                logger,
                f"We are readining in the embedding file from a .pt. We assume this is already mapped to eids",
            )
            # NOTE(review): torch.load unpickles the file, so emb_file must come from a trusted source.
            (qid2eid_map, entity2staticemb_table_raw) = torch.load(emb_file)
            entity2staticemb_table_raw = (
                entity2staticemb_table_raw.detach().cpu().numpy()
            )
            raw_num_ents = entity2staticemb_table_raw.shape[0]
            # +2 handles the PAD and UNK entities
            assert entity2staticemb_table_raw.shape[0] == len(qid2eid_map) + 2, (
                f"The saved static embeddings file had mismatched shapes between qid2eid {len(qid2eid_map)} and "
                f"weights {entity2staticemb_table_raw.shape[0]}"
            )
            entity2staticemb_table = np.zeros(
                (
                    entity_symbols.num_entities_with_pad_and_nocand,
                    entity2staticemb_table_raw.shape[1],
                )
            )
            for qid in tqdm(entity_symbols.get_all_qids()):
                if qid in qid2eid_map:
                    found += 1
                    # Translate the row index of the saved table into our own eid
                    raw_eid = qid2eid_map[qid]
                    emb = entity2staticemb_table_raw[raw_eid]
                    new_eid = entity_symbols.get_eid(qid)
                    entity2staticemb_table[new_eid, :] = emb
        else:
            raise ValueError(
                f"We do not support static embeddings from {ending}. We only support .json and .pt"
            )
        # Report a real percentage and do not divide by zero for an empty entity set
        num_qids = len(entity_symbols.get_all_qids())
        perc_found = 100 * found / num_qids if num_qids > 0 else 0.0
        log_rank_0_debug(
            logger,
            f"Found {found} ({perc_found:.2f} percent) of all entities after "
            f"reading {raw_num_ents} original entities have a static embedding",
        )
        return entity2staticemb_table
示例#21
0
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        """Sets up a learned entity embedding over a reduced ("topK") set of entity ids.

        Every original eid is mapped to a new eid through the eid2topkeid
        buffer; the embedding matrix has one row per new eid. Unless the
        model is loaded from a checkpoint (self.from_pretrained), the mapping
        is read from emb_args.qid2topk_eid and the embeddings are optionally
        initialized to a shared vector.

        Args:
            main_args: main config; model_config.hidden_size and data_config are read here
            emb_args: embedding arguments; only the keys in allowable_keys are accepted
            entity_symbols: entity symbols
            key: forwarded unchanged to the parent constructor
            cpu: forwarded unchanged to the parent constructor
            normalize: forwarded unchanged to the parent constructor
            dropout1d_perc: forwarded unchanged to the parent constructor
            dropout2d_perc: forwarded unchanged to the parent constructor
        """
        super(TopKEntityEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "learned_embedding_size",
            "perc_emb_drop",
            "qid2topk_eid",
            "regularize_mapping",
            "tail_init",
            "tail_init_zeros",
        }
        keys_ok, offending_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not keys_ok:
            raise ValueError(f"The key {offending_key} is not in {allowable_keys}")
        assert (
            "learned_embedding_size" in emb_args
        ), f"TopKEntityEmb must have learned_embedding_size in args"
        assert "perc_emb_drop" in emb_args, (
            f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings"
            f" removed."
        )
        self.learned_embedding_size = emb_args.learned_embedding_size

        # perc_emb_drop percent of the embeddings are removed; one extra row is added for the
        # embedding shared by the removed entities
        num_removed = int(emb_args.perc_emb_drop * entity_symbols.num_entities)
        num_topk_entities_with_pad_and_nocand = (
            entity_symbols.num_entities_with_pad_and_nocand - num_removed + 1
        )
        # Identity mapping to start with; a negative index does not work well with embeddings,
        # so the last eid is pointed at the last row of the reduced table explicitly
        eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
        eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
        assert "qid2topk_eid" in emb_args or self.from_pretrained, (
            f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
            f"you must be loading a model from a checkpoint to build this index mapping"
        )

        self.learned_entity_embedding = nn.Embedding(
            num_topk_entities_with_pad_and_nocand,
            self.learned_embedding_size,
            padding_idx=-1,
            sparse=False,
        )

        self._dim = main_args.model_config.hidden_size

        use_reg_mapping = "regularize_mapping" in emb_args
        eid2reg = (
            torch.zeros(num_topk_entities_with_pad_and_nocand)
            if use_reg_mapping
            else None
        )
        # tail_init False -> embeddings keep their random initialization.
        # tail_init True -> all embeddings are initialized to the same vector.
        self.tail_init = True
        self.tail_init_zeros = False
        # A None init vector means a random one is used
        init_vec = None
        if not self.from_pretrained:
            qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
            assert (
                len(qid2topk_eid) == entity_symbols.num_entities
            ), f"You must have an item in qid2topk_eid for each qid in entity_symbols"
            for qid in entity_symbols.get_all_qids():
                source_eid = entity_symbols.get_eid(qid)
                target_eid = qid2topk_eid[qid]
                eid2topkeid[source_eid] = target_eid
            assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
            assert (
                eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1
            ), "The -1 eid should still map to -1"

            if "tail_init" in emb_args:
                self.tail_init = emb_args.tail_init
            if "tail_init_zeros" in emb_args:
                self.tail_init_zeros = emb_args.tail_init_zeros
                self.tail_init = False
                init_vec = torch.zeros(1, self.learned_embedding_size)
            assert (
                not self.tail_init or not self.tail_init_zeros
            ), f"Can only have one of tail_init or tail_init_zeros set"
            if not (self.tail_init or self.tail_init_zeros):
                log_rank_0_debug(
                    logger, f"All learned embeddings are randomly initialized."
                )
            else:
                init_msg = (
                    "All learned entity embeddings are initialized to zero."
                    if self.tail_init_zeros
                    else "All learned entity embeddings are initialized to the same value."
                )
                log_rank_0_debug(logger, init_msg)
                init_vec = model_utils.init_embeddings_to_vec(
                    self.learned_entity_embedding, pad_idx=-1, vec=init_vec
                )
                vec_save_file = os.path.join(
                    emmental.Meta.log_path, "init_vec_entity_embs.npy"
                )
                log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
                # The vector is only written by rank 0 of an initialized distributed run
                is_dist_rank_zero = (
                    torch.distributed.is_initialized()
                    and torch.distributed.get_rank() == 0
                )
                if is_dist_rank_zero:
                    np.save(vec_save_file, init_vec)
            # Regularization mapping goes from eid to 2d dropout percent
            if use_reg_mapping:
                log_rank_0_debug(
                    logger,
                    f"You are using regularization mapping with a topK entity embedding. "
                    f"This means all QIDs that are mapped to the same"
                    f" EID will get the same regularization value.",
                )
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    log_rank_0_debug(
                        logger,
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?",
                    )
                log_rank_0_debug(
                    logger,
                    f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}",
                )
                eid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    entity_symbols,
                    qid2topk_eid,
                    num_topk_entities_with_pad_and_nocand,
                    emb_args.regularize_mapping,
                )
        # Storing the mapping as a buffer lets a topK model be loaded without the new eid mapping
        self.register_buffer("eid2topkeid", eid2topkeid)
        self.register_buffer("eid2reg", eid2reg)
示例#22
0
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        """Sets up a learned embedding with one row per entity id.

        Whether the embedding uses sparse gradients depends on the optimizer,
        fp16 and the distributed backend. Unless the model is loaded from a
        checkpoint (self.from_pretrained), the embeddings are optionally
        initialized to a shared vector and the regularization mapping is
        loaded.

        Args:
            main_args: main config; learner_config, model_config and data_config are read here
            emb_args: embedding arguments; only the keys in allowable_keys are accepted
            entity_symbols: entity symbols
            key: forwarded unchanged to the parent constructor
            cpu: forwarded unchanged to the parent constructor
            normalize: forwarded unchanged to the parent constructor
            dropout1d_perc: forwarded unchanged to the parent constructor
            dropout2d_perc: forwarded unchanged to the parent constructor
        """
        super(LearnedEntityEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "learned_embedding_size",
            "regularize_mapping",
            "tail_init",
            "tail_init_zeros",
        }
        keys_ok, offending_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not keys_ok:
            raise ValueError(f"The key {offending_key} is not in {allowable_keys}")
        assert (
            "learned_embedding_size" in emb_args
        ), f"LearnedEntityEmb must have learned_embedding_size in args"
        self.learned_embedding_size = emb_args.learned_embedding_size

        # Sparse gradients are only used with the None optimizer (Bootleg's SparseDenseAdam) or
        # sparse_adam, and never together with fp16.
        optimiz = main_args.learner_config.optimizer_config.optimizer
        sparse = (
            optimiz in [None, "sparse_adam"]
            and main_args.learner_config.fp16 is False
        )
        # A distributed run on the nccl backend always uses dense embeddings
        on_nccl = (
            torch.distributed.is_initialized()
            and main_args.model_config.distributed_backend == "nccl"
        )
        if on_nccl:
            sparse = False
        log_rank_0_debug(
            logger, f"Setting sparsity for entity embeddings to be {sparse}"
        )
        self.learned_entity_embedding = nn.Embedding(
            entity_symbols.num_entities_with_pad_and_nocand,
            self.learned_embedding_size,
            padding_idx=-1,
            sparse=sparse,
        )
        self._dim = main_args.model_config.hidden_size

        use_reg_mapping = "regularize_mapping" in emb_args
        eid2reg = (
            torch.zeros(entity_symbols.num_entities_with_pad_and_nocand)
            if use_reg_mapping
            else None
        )

        # tail_init False -> embeddings keep their random initialization.
        # tail_init True -> all embeddings are initialized to the same vector.
        self.tail_init = True
        self.tail_init_zeros = False
        # A None init vector means a random one is used
        init_vec = None
        if not self.from_pretrained:
            if "tail_init" in emb_args:
                self.tail_init = emb_args.tail_init
            if "tail_init_zeros" in emb_args:
                self.tail_init_zeros = emb_args.tail_init_zeros
                self.tail_init = False
                init_vec = torch.zeros(1, self.learned_embedding_size)
            assert (
                not self.tail_init or not self.tail_init_zeros
            ), f"Can only have one of tail_init or tail_init_zeros set"
            if not (self.tail_init or self.tail_init_zeros):
                log_rank_0_debug(
                    logger, f"All learned embeddings are randomly initialized."
                )
            else:
                init_msg = (
                    "All learned entity embeddings are initialized to zero."
                    if self.tail_init_zeros
                    else "All learned entity embeddings are initialized to the same value."
                )
                log_rank_0_debug(logger, init_msg)
                init_vec = model_utils.init_embeddings_to_vec(
                    self.learned_entity_embedding, pad_idx=-1, vec=init_vec
                )
                vec_save_file = os.path.join(
                    emmental.Meta.log_path, "init_vec_entity_embs.npy"
                )
                log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
                # The vector is only written by rank 0 of an initialized distributed run
                is_dist_rank_zero = (
                    torch.distributed.is_initialized()
                    and torch.distributed.get_rank() == 0
                )
                if is_dist_rank_zero:
                    np.save(vec_save_file, init_vec)
            # Regularization mapping goes from eid to 2d dropout percent
            if use_reg_mapping:
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    log_rank_0_debug(
                        logger,
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?",
                    )
                log_rank_0_debug(
                    logger,
                    f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}",
                )
                eid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    entity_symbols,
                    emb_args.regularize_mapping,
                )
        self.register_buffer("eid2reg", eid2reg)
示例#23
0
def merge_subsentences(
    num_processes,
    subset_sent_idx2num_mens,
    cache_folder,
    to_save_file,
    to_save_storage,
    to_read_file,
    to_read_storage,
    dump_embs=False,
):
    """Flatten all sentences back together over sub-sentences; removing the PAD
    aliases from the data I.e., converts from sent_idx -> array of values to
    (sent_idx, alias_idx) -> value with varying numbers of aliases per
    sentence.

    The flattened records are written into the memmap at ``to_save_file``;
    nothing is returned.

    Args:
        num_processes: number of processes
        subset_sent_idx2num_mens: Dict of sentence index to number of mentions for this batch
        cache_folder: cache directory
        to_save_file: memmap file to save results to
        to_save_storage: save file storage type
        to_read_file: memmap file to read predictions from
        to_read_storage: read file storage type
        dump_embs: whether to save embeddings or not

    Returns: None
    """
    # Compute sent idx to offset so we know where to fill in mentions
    cur_offset = 0
    sentidx2offset = {}
    for sent_idx, num_mens in subset_sent_idx2num_mens.items():
        sentidx2offset[sent_idx] = cur_offset
        cur_offset += num_mens
    total_num_mentions = cur_offset

    full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode="r")
    # M, K and hidden_size are stored on every record; read them from the first one
    M = int(full_pred_data[0]["M"])
    K = int(full_pred_data[0]["K"])
    hidden_size = int(full_pred_data[0]["hidden_size"])

    filt_emb_data = np.memmap(
        to_save_file,
        dtype=to_save_storage,
        mode="w+",
        shape=(total_num_mentions,),
    )
    filt_emb_data["hidden_size"] = hidden_size
    # -1 marks slots that have not been filled in yet
    filt_emb_data["sent_idx"][:] = -1
    filt_emb_data["alias_list_pos"][:] = -1

    all_ids = list(range(0, len(full_pred_data)))
    start = time.time()
    if num_processes == 1:
        seen_ids = merge_subsentences_single(
            M,
            K,
            hidden_size,
            dump_embs,
            all_ids,
            filt_emb_data,
            full_pred_data,
            sentidx2offset,
        )
    else:
        # Get trie for sentence start map
        trie_folder = os.path.join(cache_folder, "bootleg_sent_idx2num_mens")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sentidx2offset, out_file=trie_file)
        # Chunk up data
        chunk_size = int(np.ceil(len(full_pred_data) / num_processes))
        row_idx_set_chunks = [
            all_ids[ids:ids + chunk_size]
            for ids in range(0, len(full_pred_data), chunk_size)
        ]
        input_args = [
            [M, K, hidden_size, dump_embs, chunk] for chunk in row_idx_set_chunks
        ]
        # Push the initial values to disk before the workers re-open the file
        filt_emb_data.flush()
        log_rank_0_debug(
            logger,
            f"Merging sentences together with {num_processes} processes")
        # Start pool
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=merge_subsentences_initializer,
            initargs=[
                to_save_file,
                to_save_storage,
                to_read_file,
                to_read_storage,
                trie_file,
            ],
        )

        seen_ids = set()
        try:
            for sent_ids_seen in pool.imap_unordered(merge_subsentences_hlp,
                                                     input_args,
                                                     chunksize=1):
                for emb_id in sent_ids_seen:
                    assert (
                        emb_id not in seen_ids
                    ), f"{emb_id} already seen, something went wrong with sub-sentences"
                    seen_ids.add(emb_id)
        finally:
            # Always release the worker processes, even if a chunk failed
            pool.close()
            pool.join()
    logger.debug(f"Saw {len(seen_ids)} sentences")
    logger.debug(f"Time to merge sub-sentences {time.time() - start}s")
    return
示例#24
0
def write_data_labels(
    num_processes,
    result_alias_offset,
    merged_entity_emb_file,
    merged_storage_type,
    sent_idx2row,
    cache_folder,
    out_file,
    entity_dump,
    train_in_candidates,
    max_candidates,
    dump_embs,
    trie_candidate_map_folder=None,
    trie_qid2eid_file=None,
):
    """Takes the flattened data from merge_sentences and writes out predictions
    to a file, one line per sentence.

    The embedding ids are added to the file if dump_embs is True.

    With a single process the rows are written directly to ``out_file``.
    Otherwise the rows are split into chunk files under ``cache_folder``,
    each chunk is processed by a pool worker, and the per-chunk outputs are
    concatenated into ``out_file`` in chunk order.

    Args:
        num_processes: number of processes
        result_alias_offset: alias offset of this batch of examples for writing out
        merged_entity_emb_file: input memmap file after merge sentences
        merged_storage_type: input file storage type
        sent_idx2row: Dict of sentence idx to row relevant to this subbatch
        cache_folder: folder to save temporary outputs
        out_file: final output file for predictions
        entity_dump: entity dump
        train_in_candidates: whether NC entities are not in candidate lists
        max_candidates: maximum number of candidates
        dump_embs: whether to dump embeddings or not
        trie_candidate_map_folder: folder where trie of alias->candidate map is stored for parallel proccessing
        trie_qid2eid_file: file where trie of qid->eid map is stored for parallel proccessing

    Returns: None
    """
    st = time.time()
    sental2embid = get_sental2embid(merged_entity_emb_file,
                                    merged_storage_type)
    log_rank_0_debug(logger,
                     f"Finished getting sentence map {time.time() - st}s")

    total_input = len(sent_idx2row)
    if num_processes == 1:
        filt_emb_data = np.memmap(merged_entity_emb_file,
                                  dtype=merged_storage_type,
                                  mode="r+")
        write_data_labels_single(
            sentidx2row=sent_idx2row,
            output_file=out_file,
            filt_emb_data=filt_emb_data,
            sental2embid=sental2embid,
            alias_cand_map=entity_dump.get_alias2qids(),
            qid2eid=entity_dump.get_qid2eid(),
            result_alias_offset=result_alias_offset,
            train_in_cands=train_in_candidates,
            max_cands=max_candidates,
            dump_embs=dump_embs,
        )
    else:
        assert (
            trie_candidate_map_folder is not None
        ), "trie_candidate_map_folder is None and you have parallel turned on"
        assert (trie_qid2eid_file is not None
                ), "trie_qid2eid_file is None and you have parallel turned on"

        # Get trie of sentence map
        trie_folder = os.path.join(cache_folder, "bootleg_sental2embid")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sental2embid, out_file=trie_file)
        # Chunk file for parallel writing
        # We do not use TemporaryFolders as the temp dir may not have enough space for large files
        create_ex_indir = os.path.join(cache_folder,
                                       "_bootleg_eval_temp_indir")
        utils.ensure_dir(create_ex_indir)
        create_ex_outdir = os.path.join(cache_folder,
                                        "_bootleg_eval_temp_outdir")
        utils.ensure_dir(create_ex_outdir)
        chunk_input = int(np.ceil(total_input / num_processes))
        logger.debug(
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines"
        )
        # Chunk up dictionary of data for parallel processing. An empty input
        # still produces a single (empty) chunk file.
        sent_indices = list(sent_idx2row)
        chunk_starts = (range(0, total_input, chunk_input)
                        if total_input > 0 else range(1))
        input_files = []
        for i, chunk_start in enumerate(chunk_starts):
            file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
            # Context manager guarantees the chunk file is closed on error
            with open(file_split, "w") as open_file:
                for s_idx in sent_indices[chunk_start:chunk_start + chunk_input]:
                    open_file.write(ujson.dumps(sent_idx2row[s_idx]) + "\n")
            input_files.append(file_split)
        # Generation input/output pairs
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        log_rank_0_debug(logger, "Done chunking files. Starting pool")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=write_data_labels_initializer,
            initargs=[
                merged_entity_emb_file,
                merged_storage_type,
                trie_file,
                result_alias_offset,
                train_in_candidates,
                max_candidates,
                dump_embs,
                trie_candidate_map_folder,
                trie_qid2eid_file,
            ],
        )

        input_args = list(zip(input_files, output_files))

        try:
            # Drain the iterator so every chunk is fully processed
            for _ in pool.imap(write_data_labels_hlp, input_args, chunksize=1):
                pass
        finally:
            # Always release the worker processes, even if a chunk failed
            pool.close()
            pool.join()

        # Merge output files to final file, in chunk order so the result is
        # deterministic (glob order depends on the filesystem).
        # NOTE(review): assumes each worker writes exactly the output path it
        # was handed - confirm against write_data_labels_hlp.
        log_rank_0_debug(logger, "Merging output files")
        with open(out_file, "wb") as outfile:
            for filename in output_files:
                if filename == out_file:
                    # don't want to copy the output into the output
                    continue
                if not os.path.exists(filename):
                    # nothing was produced for this chunk
                    continue
                with open(filename, "rb") as readfile:
                    shutil.copyfileobj(readfile, outfile)
示例#25
0
def disambig_dump_preds(
    result_idx,
    result_alias_offset,
    config,
    res_dict,
    sent_idx2num_mens,
    sent_idx2row,
    save_folder,
    entity_symbols,
    dump_embs,
    task_name,
):
    """Dumps the predictions of a disambiguation task.

    The raw per-subsentence results are first written to a memmap, then
    merged back into per-mention records (``merge_subsentences``) and finally
    written out one line per sentence (``write_data_labels``).

    Args:
        result_idx: batch index of the result arrays
        result_alias_offset: overall offset of the starting example (i.e., the number of previous mens already written)
        config: model config
        res_dict: result dictionary from Emmental predict
        sent_idx2num_mens: Dict sentence idx to number of mentions
        sent_idx2row: Dict sentence idx to row of eval data
        save_folder: folder to save results
        entity_symbols: entity symbols
        dump_embs: whether to save the contextualized embeddings or not
        task_name: task name

    Returns: saved prediction file, saved embedding file (will be None if dump_embs is False),
        set of sentence indices (as strings) seen in this batch, total number of mentions written
    """
    # int(cpu_count * 0.9) is 0 on a single-CPU machine; clamp to at least one
    # process so the single-process code path is taken instead of dividing by zero.
    num_processes = max(
        1,
        min(config.run_config.dataset_threads,
            int(multiprocessing.cpu_count() * 0.9)),
    )
    cache_dir = os.path.join(save_folder, f"cache_{result_idx}")
    utils.ensure_dir(cache_dir)
    trie_candidate_map_folder = None
    trie_qid2eid_file = None
    # Save the alias->QID candidate map and the QID->EID mapping in memory efficient structures for faster
    # prediction dumping
    if num_processes > 1:
        entity_prep_dir = data_utils.get_emb_prep_dir(config.data_config)
        trie_candidate_map_folder = os.path.join(entity_prep_dir,
                                                 "for_dumping_preds",
                                                 "alias_cand_trie")
        utils.ensure_dir(trie_candidate_map_folder)
        check_and_create_alias_cand_trie(trie_candidate_map_folder,
                                         entity_symbols)
        trie_qid2eid_file = os.path.join(entity_prep_dir, "for_dumping_preds",
                                         "qid2eid_trie.marisa")
        if not os.path.exists(trie_qid2eid_file):
            utils.create_single_item_trie(entity_symbols.get_qid2eid(),
                                          out_file=trie_qid2eid_file)

    # Keep only the results belonging to this task
    disambig_res_dict = {}
    for k in res_dict:
        assert task_name in res_dict[
            k], f"{task_name} not in res_dict for key {k}"
        disambig_res_dict[k] = res_dict[k][task_name]

    # write to file (M x hidden x size for each data point -- next step will deal with recovering original sentence
    # indices for overflowing sentences)
    # NOTE(review): the two file names look swapped relative to the variable
    # names; they are kept as-is because they are on-disk paths.
    unmerged_entity_emb_file = os.path.join(save_folder, "entity_embs.pt")
    merged_entity_emb_file = os.path.join(save_folder,
                                          "entity_embs_unmerged.pt")
    emb_file_config = os.path.splitext(
        unmerged_entity_emb_file)[0] + "_config.npy"
    M = config.data_config.max_aliases
    # One extra candidate slot when not training in candidates
    K = entity_symbols.max_candidates + (
        not config.data_config.train_in_candidates)
    if dump_embs:
        unmerged_storage_type = np.dtype([
            ("M", int),
            ("K", int),
            ("hidden_size", int),
            ("sent_idx", int),
            ("subsent_idx", int),
            ("alias_list_pos", int, (M, )),
            ("entity_emb", float, M * config.model_config.hidden_size),
            ("final_loss_true", int, (M, )),
            ("final_loss_pred", int, (M, )),
            ("final_loss_prob", float, (M, )),
            ("final_loss_cand_probs", float, M * K),
        ])
        merged_storage_type = np.dtype([
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("entity_emb", float, config.model_config.hidden_size),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ])
    else:
        # don't need to extract contextualized entity embedding
        unmerged_storage_type = np.dtype([
            ("M", int),
            ("K", int),
            ("hidden_size", int),
            ("sent_idx", int),
            ("subsent_idx", int),
            ("alias_list_pos", int, (M, )),
            ("final_loss_true", int, (M, )),
            ("final_loss_pred", int, (M, )),
            ("final_loss_prob", float, (M, )),
            ("final_loss_cand_probs", float, M * K),
        ])
        merged_storage_type = np.dtype([
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ])
    mmap_file = np.memmap(
        unmerged_entity_emb_file,
        dtype=unmerged_storage_type,
        mode="w+",
        shape=(len(disambig_res_dict["uids"]), ),
    )
    # Init sent_idx to -1 for debugging
    mmap_file[:]["sent_idx"] = -1
    np.save(emb_file_config, unmerged_storage_type, allow_pickle=True)
    log_rank_0_debug(
        logger,
        f"Created file {unmerged_entity_emb_file} to save predictions.")

    log_rank_0_debug(logger, f'{len(disambig_res_dict["uids"])} samples')
    for_iteration = [
        disambig_res_dict["uids"],
        disambig_res_dict["golds"],
        disambig_res_dict["probs"],
        disambig_res_dict["preds"],
    ]
    all_sent_idx = set()
    for i, (uid, gold, probs, model_pred) in enumerate(zip(*for_iteration)):
        # disambig_res_dict["output"] is dict with keys ['_input__alias_orig_list_pos',
        # 'bootleg_pred_1', '_input__sent_idx', '_input__for_dump_gold_cand_K_idx_train', '_input__subsent_idx', 0, 1]
        sent_idx = disambig_res_dict["outputs"]["_input__sent_idx"][i]
        subsent_idx = disambig_res_dict["outputs"]["_input__subsent_idx"][i]
        alias_orig_list_pos = disambig_res_dict["outputs"][
            "_input__alias_orig_list_pos"][i]
        gold_cand_K_idx_train = disambig_res_dict["outputs"][
            "_input__for_dump_gold_cand_K_idx_train"][i]
        output_embeddings = disambig_res_dict["outputs"][
            f"{PRED_LAYER}_ent_embs"][i]
        mmap_file[i]["M"] = M
        mmap_file[i]["K"] = K
        mmap_file[i]["hidden_size"] = config.model_config.hidden_size
        mmap_file[i]["sent_idx"] = sent_idx
        mmap_file[i]["subsent_idx"] = subsent_idx
        mmap_file[i]["alias_list_pos"] = alias_orig_list_pos
        # This will give all aliases seen by the model during training, independent of if it's gold or not
        mmap_file[i]["final_loss_true"] = gold_cand_K_idx_train.reshape(M)

        # get max for each alias, probs is M x K
        max_probs = probs.max(axis=1)
        pred_cands = probs.argmax(axis=1)

        mmap_file[i]["final_loss_pred"] = pred_cands
        mmap_file[i]["final_loss_prob"] = max_probs
        mmap_file[i]["final_loss_cand_probs"] = probs.reshape(1, -1)

        # Stored as string so it can be matched against the str keys of the sentence maps
        all_sent_idx.add(str(sent_idx))
        # final_entity_embs is M x K x hidden_size, pred_cands is M
        if dump_embs:
            chosen_entity_embs = select_embs(embs=output_embeddings,
                                             pred_cands=pred_cands,
                                             M=M)

            # write chosen entity embs to file for contextualized entity embeddings
            mmap_file[i]["entity_emb"] = chosen_entity_embs.reshape(1, -1)

    # Store all predicted sentences to filter the sentence mapping by
    subset_sent_idx2num_mens = {
        k: v
        for k, v in sent_idx2num_mens.items() if k in all_sent_idx
    }
    subsent_sent_idx2row = {
        k: v
        for k, v in sent_idx2row.items() if k in all_sent_idx
    }
    result_file = get_result_file(result_idx, save_folder)
    log_rank_0_debug(logger, f"Writing predictions to {result_file}...")
    merge_subsentences(
        num_processes=num_processes,
        subset_sent_idx2num_mens=subset_sent_idx2num_mens,
        cache_folder=cache_dir,
        to_save_file=merged_entity_emb_file,
        to_save_storage=merged_storage_type,
        to_read_file=unmerged_entity_emb_file,
        to_read_storage=unmerged_storage_type,
        dump_embs=dump_embs,
    )
    write_data_labels(
        num_processes=num_processes,
        result_alias_offset=result_alias_offset,
        merged_entity_emb_file=merged_entity_emb_file,
        merged_storage_type=merged_storage_type,
        sent_idx2row=subsent_sent_idx2row,
        cache_folder=cache_dir,
        out_file=result_file,
        entity_dump=entity_symbols,
        train_in_candidates=config.data_config.train_in_candidates,
        max_candidates=entity_symbols.max_candidates,
        dump_embs=dump_embs,
        trie_candidate_map_folder=trie_candidate_map_folder,
        trie_qid2eid_file=trie_qid2eid_file,
    )

    out_emb_file = None
    filt_emb_data = np.memmap(merged_entity_emb_file,
                              dtype=merged_storage_type,
                              mode="r+")
    total_mentions_seen = len(filt_emb_data)
    # save easier-to-use embedding file
    if dump_embs:
        hidden_size = filt_emb_data[0]["hidden_size"]
        out_emb_file = get_emb_file(result_idx, save_folder)
        np.save(out_emb_file,
                filt_emb_data["entity_emb"].reshape(-1, hidden_size))
        log_rank_0_debug(
            logger,
            f"Saving contextual entity embeddings for {result_idx} to {out_emb_file}",
        )
    # Drop the reference to the memmap before the cache is cleaned up
    filt_emb_data = None

    # Cleanup cache - sometimes the file in cache_dir is still open so we need to retry to delete it
    try_rmtree(cache_dir)

    log_rank_0_debug(logger,
                     f"Wrote predictions for {result_idx} to {result_file}")
    return result_file, out_emb_file, all_sent_idx, total_mentions_seen
示例#26
0
    def load_regularization_mapping(cls, data_config,
                                    num_types_with_pad_and_unk, type2row_dict,
                                    reg_file):
        """Reads in a csv file with columns [typeid, regularization].

        Builds a tensor with one regularization value per row of the type
        embedding matrix. The result is cached in the data prep directory
        and reloaded from there on later calls unless
        ``data_config.overwrite_preprocessed_data`` is set.

        Args:
            data_config: data config
            num_types_with_pad_and_unk: number of rows in the returned tensor
            type2row_dict: Dict from typeID to row id in the type embedding matrix
            reg_file: regularization csv file

        Returns: Tensor where entry i is the regularization value for type
            row i (0.0 for rows with no entry in the csv file)
        """
        reg_str = os.path.splitext(os.path.basename(reg_file.replace("/",
                                                                     "_")))[0]
        prep_dir = data_utils.get_data_prep_dir(data_config)
        prep_file = os.path.join(prep_dir,
                                 f"type_regularization_mapping_{reg_str}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        log_rank_0_debug(logger,
                         f"Looking for regularization mapping in {prep_file}")
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(
                logger,
                f"Loading existing entity regularization mapping from {prep_file}",
            )
            start = time.time()
            typeid2reg = torch.load(prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s",
            )
        else:
            start = time.time()
            log_rank_0_debug(
                logger,
                f"Building entity regularization mapping from {reg_file}")
            typeid2reg_raw = pd.read_csv(reg_file)
            assert (
                "typeid" in typeid2reg_raw.columns
                and "regularization" in typeid2reg_raw.columns
            ), f"Expected typeid and regularization as the column names for {reg_file}"
            # default of no mask
            typeid2reg_arr = [0.0] * num_types_with_pad_and_unk
            # Iterate the two columns directly: iterrows() up-casts every row to
            # a common dtype and is much slower on large files.
            for raw_typeid, reg_value in zip(typeid2reg_raw["typeid"],
                                             typeid2reg_raw["regularization"]):
                type_key = int(raw_typeid)
                # Happens when we filter QIDs not in our entity db and the max typeid is smaller than the total number
                if type_key not in type2row_dict:
                    continue
                typeid2reg_arr[type2row_dict[type_key]] = reg_value
            typeid2reg = torch.Tensor(typeid2reg_arr)
            torch.save(typeid2reg, prep_file)
            log_rank_0_debug(
                logger,
                f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.",
            )
        return typeid2reg
示例#27
0
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        """Set up the type embedding.

        Validates ``emb_args``, selects the function used to merge multiple
        types, builds the eid -> type ids table via ``self.prep`` and, when a
        ``regularize_mapping`` is configured, registers the per-type
        regularization buffer.

        Args:
            main_args: main arguments (``main_args.data_config`` is read here)
            emb_args: embedding arguments; must contain max_types, type_dim,
                type_labels and type_vocab
            entity_symbols: entity symbols
            key: embedding key, used as a prefix in log/assert messages
            cpu: forwarded to the parent class; must be False
            normalize: forwarded to the parent class
            dropout1d_perc: forwarded to the parent class
            dropout2d_perc: forwarded to the parent class

        Raises:
            ValueError: if ``emb_args`` contains a key outside the allowed set
        """
        super(TypeEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "max_types",
            "type_dim",
            "type_labels",
            "type_vocab",
            "merge_func",
            "attn_hidden_size",
            "regularize_mapping",
        }
        keys_ok, offending_key = utils.assert_keys_in_dict(
            allowable_keys, emb_args)
        if not keys_ok:
            raise ValueError(
                f"The key {offending_key} is not in {allowable_keys}")

        # Mandatory arguments, checked in order, with the message shown when missing
        required_args = (
            ("max_types",
             "Type embedding requires max_types to be set in args"),
            ("type_dim",
             "Type embedding requires type_dim to be set in args"),
            ("type_labels",
             "Type embedding requires type_labels to be set in args. A Dict from QID -> TypeId or TypeName"),
            ("type_vocab",
             "Type embedding requires type_vocab to be set in args. A Dict from TypeName -> TypeId"),
        )
        for arg_name, missing_msg in required_args:
            assert arg_name in emb_args, missing_msg
        assert (
            self.cpu is False
        ), "We don't support putting type embeddings on CPU right now"

        # Defaults: average the types, no additive attention module
        self.merge_func = self.average_types
        self.orig_dim = emb_args.type_dim
        self.add_attn = None
        # Function for merging multiple types
        if "merge_func" in emb_args:
            requested_merge = emb_args.merge_func
            assert requested_merge in ["average", "addattn"], (
                f"{key}: You have set the type merge_func to be {requested_merge} but"
                f" that is not in the allowable list of [average, addattn]")
            if requested_merge == "addattn":
                attn_hidden_size = (emb_args.attn_hidden_size
                                    if "attn_hidden_size" in emb_args else 100)
                # Softmax of types using the sentence context
                self.add_attn = PositionAwareAttention(
                    input_size=self.orig_dim,
                    attn_size=attn_hidden_size,
                    feature_size=0,
                )
                self.merge_func = self.add_attn_merge
        self.max_types = emb_args.max_types

        prep_result = self.prep(
            data_config=main_args.data_config,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
        )
        (
            eid2typeids_table,
            self.type2row_dict,
            num_types_with_unk,
            self.prep_file,
        ) = prep_result
        self.register_buffer("eid2typeids_table",
                             eid2typeids_table,
                             persistent=False)
        # One more row than prep reports, for the pad entry
        self.num_types_with_pad_and_unk = num_types_with_unk + 1

        # Regularization mapping goes from typeid to 2d dropout percent
        typeid2reg = (torch.zeros(self.num_types_with_pad_and_unk)
                      if "regularize_mapping" in emb_args else None)

        # Only read the mapping from file when not loading a pretrained embedding
        if not self.from_pretrained and "regularize_mapping" in emb_args:
            if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                logger.warning(
                    "You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?"
                )
            log_rank_0_info(
                logger,
                f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}",
            )
            typeid2reg = self.load_regularization_mapping(
                main_args.data_config,
                self.num_types_with_pad_and_unk,
                self.type2row_dict,
                emb_args.regularize_mapping,
            )
        self.register_buffer("typeid2reg", typeid2reg)

        table_width = self.eid2typeids_table.shape[1]
        assert table_width == emb_args.max_types, (
            "Something went wrong with loading type file."
            f" The given max types {emb_args.max_types} does not match that "
            f"of type table {table_width}")
        log_rank_0_debug(
            logger,
            f"{key}: Type embedding with {self.max_types} types with dim {self.orig_dim}. "
            f"Setting merge_func to be {self.merge_func.__name__} in type emb.",
        )
示例#28
0
def get_dataloader_embeddings(main_args, entity_symbols):
    """Gets KG embeddings that need to be processed in the __get_item__ method
    of a dataset (e.g., querying a sparse numpy matrix). We save, for each KG
    embedding class that needs this preprocessing, the adjacency matrix (for KG
    connections), the processing function to run in __get_item__, and the file
    to load the adj matrix for dumping/loading.

    Only entries of ``main_args.data_config.ent_embeddings`` whose
    ``batch_on_the_fly`` value is exactly ``True`` are instantiated; every
    other entry is skipped.

    Args:
        main_args: main arguments
        entity_symbols: entity symbols

    Returns: Dict of embedding key -> dict with the entries ``kg_adj``,
        ``kg_adj_process_func`` and ``prep_file`` read from the instantiated
        embedding class, for use in the __get_item__ method.

    Raises:
        AssertionError: if the instantiated embedding class is missing one of
            the ``kg_adj``, ``kg_adj_process_func`` or ``prep_file`` attributes.
        Exception: any error raised while importing or instantiating the
            embedding class is logged and then re-raised unchanged.
    """
    batch_on_the_fly_kg_adj = {}
    for emb in main_args.data_config.ent_embeddings:
        # Find embeddings that have a "batch of the fly" key set to True;
        # everything else needs no dataloader-side preparation.
        if not ("batch_on_the_fly" in emb and emb["batch_on_the_fly"] is True):
            continue
        log_rank_0_debug(
            logger,
            f"Loading class {emb.load_class} for preprocessing as on the fly or in data prep embeddings",
        )
        # freeze and through_bert are part of the returned tuple but are not
        # passed to the embedding constructor below.
        (
            cpu,
            dropout1d_perc,
            dropout2d_perc,
            emb_args,
            _freeze,
            normalize,
            _through_bert,
        ) = embedding_utils.get_embedding_args(emb)
        try:
            # Load the object
            mod, load_class = import_class("bootleg.embeddings", emb.load_class)
            kg_class = getattr(mod, load_class)(
                main_args=main_args,
                emb_args=emb_args,
                entity_symbols=entity_symbols,
                key=emb.key,
                cpu=cpu,
                normalize=normalize,
                dropout1d_perc=dropout1d_perc,
                dropout2d_perc=dropout2d_perc,
            )
            # Extract its kg adj, we'll use this later
            # Extract the kg_adj_process_func (how to process the embeddings in __get_item__ or dataset prep)
            # Extract the prep_file. We use this to load the kg_adj back after
            # saving/loading state using scipy.sparse.load_npz(prep_file)
            assert hasattr(
                kg_class, "kg_adj"
            ), f"The embedding class {emb.key} does not have a kg_adj attribute and it needs to."
            assert hasattr(
                kg_class, "kg_adj_process_func"
            ), f"The embedding class {emb.key} does not have a kg_adj_process_func attribute and it needs to."
            assert hasattr(kg_class, "prep_file"), (
                f"The embedding class {emb.key} does not have a prep_file attribute and it needs to. We will call"
                f" `scipy.sparse.load_npz(prep_file)` to load the kg_adj matrix."
            )
            batch_on_the_fly_kg_adj[emb.key] = {
                "kg_adj": kg_class.kg_adj,
                "kg_adj_process_func": kg_class.kg_adj_process_func,
                "prep_file": kg_class.prep_file,
            }
        except AttributeError as e:
            logger.warning(
                f"No prep method found for {emb.load_class} with error {e}"
            )
            raise
        except Exception as e:
            # Report through the module logger (not print) so the failure is
            # routed like every other message here, then propagate it.
            logger.error(
                f"Error loading {emb.load_class} for on the fly preprocessing: {e}"
            )
            raise
    return batch_on_the_fly_kg_adj
示例#29
0
    def build_kg_adj(cls, kg_adj_file, entity_symbols, threshold, log_weight):
        """Builds the KG adjacency matrix from inputs.

        The edge file is either a ``.json`` file holding a dict of
        ``{head: {tail: weight}}`` or a ``.txt`` file with one whitespace
        separated ``head tail [weight]`` edge per line (a missing weight
        defaults to 1.0). An edge is kept only when both endpoints are known
        QIDs and its weight is strictly greater than ``threshold``.

        Args:
            kg_adj_file: KG adjacency file
            entity_symbols: entity symbols
            threshold: weight threshold to count as an edge
            log_weight: whether to log the weight after the threshold

        Returns: KG adjacency

        Raises:
            AssertionError: if the file extension is not ``json`` or ``txt``,
                if the resulting matrix is all zeros, or if the first/last
                row or column (unk / padded entity) has any connection.
            ValueError: if a txt line does not have 2 or 3 values, or if its
                third value cannot be parsed as a float.
        """
        G = nx.Graph()
        qids = set(entity_symbols.get_all_qids())
        edges_to_add = []
        num_added = 0
        num_total = 0
        file_ending = os.path.splitext(kg_adj_file)[1][1:]
        # Get ending to determine if txt or json file
        assert file_ending in [
            "json",
            "txt",
        ], f"We only support loading txt or json files for edge weights. You provided {file_ending}"

        def _iter_edges(in_f):
            # Yields one (head, tail, weight) triple per edge listed in the file
            if file_ending == "json":
                all_edges = json.load(in_f)
                for head, tails in all_edges.items():
                    for tail, weight in tails.items():
                        yield head, tail, weight
                return
            for line in in_f:
                splt = line.strip().split()
                if len(splt) == 2:
                    head, tail = splt
                    weight = 1.0
                elif len(splt) == 3:
                    head, tail = splt[:2]
                    # split() returns strings; the weight must be numeric to be
                    # compared against the threshold and passed to np.log
                    weight = float(splt[2])
                else:
                    raise ValueError(
                        f"A line {line} in {kg_adj_file} has not 2 or 3 values after called split()."
                    )
                yield head, tail, weight

        with open(kg_adj_file) as in_f:
            for head, tail, weight in _iter_edges(in_f):
                num_total += 1
                # head and tail must be in list of qids
                if head in qids and tail in qids and weight > threshold:
                    num_added += 1
                    edges_to_add.append(
                        (head, tail, np.log(weight) if log_weight else weight)
                    )
        log_rank_0_debug(
            logger,
            f"Adding {num_added} out of {num_total} items from {kg_adj_file}")
        G.add_weighted_edges_from(edges_to_add)
        # convert to entityids
        G = nx.relabel_nodes(G, entity_symbols.get_qid2eid())
        # create adjacency matrix
        adj = nx.adjacency_matrix(
            G, nodelist=range(entity_symbols.num_entities_with_pad_and_nocand))
        assert (
            adj.sum() > 0
        ), "Your KG Adj matrix has all 0 values. Something was likely parsed wrong."
        # assert that the padded entity has no connections
        assert (adj[:, -1] != 0).sum() == 0
        assert (adj[-1, :] != 0).sum() == 0
        # assert that the unk entity has no connections
        assert (adj[:, 0] != 0).sum() == 0
        assert (adj[0, :] != 0).sum() == 0
        return adj