def get_eval_slice_subset_indices(args, eval_slice_dataset, dataset):
    logger = logging_utils.get_logger(args)
    logger.debug('Starting to sample indices')
    eval_slices = args.run_config.eval_slices
    # Get unique sentence indexes from samples for all slices.
    # We take the union of all data rows that map from these sentence indexes to eval.
    sent_indices = set()
    for slice_name in eval_slices:
        # NOTE: if the seed changes, what is sampled for dev will change.
        # Randomly sample indices from the slice.
        slice_indexes = eval_slice_dataset.get_non_empty_sent_idxs(slice_name)
        perc_eval_examples = int(args.run_config.perc_eval * len(slice_indexes))
        data_len = max(perc_eval_examples, args.run_config.min_eval_size, 1)
        if data_len >= len(slice_indexes):
            # If the requested sample is larger than the actual data, just use the whole slice
            random_indices = range(len(slice_indexes))
        else:
            random_indices = np.random.choice(len(slice_indexes), data_len, replace=False)
        for idx in random_indices:
            sent_indices.add(slice_indexes[idx])
    logger.debug('Starting to gather indices')
    # Get the corresponding dataset row indices for the sampled sentence indexes
    indices = []
    for sent_idx in sent_indices:
        if sent_idx in dataset.sent_idx_to_idx:
            samples = dataset.sent_idx_to_idx[sent_idx]
            for data_idx in samples:
                indices.append(data_idx)
    logger.info(f'Sampled {len(indices)} indices from dataset (dev/test) for evaluation.')
    return indices
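# A minimal sketch (hypothetical values, not from the source) of the mapping the gather
# loop above assumes: dataset.sent_idx_to_idx maps a sentence index to the list of
# dataset row indices built from that sentence, so one sampled sentence can contribute
# several evaluation rows.
sent_idx_to_idx = {0: [0, 1], 7: [2], 12: [3, 4, 5]}
sampled_sent_indices = {0, 12}
indices = [i for s in sampled_sent_indices for i in sent_idx_to_idx.get(s, [])]
assert sorted(indices) == [0, 1, 3, 4, 5]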
def __init__(self, args, model_device, entity_symbols, word_symbols):
    super(Model, self).__init__()
    self.model_device = model_device
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.logger = logging_utils.get_logger(args)
    # Embeddings
    self.emb_layer = EmbeddingLayer(args, self.model_device, entity_symbols, word_symbols)
    self.type_pred = False
    if args.data_config.type_prediction.use_type_pred:
        self.type_pred = True
        # Add 1 for the pad type
        self.type_prediction = TypePred(args.model_config.hidden_size,
                                        args.data_config.type_prediction.dim,
                                        args.data_config.type_prediction.num_types + 1)
    self.emb_combiner = EmbCombinerProj(args, self.emb_layer.emb_sizes,
                                        self.emb_layer.sent_emb_size,
                                        word_symbols, entity_symbols)
    # Attention network
    mod, load_class = import_class("bootleg.layers.attn_networks",
                                   args.model_config.attn_load_class)
    self.attn_network = getattr(mod, load_class)(args, self.emb_layer.emb_sizes,
                                                 self.emb_layer.sent_emb_size,
                                                 entity_symbols, word_symbols)
    # Slice heads
    self.slice_heads = self.get_slice_method(args, entity_symbols)
    self.freeze_components(args)
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(AvgTitleEmb, self).__init__(main_args=main_args,
                                      emb_args=emb_args,
                                      model_device=model_device,
                                      entity_symbols=entity_symbols,
                                      word_symbols=word_symbols,
                                      word_emb=word_emb,
                                      key=key)
    self.logger = logging_utils.get_logger(main_args)
    self.model_device = model_device
    self.word_emb = word_emb
    self.word_symbols = word_symbols
    self.orig_dim = word_emb.get_dim()
    self.merge_func = self.average_titles
    self.normalize = True
    self._dim = main_args.model_config.hidden_size
    self.requires_grad_title = word_emb.requires_grad
    if "freeze_word_emb_for_titles" in emb_args:
        self.requires_grad_title = not emb_args.freeze_word_emb_for_titles
    assert not self.requires_grad_title or word_emb.requires_grad, \
        "Inconsistent args: title embeddings cannot require gradients while the word embeddings themselves are frozen"
    self.entity2titleid_table = self.prep(main_args=main_args,
                                          word_symbols=word_symbols,
                                          entity_symbols=entity_symbols,
                                          log_func=self.logger.debug)
    self.entity2titleid_table = self.entity2titleid_table.to(model_device)
def __setstate__(self, state):
    self.__dict__.update(state)
    self.alias2entity_table = torch.tensor(
        np.memmap(self.prep_file,
                  dtype='int64',
                  mode='r',
                  shape=(self.num_aliases_with_pad, self.K)))
    self.logger = logging_utils.get_logger(self.args)
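# The __setstate__ above rebuilds the memmap-backed table from its prep file, which only
# works if the matching __getstate__ dropped the unpicklable members first. The source
# notes elsewhere that the table "shouldn't be pickled since we delete it in the class
# __getstate__"; a minimal sketch of that counterpart, assuming the attribute names used above:
def __getstate__(self):
    state = self.__dict__.copy()
    # Drop members that __setstate__ rebuilds from disk
    del state['alias2entity_table']
    del state['logger']
    return state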
def __init__(self, args, use_weak_label, input_src, dataset_name,
             is_writer, distributed, dataset_is_eval):
    self.logger = logging_utils.get_logger(args)
    self.args = args
    self.dataset_is_eval = dataset_is_eval
    self.storage_type = None
    self.dataset_name = dataset_name
    self.config_dataset_name = data_utils.get_slice_storage_file(dataset_name)
    self.sent_idx_to_idx_file = data_utils.get_sent_idx_file(dataset_name)
    # Load memory-mapped files
    self.logger.info("Loading slices...")
    self.logger.debug("Seeing if " + dataset_name + " exists")
    start = time.time()
    if (not args.data_config.overwrite_preprocessed_data
            and os.path.exists(self.dataset_name)
            and os.path.exists(self.config_dataset_name)
            and os.path.exists(self.sent_idx_to_idx_file)):
        self.logger.debug(f"Will load existing dataset {dataset_name}")
    else:
        self.logger.debug(f"Building dataset with {input_src}")
        # Only prep data once per node
        if is_writer:
            self.storage_type = prep_slice(
                args=args,
                file=os.path.basename(input_src),
                use_weak_label=use_weak_label,
                dataset_is_eval=dataset_is_eval,
                dataset_name=self.dataset_name,
                sent_idx_file=self.sent_idx_to_idx_file,
                storage_config=self.config_dataset_name,
                logger=self.logger)
            np.save(self.config_dataset_name, self.storage_type, allow_pickle=True)
        if distributed:
            # Make sure all processes wait for the data to be created
            dist.barrier()
        self.logger.debug(
            f"Finished building and saving dataset in {round(time.time() - start, 2)}s.")
    self.storage_type = np.load(self.config_dataset_name, allow_pickle=True).item()
    self.sent_idx_arr = np.memmap(self.sent_idx_to_idx_file, dtype=np.int64, mode='r')
    # Load the memory-mapped file and reshape it to the proper recarray shape of rows x 1
    self.data = np.expand_dims(
        np.memmap(self.dataset_name, dtype=self.storage_type, mode='r').view(np.recarray),
        axis=1)
    assert len(self.data) > 0
    assert len(self.sent_idx_arr) > 0
    self.logger.info("Finished loading slices.")
    self.data_len = len(self.data)
def __init__(self, args, entity_symbols):
    super(SliceHeadsSBL, self).__init__()
    self.logger = logging_utils.get_logger(args)
    self.dropout = args.train_config.dropout
    if "use_ind_attn" in args.model_config.custom_args and args.model_config.custom_args.use_ind_attn:
        self.logger.info('Using attention only over indicator confidences')
        self.use_ind_attn = args.model_config.custom_args.use_ind_attn
    else:
        self.use_ind_attn = False
    # Debugging parameter to see what happens when we remove the "final_loss" merging head
    if "remove_final_loss" in args.model_config.custom_args:
        self.remove_final_loss = args.model_config.custom_args.remove_final_loss
    else:
        self.remove_final_loss = False
    self.logger.info(
        f"Remove final loss: {self.remove_final_loss} and sen's trick {self.use_ind_attn}")
    # Softmax temperature
    self.temperature = args.train_config.softmax_temp
    self.hidden_size = args.model_config.hidden_size
    self.K = entity_symbols.max_candidates + (not args.data_config.train_in_candidates)
    self.M = args.data_config.max_aliases
    self.train_heads = args.train_config.train_heads
    # Predicts whether or not an example is in the slice
    self.indicator_heads = nn.ModuleDict()
    self.ind_alias_mha = nn.ModuleDict()
    # Creates slice expert representations
    self.transform_modules = nn.ModuleDict()
    for slice_head in self.train_heads:
        # Generates a BxMxH representation that gets fed to the linear layer for predictions.
        # This attends over the alias word and sentence (with an added alias-slice learned
        # embedding), then attends over the alias candidates (with an added entity-slice
        # learned embedding).
        self.ind_alias_mha[slice_head] = AliasMHA(args)
        # Binary prediction of whether the example is in the slice or not
        self.indicator_heads[slice_head] = nn.Linear(self.hidden_size, 2)
        # Transform layer for each slice
        self.transform_modules[slice_head] = nn.Linear(self.hidden_size, self.hidden_size)
    # Shared prediction layer to get confidences
    self.shared_slice_pred_head = nn.Linear(self.hidden_size, 1)
    # Embedding for each slice head for the indicator (added to queries in the AliasMHA heads)
    self.slice_emb_ind_alias = nn.Embedding(len(self.train_heads), self.hidden_size)
    self.slice_emb_ind_ent = nn.Embedding(len(self.train_heads), self.hidden_size)
    # Final prediction layer
    self.final_pred_head = nn.Linear(self.hidden_size, 1)
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(TypeEmb, self).__init__(main_args=main_args, emb_args=emb_args,
                                  model_device=model_device,
                                  entity_symbols=entity_symbols,
                                  word_symbols=word_symbols,
                                  word_emb=word_emb, key=key)
    self.logger = logging_utils.get_logger(main_args)
    self.merge_func = self.average_types
    self.orig_dim = emb_args.type_dim
    # Function for merging multiple types
    if "merge_func" in emb_args:
        if emb_args.merge_func not in ["average", "softattn", "addattn"]:
            self.logger.warning(
                f"{key}: You have set the type merge_func to be {emb_args.merge_func}, "
                f"but it is not in the allowable list of [average, softattn, addattn]")
        elif emb_args.merge_func == "softattn":
            if "attn_hidden_size" in emb_args:
                attn_hidden_size = emb_args.attn_hidden_size
            else:
                attn_hidden_size = 100
            self.logger.debug(
                f"{key}: Setting merge_func to be soft_attn in type emb with context size {attn_hidden_size}")
            # Soft attention over types using the sentence context
            self.soft_attn = SoftAttn(emb_dim=self.orig_dim,
                                      context_dim=main_args.model_config.hidden_size,
                                      size=attn_hidden_size)
            self.merge_func = self.soft_attn_merge
        elif emb_args.merge_func == "addattn":
            if "attn_hidden_size" in emb_args:
                attn_hidden_size = emb_args.attn_hidden_size
            else:
                attn_hidden_size = 100
            self.logger.debug(
                f"{key}: Setting merge_func to be add_attn in type emb with context size {attn_hidden_size}")
            # Additive attention over types using the sentence context
            self.add_attn = PositionAwareAttention(input_size=self.orig_dim,
                                                   attn_size=attn_hidden_size,
                                                   feature_size=0)
            self.merge_func = self.add_attn_merge
    self.normalize = True
    self.entity_symbols = entity_symbols
    self.max_types = emb_args.max_types
    self.eid2typeids_table, self.type2rowid, num_types_with_unk, self.prep_file = self.prep(
        main_args=main_args, emb_args=emb_args,
        entity_symbols=entity_symbols, log_func=self.logger.debug)
    self.num_types_with_pad_and_unk = num_types_with_unk + 1
    self.eid2typeids_table = self.eid2typeids_table.to(model_device)
def __setstate__(self, state):
    self.__dict__.update(state)
    self.data = np.expand_dims(
        np.memmap(self.dataset_name, dtype=self.storage_type, mode='r').view(np.recarray),
        axis=1)
    self.sent_idx_arr = np.memmap(self.sent_idx_to_idx_file, dtype=np.int64, mode='r')
    self.logger = logging_utils.get_logger(self.args)
def __setstate__(self, state):
    self.__dict__.update(state)
    self.data = np.memmap(self.dataset_name, dtype=self.storage_type, mode='r')
    self.batch_prepped_emb_files = {}
    for emb_name, file_name in self.batch_prepped_emb_file_names.items():
        self.batch_prepped_emb_files[emb_name] = np.memmap(
            file_name,
            dtype=self.batch_prep_config[emb_name]['dtype'],
            shape=tuple(self.batch_prep_config[emb_name]['shape']),
            mode='r')
    self.logger = logging_utils.get_logger(self.args)
def __init__(self, args, model_device, entity_symbols, word_symbols):
    super(BaselineModel, self).__init__(args, model_device, entity_symbols, word_symbols)
    self.model_device = model_device
    self.logger = logging_utils.get_logger(args)
    mod, load_class = import_class("bootleg.layers.attn_networks",
                                   args.model_config.attn_load_class)
    self.emb_layer = EmbeddingLayerNoProj(args, self.model_device,
                                          entity_symbols, word_symbols)
    self.attn_network = getattr(mod, load_class)(args, self.emb_layer.emb_sizes,
                                                 self.emb_layer.sent_emb_size,
                                                 entity_symbols, word_symbols)
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.freeze_components(args)
def __init__(self, args, embedding_sizes, sent_emb_size, entity_symbols, word_symbols):
    super(AttnNetwork, self).__init__()
    self.logger = logging_utils.get_logger(args)
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    # Number of candidates
    self.K = entity_symbols.max_candidates + (not args.data_config.train_in_candidates)
    # Number of aliases
    self.M = args.data_config.max_aliases
    self.sent_emb_size = sent_emb_size
    self.sent_len = args.data_config.max_word_token_len + 2 * word_symbols.is_bert
    self.hidden_size = args.model_config.hidden_size
    self.num_heads = args.model_config.num_heads
    self.num_model_stages = args.model_config.num_model_stages
    assert self.num_model_stages > 0, \
        f"You must have > 0 model stages. You have {self.num_model_stages}"
    self.num_fc_layers = args.model_config.num_fc_layers
    self.ff_inner_size = args.model_config.ff_inner_size
def __init__(self, args, entity_symbols):
    super(AliasEntityTable, self).__init__()
    self.args = args
    self.logger = logging_utils.get_logger(args)
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.num_aliases_with_pad = len(entity_symbols.get_all_aliases()) + 1
    self.M = args.data_config.max_aliases
    self.K = entity_symbols.max_candidates + (not args.data_config.train_in_candidates)
    self.alias2entity_table, self.prep_file = self.prep(
        args, entity_symbols,
        num_aliases_with_pad=self.num_aliases_with_pad,
        num_cands_K=self.K,
        log_func=self.logger.debug)
    # Small sanity check that loading was done correctly. This isn't a catch-all, but it
    # will catch cases where the pad row is wrong or the table wasn't loaded right.
    assert np.all(np.array(self.alias2entity_table[-1]) == np.ones(self.K) * -1), \
        "The last row of the alias table isn't all -1; something wasn't loaded right."
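# A minimal sketch (an assumption, not the source's forward pass) of how such a table is
# typically used: indexing alias2entity_table with a batch of alias ids yields, for each
# alias, its K candidate entity ids, with the final pad row mapping to -1 everywhere
# (which is what the assert above checks).
import torch

alias2entity_table = torch.tensor([[3, 7, -1],     # alias 0 has two candidates
                                   [5, -1, -1],    # alias 1 has one candidate
                                   [-1, -1, -1]])  # pad alias row
alias_idx = torch.tensor([[0, 1]])                 # batch of B=1 sentences, M=2 aliases
candidate_eids = alias2entity_table[alias_idx]     # shape B x M x K
assert candidate_eids.shape == (1, 2, 3)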
def create_dataloader(args, dataset, batch_size, eval_slice_dataset=None,
                      world_size=None, rank=None):
    logger = logging_utils.get_logger(args)
    if eval_slice_dataset is not None and not args.run_config.distributed:
        indices = get_eval_slice_subset_indices(args,
                                                eval_slice_dataset=eval_slice_dataset,
                                                dataset=dataset)
        # Form a sampler over the indices from eval_slice_dataset
        sampler = SubsetRandomSampler(indices)
    elif args.run_config.distributed:
        # Wrap the dataset object so we can use a subset sampler with distributed training
        if eval_slice_dataset is not None:
            indices = get_eval_slice_subset_indices(args,
                                                    eval_slice_dataset=eval_slice_dataset,
                                                    dataset=dataset)
            dataset = DistributedIndicesWrapper(dataset, torch.tensor(indices))
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    else:
        sampler = None
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=(sampler is None),
                            sampler=sampler,
                            num_workers=args.run_config.dataloader_threads,
                            pin_memory=False)
    return dataloader, sampler
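# A hedged usage sketch (the dataset and args objects are hypothetical) showing the two
# non-distributed paths through create_dataloader: a plain shuffled training loader, and
# an eval loader restricted to the slice-sampled subset via SubsetRandomSampler.
train_loader, _ = create_dataloader(args, train_dataset, batch_size=32)
eval_loader, _ = create_dataloader(args, dev_dataset, batch_size=32,
                                   eval_slice_dataset=dev_slice_dataset)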
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(EntityEmb, self).__init__()
    # Metaclasses are types that create classes
    # https://stackoverflow.com/questions/100003/what-are-metaclasses-in-python
    # We use this to enforce that a self.normalize attribute is instantiated in subclasses.
    # NOTE: in Python 3 a __metaclass__ assignment has no effect (and as a local inside
    # __init__ it never did), so this enforcement does not actually run.
    __metaclass__ = RequiredAttributes("normalize")
    self.logger = logging_utils.get_logger(main_args)
    self.entity_symbols = entity_symbols
    self.key = key
    self.dropout_perc = 0
    self.mask_perc = 0
    # Used for 2D dropout
    if MASK_PERC in emb_args:
        self.mask_perc = emb_args[MASK_PERC]
        self.logger.debug(f'Setting {self.key} mask perc to {self.mask_perc}')
    # Used for 1D dropout
    if "dropout" in emb_args:
        self.dropout_perc = emb_args.dropout
        self.logger.debug(f'Setting {self.key} dropout to {self.dropout_perc}')
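# Since the __metaclass__ assignment above is inert in Python 3, nothing actually enforces
# that subclasses set self.normalize. A minimal sketch of one way to get the intended
# check without a metaclass; this is my substitute for RequiredAttributes, which isn't
# shown in the source:
class EntityEmbBase:
    REQUIRED_ATTRS = ("normalize",)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        init = cls.__init__

        def checked_init(self, *args, **kw):
            init(self, *args, **kw)
            # Verify every required attribute was set by the subclass constructor
            for attr in EntityEmbBase.REQUIRED_ATTRS:
                assert hasattr(self, attr), f"{cls.__name__} must set self.{attr}"

        cls.__init__ = checked_init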
def __init__(self, args, model_device, entity_symbols, word_symbols):
    super(Model, self).__init__()
    self.model_device = model_device
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.logger = logging_utils.get_logger(args)
    # Embeddings
    self.emb_layer = EmbeddingLayer(args, self.model_device, entity_symbols, word_symbols)
    self.emb_combiner = EmbCombinerProj(args, self.emb_layer.emb_sizes,
                                        self.emb_layer.sent_emb_size,
                                        word_symbols, entity_symbols)
    # Attention network
    mod, load_class = import_class("bootleg.layers.attn_networks",
                                   args.model_config.attn_load_class)
    self.attn_network = getattr(mod, load_class)(args, self.emb_layer.emb_sizes,
                                                 self.emb_layer.sent_emb_size,
                                                 entity_symbols, word_symbols)
    # Slice heads
    self.slice_heads = self.get_slice_method(args, entity_symbols)
    self.freeze_components(args)
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(LearnedEntityEmb, self).__init__(main_args=main_args, emb_args=emb_args,
                                           model_device=model_device,
                                           entity_symbols=entity_symbols,
                                           word_symbols=word_symbols,
                                           word_emb=word_emb, key=key)
    self.logger = logging_utils.get_logger(main_args)
    self.learned_embedding_size = emb_args.learned_embedding_size
    self.normalize = True
    self.learned_entity_embedding = nn.Embedding(
        entity_symbols.num_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=True)
    # If tail_init is False, all embeddings are randomly initialized.
    # If tail_init is True, we initialize all embeddings to be the same.
    tail_init = True
    if "tail_init" in emb_args:
        tail_init = emb_args.tail_init
    if tail_init:
        qid_count_dict = {}
        self.logger.debug("All learned entity embeddings are initialized to the same value.")
        init_vec = model_utils.init_tail_embeddings(self.learned_entity_embedding,
                                                    qid_count_dict,
                                                    entity_symbols,
                                                    pad_idx=-1)
    else:
        self.logger.debug("All learned embeddings are randomly initialized.")
    self._dim = main_args.model_config.hidden_size
    self.dropout = nn.Dropout(self.dropout_perc)
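# A minimal sketch of what the tail initialization amounts to; this is an assumption
# about model_utils.init_tail_embeddings, whose body is not shown in the source: set
# every non-pad row of the embedding to one shared vector and return that vector.
import torch
import torch.nn as nn

def init_rows_to_shared_vec(emb: nn.Embedding, pad_idx: int = -1) -> torch.Tensor:
    vec = torch.randn(1, emb.embedding_dim)
    with torch.no_grad():
        emb.weight.copy_(vec.expand_as(emb.weight))
        emb.weight[pad_idx].fill_(0)  # keep the padding row at zero
    return vec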
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols=None, word_emb=None, key=""):
    super(KGWeightedAdjEmb, self).__init__(main_args=main_args, emb_args=emb_args,
                                           model_device=model_device,
                                           entity_symbols=entity_symbols,
                                           word_symbols=word_symbols,
                                           word_emb=word_emb, key=key)
    # Needed to recreate the logger
    self.main_args = main_args
    self.logger = logging_utils.get_logger(main_args)
    self._dim = 1
    self.model_device = model_device
    self.kg_adj, self.prep_file = self.prep(main_args=main_args,
                                            emb_args=emb_args,
                                            entity_symbols=entity_symbols,
                                            log_func=self.logger.debug)
def __setstate__(self, state):
    self.__dict__.update(state)
    self.logger = logging_utils.get_logger(self.main_args)
    # We can assume the adjacency matrix has already been built and saved
    self.kg_adj = scipy.sparse.load_npz(self.prep_file)
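# The load above assumes prep already wrote the adjacency matrix with scipy's sparse
# serializer. A minimal round-trip sketch (the file name is hypothetical):
import scipy.sparse

kg_adj = scipy.sparse.lil_matrix((3, 3))
kg_adj[0, 1] = 1.0
scipy.sparse.save_npz("kg_adj.npz", kg_adj.tocsr())
assert scipy.sparse.load_npz("kg_adj.npz").shape == (3, 3)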
def __init__(self, emb_args, main_args, word_emb_dim, word_symbols):
    super(BaseSentEmbedding, self).__init__()
    self.logger = logging_utils.get_logger(main_args)
    self._key = "sentence"
    self._dim = word_emb_dim
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols=None, word_emb=None, key=""):
    super(KGRelEmb, self).__init__(main_args=main_args, emb_args=emb_args,
                                   model_device=model_device,
                                   entity_symbols=entity_symbols,
                                   word_symbols=word_symbols,
                                   word_emb=word_emb, key=key)
    # Needed to recreate the logger
    self.main_args = main_args
    self.logger = logging_utils.get_logger(main_args)
    self.model_device = model_device
    self.mask_candidates = True
    self.kg_adj, self.kg_relations, self.rel2rowid, self.prep_file_adj = self.prep(
        main_args=main_args, emb_args=emb_args,
        entity_symbols=entity_symbols, log_func=self.logger.debug)
    # Initialize the learned relation embedding
    self.num_relations_with_pad = len(self.kg_relations) + 1
    self._dim = emb_args.rel_dim
    # sparse cannot be True for the relation embedding or we get distributed Gloo errors
    # depending on the batch (Gloo EnforceNotMet...)
    self.relation_emb = torch.nn.Embedding(self.num_relations_with_pad,
                                           self._dim,
                                           padding_idx=0,
                                           sparse=False).to(model_device)
    self.merge_func = self.average_rels
    if "merge_func" in emb_args:
        if emb_args.merge_func not in ["average", "softattn", "addattn"]:
            self.logger.warning(
                f"{key}: You have set the relation merge_func to be {emb_args.merge_func}, "
                f"but it is not in the allowable list of [average, softattn, addattn]")
        elif emb_args.merge_func == "softattn":
            if "attn_hidden_size" in emb_args:
                attn_hidden_size = emb_args.attn_hidden_size
            else:
                attn_hidden_size = 100
            self.sub_chunk = 3000
            self.logger.debug(
                f"{key}: Setting merge_func to be soft_attn in relation emb with context size "
                f"{attn_hidden_size} and sub_chunk of {self.sub_chunk}")
            # Soft attention over relations using the sentence context
            self.soft_attn = SoftAttn(emb_dim=self._dim,
                                      context_dim=main_args.model_config.hidden_size,
                                      size=attn_hidden_size)
            self.merge_func = self.soft_attn_merge
        elif emb_args.merge_func == "addattn":
            if "attn_hidden_size" in emb_args:
                attn_hidden_size = emb_args.attn_hidden_size
            else:
                attn_hidden_size = 100
            self.logger.info(
                f"{key}: Setting merge_func to be add_attn in relation emb with context size {attn_hidden_size}")
            # Additive attention over relations using the sentence context
            self.add_attn = PositionAwareAttention(input_size=self._dim,
                                                   attn_size=attn_hidden_size,
                                                   feature_size=0)
            self.merge_func = self.add_attn_merge
    self.normalize = True
    self.logger.debug(
        f'{key}: Using {self._dim}-dim relation embedding for {len(self.kg_relations)} relations '
        f'with 1 pad relation. Normalize is {self.normalize}')
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(TypeEmb, self).__init__(main_args=main_args, emb_args=emb_args,
                                  model_device=model_device,
                                  entity_symbols=entity_symbols,
                                  word_symbols=word_symbols,
                                  word_emb=word_emb, key=key)
    self.logger = logging_utils.get_logger(main_args)
    self.merge_func = self.average_types
    self.orig_dim = emb_args.type_dim
    # Function for merging multiple types
    if "merge_func" in emb_args:
        if emb_args.merge_func not in ["average", "softattn", "addattn"]:
            self.logger.warning(
                f"{key}: You have set the type merge_func to be {emb_args.merge_func}, "
                f"but it is not in the allowable list of [average, softattn, addattn]")
        elif emb_args.merge_func == "softattn":
            if "attn_hidden_size" in emb_args:
                attn_hidden_size = emb_args.attn_hidden_size
            else:
                attn_hidden_size = 100
            # Soft attention over types using the sentence context
            self.soft_attn = SoftAttn(emb_dim=self.orig_dim,
                                      context_dim=main_args.model_config.hidden_size,
                                      size=attn_hidden_size)
            self.merge_func = self.soft_attn_merge
        elif emb_args.merge_func == "addattn":
            if "attn_hidden_size" in emb_args:
                attn_hidden_size = emb_args.attn_hidden_size
            else:
                attn_hidden_size = 100
            # Additive attention over types using the sentence context
            self.add_attn = PositionAwareAttention(input_size=self.orig_dim,
                                                   attn_size=attn_hidden_size,
                                                   feature_size=0)
            self.merge_func = self.add_attn_merge
    self.normalize = True
    self.entity_symbols = entity_symbols
    self.max_types = emb_args.max_types
    self.eid2typeids_table, self.type2row_dict, num_types_with_unk, self.prep_file = self.prep(
        main_args=main_args, emb_args=emb_args,
        entity_symbols=entity_symbols, log_func=self.logger.debug)
    self.num_types_with_pad_and_unk = num_types_with_unk + 1
    self.eid2typeids_table = self.eid2typeids_table.to(model_device)
    # Regularization mapping goes from typeid to 2D dropout percent
    self.typeid2reg = None
    if "regularize_mapping" in emb_args:
        self.logger.debug(
            f"Using regularization mapping in type embedding from {emb_args.regularize_mapping}")
        self.typeid2reg = self.load_regularization_mapping(
            main_args, self.num_types_with_pad_and_unk, self.type2row_dict,
            emb_args.regularize_mapping, self.logger.debug)
        self.typeid2reg = self.typeid2reg.to(model_device)
    assert self.eid2typeids_table.shape[1] == emb_args.max_types, \
        f"Something went wrong with loading the type file. The given max types {emb_args.max_types} " \
        f"does not match that of the type table {self.eid2typeids_table.shape[1]}"
    self.logger.debug(
        f"{key}: Type embedding with {self.max_types} types with dim {self.orig_dim}. "
        f"Setting merge_func to be {self.merge_func.__name__} in type emb.")
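# A hedged sketch of what load_regularization_mapping might look like; the real file
# format isn't shown in the source, so this assumes a headerless two-column CSV of
# `typeid,regularization_percent`, producing a per-type dropout tensor that defaults to
# 0.0 for unmapped ids:
import csv
import torch

def load_regularization_mapping_sketch(path, num_types_with_pad_and_unk):
    typeid2reg = torch.zeros(num_types_with_pad_and_unk)
    with open(path) as f:
        for row in csv.DictReader(f, fieldnames=["typeid", "perc"]):
            typeid2reg[int(row["typeid"])] = float(row["perc"])
    return typeid2reg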
def __init__(self, args, use_weak_label, input_src, dataset_name, is_writer,
             distributed, word_symbols, entity_symbols, slice_dataset=None,
             dataset_is_eval=False):
    # Need to save args to reinstantiate the logger
    self.args = args
    self.logger = logging_utils.get_logger(args)
    # Number of candidates, including NIL if a NIL model (train_in_candidates is False)
    self.K = entity_symbols.max_candidates + (not args.data_config.train_in_candidates)
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.dataset_name = dataset_name
    self.slice_dataset = slice_dataset
    self.dataset_is_eval = dataset_is_eval
    # Slice names used for eval slices and a slicing model
    self.slice_names = train_utils.get_data_slices(args, dataset_is_eval)
    self.storage_type_file = data_utils.get_storage_file(self.dataset_name)
    # Mapping from sent_idx to row_id in the dataset
    self.sent_idx_file = os.path.splitext(dataset_name)[0] + "_sent_idx.json"
    self.type_pred = False
    if args.data_config.type_prediction.use_type_pred:
        self.type_pred = True
        self.eid2typeid, self.num_types_with_pad = self.load_coarse_type_table(
            args, entity_symbols)
    # Load the memory-mapped file
    self.logger.info("Loading dataset...")
    self.logger.debug("Seeing if " + dataset_name + " exists")
    if (args.data_config.overwrite_preprocessed_data
            or (not os.path.exists(self.dataset_name))
            or (not os.path.exists(self.sent_idx_file))
            or (not os.path.exists(self.storage_type_file))
            or (not os.path.exists(data_utils.get_batch_prep_config(self.dataset_name)))):
        start = time.time()
        self.logger.debug(f"Building dataset with {input_src}")
        # Only prep data once per node
        if is_writer:
            prep_data(args,
                      use_weak_label=use_weak_label,
                      dataset_is_eval=self.dataset_is_eval,
                      input_src=input_src,
                      dataset_name=dataset_name,
                      prep_dir=data_utils.get_data_prep_dir(args))
        if distributed:
            # Make sure all processes wait for the data to be created
            dist.barrier()
        self.logger.debug(
            f"Finished building and saving dataset in {round(time.time() - start, 2)}s.")
    start = time.time()
    # Storage type for loading the memory-mapped dataset file
    self.storage_type = pickle.load(open(self.storage_type_file, 'rb'))
    self.data = np.memmap(self.dataset_name, dtype=self.storage_type, mode='r')
    self.data_len = len(self.data)
    # Mapping from sentence idx to rows in the dataset (indices).
    # Needed when sampling sentence indices from slices for evaluation.
    sent_idx_to_idx_str = utils.load_json_file(self.sent_idx_file)
    self.sent_idx_to_idx = {int(i): val for i, val in sent_idx_to_idx_str.items()}
    self.logger.info("Finished loading dataset.")
    # Stores info about the batch-prepped embedding memory-mapped files
    # (their shapes and datatypes) so we can load them
    self.batch_prep_config = utils.load_json_file(
        data_utils.get_batch_prep_config(self.dataset_name))
    self.batch_prepped_emb_files = {}
    self.batch_prepped_emb_file_names = {}
    for emb in args.data_config.ent_embeddings:
        if 'batch_prep' in emb and emb['batch_prep']:
            assert emb.key in self.batch_prep_config, \
                f'Need to prep {emb.key}. Please call prep instead of run with batch_prep_embeddings set to true.'
            self.batch_prepped_emb_file_names[emb.key] = os.path.join(
                os.path.dirname(self.dataset_name),
                os.path.basename(self.batch_prep_config[emb.key]['file_name']))
            self.batch_prepped_emb_files[emb.key] = np.memmap(
                self.batch_prepped_emb_file_names[emb.key],
                dtype=self.batch_prep_config[emb.key]['dtype'],
                shape=tuple(self.batch_prep_config[emb.key]['shape']),
                mode='r')
            assert len(self.batch_prepped_emb_files[emb.key]) == self.data_len, \
                f'Preprocessed emb data file {self.batch_prep_config[emb.key]["file_name"]} does not match length of main data file.'
    # Stores embeddings that we compute on the fly (those with batch_on_the_fly set to true)
    self.batch_on_the_fly_embs = {}
    for emb in args.data_config.ent_embeddings:
        if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
            mod, load_class = import_class("bootleg.embeddings", emb.load_class)
            try:
                self.batch_on_the_fly_embs[emb.key] = getattr(mod, load_class)(
                    main_args=args,
                    emb_args=emb['args'],
                    entity_symbols=entity_symbols,
                    model_device=None,
                    word_symbols=None,
                    key=emb.key)
            except AttributeError as e:
                self.logger.warning(
                    f'No prep method found for {emb.load_class} with error {e}')
            except Exception as e:
                self.logger.error(f'Error loading {emb.load_class}: {e}')
    # The data in this table shouldn't be pickled since we delete it in the class __getstate__
    self.alias2entity_table = AliasEntityTable(args=args, entity_symbols=entity_symbols)
    # Random NIL percent
    self.mask_perc = args.train_config.random_nil_perc
    self.random_nil = False
    # We don't want to randomly mask for eval
    if not dataset_is_eval:
        # Whether to use a random NIL training regime
        self.random_nil = args.train_config.random_nil
        if self.random_nil:
            self.logger.info(
                f'Using random NILs during training with {self.mask_perc} percent')
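# A minimal sketch of the storage pattern the dataset classes above rely on: prep writes
# a structured-array memmap plus its dtype (the "storage type"), and the loader re-opens
# the file with that same dtype. Field names and the file name here are illustrative,
# not the source's.
import numpy as np

storage_type = np.dtype([('sent_idx', 'i8'), ('alias_idx', 'i8')])
writer = np.memmap('data.bin', dtype=storage_type, mode='w+', shape=(2,))
writer['sent_idx'] = [0, 7]
writer.flush()

# Re-open read-only; the length is inferred from the file size and the dtype's itemsize
data = np.memmap('data.bin', dtype=storage_type, mode='r').view(np.recarray)
assert data.sent_idx[1] == 7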
def __init__(self, args=None, model_device=None):
    self.model_device = model_device
    self.logger = logging_utils.get_logger(args)
    self.crit_pred = nn.NLLLoss(ignore_index=-1)
    self.type_pred = nn.CrossEntropyLoss(ignore_index=-1)
    self.weights = {
        train_utils.get_slice_head_ind_name(slice_name): None
        for slice_name in args.train_config.train_heads
    }
def __init__(self, main_args, emb_args, model_device, entity_symbols,
             word_symbols, word_emb, key):
    super(TopKEntityEmb, self).__init__(main_args=main_args, emb_args=emb_args,
                                        model_device=model_device,
                                        entity_symbols=entity_symbols,
                                        word_symbols=word_symbols,
                                        word_emb=word_emb, key=key)
    self.logger = logging_utils.get_logger(main_args)
    self.learned_embedding_size = emb_args.learned_embedding_size
    self.normalize = True
    assert "perc_emb_drop" in emb_args, \
        "To use TopKEntityEmb we need perc_emb_drop to be in the args. " \
        "This gives the percentage of embeddings removed."
    # We remove perc_emb_drop percent of the embeddings and add one row to represent
    # the new tail embedding
    num_topk_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand - int(
        emb_args.perc_emb_drop * entity_symbols.num_entities) + 1
    # Mapping from entity to the new eid
    qid2topk_eid = {}
    eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
    # There are issues with using -1 to index into the embeddings, so we manually set it
    # to be the last value
    eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
    if "qid2topk_eid" not in emb_args:
        assert len(main_args.run_config.init_checkpoint) > 0, \
            "If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, " \
            "you must be loading a model from a checkpoint to build this index mapping"
    else:
        qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
        assert len(qid2topk_eid) == entity_symbols.num_entities, \
            "You must have an item in qid2topk_eid for each qid in entity_symbols"
        for qid in entity_symbols.get_all_qids():
            old_eid = entity_symbols.get_eid(qid)
            new_eid = qid2topk_eid[qid]
            eid2topkeid[old_eid] = new_eid
    assert eid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
    assert eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1, \
        "The -1 eid should still map to the last row"
    self.learned_entity_embedding = nn.Embedding(
        num_topk_entities_with_pad_and_nocand,
        self.learned_embedding_size,
        padding_idx=-1,
        sparse=True)
    # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
    self.register_buffer("eid2topkeid", eid2topkeid)
    # If tail_init is False, all embeddings are randomly initialized.
    # If tail_init is True, we initialize all embeddings to be the same.
    self.tail_init = True
    self.tail_init_zeros = False
    # A None init vec means random initialization
    init_vec = None
    if "tail_init" in emb_args:
        self.tail_init = emb_args.tail_init
    if "tail_init_zeros" in emb_args:
        self.tail_init_zeros = emb_args.tail_init_zeros
        self.tail_init = False
        init_vec = torch.zeros(1, self.learned_embedding_size)
    assert not (self.tail_init and self.tail_init_zeros), \
        "Can only have one of tail_init or tail_init_zeros set"
    if self.tail_init or self.tail_init_zeros:
        if self.tail_init_zeros:
            self.logger.debug("All learned entity embeddings are initialized to zero.")
        else:
            self.logger.debug("All learned entity embeddings are initialized to the same value.")
        init_vec = model_utils.init_embeddings_to_vec(self.learned_entity_embedding,
                                                      pad_idx=-1, vec=init_vec)
        vec_save_file = os.path.join(train_utils.get_save_folder(main_args.run_config),
                                     "init_vec_entity_embs.npy")
        self.logger.debug(f"Saving init vector to {vec_save_file}")
        np.save(vec_save_file, init_vec)
    else:
        self.logger.debug("All learned embeddings are randomly initialized.")
    self._dim = main_args.model_config.hidden_size
    self.eid2reg = None
    # Regularization mapping goes from eid to 2D dropout percent
    if "regularize_mapping" in emb_args:
        self.logger.warning(
            "You are using regularization mapping with a topK entity embedding. This means "
            "all QIDs that are mapped to the same EID will get the same regularization value.")
        self.logger.debug(
            f"Using regularization mapping in entity embedding from {emb_args.regularize_mapping}")
        self.eid2reg = self.load_regularization_mapping(
            main_args, qid2topk_eid, num_topk_entities_with_pad_and_nocand,
            emb_args.regularize_mapping, self.logger.debug)
        self.eid2reg = self.eid2reg.to(model_device)
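# A hedged sketch (not the source's forward pass) of how the eid2topkeid buffer is meant
# to be used: remap full entity ids into the compressed top-K id space before the
# embedding lookup, so all dropped entities share the tail row.
def lookup_topk_emb(self, eids):
    topk_eids = self.eid2topkeid[eids]
    return self.learned_entity_embedding(topk_eids)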
def __init__(self, args, main_args, word_symbols):
    super(BaseWordEmbedding, self).__init__()
    self.logger = logging_utils.get_logger(main_args)
    self._key = "word"
    self.pad_id = word_symbols.pad_id
def __setstate__(self, state):
    self.__dict__.update(state)
    self.logger = logging_utils.get_logger(self.main_args)
def __init__(self, args=None, entity_symbols=None, word_symbols=None,
             total_steps_per_epoch=0, resume_model_file="",
             eval_slice_names=None, model_eval=False):
    # Keep track of the mode for model loading
    self.model_eval = model_eval
    self.distributed = args.run_config.distributed
    self.args = args
    self.total_steps_per_epoch = total_steps_per_epoch
    self.start_epoch = 0
    self.start_step = 0
    self.use_cuda = not args.run_config.cpu and torch.cuda.is_available()
    self.logger = logging_utils.get_logger(args)
    if not self.use_cuda:
        self.model_device = "cpu"
        self.embedding_device = "cpu"
    else:
        self.model_device = args.run_config.gpu
        self.embedding_device = args.run_config.gpu
    # Load the base model
    mod, load_class = import_class("bootleg", args.model_config.base_model_load_class)
    self.model = getattr(mod, load_class)(args=args,
                                          model_device=self.model_device,
                                          entity_symbols=entity_symbols,
                                          word_symbols=word_symbols)
    self.use_eval_wrapper = False
    if eval_slice_names is not None:
        self.use_eval_wrapper = True
        # Mapping of all output heads to indexes for the buffers
        head_key_to_idx = train_utils.get_head_key_to_idx(args)
        self.eval_wrapper = EvalWrapper(args=args,
                                        head_key_to_idx=head_key_to_idx,
                                        eval_slice_names=eval_slice_names,
                                        train_head_names=args.train_config.train_heads)
        self.eval_wrapper.to(self.model_device)
    self.optimizer = SparseDenseAdam(list(self.model.parameters()),
                                     lr=args.train_config.lr,
                                     weight_decay=args.train_config.weight_decay)
    self.scorer = Scorer(args, self.model_device)
    self.model.to(self.model_device)
    if self.distributed:
        # Move everything to the GPU
        self.model = nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[self.model_device],
            find_unused_parameters=True)
    # Load weights into the existing model if a model file is provided
    if resume_model_file.endswith(".pt"):
        self.logger.info(f'Loading model from {resume_model_file}...')
        self.load(resume_model_file)
    self.logger.debug("Model device " + str(self.model_device))
    self.logger.debug("Embedding device " + str(self.embedding_device))
    self.logger.debug(
        "*************************MODEL PARAMS WITH GRAD*************************")
    self.logger.debug(
        f'Number of model parameters with grad: {count_parameters(self.model, True, self.logger)}')
    self.logger.debug(
        "*************************MODEL PARAMS WITHOUT GRAD*************************")
    self.logger.debug(
        f'Number of model parameters without grad: {count_parameters(self.model, False, self.logger)}')
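# count_parameters is called above but not shown in the source; a minimal sketch
# consistent with how it is invoked (model, requires_grad flag, logger):
def count_parameters(model, requires_grad, logger):
    total = 0
    for name, param in model.named_parameters():
        if param.requires_grad == requires_grad:
            logger.debug(f"{name}: {param.numel()}")
            total += param.numel()
    return total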