Example #1
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(LearnedEntityEmb, self).__init__(main_args=main_args,
                                            emb_args=emb_args,
                                            model_device=model_device,
                                            entity_symbols=entity_symbols,
                                            word_symbols=word_symbols,
                                            word_emb=word_emb,
                                            key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.learned_embedding_size = emb_args.learned_embedding_size
     self.normalize = True
     self.learned_entity_embedding = nn.Embedding(
         entity_symbols.num_entities_with_pad_and_nocand,
         self.learned_embedding_size,
         padding_idx=-1,
         sparse=True)
     # If tail_init is false, all embeddings are randomly initialized.
     # If tail_init is true, we initialize all embeddings to be the same.
     self.tail_init = True
     self.tail_init_zeros = False
     # Leaving init_vec as None means random initialization
     init_vec = None
     if "tail_init" in emb_args:
         self.tail_init = emb_args.tail_init
     if "tail_init_zeros" in emb_args:
         self.tail_init_zeros = emb_args.tail_init_zeros
         self.tail_init = False
         init_vec = torch.zeros(1, self.learned_embedding_size)
     assert not (self.tail_init and self.tail_init_zeros
                 ), f"Can only have one of tail_init or tail_init_zeros set"
     if self.tail_init or self.tail_init_zeros:
         if self.tail_init_zeros:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to zero.")
         else:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to the same value."
             )
         init_vec = model_utils.init_embeddings_to_vec(
             self.learned_entity_embedding, pad_idx=-1, vec=init_vec)
         vec_save_file = os.path.join(
             train_utils.get_save_folder(main_args.run_config),
             "init_vec_entity_embs.npy")
         self.logger.debug(f"Saving init vector to {vec_save_file}")
         np.save(vec_save_file, init_vec)
     else:
         self.logger.debug(
             f"All learned embeddings are randomly initialized.")
     self._dim = main_args.model_config.hidden_size
     self.eid2reg = None
     # Regularization mapping goes from eid to 2d dropout percent
     if "regularize_mapping" in emb_args:
         self.logger.debug(
             f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}"
         )
         self.eid2reg = self.load_regularization_mapping(
             main_args, entity_symbols, emb_args.regularize_mapping,
             self.logger.debug)
         self.eid2reg = self.eid2reg.to(model_device)
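The tail-init branching in Example #1 (and repeated in the later examples) can be read as a small standalone rule. The sketch below restates it with a plain dict in place of Bootleg's emb_args config object; the helper name resolve_init_vec is illustrative and not part of the library:

import torch


def resolve_init_vec(emb_args: dict, emb_size: int):
    """Restates the tail_init / tail_init_zeros branching above (sketch only)."""
    tail_init = True
    tail_init_zeros = False
    init_vec = None  # staying None means a random vector is drawn later
    if "tail_init" in emb_args:
        tail_init = emb_args["tail_init"]
    if "tail_init_zeros" in emb_args:
        # The presence of the key (not its value) switches to the zero-vector path,
        # mirroring the constructor above
        tail_init_zeros = emb_args["tail_init_zeros"]
        tail_init = False
        init_vec = torch.zeros(1, emb_size)
    assert not (tail_init and tail_init_zeros)
    return tail_init, tail_init_zeros, init_vec


# tail_init_zeros takes precedence: tail_init is forced off and a zero vector is used
print(resolve_init_vec({"tail_init": True, "tail_init_zeros": True}, 4))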
Example #2
    def test_initialization(self):
        self.learned_entity_embedding = nn.Embedding(
            self.entity_symbols.num_entities_with_pad_and_nocand,
            4,
            padding_idx=-1,
            sparse=True,
        )
        gold_emb = self.learned_entity_embedding.weight.data[:]

        gold_emb[1:] = torch.tensor(
            [
                [1.0, 2.0, 3.0, 4.0],
                [1.0, 2.0, 3.0, 4.0],
                [1.0, 2.0, 3.0, 4.0],
                [1.0, 2.0, 3.0, 4.0],
                [0.0, 0.0, 0.0, 0.0],
            ]
        )

        init_vec = torch.tensor([1.0, 2.0, 3.0, 4.0])
        init_vec_out = model_utils.init_embeddings_to_vec(
            self.learned_entity_embedding, pad_idx=-1, vec=init_vec
        )
        assert torch.equal(init_vec, init_vec_out)
        assert torch.equal(gold_emb, self.learned_entity_embedding.weight.data)
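Judging from the gold embedding in this test, model_utils.init_embeddings_to_vec is expected to copy the given vector into the entity rows, zero the padding row, and return the vector it used (drawing a random one when vec is None). Below is a minimal sketch under that reading, not Bootleg's actual implementation:

import torch
import torch.nn as nn


def init_embeddings_to_vec_sketch(embedding: nn.Embedding, pad_idx: int, vec=None):
    """Hypothetical stand-in for model_utils.init_embeddings_to_vec."""
    if vec is None:
        # No vector supplied: draw one random row to copy into every entity row
        vec = torch.randn(1, embedding.weight.shape[1])
    embedding.weight.data[:] = vec        # broadcast the vector into all rows
    embedding.weight.data[pad_idx] = 0.0  # the padding row stays at zero
    return vec


# Mirrors the test above: 6 rows (pad last), dim 4, explicit init vector
emb = nn.Embedding(6, 4, padding_idx=-1)
out = init_embeddings_to_vec_sketch(emb, pad_idx=-1, vec=torch.tensor([1.0, 2.0, 3.0, 4.0]))
assert torch.equal(out, torch.tensor([1.0, 2.0, 3.0, 4.0]))
assert torch.equal(emb.weight.data[-1], torch.zeros(4))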
Example #3
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(LearnedEntityEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "learned_embedding_size",
            "regularize_mapping",
            "tail_init",
            "tail_init_zeros",
        }
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert (
            "learned_embedding_size" in emb_args
        ), f"LearnedEntityEmb must have learned_embedding_size in args"
        self.learned_embedding_size = emb_args.learned_embedding_size

        # Set sparsity based on optimizer and fp16. The None optimizer is Bootleg's SparseDenseAdam.
        # If fp16 is True, must use dense.
        optimiz = main_args.learner_config.optimizer_config.optimizer
        if optimiz in [None, "sparse_adam"] and main_args.learner_config.fp16 is False:
            sparse = True
        else:
            sparse = False

        if (
            torch.distributed.is_initialized()
            and main_args.model_config.distributed_backend == "nccl"
        ):
            sparse = False
        log_rank_0_debug(
            logger, f"Setting sparsity for entity embeddings to be {sparse}"
        )
        self.learned_entity_embedding = nn.Embedding(
            entity_symbols.num_entities_with_pad_and_nocand,
            self.learned_embedding_size,
            padding_idx=-1,
            sparse=sparse,
        )
        self._dim = main_args.model_config.hidden_size

        if "regularize_mapping" in emb_args:
            eid2reg = torch.zeros(entity_symbols.num_entities_with_pad_and_nocand)
        else:
            eid2reg = None

        # If tail_init is false, all embeddings are randomly initialized.
        # If tail_init is true, we initialize all embeddings to be the same.
        self.tail_init = True
        self.tail_init_zeros = False
        # Leaving init_vec as None means random initialization
        init_vec = None
        if not self.from_pretrained:
            if "tail_init" in emb_args:
                self.tail_init = emb_args.tail_init
            if "tail_init_zeros" in emb_args:
                self.tail_init_zeros = emb_args.tail_init_zeros
                self.tail_init = False
                init_vec = torch.zeros(1, self.learned_embedding_size)
            assert not (
                self.tail_init and self.tail_init_zeros
            ), f"Can only have one of tail_init or tail_init_zeros set"
            if self.tail_init or self.tail_init_zeros:
                if self.tail_init_zeros:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to zero.",
                    )
                else:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to the same value.",
                    )
                init_vec = model_utils.init_embeddings_to_vec(
                    self.learned_entity_embedding, pad_idx=-1, vec=init_vec
                )
                vec_save_file = os.path.join(
                    emmental.Meta.log_path, "init_vec_entity_embs.npy"
                )
                log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
                if (
                    torch.distributed.is_initialized()
                    and torch.distributed.get_rank() == 0
                ):
                    np.save(vec_save_file, init_vec)
            else:
                log_rank_0_debug(
                    logger, f"All learned embeddings are randomly initialized."
                )
            # Regularization mapping goes from eid to 2d dropout percent
            if "regularize_mapping" in emb_args:
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    log_rank_0_debug(
                        logger,
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?",
                    )
                log_rank_0_debug(
                    logger,
                    f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}",
                )
                eid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    entity_symbols,
                    emb_args.regularize_mapping,
                )
        self.register_buffer("eid2reg", eid2reg)
Example #4
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(TopKEntityEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "learned_embedding_size",
            "perc_emb_drop",
            "qid2topk_eid",
            "regularize_mapping",
            "tail_init",
            "tail_init_zeros",
        }
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert (
            "learned_embedding_size" in emb_args
        ), f"TopKEntityEmb must have learned_embedding_size in args"
        assert "perc_emb_drop" in emb_args, (
            f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings"
            f" removed."
        )
        self.learned_embedding_size = emb_args.learned_embedding_size
        # We remove perc_emb_drop percent of the embeddings and add one shared
        # embedding to represent all of the removed (tail) entities
        num_topk_entities_with_pad_and_nocand = (
            entity_symbols.num_entities_with_pad_and_nocand
            - int(emb_args.perc_emb_drop * entity_symbols.num_entities)
            + 1
        )
        # Mapping from the original eid to the new top-K eid
        eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
        # There are issues with using a -1 index into the embeddings, so we manually set it to the last value
        eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
        if "qid2topk_eid" not in emb_args:
            assert self.from_pretrained, (
                f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
                f"you must be loading a model from a checkpoint to build this index mapping"
            )

        self.learned_entity_embedding = nn.Embedding(
            num_topk_entities_with_pad_and_nocand,
            self.learned_embedding_size,
            padding_idx=-1,
            sparse=False,
        )

        self._dim = main_args.model_config.hidden_size

        if "regularize_mapping" in emb_args:
            eid2reg = torch.zeros(num_topk_entities_with_pad_and_nocand)
        else:
            eid2reg = None
        # If tail_init is false, all embeddings are randomly initialized.
        # If tail_init is true, we initialize all embeddings to be the same.
        self.tail_init = True
        self.tail_init_zeros = False
        # Leaving init_vec as None means random initialization
        init_vec = None
        if not self.from_pretrained:
            qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
            assert (
                len(qid2topk_eid) == entity_symbols.num_entities
            ), f"You must have an item in qid2topk_eid for each qid in entity_symbols"
            for qid in entity_symbols.get_all_qids():
                old_eid = entity_symbols.get_eid(qid)
                new_eid = qid2topk_eid[qid]
                eid2topkeid[old_eid] = new_eid
            assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
            assert (
                eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1
            ), "The -1 eid should still map to -1"

            if "tail_init" in emb_args:
                self.tail_init = emb_args.tail_init
            if "tail_init_zeros" in emb_args:
                self.tail_init_zeros = emb_args.tail_init_zeros
                self.tail_init = False
                init_vec = torch.zeros(1, self.learned_embedding_size)
            assert not (
                self.tail_init and self.tail_init_zeros
            ), f"Can only have one of tail_init or tail_init_zeros set"
            if self.tail_init or self.tail_init_zeros:
                if self.tail_init_zeros:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to zero.",
                    )
                else:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to the same value.",
                    )
                init_vec = model_utils.init_embeddings_to_vec(
                    self.learned_entity_embedding, pad_idx=-1, vec=init_vec
                )
                vec_save_file = os.path.join(
                    emmental.Meta.log_path, "init_vec_entity_embs.npy"
                )
                log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
                if (
                    torch.distributed.is_initialized()
                    and torch.distributed.get_rank() == 0
                ):
                    np.save(vec_save_file, init_vec)
            else:
                log_rank_0_debug(
                    logger, f"All learned embeddings are randomly initialized."
                )
            # Regularization mapping goes from eid to 2d dropout percent
            if "regularize_mapping" in emb_args:
                log_rank_0_debug(
                    logger,
                    f"You are using regularization mapping with a topK entity embedding. "
                    f"This means all QIDs that are mapped to the same"
                    f" EID will get the same regularization value.",
                )
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    log_rank_0_debug(
                        logger,
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?",
                    )
                log_rank_0_debug(
                    logger,
                    f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}",
                )
                eid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    entity_symbols,
                    qid2topk_eid,
                    num_topk_entities_with_pad_and_nocand,
                    emb_args.regularize_mapping,
                )
        # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
        self.register_buffer("eid2topkeid", eid2topkeid)
        self.register_buffer("eid2reg", eid2reg)
Example #5
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(TopKEntityEmb, self).__init__(main_args=main_args,
                                         emb_args=emb_args,
                                         model_device=model_device,
                                         entity_symbols=entity_symbols,
                                         word_symbols=word_symbols,
                                         word_emb=word_emb,
                                         key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.learned_embedding_size = emb_args.learned_embedding_size
     self.normalize = True
     assert "perc_emb_drop" in emb_args, f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings" \
                                         f" removed."
     # We remove perc_emb_drop percent of the embeddings and add one shared
     # embedding to represent all of the removed (tail) entities
     num_topk_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand - int(
         emb_args.perc_emb_drop * entity_symbols.num_entities) + 1
     # Mappings from qid and from the original eid to the new top-K eid
     qid2topk_eid = {}
     eid2topkeid = torch.arange(
         0, entity_symbols.num_entities_with_pad_and_nocand)
     # There are issues with using a -1 index into the embeddings, so we manually set it to the last value
     eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
     if "qid2topk_eid" not in emb_args:
         assert len(main_args.run_config.init_checkpoint) > 0, f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, " \
                                                               f"you must be loading a model from a checkpoint to build this index mapping"
     else:
         qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
         assert len(
             qid2topk_eid
         ) == entity_symbols.num_entities, f"You must have an item in qid2topk_eid for each qid in entity_symbols"
         for qid in entity_symbols.get_all_qids():
             old_eid = entity_symbols.get_eid(qid)
             new_eid = qid2topk_eid[qid]
             eid2topkeid[old_eid] = new_eid
         assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
         assert eid2topkeid[
             -1] == num_topk_entities_with_pad_and_nocand - 1, "The -1 eid should still map to -1"
     self.learned_entity_embedding = nn.Embedding(
         num_topk_entities_with_pad_and_nocand,
         self.learned_embedding_size,
         padding_idx=-1,
         sparse=True)
     # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
     self.register_buffer("eid2topkeid", eid2topkeid)
     # If tail_init is false, all embeddings are randomly initialized.
     # If tail_init is true, we initialize all embeddings to be the same.
     self.tail_init = True
     self.tail_init_zeros = False
     # Leaving init_vec as None means random initialization
     init_vec = None
     if "tail_init" in emb_args:
         self.tail_init = emb_args.tail_init
     if "tail_init_zeros" in emb_args:
         self.tail_init_zeros = emb_args.tail_init_zeros
         self.tail_init = False
         init_vec = torch.zeros(1, self.learned_embedding_size)
     assert not (self.tail_init and self.tail_init_zeros
                 ), f"Can only have one of tail_init or tail_init_zeros set"
     if self.tail_init or self.tail_init_zeros:
         if self.tail_init_zeros:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to zero.")
         else:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to the same value."
             )
         init_vec = model_utils.init_embeddings_to_vec(
             self.learned_entity_embedding, pad_idx=-1, vec=init_vec)
         vec_save_file = os.path.join(
             train_utils.get_save_folder(main_args.run_config),
             "init_vec_entity_embs.npy")
         self.logger.debug(f"Saving init vector to {vec_save_file}")
         np.save(vec_save_file, init_vec)
     else:
         self.logger.debug(
             f"All learned embeddings are randomly initialized.")
     self._dim = main_args.model_config.hidden_size
     self.eid2reg = None
     # Regularization mapping goes from eid to 2d dropout percent
     if "regularize_mapping" in emb_args:
         self.logger.warning(
             f"You are using regularization mapping with a topK entity embedding. This means all QIDs that are mapped to the same"
             f" EID will get the same regularization value.")
         self.logger.debug(
             f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}"
         )
         self.eid2reg = self.load_regularization_mapping(
             main_args, qid2topk_eid, num_topk_entities_with_pad_and_nocand,
             emb_args.regularize_mapping, self.logger.debug)
         self.eid2reg = self.eid2reg.to(model_device)
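The registered eid2topkeid buffer is presumably applied at lookup time: original entity ids are translated into top-K ids before the smaller embedding table is indexed. A hedged sketch of that step (not shown in the snippets above), reusing the toy mapping from the earlier sketch:

import torch
import torch.nn as nn

# Assumed lookup step, for illustration only
eid2topkeid = torch.tensor([0, 1, 2, 2, 2, 3])        # toy mapping: eids 2-4 share row 2
topk_embedding = nn.Embedding(4, 8, padding_idx=-1)   # rows: nocand, Q1, shared tail, pad

original_eids = torch.tensor([[1, 3, 5]])             # batch of candidate eids (5 is the pad eid)
topk_eids = eid2topkeid[original_eids]                # tensor([[1, 2, 3]])
cand_embs = topk_embedding(topk_eids)                 # shape (1, 3, 8); the pad row is all zeros
print(cand_embs.shape)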