Example #1
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(TitleEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {"proj", "requires_grad"}
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        self.orig_dim = BERT_WORD_DIM
        self.merge_func = self.average_titles
        self.M = main_args.data_config.max_aliases
        self.K = get_max_candidates(entity_symbols, main_args.data_config)
        self._dim = main_args.model_config.hidden_size
        if "proj" in emb_args:
            self._dim = emb_args.proj
        self.requires_grad = True
        if "requires_grad" in emb_args:
            self.requires_grad = emb_args.requires_grad
        self.title_proj = torch.nn.Linear(self.orig_dim, self._dim)
        log_rank_0_debug(
            logger,
            f'Setting the "proj" parameter to {self._dim} and the "requires_grad" parameter to {self.requires_grad}',
        )

        (
            entity2titleid_table,
            entity2titlemask_table,
            entity2tokentypeid_table,
        ) = self.prep(
            data_config=main_args.data_config,
            entity_symbols=entity_symbols,
        )
        self.register_buffer("entity2titleid_table",
                             entity2titleid_table,
                             persistent=False)
        self.register_buffer("entity2titlemask_table",
                             entity2titlemask_table,
                             persistent=False)
        self.register_buffer("entity2tokentypeid_table",
                             entity2tokentypeid_table,
                             persistent=False)
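The allowable_keys check above means a TitleEmb entry may only carry "proj" and "requires_grad" inside its "args". Below is a minimal sketch of such an entry as a plain Python dict, following the "load_class"/"key"/"args" layout parsed by get_embedding_args in Example #4; the concrete values (key name, projection size) are illustrative assumptions, not taken from a real Bootleg config.

# Hypothetical TitleEmb config entry (values are assumptions for illustration).
title_emb_entry = {
    "load_class": "TitleEmb",   # embedding class to instantiate
    "key": "title_emb",         # name under which the embedding is registered
    "args": {
        "proj": 256,            # project the BERT_WORD_DIM title vectors down to 256
        "requires_grad": False, # presumably controls whether the title projection is trained
    },
}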
Example #2
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(StaticEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {"emb_file", "proj"}
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert "emb_file" in emb_args, f"Must have emb_file in args for StaticEmb"
        self.entity2static = self.prep(
            data_config=main_args.data_config,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
        )
        self.orig_dim = self.entity2static.shape[1]
        # entity2static_embedding = torch.nn.Embedding(
        #     entity_symbols.num_entities_with_pad_and_nocand,
        #     self.orig_dim,
        #     padding_idx=-1,
        #     sparse=True,
        # )

        self.normalize = False
        if self.orig_dim > 1:
            self.normalize = True
        self.proj = None
        self._dim = self.orig_dim
        if "proj" in emb_args:
            log_rank_0_debug(
                logger,
                f"Adding a projection layer to the static emb to go to dim {emb_args.proj}",
            )
            self._dim = emb_args.proj
            self.proj = MLP(
                input_size=self.orig_dim,
                num_hidden_units=None,
                output_size=self._dim,
                num_layers=1,
            )
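When "proj" is present, the single-layer MLP above acts as a learned linear map from the static table's original dimension down to emb_args.proj. A standalone sketch of that step with torch.nn.Linear on a toy table (shapes are made up, and Bootleg's MLP helper may differ in details such as bias or activation):

import torch

# Toy static table: 1000 entities, 300-dim vectors (the real table also has pad/no-cand rows).
entity2static = torch.randn(1000, 300)

orig_dim = entity2static.shape[1]  # read off the table, as in the constructor above
proj_dim = 128                     # stands in for emb_args.proj

# A 1-layer MLP is effectively a learned linear projection to the smaller dimension.
proj = torch.nn.Linear(orig_dim, proj_dim)
projected = proj(entity2static)    # shape: (1000, 128)
print(projected.shape)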
Example #3
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(KGAdjEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {"kg_adj", "threshold", "log_weight"}
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert "kg_adj" in emb_args, f"KG embedding requires kg_adj to be set in args"
        assert (self.normalize is
                False), f"We can't normalize a KGAdjEmb as it has hidden dim 1"
        assert (self.dropout1d_perc == 0.0
                ), f"We can't dropout 1d a KGAdjEmb as it has hidden dim 1"
        assert (self.dropout2d_perc == 0.0
                ), f"We can't dropout 2d a KGAdjEmb as it has hidden dim 1"
        # This determines that, when prepping the embeddings, we will query the kg_adj matrix and sum the results -
        # generating one value per entity candidate
        self.kg_adj_process_func = embedding_utils.prep_kg_feature_sum
        self._dim = 1

        self.threshold_weight = 0
        if "threshold" in emb_args:
            self.threshold_weight = float(emb_args.threshold)
        self.log_weight = False
        if "log_weight" in emb_args:
            self.log_weight = emb_args.log_weight
            assert type(self.log_weight) is bool
        log_rank_0_debug(
            logger,
            f"Setting log_weight to be {self.log_weight} and threshold to be {self.threshold_weight} in {key}",
        )

        self.kg_adj, self.prep_file = self.prep(
            data_config=main_args.data_config,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            threshold=self.threshold_weight,
            log_weight=self.log_weight,
        )
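The constructor only validates and stores threshold_weight and log_weight; self.prep applies them while summing kg_adj entries into one scalar per candidate. One plausible reading of the two options, sketched on a toy weight vector. This is an assumption about their effect, not Bootleg's actual prep code:

import torch

# Toy summed adjacency weights for four candidates.
kg_weights = torch.tensor([0.0, 0.5, 3.0, 12.0])

threshold_weight = 1.0  # stands in for emb_args.threshold
log_weight = True       # stands in for emb_args.log_weight

# Drop weak connections below the threshold...
filtered = torch.where(
    kg_weights >= threshold_weight, kg_weights, torch.zeros_like(kg_weights)
)
# ...and optionally compress the surviving weights with a log.
if log_weight:
    filtered = torch.log1p(filtered)

print(filtered)  # tensor([0.0000, 0.0000, 1.3863, 2.5649])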
Example #4
def get_embedding_args(emb):
    """Extract the embedding arguments that are the same for _all_ embedding
    objects (see base_emb.py). These are defined in the upper level of the
    config.

    Allowed arguments:
        - cpu: True/False (whether embedding on CPU or not)
        - freeze: True/False (freeze parameters or not)
        - dropout1d: float between 0, 1
        - dropout2d: float between 0, 1
        - normalize: True/False
        - send_through_bert: True/False (whether this embedding outputs indices for BERT encoder -- see bert_encoder.py)

    Args:
        emb: embedding dictionary arguments from config

    Returns: parsed arguments with defaults
    """
    # batch_on_the_fly is used for determining when embeddings are prepped (see data.py)
    allowable_keys = {
        "args",
        "load_class",
        "key",
        "cpu",
        "batch_on_the_fly",
        FREEZE,
        DROPOUT_1D,
        DROPOUT_2D,
        NORMALIZE,
        SEND_THROUGH_BERT,
    }
    emb_args = emb.get("args", None)
    assert (
        "load_class" in emb
    ), "You must specify a load_class in the embedding config: {load_class: ..., key: ...}"
    assert (
        "key" in emb
    ), "You must specify a key in the embedding config: {load_class: ..., key: ...}"
    correct, bad_key = assert_keys_in_dict(allowable_keys, emb)
    if not correct:
        raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
    # Add cpu
    cpu = emb.get("cpu", False)
    assert type(cpu) is bool
    # Add freeze
    freeze = emb.get(FREEZE, False)
    assert type(freeze) is bool
    # Add 1D dropout
    dropout1d_perc = emb.get(DROPOUT_1D, 0.0)
    assert 1.0 >= dropout1d_perc >= 0.0
    # Add 2D dropout
    dropout2d_perc = emb.get(DROPOUT_2D, 0.0)
    assert 1.0 >= dropout2d_perc >= 0.0
    # Add normalize
    normalize = emb.get(NORMALIZE, True)
    assert type(normalize) is bool
    # Add through BERT
    through_bert = emb.get(SEND_THROUGH_BERT, False)
    assert type(through_bert) is bool
    return (
        cpu,
        dropout1d_perc,
        dropout2d_perc,
        emb_args,
        freeze,
        normalize,
        through_bert,
    )
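A sketch of how one embedding entry flows through get_embedding_args. It assumes the module constants (FREEZE, DROPOUT_1D, DROPOUT_2D, NORMALIZE, SEND_THROUGH_BERT) resolve to the lowercase key names listed in the docstring and that the function is imported from the module shown above; the concrete values are illustrative.

# Hypothetical embedding entry; keys outside "args" are the shared ones parsed here.
emb = {
    "load_class": "LearnedEntityEmb",
    "key": "learned",
    "cpu": False,
    "dropout2d": 0.2,                          # DROPOUT_2D
    "args": {"learned_embedding_size": 256},   # class-specific args, returned untouched
}

(
    cpu,
    dropout1d_perc,
    dropout2d_perc,
    emb_args,
    freeze,
    normalize,
    through_bert,
) = get_embedding_args(emb)
# cpu=False, dropout1d_perc=0.0 (default), dropout2d_perc=0.2,
# emb_args={"learned_embedding_size": 256}, freeze=False, normalize=True, through_bert=False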
Example #5
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(LearnedEntityEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "learned_embedding_size",
            "regularize_mapping",
            "tail_init",
            "tail_init_zeros",
        }
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert (
            "learned_embedding_size" in emb_args
        ), f"LearnedEntityEmb must have learned_embedding_size in args"
        self.learned_embedding_size = emb_args.learned_embedding_size

        # Set sparsity based on optimizer and fp16. The None optimizer is Bootleg's SparseDenseAdam.
        # If fp16 is True, must use dense.
        optimiz = main_args.learner_config.optimizer_config.optimizer
        if optimiz in [None, "sparse_adam"] and main_args.learner_config.fp16 is False:
            sparse = True
        else:
            sparse = False

        if (
            torch.distributed.is_initialized()
            and main_args.model_config.distributed_backend == "nccl"
        ):
            sparse = False
        log_rank_0_debug(
            logger, f"Setting sparsity for entity embeddings to be {sparse}"
        )
        self.learned_entity_embedding = nn.Embedding(
            entity_symbols.num_entities_with_pad_and_nocand,
            self.learned_embedding_size,
            padding_idx=-1,
            sparse=sparse,
        )
        self._dim = main_args.model_config.hidden_size

        if "regularize_mapping" in emb_args:
            eid2reg = torch.zeros(entity_symbols.num_entities_with_pad_and_nocand)
        else:
            eid2reg = None

        # If tail_init is false, all embeddings are randomly initialized.
        # If tail_init is true, we initialize all embeddings to be the same.
        self.tail_init = True
        self.tail_init_zeros = False
        # None init vec will be random
        init_vec = None
        if not self.from_pretrained:
            if "tail_init" in emb_args:
                self.tail_init = emb_args.tail_init
            if "tail_init_zeros" in emb_args:
                self.tail_init_zeros = emb_args.tail_init_zeros
                self.tail_init = False
                init_vec = torch.zeros(1, self.learned_embedding_size)
            assert not (
                self.tail_init and self.tail_init_zeros
            ), f"Can only have one of tail_init or tail_init_zeros set"
            if self.tail_init or self.tail_init_zeros:
                if self.tail_init_zeros:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to zero.",
                    )
                else:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to the same value.",
                    )
                init_vec = model_utils.init_embeddings_to_vec(
                    self.learned_entity_embedding, pad_idx=-1, vec=init_vec
                )
                vec_save_file = os.path.join(
                    emmental.Meta.log_path, "init_vec_entity_embs.npy"
                )
                log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
                if (
                    torch.distributed.is_initialized()
                    and torch.distributed.get_rank() == 0
                ):
                    np.save(vec_save_file, init_vec)
            else:
                log_rank_0_debug(
                    logger, f"All learned embeddings are randomly initialized."
                )
            # Regularization mapping goes from eid to 2d dropout percent
            if "regularize_mapping" in emb_args:
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    log_rank_0_debug(
                        logger,
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?",
                    )
                log_rank_0_debug(
                    logger,
                    f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}",
                )
                eid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    entity_symbols,
                    emb_args.regularize_mapping,
                )
        self.register_buffer("eid2reg", eid2reg)
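The tail-init branch above relies on model_utils.init_embeddings_to_vec to write one shared vector into every entity row (zeros under tail_init_zeros, a random vector otherwise) while keeping the padding row untouched. A standalone sketch of that idea on a toy table; the helper's exact behaviour and return value are assumed, not copied from Bootleg:

import torch
import torch.nn as nn

# Toy stand-in for learned_entity_embedding: 10 "entities", dim 4, pad row at index -1.
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=-1)

# One shared init vector; torch.zeros(1, 4) would correspond to tail_init_zeros.
init_vec = torch.randn(1, emb.embedding_dim)

with torch.no_grad():
    emb.weight.copy_(init_vec.expand_as(emb.weight))  # every row gets the same vector
    emb.weight[-1].zero_()                            # keep the padding row at zero

print(emb.weight[0], emb.weight[-1])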
Example #6
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(TopKEntityEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "learned_embedding_size",
            "perc_emb_drop",
            "qid2topk_eid",
            "regularize_mapping",
            "tail_init",
            "tail_init_zeros",
        }
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert (
            "learned_embedding_size" in emb_args
        ), f"TopKEntityEmb must have learned_embedding_size in args"
        assert "perc_emb_drop" in emb_args, (
            f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings"
            f" removed."
        )
        self.learned_embedding_size = emb_args.learned_embedding_size
        # We remove perc_emb_drop percent of the embeddings and add one extra row that all dropped (tail) entities share
        num_topk_entities_with_pad_and_nocand = (
            entity_symbols.num_entities_with_pad_and_nocand
            - int(emb_args.perc_emb_drop * entity_symbols.num_entities)
            + 1
        )
        # Mapping of entity to the new eid mapping
        eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
        # There are issues with using a -1 index into the embeddings, so we manually set it to be the last value
        eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
        if "qid2topk_eid" not in emb_args:
            assert self.from_pretrained, (
                f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
                f"you must be loading a model from a checkpoint to build this index mapping"
            )

        self.learned_entity_embedding = nn.Embedding(
            num_topk_entities_with_pad_and_nocand,
            self.learned_embedding_size,
            padding_idx=-1,
            sparse=False,
        )

        self._dim = main_args.model_config.hidden_size

        if "regularize_mapping" in emb_args:
            eid2reg = torch.zeros(num_topk_entities_with_pad_and_nocand)
        else:
            eid2reg = None
        # If tail_init is false, all embeddings are randomly initialized.
        # If tail_init is true, we initialize all embeddings to be the same.
        self.tail_init = True
        self.tail_init_zeros = False
        # None init vec will be random
        init_vec = None
        if not self.from_pretrained:
            qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
            assert (
                len(qid2topk_eid) == entity_symbols.num_entities
            ), f"You must have an item in qid2topk_eid for each qid in entity_symbols"
            for qid in entity_symbols.get_all_qids():
                old_eid = entity_symbols.get_eid(qid)
                new_eid = qid2topk_eid[qid]
                eid2topkeid[old_eid] = new_eid
            assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
            assert (
                eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1
            ), "The -1 eid should still map to -1"

            if "tail_init" in emb_args:
                self.tail_init = emb_args.tail_init
            if "tail_init_zeros" in emb_args:
                self.tail_init_zeros = emb_args.tail_init_zeros
                self.tail_init = False
                init_vec = torch.zeros(1, self.learned_embedding_size)
            assert not (
                self.tail_init and self.tail_init_zeros
            ), f"Can only have one of tail_init or tail_init_zeros set"
            if self.tail_init or self.tail_init_zeros:
                if self.tail_init_zeros:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to zero.",
                    )
                else:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to the same value.",
                    )
                init_vec = model_utils.init_embeddings_to_vec(
                    self.learned_entity_embedding, pad_idx=-1, vec=init_vec
                )
                vec_save_file = os.path.join(
                    emmental.Meta.log_path, "init_vec_entity_embs.npy"
                )
                log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
                if (
                    torch.distributed.is_initialized()
                    and torch.distributed.get_rank() == 0
                ):
                    np.save(vec_save_file, init_vec)
            else:
                log_rank_0_debug(
                    logger, f"All learned embeddings are randomly initialized."
                )
            # Regularization mapping goes from eid to 2d dropout percent
            if "regularize_mapping" in emb_args:
                log_rank_0_debug(
                    logger,
                    f"You are using regularization mapping with a topK entity embedding. "
                    f"This means all QIDs that are mapped to the same"
                    f" EID will get the same regularization value.",
                )
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    log_rank_0_debug(
                        logger,
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?",
                    )
                log_rank_0_debug(
                    logger,
                    f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}",
                )
                eid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    entity_symbols,
                    qid2topk_eid,
                    num_topk_entities_with_pad_and_nocand,
                    emb_args.regularize_mapping,
                )
        # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
        self.register_buffer("eid2topkeid", eid2topkeid)
        self.register_buffer("eid2reg", eid2reg)
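The size of the compacted table above follows directly from perc_emb_drop: drop that fraction of the real entity rows and add one extra row that the dropped (tail) entities collapse onto. A small numeric check of that arithmetic with made-up counts:

# Made-up counts for illustration.
num_entities = 1_000_000                             # real entities
num_entities_with_pad_and_nocand = num_entities + 2  # plus the no-candidate and padding rows
perc_emb_drop = 0.9                                  # stands in for emb_args.perc_emb_drop

# Same formula as in the constructor: drop 90% of the entity rows, add one shared row for them.
num_topk_entities_with_pad_and_nocand = (
    num_entities_with_pad_and_nocand
    - int(perc_emb_drop * num_entities)
    + 1
)
print(num_topk_entities_with_pad_and_nocand)  # 100003 rows instead of 1000002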
Example #7
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(TypeEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "max_types",
            "type_dim",
            "type_labels",
            "type_vocab",
            "merge_func",
            "attn_hidden_size",
            "regularize_mapping",
        }
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert (
            "max_types"
            in emb_args), "Type embedding requires max_types to be set in args"
        assert (
            "type_dim"
            in emb_args), "Type embedding requires type_dim to be set in args"
        assert (
            "type_labels" in emb_args
        ), "Type embedding requires type_labels to be set in args. A Dict from QID -> TypeId or TypeName"
        assert (
            "type_vocab" in emb_args
        ), "Type embedding requires type_vocab to be set in args. A Dict from TypeName -> TypeId"
        assert (self.cpu is False
                ), f"We don't support putting type embeddings on CPU right now"
        self.merge_func = self.average_types
        self.orig_dim = emb_args.type_dim
        self.add_attn = None
        # Function for merging multiple types
        if "merge_func" in emb_args:
            assert emb_args.merge_func in [
                "average",
                "addattn",
            ], (f"{key}: You have set the type merge_func to be {emb_args.merge_func} but"
                f" that is not in the allowable list of [average, addattn]")
            if emb_args.merge_func == "addattn":
                if "attn_hidden_size" in emb_args:
                    attn_hidden_size = emb_args.attn_hidden_size
                else:
                    attn_hidden_size = 100
                # Softmax of types using the sentence context
                self.add_attn = PositionAwareAttention(
                    input_size=self.orig_dim,
                    attn_size=attn_hidden_size,
                    feature_size=0)
                self.merge_func = self.add_attn_merge
        self.max_types = emb_args.max_types
        (
            eid2typeids_table,
            self.type2row_dict,
            num_types_with_unk,
            self.prep_file,
        ) = self.prep(
            data_config=main_args.data_config,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
        )
        self.register_buffer("eid2typeids_table",
                             eid2typeids_table,
                             persistent=False)
        # self.eid2typeids_table.requires_grad = False
        self.num_types_with_pad_and_unk = num_types_with_unk + 1

        # Regularization mapping goes from typeid to 2d dropout percent
        if "regularize_mapping" in emb_args:
            typeid2reg = torch.zeros(self.num_types_with_pad_and_unk)
        else:
            typeid2reg = None

        if not self.from_pretrained:
            if "regularize_mapping" in emb_args:
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    logger.warning(
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?"
                    )
                log_rank_0_info(
                    logger,
                    f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}",
                )
                typeid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    self.num_types_with_pad_and_unk,
                    self.type2row_dict,
                    emb_args.regularize_mapping,
                )
        self.register_buffer("typeid2reg", typeid2reg)
        assert self.eid2typeids_table.shape[1] == emb_args.max_types, (
            f"Something went wrong with loading type file."
            f" The given max types {emb_args.max_types} does not match that "
            f"of type table {self.eid2typeids_table.shape[1]}")
        log_rank_0_debug(
            logger,
            f"{key}: Type embedding with {self.max_types} types with dim {self.orig_dim}. "
            f"Setting merge_func to be {self.merge_func.__name__} in type emb.",
        )
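merge_func decides how a candidate's (up to max_types) type vectors are collapsed into a single vector: a plain average by default, or an additive-attention weighting over them when merge_func is "addattn". A toy sketch of the averaging case that masks out padded type slots; the masking convention here is an assumption, and the addattn variant would replace the mean with a learned softmax weighting:

import torch

# Toy batch: 2 candidates, up to 3 type slots each, 5-dim type vectors.
type_vecs = torch.randn(2, 3, 5)
# 1 = real type, 0 = padded slot (the second candidate has only one type).
type_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

# Zero out padded slots, then average over the real types only.
summed = (type_vecs * type_mask.unsqueeze(-1)).sum(dim=1)
counts = type_mask.sum(dim=1, keepdim=True).clamp(min=1)
avg_types = summed / counts  # shape: (2, 5)
print(avg_types.shape)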