def __init__(self, args, model_device, entity_symbols, word_symbols):
    super(EmbeddingLayer, self).__init__()
    self.logger = logging_utils.get_logger(args)
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    # Word Embedding (passed to Sentence and Entity Embedding classes)
    mod, load_class = import_class("bootleg.embeddings.word_embeddings",
                                   args.data_config.word_embedding.load_class)
    self.word_emb = getattr(mod, load_class)(args.data_config.word_embedding, args, word_symbols)
    # Sentence Embedding
    mod, load_class = import_class("bootleg.embeddings.word_embeddings",
                                   args.data_config.word_embedding.sent_class)
    self.sent_emb = getattr(mod, load_class)(args.data_config.word_embedding, args,
                                             self.word_emb.get_dim(), word_symbols)
    # Entity Embedding
    self.entity_embs = nn.ModuleDict()
    self.logger.info('Loading embeddings...')
    for emb in args.data_config.ent_embeddings:
        try:
            emb_args = emb.args
        except Exception:
            emb_args = None
        assert "load_class" in emb, "You must specify a load_class in the embedding config: {load_class: ..., key: ...}"
        assert "key" in emb, "You must specify a key in the embedding config: {load_class: ..., key: ...}"
        mod, load_class = import_class("bootleg.embeddings", emb.load_class)
        emb_obj = getattr(mod, load_class)(main_args=args,
                                           emb_args=emb_args,
                                           model_device=model_device,
                                           entity_symbols=entity_symbols,
                                           word_symbols=word_symbols,
                                           word_emb=self.word_emb,
                                           key=emb.key)
        self.entity_embs[emb.key] = emb_obj
    self.logger.info('Finished loading embeddings.')
    # Track the dimensions of the different embeddings
    self.emb_sizes = {}
    for emb in self.entity_embs.values():
        key = emb.key
        dim = emb.get_dim()
        assert key not in self.emb_sizes, f"Can't have duplicate keys in your embeddings and {key} is already here"
        self.emb_sizes[key] = dim
    self.sent_emb_size = self.sent_emb._dim
    self.project_sent = MLP(input_size=self.sent_emb_size,
                            num_hidden_units=0,
                            output_size=args.model_config.hidden_size,
                            num_layers=1,
                            dropout=0,
                            residual=False,
                            activation=None)
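# For reference: a minimal sketch of what the `import_class` helper used throughout these
# snippets is assumed to do -- resolve a (possibly dotted) class path relative to a base
# package and return the imported module together with the bare class name, so callers can
# instantiate via `getattr(mod, load_class)(...)`. This is an illustrative reconstruction,
# not the actual Bootleg utility.
import importlib


def import_class(base_module, load_class):
    # If load_class is dotted (e.g. "word_embeddings.StandardWordEmb"), treat everything
    # before the final dot as a submodule of base_module.
    if "." in load_class:
        submodule, load_class = load_class.rsplit(".", 1)
        base_module = f"{base_module}.{submodule}"
    mod = importlib.import_module(base_module)
    return mod, load_class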
def __init__(self, args, model_device, entity_symbols, word_symbols):
    super(Model, self).__init__()
    self.model_device = model_device
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.logger = logging_utils.get_logger(args)
    # embeddings
    self.emb_layer = EmbeddingLayer(args, self.model_device, entity_symbols, word_symbols)
    self.type_pred = False
    if args.data_config.type_prediction.use_type_pred:
        self.type_pred = True
        # Add 1 for pad type
        self.type_prediction = TypePred(args.model_config.hidden_size,
                                        args.data_config.type_prediction.dim,
                                        args.data_config.type_prediction.num_types + 1)
    self.emb_combiner = EmbCombinerProj(args, self.emb_layer.emb_sizes,
                                        self.emb_layer.sent_emb_size, word_symbols, entity_symbols)
    # attention network
    mod, load_class = import_class("bootleg.layers.attn_networks", args.model_config.attn_load_class)
    self.attn_network = getattr(mod, load_class)(args, self.emb_layer.emb_sizes,
                                                 self.emb_layer.sent_emb_size,
                                                 entity_symbols, word_symbols)
    # slice heads
    self.slice_heads = self.get_slice_method(args, entity_symbols)
    self.freeze_components(args)
def create_dataset(args, data_args, is_writer, word_symbols, entity_symbols,
                   slice_dataset=None, dataset_is_eval=False):
    dataset_name = generate_save_data_name(data_args=args.data_config,
                                           use_weak_label=data_args.use_weak_label,
                                           split_name=os.path.splitext(data_args.file)[0])
    prep_dir = get_data_prep_dir(args)
    full_dataset_name = os.path.join(prep_dir, dataset_name)
    mod, load_class = import_class("bootleg.dataloaders", data_args.load_class)
    dataset = getattr(mod, load_class)(args=args,
                                       use_weak_label=data_args.use_weak_label,
                                       input_src=os.path.join(args.data_config.data_dir, data_args.file),
                                       dataset_name=full_dataset_name,
                                       is_writer=is_writer,
                                       distributed=args.run_config.distributed,
                                       word_symbols=word_symbols,
                                       entity_symbols=entity_symbols,
                                       slice_dataset=slice_dataset,
                                       dataset_is_eval=dataset_is_eval)
    return dataset
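# Hedged usage sketch: wiring create_dataset into a standard torch DataLoader. The config
# fields referenced here (data_config.train_dataset, train_config.batch_size) and the
# shuffling choice are assumptions for illustration, not the exact Bootleg training pipeline.
from torch.utils.data import DataLoader

train_dataset = create_dataset(args, args.data_config.train_dataset, is_writer=True,
                               word_symbols=word_symbols, entity_symbols=entity_symbols)
train_dataloader = DataLoader(train_dataset,
                              batch_size=args.train_config.batch_size,
                              shuffle=True)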
def create_slice_dataset(args, data_args, is_writer, dataset_is_eval):
    # Note that the weak labelling is going to alter our indexing for the slices.
    # Our slices still only score gold==True.
    dataset_name = generate_slice_name(args, args.data_config,
                                       use_weak_label=data_args.use_weak_label,
                                       split_name="slice_" + os.path.splitext(data_args.file)[0],
                                       dataset_is_eval=dataset_is_eval)
    prep_dir = get_data_prep_dir(args)
    full_dataset_name = os.path.join(prep_dir, dataset_name)
    mod, load_class = import_class("bootleg.dataloaders", data_args.slice_class)
    dataset = getattr(mod, load_class)(args=args,
                                       use_weak_label=data_args.use_weak_label,
                                       input_src=os.path.join(args.data_config.data_dir, data_args.file),
                                       dataset_name=full_dataset_name,
                                       is_writer=is_writer,
                                       distributed=args.run_config.distributed,
                                       dataset_is_eval=dataset_is_eval)
    return dataset
def __init__(self, config_args, device='cuda', max_alias_len=6, cand_map=None, threshold=0.0):
    self.args = config_args
    self.device = device
    self.entity_db = EntitySymbols(os.path.join(self.args.data_config.entity_dir,
                                                self.args.data_config.entity_map_dir),
                                   alias_cand_map_file=self.args.data_config.alias_cand_map)
    self.word_db = data_utils.load_wordsymbols(self.args.data_config, is_writer=True, distributed=False)
    self.model = self._load_model()
    self.max_alias_len = max_alias_len
    if cand_map is None:
        alias_map = self.entity_db._alias2qids
    else:
        alias_map = ujson.load(open(cand_map))
    self.all_aliases_trie = get_all_aliases(alias_map, logger=logging.getLogger())
    self.alias_table = AliasEntityTable(args=self.args, entity_symbols=self.entity_db)
    # Minimum probability a prediction must have for the mention to be returned
    self.threshold = threshold
    # Get the batch_on_the_fly embeddings _and_ the batch_prep embeddings
    self.batch_on_the_fly_embs = {}
    for i, emb in enumerate(self.args.data_config.ent_embeddings):
        # At annotation time, batch-prepped embeddings are recomputed on the fly
        if 'batch_prep' in emb and emb['batch_prep'] is True:
            self.args.data_config.ent_embeddings[i]['batch_on_the_fly'] = True
            del self.args.data_config.ent_embeddings[i]['batch_prep']
        if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
            mod, load_class = import_class("bootleg.embeddings", emb.load_class)
            try:
                self.batch_on_the_fly_embs[emb.key] = getattr(mod, load_class)(main_args=self.args,
                                                                               emb_args=emb['args'],
                                                                               entity_symbols=self.entity_db,
                                                                               model_device=None,
                                                                               word_symbols=None)
            except AttributeError as e:
                print(f'No prep method found for {emb.load_class} with error {e}')
            except Exception as e:
                print("ERROR", e)
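# Hedged illustration of the shape of data_config.ent_embeddings consumed above and in
# EmbeddingLayer: each entry needs a `key` and a `load_class`, may carry class-specific
# `args`, and may set `batch_prep` or `batch_on_the_fly`. Class names, keys, and values
# below are purely illustrative assumptions.
example_ent_embeddings = [
    {"key": "learned", "load_class": "LearnedEntityEmb",
     "args": {"learned_embedding_size": 256}},
    {"key": "adj_index", "load_class": "kg_embs.KGIndices",
     "batch_on_the_fly": True, "args": {"kg_adj": "kg_adj.txt"}},
]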
def __init__(self, args, model_device, entity_symbols, word_symbols):
    super(BaselineModel, self).__init__(args, model_device, entity_symbols, word_symbols)
    self.model_device = model_device
    self.logger = logging_utils.get_logger(args)
    # Embeddings (without projection) and attention network
    mod, load_class = import_class("bootleg.layers.attn_networks", args.model_config.attn_load_class)
    self.emb_layer = EmbeddingLayerNoProj(args, self.model_device, entity_symbols, word_symbols)
    self.attn_network = getattr(mod, load_class)(args, self.emb_layer.emb_sizes,
                                                 self.emb_layer.sent_emb_size,
                                                 entity_symbols, word_symbols)
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.freeze_components(args)
def __init__(self, args, model_device, entity_symbols, word_symbols):
    super(Model, self).__init__()
    self.model_device = model_device
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.logger = logging_utils.get_logger(args)
    # embeddings
    self.emb_layer = EmbeddingLayer(args, self.model_device, entity_symbols, word_symbols)
    self.emb_combiner = EmbCombinerProj(args, self.emb_layer.emb_sizes,
                                        self.emb_layer.sent_emb_size, word_symbols, entity_symbols)
    # attention network
    mod, load_class = import_class("bootleg.layers.attn_networks", args.model_config.attn_load_class)
    self.attn_network = getattr(mod, load_class)(args, self.emb_layer.emb_sizes,
                                                 self.emb_layer.sent_emb_size,
                                                 entity_symbols, word_symbols)
    # slice heads
    self.slice_heads = self.get_slice_method(args, entity_symbols)
    self.freeze_components(args)
def __init__(self, args, use_weak_label, input_src, dataset_name, is_writer, distributed,
             word_symbols, entity_symbols, slice_dataset=None, dataset_is_eval=False):
    # Need to save args to reinstantiate the logger
    self.args = args
    self.logger = logging_utils.get_logger(args)
    # Number of candidates, including the NIL candidate if this is a NIL model (train_in_candidates is False)
    self.K = entity_symbols.max_candidates + (not args.data_config.train_in_candidates)
    self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
    self.dataset_name = dataset_name
    self.slice_dataset = slice_dataset
    self.dataset_is_eval = dataset_is_eval
    # Slice names used for eval slices and a slicing model
    self.slice_names = train_utils.get_data_slices(args, dataset_is_eval)
    self.storage_type_file = data_utils.get_storage_file(self.dataset_name)
    # Mapping from sent_idx to row_id in the dataset
    self.sent_idx_file = os.path.splitext(dataset_name)[0] + "_sent_idx.json"
    self.type_pred = False
    if args.data_config.type_prediction.use_type_pred:
        self.type_pred = True
        self.eid2typeid, self.num_types_with_pad = self.load_coarse_type_table(args, entity_symbols)
    # Load memory-mapped file
    self.logger.info("Loading dataset...")
    self.logger.debug("Seeing if " + dataset_name + " exists")
    if (args.data_config.overwrite_preprocessed_data
            or (not os.path.exists(self.dataset_name))
            or (not os.path.exists(self.sent_idx_file))
            or (not os.path.exists(self.storage_type_file))
            or (not os.path.exists(data_utils.get_batch_prep_config(self.dataset_name)))):
        start = time.time()
        self.logger.debug(f"Building dataset with {input_src}")
        # Only prep data once per node
        if is_writer:
            prep_data(args, use_weak_label=use_weak_label, dataset_is_eval=self.dataset_is_eval,
                      input_src=input_src, dataset_name=dataset_name,
                      prep_dir=data_utils.get_data_prep_dir(args))
        if distributed:
            # Make sure all processes wait for the data to be created
            dist.barrier()
        self.logger.debug(f"Finished building and saving dataset in {round(time.time() - start, 2)}s.")
    start = time.time()
    # Storage type for loading the memory-mapped dataset file
    self.storage_type = pickle.load(open(self.storage_type_file, 'rb'))
    self.data = np.memmap(self.dataset_name, dtype=self.storage_type, mode='r')
    self.data_len = len(self.data)
    # Mapping from sentence idx to rows in the dataset (indices).
    # Needed when sampling sentence indices from slices for evaluation.
    sent_idx_to_idx_str = utils.load_json_file(self.sent_idx_file)
    self.sent_idx_to_idx = {int(i): val for i, val in sent_idx_to_idx_str.items()}
    self.logger.info("Finished loading dataset.")
    # Stores info about the batch-prepped embedding memory-mapped files and their shapes
    # and datatypes so we can load them
    self.batch_prep_config = utils.load_json_file(data_utils.get_batch_prep_config(self.dataset_name))
    self.batch_prepped_emb_files = {}
    self.batch_prepped_emb_file_names = {}
    for emb in args.data_config.ent_embeddings:
        if 'batch_prep' in emb and emb['batch_prep']:
            assert emb.key in self.batch_prep_config, \
                f'Need to prep {emb.key}. Please call prep instead of run with batch_prep_embeddings set to true.'
            self.batch_prepped_emb_file_names[emb.key] = os.path.join(
                os.path.dirname(self.dataset_name),
                os.path.basename(self.batch_prep_config[emb.key]['file_name']))
            self.batch_prepped_emb_files[emb.key] = np.memmap(
                self.batch_prepped_emb_file_names[emb.key],
                dtype=self.batch_prep_config[emb.key]['dtype'],
                shape=tuple(self.batch_prep_config[emb.key]['shape']),
                mode='r')
            assert len(self.batch_prepped_emb_files[emb.key]) == self.data_len, \
                f'Preprocessed emb data file {self.batch_prep_config[emb.key]["file_name"]} does not match length of main data file.'
    # Stores embeddings that we compute on the fly; these are embeddings where batch_on_the_fly is set to true.
    self.batch_on_the_fly_embs = {}
    for emb in args.data_config.ent_embeddings:
        if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
            mod, load_class = import_class("bootleg.embeddings", emb.load_class)
            try:
                self.batch_on_the_fly_embs[emb.key] = getattr(mod, load_class)(main_args=args,
                                                                               emb_args=emb['args'],
                                                                               entity_symbols=entity_symbols,
                                                                               model_device=None,
                                                                               word_symbols=None,
                                                                               key=emb.key)
            except AttributeError as e:
                self.logger.warning(f'No prep method found for {emb.load_class} with error {e}')
            except Exception as e:
                print("ERROR", e)
    # The data in this table shouldn't be pickled since we delete it in the class __getstate__
    self.alias2entity_table = AliasEntityTable(args=args, entity_symbols=entity_symbols)
    # Random NIL percent
    self.mask_perc = args.train_config.random_nil_perc
    self.random_nil = False
    # Don't random mask for eval
    if not dataset_is_eval:
        # Whether to use a random NIL training regime
        self.random_nil = args.train_config.random_nil
        if self.random_nil:
            self.logger.info(f'Using random nils during training with {self.mask_perc} percent')
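# Hedged illustration: the batch prep config JSON loaded above is assumed to map each
# prepped embedding key to the memmap metadata read back in the loop (file_name, dtype,
# shape, where the first shape dimension must equal the main data length). The key and
# values below are purely illustrative, not taken from a real prep run.
example_batch_prep_config = {
    "adj_index": {
        "file_name": "train_adj_index_prep.bin",  # memmap file next to the dataset
        "dtype": "float32",                       # dtype passed to np.memmap
        "shape": [52000, 30, 100],                # first dim matches data_len
    }
}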
def __init__(self, args=None, entity_symbols=None, word_symbols=None, total_steps_per_epoch=0,
             resume_model_file="", eval_slice_names=None, model_eval=False):
    self.model_eval = model_eval  # keep track of mode for model loading
    self.distributed = args.run_config.distributed
    self.args = args
    self.total_steps_per_epoch = total_steps_per_epoch
    self.start_epoch = 0
    self.start_step = 0
    self.use_cuda = not args.run_config.cpu and torch.cuda.is_available()
    self.logger = logging_utils.get_logger(args)
    if not self.use_cuda:
        self.model_device = "cpu"
        self.embedding_device = "cpu"
    else:
        self.model_device = args.run_config.gpu
        self.embedding_device = args.run_config.gpu
    # Load base model
    mod, load_class = import_class("bootleg", args.model_config.base_model_load_class)
    self.model = getattr(mod, load_class)(args=args, model_device=self.model_device,
                                          entity_symbols=entity_symbols, word_symbols=word_symbols)
    self.use_eval_wrapper = False
    if eval_slice_names is not None:
        self.use_eval_wrapper = True
        # Mapping of all output heads to indexes for the buffers
        head_key_to_idx = train_utils.get_head_key_to_idx(args)
        self.eval_wrapper = EvalWrapper(args=args,
                                        head_key_to_idx=head_key_to_idx,
                                        eval_slice_names=eval_slice_names,
                                        train_head_names=args.train_config.train_heads)
        self.eval_wrapper.to(self.model_device)
    self.optimizer = SparseDenseAdam(list(self.model.parameters()),
                                     lr=args.train_config.lr,
                                     weight_decay=args.train_config.weight_decay)
    self.scorer = Scorer(args, self.model_device)
    self.model.to(self.model_device)
    if self.distributed:
        # Move everything to the GPU
        self.model = nn.parallel.DistributedDataParallel(self.model,
                                                         device_ids=[self.model_device],
                                                         find_unused_parameters=True)
    # Load weights into the existing model if resume_model_file is provided
    if resume_model_file.endswith(".pt"):
        self.logger.info(f'Loading model from {resume_model_file}...')
        self.load(resume_model_file)
    self.logger.debug("Model device " + str(self.model_device))
    self.logger.debug("Embedding device " + str(self.embedding_device))
    self.logger.debug("*************************MODEL PARAMS WITH GRAD*************************")
    self.logger.debug(f'Number of model parameters with grad: {count_parameters(self.model, True, self.logger)}')
    self.logger.debug("*************************MODEL PARAMS WITHOUT GRAD*************************")
    self.logger.debug(f'Number of model parameters without grad: {count_parameters(self.model, False, self.logger)}')
def get_dataloader_embeddings(main_args, entity_symbols):
    """Gets KG embeddings that need to be processed in the __getitem__ method of a dataset
    (e.g., querying a sparse numpy matrix). For each KG embedding class that needs this
    preprocessing, we save the adjacency matrix (for KG connections), the processing function
    to run in __getitem__, and the file from which to reload the adj matrix when dumping/loading.

    Args:
        main_args: main arguments
        entity_symbols: entity symbols

    Returns: Dict of KG metadata for use in the __getitem__ method.
    """
    batch_on_the_fly_kg_adj = {}
    for emb in main_args.data_config.ent_embeddings:
        # Find embeddings that have a "batch_on_the_fly" key
        batch_on_fly = "batch_on_the_fly" in emb and emb["batch_on_the_fly"] is True
        if batch_on_fly:
            log_rank_0_debug(
                logger,
                f"Loading class {emb.load_class} for preprocessing as on the fly or in data prep embeddings",
            )
            (
                cpu,
                dropout1d_perc,
                dropout2d_perc,
                emb_args,
                freeze,
                normalize,
                through_bert,
            ) = embedding_utils.get_embedding_args(emb)
            try:
                # Load the object
                mod, load_class = import_class("bootleg.embeddings", emb.load_class)
                kg_class = getattr(mod, load_class)(
                    main_args=main_args,
                    emb_args=emb_args,
                    entity_symbols=entity_symbols,
                    key=emb.key,
                    cpu=cpu,
                    normalize=normalize,
                    dropout1d_perc=dropout1d_perc,
                    dropout2d_perc=dropout2d_perc,
                )
                # Extract its kg_adj; we'll use this later.
                # Extract the kg_adj_process_func (how to process the embeddings in __getitem__ or dataset prep).
                # Extract the prep_file. We use this to load the kg_adj back after
                # saving/loading state using scipy.sparse.load_npz(prep_file).
                assert hasattr(kg_class, "kg_adj"), \
                    f"The embedding class {emb.key} does not have a kg_adj attribute and it needs to."
                assert hasattr(kg_class, "kg_adj_process_func"), \
                    f"The embedding class {emb.key} does not have a kg_adj_process_func attribute and it needs to."
                assert hasattr(kg_class, "prep_file"), (
                    f"The embedding class {emb.key} does not have a prep_file attribute and it needs to. We will call"
                    f" `scipy.sparse.load_npz(prep_file)` to load the kg_adj matrix."
                )
                batch_on_the_fly_kg_adj[emb.key] = {
                    "kg_adj": kg_class.kg_adj,
                    "kg_adj_process_func": kg_class.kg_adj_process_func,
                    "prep_file": kg_class.prep_file,
                }
            except AttributeError as e:
                logger.warning(f"No prep method found for {emb.load_class} with error {e}")
                raise
            except Exception as e:
                print("ERROR", e)
                raise
    return batch_on_the_fly_kg_adj
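# Hedged usage sketch: the returned metadata is keyed by embedding key; each value carries
# the adjacency matrix, the per-example processing function used inside __getitem__, and the
# prep file needed to reload the matrix via scipy.sparse.load_npz. How the dict is threaded
# into the dataset is an assumption here, shown only for illustration.
batch_on_the_fly_kg_adj = get_dataloader_embeddings(main_args, entity_symbols)
for key, meta in batch_on_the_fly_kg_adj.items():
    print(key, meta["prep_file"], type(meta["kg_adj"]))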