def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(LearnedEntityEmb, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        model_device=model_device,
        entity_symbols=entity_symbols,
        word_symbols=word_symbols,
        word_emb=word_emb,
        key=key,
    )
    self.logger = logging_utils.get_logger(main_args)
    self.learned_embedding_size = emb_args.learned_embedding_size
    self.normalize = True
    self.learned_entity_embedding = nn.Embedding(
        entity_symbols.num_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=True,
    )
    # If tail_init is false, all embeddings are randomly initialized.
    # If tail_init is true, we initialize all embeddings to be the same.
    self.tail_init = True
    self.tail_init_zeros = False
    # A None init vec means random initialization
    init_vec = None
    if "tail_init" in emb_args:
        self.tail_init = emb_args.tail_init
    if "tail_init_zeros" in emb_args:
        self.tail_init_zeros = emb_args.tail_init_zeros
        self.tail_init = False
        init_vec = torch.zeros(1, self.learned_embedding_size)
    assert not (
        self.tail_init and self.tail_init_zeros
    ), "Can only have one of tail_init or tail_init_zeros set"
    if self.tail_init or self.tail_init_zeros:
        if self.tail_init_zeros:
            self.logger.debug(
                "All learned entity embeddings are initialized to zero."
            )
        else:
            self.logger.debug(
                "All learned entity embeddings are initialized to the same value."
            )
        init_vec = model_utils.init_embeddings_to_vec(
            self.learned_entity_embedding, pad_idx=-1, vec=init_vec
        )
        vec_save_file = os.path.join(
            train_utils.get_save_folder(main_args.run_config),
            "init_vec_entity_embs.npy",
        )
        self.logger.debug(f"Saving init vector to {vec_save_file}")
        np.save(vec_save_file, init_vec)
    else:
        self.logger.debug("All learned embeddings are randomly initialized.")
    self._dim = main_args.model_config.hidden_size
    self.eid2reg = None
    # Regularization mapping goes from eid to 2D dropout percent
    if "regularize_mapping" in emb_args:
        self.logger.debug(
            f"Using regularization mapping in entity embedding from "
            f"{emb_args.regularize_mapping}"
        )
        self.eid2reg = self.load_regularization_mapping(
            main_args,
            entity_symbols,
            emb_args.regularize_mapping,
            self.logger.debug,
        )
        self.eid2reg = self.eid2reg.to(model_device)
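# A minimal sketch of what model_utils.init_embeddings_to_vec plausibly does,
# inferred from how it is called and tested in this file; the real Bootleg
# helper may differ in details. It fills every row after index 0 with the same
# vector (drawing a random one if vec is None), zeroes the pad row, and returns
# the vector used so callers can save it to disk.
import torch
import torch.nn as nn


def init_embeddings_to_vec_sketch(embedding: nn.Embedding, pad_idx: int, vec=None):
    if vec is None:
        # Fall back to a random vector of the embedding dimensionality.
        vec = torch.randn(1, embedding.weight.shape[1])
    # Row 0 (the "no candidate" entity) is left untouched; all remaining rows
    # share the same init vector.
    embedding.weight.data[1:] = vec
    # The pad row must stay all-zero so padded candidates contribute nothing.
    embedding.weight.data[pad_idx] = 0.0
    return vec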
def test_initialization(self):
    self.learned_entity_embedding = nn.Embedding(
        self.entity_symbols.num_entities_with_pad_and_nocand,
        4,
        padding_idx=-1,
        sparse=True,
    )
    # Clone so gold_emb is a snapshot rather than a view of the live weights;
    # a plain slice would alias the tensor being tested and make the final
    # assert vacuously true.
    gold_emb = self.learned_entity_embedding.weight.data.clone()
    gold_emb[1:] = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0],
            [1.0, 2.0, 3.0, 4.0],
            [1.0, 2.0, 3.0, 4.0],
            [1.0, 2.0, 3.0, 4.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
    )
    init_vec = torch.tensor([1.0, 2.0, 3.0, 4.0])
    init_vec_out = model_utils.init_embeddings_to_vec(
        self.learned_entity_embedding, pad_idx=-1, vec=init_vec
    )
    assert torch.equal(init_vec, init_vec_out)
    assert torch.equal(gold_emb, self.learned_entity_embedding.weight.data)
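# Quick illustration of the layout the test above checks, reusing the
# init_embeddings_to_vec_sketch helper defined earlier (illustrative names,
# not part of Bootleg). With 6 rows (nocand at index 0, 4 entities, pad at
# index -1), every entity row gets the init vector and the pad row stays zero.
import torch
import torch.nn as nn

emb = nn.Embedding(6, 4, padding_idx=-1)
vec_out = init_embeddings_to_vec_sketch(
    emb, pad_idx=-1, vec=torch.tensor([1.0, 2.0, 3.0, 4.0])
)
# emb.weight.data now looks like:
#   row 0       -> untouched random init (the "no candidate" entity)
#   rows 1..4   -> [1., 2., 3., 4.]
#   row 5 (pad) -> [0., 0., 0., 0.]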
def __init__(
    self,
    main_args,
    emb_args,
    entity_symbols,
    key,
    cpu,
    normalize,
    dropout1d_perc,
    dropout2d_perc,
):
    super(LearnedEntityEmb, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
        key=key,
        cpu=cpu,
        normalize=normalize,
        dropout1d_perc=dropout1d_perc,
        dropout2d_perc=dropout2d_perc,
    )
    allowable_keys = {
        "learned_embedding_size",
        "regularize_mapping",
        "tail_init",
        "tail_init_zeros",
    }
    correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
    if not correct:
        raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
    assert (
        "learned_embedding_size" in emb_args
    ), "LearnedEntityEmb must have learned_embedding_size in args"
    self.learned_embedding_size = emb_args.learned_embedding_size

    # Set sparsity based on the optimizer and fp16. The None optimizer is
    # Bootleg's SparseDenseAdam. If fp16 is True, we must use dense embeddings.
    optimiz = main_args.learner_config.optimizer_config.optimizer
    if optimiz in [None, "sparse_adam"] and main_args.learner_config.fp16 is False:
        sparse = True
    else:
        sparse = False
    # NCCL does not support the sparse gradient communication needed for
    # sparse embeddings, so fall back to dense under that backend.
    if (
        torch.distributed.is_initialized()
        and main_args.model_config.distributed_backend == "nccl"
    ):
        sparse = False
    log_rank_0_debug(
        logger, f"Setting sparsity for entity embeddings to be {sparse}"
    )
    self.learned_entity_embedding = nn.Embedding(
        entity_symbols.num_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=sparse,
    )
    self._dim = main_args.model_config.hidden_size
    if "regularize_mapping" in emb_args:
        eid2reg = torch.zeros(entity_symbols.num_entities_with_pad_and_nocand)
    else:
        eid2reg = None

    # If tail_init is false, all embeddings are randomly initialized.
    # If tail_init is true, we initialize all embeddings to be the same.
    self.tail_init = True
    self.tail_init_zeros = False
    # A None init vec means random initialization
    init_vec = None
    if not self.from_pretrained:
        if "tail_init" in emb_args:
            self.tail_init = emb_args.tail_init
        if "tail_init_zeros" in emb_args:
            self.tail_init_zeros = emb_args.tail_init_zeros
            self.tail_init = False
            init_vec = torch.zeros(1, self.learned_embedding_size)
        assert not (
            self.tail_init and self.tail_init_zeros
        ), "Can only have one of tail_init or tail_init_zeros set"
        if self.tail_init or self.tail_init_zeros:
            if self.tail_init_zeros:
                log_rank_0_debug(
                    logger,
                    "All learned entity embeddings are initialized to zero.",
                )
            else:
                log_rank_0_debug(
                    logger,
                    "All learned entity embeddings are initialized to the same value.",
                )
            init_vec = model_utils.init_embeddings_to_vec(
                self.learned_entity_embedding, pad_idx=-1, vec=init_vec
            )
            vec_save_file = os.path.join(
                emmental.Meta.log_path, "init_vec_entity_embs.npy"
            )
            log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
            # Save exactly once: either we are not distributed, or we are rank 0.
            if (
                not torch.distributed.is_initialized()
                or torch.distributed.get_rank() == 0
            ):
                np.save(vec_save_file, init_vec)
        else:
            log_rank_0_debug(
                logger, "All learned embeddings are randomly initialized."
            )
        # Regularization mapping goes from eid to 2D dropout percent
        if "regularize_mapping" in emb_args:
            if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                log_rank_0_debug(
                    logger,
                    "You have 1D or 2D dropout set together with a regularize_mapping. Did you mean to do this?",
                )
            log_rank_0_debug(
                logger,
                f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}",
            )
            eid2reg = self.load_regularization_mapping(
                main_args.data_config,
                entity_symbols,
                emb_args.regularize_mapping,
            )
    self.register_buffer("eid2reg", eid2reg)
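# A hedged sketch of what load_regularization_mapping might produce; the real
# Bootleg loader and its file format are not shown in this file, so the
# two-column CSV of (qid, dropout_percent) rows assumed below is illustrative.
# The result is a float tensor of per-entity 2D-dropout percentages indexed by eid.
import csv

import torch


def load_regularization_mapping_sketch(csv_path, entity_symbols):
    eid2reg = torch.zeros(entity_symbols.num_entities_with_pad_and_nocand)
    with open(csv_path) as f:
        for qid, reg in csv.reader(f):
            # Map each QID to its eid and record its dropout percentage.
            eid2reg[entity_symbols.get_eid(qid)] = float(reg)
    return eid2reg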
def __init__(
    self,
    main_args,
    emb_args,
    entity_symbols,
    key,
    cpu,
    normalize,
    dropout1d_perc,
    dropout2d_perc,
):
    super(TopKEntityEmb, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        entity_symbols=entity_symbols,
        key=key,
        cpu=cpu,
        normalize=normalize,
        dropout1d_perc=dropout1d_perc,
        dropout2d_perc=dropout2d_perc,
    )
    allowable_keys = {
        "learned_embedding_size",
        "perc_emb_drop",
        "qid2topk_eid",
        "regularize_mapping",
        "tail_init",
        "tail_init_zeros",
    }
    correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
    if not correct:
        raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
    assert (
        "learned_embedding_size" in emb_args
    ), "TopKEntityEmb must have learned_embedding_size in args"
    assert "perc_emb_drop" in emb_args, (
        "To use TopKEntityEmb we need perc_emb_drop to be in the args. "
        "This gives the percentage of embeddings removed."
    )
    self.learned_embedding_size = emb_args.learned_embedding_size
    # We remove perc_emb_drop percent of the embeddings and add one row to
    # serve as the shared embedding for the dropped tail entities.
    num_topk_entities_with_pad_and_nocand = (
        entity_symbols.num_entities_with_pad_and_nocand
        - int(emb_args.perc_emb_drop * entity_symbols.num_entities)
        + 1
    )
    # Mapping from each original eid to its new (compressed) eid
    eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
    # There are issues with using -1 to index into the embeddings, so we
    # manually set the pad eid to the last row.
    eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
    if "qid2topk_eid" not in emb_args:
        assert self.from_pretrained, (
            "If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
            "you must be loading a model from a checkpoint to build this index mapping"
        )
    self.learned_entity_embedding = nn.Embedding(
        num_topk_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=False,
    )
    self._dim = main_args.model_config.hidden_size
    if "regularize_mapping" in emb_args:
        eid2reg = torch.zeros(num_topk_entities_with_pad_and_nocand)
    else:
        eid2reg = None

    # If tail_init is false, all embeddings are randomly initialized.
    # If tail_init is true, we initialize all embeddings to be the same.
    self.tail_init = True
    self.tail_init_zeros = False
    # A None init vec means random initialization
    init_vec = None
    if not self.from_pretrained:
        qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
        assert (
            len(qid2topk_eid) == entity_symbols.num_entities
        ), "You must have an item in qid2topk_eid for each qid in entity_symbols"
        for qid in entity_symbols.get_all_qids():
            old_eid = entity_symbols.get_eid(qid)
            new_eid = qid2topk_eid[qid]
            eid2topkeid[old_eid] = new_eid
        assert eid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
        assert (
            eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1
        ), "The pad eid should still map to the last row"
        if "tail_init" in emb_args:
            self.tail_init = emb_args.tail_init
        if "tail_init_zeros" in emb_args:
            self.tail_init_zeros = emb_args.tail_init_zeros
            self.tail_init = False
            init_vec = torch.zeros(1, self.learned_embedding_size)
        assert not (
            self.tail_init and self.tail_init_zeros
        ), "Can only have one of tail_init or tail_init_zeros set"
        if self.tail_init or self.tail_init_zeros:
            if self.tail_init_zeros:
                log_rank_0_debug(
                    logger,
                    "All learned entity embeddings are initialized to zero.",
                )
            else:
                log_rank_0_debug(
                    logger,
                    "All learned entity embeddings are initialized to the same value.",
                )
            init_vec = model_utils.init_embeddings_to_vec(
                self.learned_entity_embedding, pad_idx=-1, vec=init_vec
            )
            vec_save_file = os.path.join(
                emmental.Meta.log_path, "init_vec_entity_embs.npy"
            )
            log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
            # Save exactly once: either we are not distributed, or we are rank 0.
            if (
                not torch.distributed.is_initialized()
                or torch.distributed.get_rank() == 0
            ):
                np.save(vec_save_file, init_vec)
        else:
            log_rank_0_debug(
                logger, "All learned embeddings are randomly initialized."
            )
        # Regularization mapping goes from eid to 2D dropout percent
        if "regularize_mapping" in emb_args:
            log_rank_0_debug(
                logger,
                "You are using regularization mapping with a topK entity embedding. "
                "This means all QIDs that are mapped to the same EID will get the "
                "same regularization value.",
            )
            if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                log_rank_0_debug(
                    logger,
                    "You have 1D or 2D dropout set together with a regularize_mapping. Did you mean to do this?",
                )
            log_rank_0_debug(
                logger,
                f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}",
            )
            eid2reg = self.load_regularization_mapping(
                main_args.data_config,
                entity_symbols,
                qid2topk_eid,
                num_topk_entities_with_pad_and_nocand,
                emb_args.regularize_mapping,
            )
    # Keep this mapping so a topK model can be loaded without rebuilding the new eid mapping
    self.register_buffer("eid2topkeid", eid2topkeid)
    self.register_buffer("eid2reg", eid2reg)
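# How the eid2topkeid buffer is consumed downstream (a hedged sketch; the real
# forward pass lives elsewhere in Bootleg). Incoming eids are first remapped
# into the compressed top-K space, which collapses all dropped tail entities
# onto the shared extra row, and only then index the smaller embedding table.
import torch
import torch.nn as nn


def topk_lookup_sketch(
    eid2topkeid: torch.Tensor, embedding: nn.Embedding, entity_ids: torch.Tensor
) -> torch.Tensor:
    # entity_ids: LongTensor of original eids, any shape.
    topk_ids = eid2topkeid[entity_ids]  # original eid -> compressed eid
    return embedding(topk_ids)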
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(TopKEntityEmb, self).__init__(
        main_args=main_args,
        emb_args=emb_args,
        model_device=model_device,
        entity_symbols=entity_symbols,
        word_symbols=word_symbols,
        word_emb=word_emb,
        key=key,
    )
    self.logger = logging_utils.get_logger(main_args)
    self.learned_embedding_size = emb_args.learned_embedding_size
    self.normalize = True
    assert "perc_emb_drop" in emb_args, (
        "To use TopKEntityEmb we need perc_emb_drop to be in the args. "
        "This gives the percentage of embeddings removed."
    )
    # We remove perc_emb_drop percent of the embeddings and add one row to
    # serve as the shared embedding for the dropped tail entities.
    num_topk_entities_with_pad_and_nocand = (
        entity_symbols.num_entities_with_pad_and_nocand
        - int(emb_args.perc_emb_drop * entity_symbols.num_entities)
        + 1
    )
    # Mapping from each original eid to its new (compressed) eid
    qid2topk_eid = {}
    eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
    # There are issues with using -1 to index into the embeddings, so we
    # manually set the pad eid to the last row.
    eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
    if "qid2topk_eid" not in emb_args:
        assert len(main_args.run_config.init_checkpoint) > 0, (
            "If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
            "you must be loading a model from a checkpoint to build this index mapping"
        )
    else:
        qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
        assert (
            len(qid2topk_eid) == entity_symbols.num_entities
        ), "You must have an item in qid2topk_eid for each qid in entity_symbols"
        for qid in entity_symbols.get_all_qids():
            old_eid = entity_symbols.get_eid(qid)
            new_eid = qid2topk_eid[qid]
            eid2topkeid[old_eid] = new_eid
        assert eid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
        assert (
            eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1
        ), "The pad eid should still map to the last row"
    self.learned_entity_embedding = nn.Embedding(
        num_topk_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=True,
    )
    # Keep this mapping so a topK model can be loaded without rebuilding the new eid mapping
    self.register_buffer("eid2topkeid", eid2topkeid)
    # If tail_init is false, all embeddings are randomly initialized.
    # If tail_init is true, we initialize all embeddings to be the same.
    self.tail_init = True
    self.tail_init_zeros = False
    # A None init vec means random initialization
    init_vec = None
    if "tail_init" in emb_args:
        self.tail_init = emb_args.tail_init
    if "tail_init_zeros" in emb_args:
        self.tail_init_zeros = emb_args.tail_init_zeros
        self.tail_init = False
        init_vec = torch.zeros(1, self.learned_embedding_size)
    assert not (
        self.tail_init and self.tail_init_zeros
    ), "Can only have one of tail_init or tail_init_zeros set"
    if self.tail_init or self.tail_init_zeros:
        if self.tail_init_zeros:
            self.logger.debug(
                "All learned entity embeddings are initialized to zero."
            )
        else:
            self.logger.debug(
                "All learned entity embeddings are initialized to the same value."
            )
        init_vec = model_utils.init_embeddings_to_vec(
            self.learned_entity_embedding, pad_idx=-1, vec=init_vec
        )
        vec_save_file = os.path.join(
            train_utils.get_save_folder(main_args.run_config),
            "init_vec_entity_embs.npy",
        )
        self.logger.debug(f"Saving init vector to {vec_save_file}")
        np.save(vec_save_file, init_vec)
    else:
        self.logger.debug("All learned embeddings are randomly initialized.")
    self._dim = main_args.model_config.hidden_size
    self.eid2reg = None
    # Regularization mapping goes from eid to 2D dropout percent
    if "regularize_mapping" in emb_args:
        self.logger.warning(
            "You are using regularization mapping with a topK entity embedding. "
            "This means all QIDs that are mapped to the same EID will get the "
            "same regularization value."
        )
        self.logger.debug(
            f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}"
        )
        self.eid2reg = self.load_regularization_mapping(
            main_args,
            qid2topk_eid,
            num_topk_entities_with_pad_and_nocand,
            emb_args.regularize_mapping,
            self.logger.debug,
        )
        self.eid2reg = self.eid2reg.to(model_device)
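# Illustrative emb_args for the two embedding classes above (a sketch: the key
# names come from the asserts and allowable_keys in this file, but the values
# and file names are made up for illustration).
learned_emb_args = {
    "learned_embedding_size": 256,
    "tail_init": True,  # default: initialize all entity rows to one shared vector
    "regularize_mapping": "eid2reg.csv",  # optional per-entity 2D dropout mapping
}
topk_emb_args = {
    "learned_embedding_size": 256,
    "perc_emb_drop": 0.9,  # drop 90% of entity rows, add one shared tail row
    "qid2topk_eid": "qid2topk_eid.json",  # qid -> compressed eid mapping
}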