def learn(self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader]) -> None:
    """Learning procedure of emmental MTL.

    Args:
        model: The emmental model that needs to learn.
        dataloaders: A list of dataloaders used to learn the model.
    """
    start_time = time.time()

    # Generate the list of dataloaders for learning process
    train_split = Meta.config["learner_config"]["train_split"]
    if isinstance(train_split, str):
        train_split = [train_split]

    train_dataloaders = [
        dataloader for dataloader in dataloaders if dataloader.split in train_split
    ]

    if not train_dataloaders:
        raise ValueError(
            f"Cannot find the specified train_split "
            f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
        )

    # Set up task_scheduler
    self._set_task_scheduler()

    # Calculate the total number of batches per epoch
    self.n_batches_per_epoch = self.task_scheduler.get_num_batches(train_dataloaders)

    # Set up logging manager
    self._set_logging_manager()
    # Set up optimizer
    self._set_optimizer(model)
    # Set up lr_scheduler
    self._set_lr_scheduler(model)

    if Meta.config["learner_config"]["fp16"]:
        try:
            from apex import amp  # type: ignore
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to "
                "use fp16 training."
            )
        logger.info(
            f"Model training with 16-bit (mixed) precision "
            f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
        )
        model, self.optimizer = amp.initialize(
            model,
            self.optimizer,
            opt_level=Meta.config["learner_config"]["fp16_opt_level"],
        )

    # Multi-gpu training (after apex fp16 initialization)
    if (
        Meta.config["learner_config"]["local_rank"] == -1
        and Meta.config["model_config"]["dataparallel"]
    ):
        model._to_dataparallel()

    # Distributed training (after apex fp16 initialization)
    if Meta.config["learner_config"]["local_rank"] != -1:
        model._to_distributed_dataparallel()

    # Set to training mode
    model.train()

    if Meta.config["meta_config"]["verbose"]:
        logger.info("Start learning...")

    self.metrics: Dict[str, float] = dict()
    self._reset_losses()

    # Set gradients of all model parameters to zero
    self.optimizer.zero_grad()

    for epoch_num in range(Meta.config["learner_config"]["n_epochs"]):
        batches = tqdm(
            enumerate(self.task_scheduler.get_batches(train_dataloaders, model)),
            total=self.n_batches_per_epoch,
            disable=(
                not Meta.config["meta_config"]["verbose"]
                or Meta.config["learner_config"]["local_rank"] not in [-1, 0]
            ),
            desc=f"Epoch {epoch_num}:",
        )
        for batch_num, batch in batches:
            # Convert a single batch into a batch list
            if not isinstance(batch, list):
                batch = [batch]

            total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
            batch_size = 0

            for uids, X_dict, Y_dict, task_to_label_dict, data_name, split in batch:
                batch_size += len(next(iter(Y_dict.values())))

                # Perform forward pass and calculate the loss and count
                uid_dict, loss_dict, prob_dict, gold_dict = model(
                    uids, X_dict, Y_dict, task_to_label_dict
                )

                # Update running loss and count
                for task_name in uid_dict.keys():
                    identifier = f"{task_name}/{data_name}/{split}"
                    self.running_uids[identifier].extend(uid_dict[task_name])
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * len(uid_dict[task_name])
                        if len(loss_dict[task_name].size()) == 0
                        else torch.sum(loss_dict[task_name]).item()
                    )
                    self.running_probs[identifier].extend(prob_dict[task_name])
                    self.running_golds[identifier].extend(gold_dict[task_name])

                # Skip the backward pass if no loss is calculated
                if not loss_dict:
                    continue

                # Calculate the average loss
                loss = sum(
                    [
                        model.weights[task_name] * task_loss
                        if len(task_loss.size()) == 0
                        else torch.mean(model.weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ]
                )

                # Perform backward pass to calculate gradients
                if Meta.config["learner_config"]["fp16"]:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()  # type: ignore

            if (total_batch_num + 1) % Meta.config["learner_config"]["optimizer_config"][
                "gradient_accumulation_steps"
            ] == 0 or (
                batch_num + 1 == self.n_batches_per_epoch
                and epoch_num + 1 == Meta.config["learner_config"]["n_epochs"]
            ):
                # Clip gradient norm
                if Meta.config["learner_config"]["optimizer_config"]["grad_clip"]:
                    if Meta.config["learner_config"]["fp16"]:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            Meta.config["learner_config"]["optimizer_config"][
                                "grad_clip"
                            ],
                        )
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(),
                            Meta.config["learner_config"]["optimizer_config"][
                                "grad_clip"
                            ],
                        )

                # Update the parameters
                self.optimizer.step()

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

            if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                self.metrics.update(self._logging(model, dataloaders, batch_size))
                batches.set_postfix(self.metrics)

            # Update lr using lr scheduler
            self._update_lr_scheduler(model, total_batch_num, self.metrics)

    if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
        model = self.logging_manager.close(model)

    logger.info(f"Total learning time: {time.time() - start_time} seconds.")
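
# --- Illustrative sketch (not part of emmental): the gradient-accumulation pattern used
# in learn() above, reduced to plain PyTorch with hypothetical names (an nn.Linear
# stand-in for the model, `accumulation_steps`). The optimizer only steps every
# `accumulation_steps` batches or on the very last batch, mirroring the
# (total_batch_num + 1) % gradient_accumulation_steps check above.
import torch
import torch.nn as nn


def toy_accumulation_loop(batches, accumulation_steps: int = 4) -> None:
    model = nn.Linear(10, 1)  # stand-in for the multi-task EmmentalModel
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    optimizer.zero_grad()
    n_batches = len(batches)
    for batch_num, (x, y) in enumerate(batches):
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()  # gradients accumulate across the skipped steps
        if (batch_num + 1) % accumulation_steps == 0 or batch_num + 1 == n_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # optional grad clip
            optimizer.step()
            optimizer.zero_grad()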
def learn(self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader]) -> None:
    """Learning procedure of emmental MTL.

    Args:
        model: The emmental model that needs to learn.
        dataloaders: A list of dataloaders used to learn the model.
    """
    start_time = time.time()

    # Generate the list of dataloaders for learning process
    train_split = Meta.config["learner_config"]["train_split"]
    if isinstance(train_split, str):
        train_split = [train_split]

    train_dataloaders = [
        dataloader for dataloader in dataloaders if dataloader.split in train_split
    ]

    if not train_dataloaders:
        raise ValueError(
            f"Cannot find the specified train_split "
            f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
        )

    # Set up task_scheduler
    self._set_task_scheduler()

    # Calculate the total number of batches per epoch
    self.n_batches_per_epoch: int = self.task_scheduler.get_num_batches(
        train_dataloaders
    )
    if self.n_batches_per_epoch == 0:
        logger.info("No batches in training dataloaders, exiting...")
        return

    # Set up learning counter
    self._set_learning_counter()
    # Set up logging manager
    self._set_logging_manager()
    # Set up wandb watch model
    if (
        Meta.config["logging_config"]["writer_config"]["writer"] == "wandb"
        and Meta.config["logging_config"]["writer_config"]["wandb_watch_model"]
    ):
        if Meta.config["logging_config"]["writer_config"]["wandb_model_watch_freq"]:
            wandb.watch(
                model,
                log_freq=Meta.config["logging_config"]["writer_config"][
                    "wandb_model_watch_freq"
                ],
            )
        else:
            wandb.watch(model)
    # Set up optimizer
    self._set_optimizer(model)
    # Set up lr_scheduler
    self._set_lr_scheduler(model)

    if Meta.config["learner_config"]["fp16"]:
        try:
            from apex import amp  # type: ignore
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to "
                "use fp16 training."
            )
        logger.info(
            f"Model training with 16-bit (mixed) precision "
            f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
        )
        model, self.optimizer = amp.initialize(
            model,
            self.optimizer,
            opt_level=Meta.config["learner_config"]["fp16_opt_level"],
        )

    # Multi-gpu training (after apex fp16 initialization)
    if (
        Meta.config["learner_config"]["local_rank"] == -1
        and Meta.config["model_config"]["dataparallel"]
    ):
        model._to_dataparallel()

    # Distributed training (after apex fp16 initialization)
    if Meta.config["learner_config"]["local_rank"] != -1:
        model._to_distributed_dataparallel()

    # Set to training mode
    model.train()

    if Meta.config["meta_config"]["verbose"]:
        logger.info("Start learning...")

    self.metrics: Dict[str, float] = dict()
    self._reset_losses()

    # Set gradients of all model parameters to zero
    self.optimizer.zero_grad()

    batch_iterator = self.task_scheduler.get_batches(train_dataloaders, model)

    for epoch_num in range(self.start_epoch, self.end_epoch):
        for train_dataloader in train_dataloaders:
            # Set epoch for distributed sampler
            if isinstance(train_dataloader, DataLoader) and isinstance(
                train_dataloader.sampler, DistributedSampler
            ):
                train_dataloader.sampler.set_epoch(epoch_num)

        step_pbar = tqdm(
            range(self.start_step, self.end_step),
            desc=f"Step {self.start_step + 1}/{self.end_step}"
            if self.use_step_base_counter
            else f"Epoch {epoch_num + 1}/{self.end_epoch}",
            disable=not Meta.config["meta_config"]["verbose"]
            or Meta.config["learner_config"]["local_rank"] not in [-1, 0],
        )

        for step_num in step_pbar:
            if self.use_step_base_counter:
                step_pbar.set_description(f"Step {step_num + 1}/{self.total_steps}")
                step_pbar.refresh()

            try:
                batch = next(batch_iterator)
            except StopIteration:
                batch_iterator = self.task_scheduler.get_batches(
                    train_dataloaders, model
                )
                batch = next(batch_iterator)

            # Check whether to skip the current batch
            if epoch_num < self.start_train_epoch or (
                epoch_num == self.start_train_epoch
                and step_num < self.start_train_step
            ):
                continue

            # Convert a single batch into a batch list
            if not isinstance(batch, list):
                batch = [batch]

            total_step_num = epoch_num * self.n_batches_per_epoch + step_num
            batch_size = 0

            for _batch in batch:
                batch_size += len(_batch.uids)

                # Perform forward pass and calculate the loss and count
                uid_dict, loss_dict, prob_dict, gold_dict = model(
                    _batch.uids,
                    _batch.X_dict,
                    _batch.Y_dict,
                    _batch.task_to_label_dict,
                    return_probs=Meta.config["learner_config"]["online_eval"],
                    return_action_outputs=False,
                )

                # Update running loss and count
                for task_name in uid_dict.keys():
                    identifier = f"{task_name}/{_batch.data_name}/{_batch.split}"
                    self.running_uids[identifier].extend(uid_dict[task_name])
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * len(uid_dict[task_name])
                        if len(loss_dict[task_name].size()) == 0
                        else torch.sum(loss_dict[task_name]).item()
                    ) * model.task_weights[task_name]
                    if (
                        Meta.config["learner_config"]["online_eval"]
                        and prob_dict
                        and gold_dict
                    ):
                        self.running_probs[identifier].extend(prob_dict[task_name])
                        self.running_golds[identifier].extend(gold_dict[task_name])

                # Calculate the average loss
                loss = sum(
                    [
                        model.task_weights[task_name] * task_loss
                        if len(task_loss.size()) == 0
                        else torch.mean(model.task_weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ]
                )

                # Perform backward pass to calculate gradients
                if Meta.config["learner_config"]["fp16"]:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()  # type: ignore

            if (total_step_num + 1) % Meta.config["learner_config"]["optimizer_config"][
                "gradient_accumulation_steps"
            ] == 0 or (
                step_num + 1 == self.end_step and epoch_num + 1 == self.end_epoch
            ):
                # Clip gradient norm
                if Meta.config["learner_config"]["optimizer_config"]["grad_clip"]:
                    if Meta.config["learner_config"]["fp16"]:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            Meta.config["learner_config"]["optimizer_config"][
                                "grad_clip"
                            ],
                        )
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(),
                            Meta.config["learner_config"]["optimizer_config"][
                                "grad_clip"
                            ],
                        )

                # Update the parameters
                self.optimizer.step()

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

            if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                self.metrics.update(self._logging(model, dataloaders, batch_size))
                step_pbar.set_postfix(self.metrics)

            # Update lr using lr scheduler
            self._update_lr_scheduler(model, total_step_num, self.metrics)

        step_pbar.close()

    if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
        model = self.logging_manager.close(model)

    logger.info(f"Total learning time: {time.time() - start_time} seconds.")
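
# --- Illustrative sketch (not part of emmental): how learn() above combines per-task
# losses into a single scalar. A 0-dim task loss is weighted directly; a per-example
# loss vector is weighted and then averaged. `loss_dict` and `task_weights` are
# hypothetical stand-ins for the forward-pass loss_dict and model.task_weights.
from typing import Dict

import torch


def combine_task_losses(
    loss_dict: Dict[str, torch.Tensor], task_weights: Dict[str, float]
) -> torch.Tensor:
    return sum(
        task_weights[task_name] * task_loss
        if task_loss.dim() == 0
        else torch.mean(task_weights[task_name] * task_loss)
        for task_name, task_loss in loss_dict.items()
    )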
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.

    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns:
    """
    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(
            config.data_config.entity_dir, config.data_config.entity_map_dir
        ),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )

    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if weak labels is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (
            config.data_config.type_prediction.use_type_pred is False
        ), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols, slice_datasets)
            )
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model, requires_grad=True, logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model, requires_grad=False, logger=logger)
        log_rank_0_info(logger, f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (
            config["learner_config"]["local_rank"] == -1
            and config["model_config"]["dataparallel"]
        ):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt", scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv", scores
            )
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename)
        )
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
            eval_utils.batched_pred_iter(
                model,
                dataloaders[0],
                config.run_config.eval_accumulation_steps,
                sentidx2num_mentions,
            )
        ):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0
            for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (
            f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
        )
        empty_sentidx2row = {
            k: v for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(
            number_dumped_batches, subeval_folder
        )
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger, f"Finished dumping. Merging results across accumulation steps."
        )
        # Final result files for labels and embeddings
        final_result_file = os.path.join(
            eval_folder, config.run_config.result_label_file
        )
        # Copy labels
        output = open(final_result_file, "wb")
        for file in all_result_files:
            shutil.copyfileobj(open(file, "rb"), output)
        output.close()
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file
            )
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"If that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = []
            for i, npfile in enumerate(all_out_emb_files):
                all_arrays.append(np.load(npfile))
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(logger, f"Bootleg embeddings saved at {final_out_emb_file}")
        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
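
# --- Hypothetical driver (not part of bootleg): one way run_model() might be invoked
# for training. The config helper (parse_boot_and_emm_args) and its import path are
# assumptions about bootleg's parser utilities; substitute whatever your bootleg version
# provides for turning a YAML run config into the attribute-style namespace that
# run_model() expects. The config path below is a placeholder.
if __name__ == "__main__":
    from bootleg.utils.parser.parser_utils import parse_boot_and_emm_args  # assumed import

    run_config_path = "configs/my_run_config.yaml"  # hypothetical path
    config = parse_boot_and_emm_args(run_config_path)
    scores = run_model(mode="train", config=config, run_config_path=run_config_path)
    print(scores)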