def __init__(
    self,
    model: Model,
    meta_optimizer: torch.optim.Optimizer,
    optimizer_cls: str,
    optimizer_kwargs: Dict[str, Any],
    grad_norm: Optional[float] = None,
    grad_clipping: Optional[float] = None,
    update_hook: Optional[Callable] = None,
    inherit: bool = False,
    loss_ratios_per_step: Optional[List[Dict[str, float]]] = None,
):
    super(BaseWrapper, self).__init__()
    self.model = model
    self.meta_optimizer = meta_optimizer
    self._grad_clipping = grad_clipping
    self._grad_norm = grad_norm
    self._container = deepcopy(self.model)
    training_util.enable_gradient_clipping(self.model, self._grad_clipping)
    self.optimizer_cls = getattr(torch.optim, optimizer_cls)
    self.optimizer_kwargs = optimizer_kwargs
    self._update_hook = update_hook

    def forward_kwargs(step):
        # Default to a dependency-parsing-only loss mix unless an explicit
        # per-step schedule was supplied.
        ratios = {"dep": 1.0, "pos": 0.0}
        if loss_ratios_per_step is not None:
            ratios = loss_ratios_per_step[step]
        return {'return_metric': True, 'loss_ratios': ratios}

    self.forward_kwargs = forward_kwargs
    self._inherit = inherit
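# Hypothetical sketch (not from the source): one way the wrapper state stored
# above might be consumed in a meta-learning inner loop. `batches` and the
# dict-style model outputs are assumptions; only `optimizer_cls`,
# `optimizer_kwargs`, `_container`, `forward_kwargs`, and `_update_hook` come
# from the __init__ above.
def inner_loop_sketch(wrapper, batches):
    # Build a fresh inner optimizer from the stored class and kwargs.
    inner_optimizer = wrapper.optimizer_cls(
        wrapper._container.parameters(), **wrapper.optimizer_kwargs)
    for step, batch in enumerate(batches):
        # forward_kwargs(step) supplies the per-step loss mix, e.g.
        # {"return_metric": True, "loss_ratios": {"dep": 1.0, "pos": 0.0}}.
        output = wrapper._container(**batch, **wrapper.forward_kwargs(step))
        inner_optimizer.zero_grad()
        output["loss"].backward()
        inner_optimizer.step()
        if wrapper._update_hook is not None:
            wrapper._update_hook()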
def train(self) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters.
    """
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError("Could not recover training from the checkpoint. Did you mean to output to "
                                 "a different serialization directory or delete the existing serialization "
                                 "directory?")

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    metrics['best_epoch'] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    if self.callbacks is not None:
        with torch.no_grad():
            for callback in self.callbacks:
                callback.on_train_begin()

    for epoch in range(epoch_counter, self._num_epochs):
        epoch_start_time = time.time()

        if self.callbacks is not None:
            with torch.no_grad():
                for callback in self.callbacks:
                    callback.on_epoch_begin(epoch)

        train_metrics = self._train_epoch(epoch)

        if not self._early_stopping_by_batch:
            # Track peak memory usage.
            if 'cpu_memory_MB' in train_metrics:
                metrics['peak_cpu_memory_MB'] = max(metrics.get('peak_cpu_memory_MB', 0),
                                                    train_metrics['cpu_memory_MB'])
            for key, value in train_metrics.items():
                if key.startswith('gpu_'):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

            if self._validation_data is not None:
                with torch.no_grad():
                    val_metrics_temp = self._estimator.estimate(self._validation_data)
                    # We have a validation set, so compute all the metrics on it.
                    # val_loss, num_batches = self._validation_loss()
                    # val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True)
                    val_metrics = {'loss': 0}
                    if 'sentiment_acc' in val_metrics_temp:
                        val_metrics['accuracy'] = val_metrics_temp['sentiment_acc']
                    if 'category_f1' in val_metrics_temp:
                        val_metrics['category_f1'] = val_metrics_temp['category_f1']['fscore']
                    if 'other_metrics' in val_metrics_temp and 'merge_micro_f1' in val_metrics_temp['other_metrics']:
                        val_metrics['merge_micro_f1'] = val_metrics_temp['other_metrics']['merge_micro_f1']
                    val_metrics.update(val_metrics_temp)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience. Stopping training.")
                        break

            self._tensorboard.log_metrics(train_metrics,
                                          val_metrics=val_metrics,
                                          log_to_console=True,
                                          epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics['best_epoch'] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value
                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir:
                dump_metrics(os.path.join(self._serialization_dir, f'metrics_epoch_{epoch}.json'), metrics)

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric, epoch)

            self._save_checkpoint(epoch)
        else:
            if self._metric_tracker.should_stop_early():
                logger.info("Ran out of patience. Stopping training.")
                break

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * \
                ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        if self.callbacks is not None:
            with torch.no_grad():
                for callback in self.callbacks:
                    callback.on_epoch_end(epoch)

        epochs_trained += 1

    # Make sure pending events are flushed to disk and files are closed properly.
    # self._tensorboard.close()

    # Load the best model state before returning
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics
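# Minimal sketch of the patience logic behind `_metric_tracker` (an assumption
# modeled on the AllenNLP MetricTracker, shown here for a higher-is-better
# validation metric): stop once `patience` consecutive epochs pass without
# improving on the best value seen so far.
class PatienceTracker:
    def __init__(self, patience: int):
        self.patience = patience
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def add_metric(self, value: float) -> None:
        if value > self.best:
            self.best = value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1

    def should_stop_early(self) -> bool:
        return self.epochs_without_improvement >= self.patience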
def enable_gradient_clipping(self, trainer: 'CallbackTrainer'):
    training_util.enable_gradient_clipping(trainer.model, self.grad_clipping)
def train(self) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters.
    """
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    metrics["best_epoch"] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    for callback in self._epoch_callbacks:
        callback(self, metrics={}, epoch=-1)

    for epoch in range(epoch_counter, self._num_epochs):
        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        # Track peak memory usage.
        if "cpu_memory_MB" in train_metrics:
            metrics["peak_cpu_memory_MB"] = max(
                metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"])
        for key, value in train_metrics.items():
            if key.startswith("gpu_"):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        if self._validation_data_loader is not None:
            with torch.no_grad():
                # We have a validation set, so compute all the metrics on it.
                val_loss, val_reg_loss, num_batches = self._validation_loss(epoch)

                # It is safe again to wait till the validation is done. This is
                # important to get the metrics right.
                if self._distributed:
                    dist.barrier()

                val_metrics = training_util.get_metrics(
                    self.model,
                    val_loss,
                    val_reg_loss,
                    num_batches,
                    reset=True,
                    world_size=self._world_size,
                    cuda_device=[self.cuda_device],
                )

                # Check validation metric for early stopping
                this_epoch_val_metric = val_metrics[self._validation_metric]
                self._metric_tracker.add_metric(this_epoch_val_metric)

                if self._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience. Stopping training.")
                    break

        if self._master:
            self._tensorboard.log_metrics(
                train_metrics, val_metrics=val_metrics, log_to_console=True,
                epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics["best_epoch"] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = val_metrics

        if self._serialization_dir and self._master:
            common_util.dump_metrics(
                os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics)

        # The Scheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric)
        if self._momentum_scheduler:
            self._momentum_scheduler.step(this_epoch_val_metric)

        if self._master:
            self._checkpointer.save_checkpoint(
                epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far())

        # Wait for the master to finish saving the checkpoint
        if self._distributed:
            dist.barrier()

        for callback in self._epoch_callbacks:
            callback(self, metrics=metrics, epoch=epoch)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * (
                (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        epochs_trained += 1

    # Make sure pending events are flushed to disk and files are closed properly.
    self._tensorboard.close()

    # Load the best model state before returning
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics
def _try_train(self) -> Tuple[Dict[str, Any], int]:
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?"
        )

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    val_metrics: Dict[str, float] = {}
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    metrics["best_epoch"] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    for epoch in range(epoch_counter, self._num_epochs):
        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        # Back up the model now, in case something goes wrong later with the evaluation
        if self._primary and self._checkpointer is not None:
            self._checkpointer.shelve_model(epoch, self)

        # Wait for the primary process to finish saving the model checkpoint
        if self._distributed:
            dist.barrier()

        # Track peak memory usage.
        for key, value in train_metrics.items():
            if key.startswith("gpu_") and key.endswith("_memory_MB"):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
            elif key.startswith("worker_") and key.endswith("_memory_MB"):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        this_epoch_val_metric: float = 0.0
        if self._validation_data_loader is not None:
            with torch.no_grad():
                # We have a validation set, so compute all the metrics on it.
                val_loss, val_reg_loss, num_batches = self._validation_loss(epoch)

                # It is safe again to wait till the validation is done. This is
                # important to get the metrics right.
                if self._distributed:
                    dist.barrier()

                val_metrics = training_util.get_metrics(
                    self.model,
                    val_loss,
                    val_reg_loss,
                    batch_loss=None,
                    batch_reg_loss=None,
                    num_batches=num_batches,
                    reset=True,
                    world_size=self._world_size,
                    cuda_device=self.cuda_device,
                )

                # Check validation metric for early stopping
                this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics)
                self._metric_tracker.add_metrics(val_metrics)

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics["best_epoch"] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = val_metrics

        if self._serialization_dir and self._primary:
            common_util.dump_metrics(
                os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"),
                metrics,
            )

        # The Scheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric)
        if self._momentum_scheduler:
            self._momentum_scheduler.step(this_epoch_val_metric)

        # The checkpointer saves state from the learning rate scheduler and the momentum
        # scheduler, so we have to make sure those are updated before we save the checkpoint here.
        if self._primary and self._checkpointer is not None:
            self._checkpointer.save_checkpoint(
                epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
            )

        # Wait for the primary process to finish saving the checkpoint
        if self._distributed:
            dist.barrier()

        for callback in self._callbacks:
            callback.on_epoch(self, metrics=metrics, epoch=epoch, is_primary=self._primary)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * (
                (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
            )
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        epochs_trained += 1

        if self._metric_tracker.should_stop_early():
            logger.info("Ran out of patience. Stopping training.")
            break
    else:
        epoch = self._num_epochs - 1

    # Load the best model state before returning
    best_model_state = (
        None if self._checkpointer is None else self._checkpointer.best_model_state()
    )
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics, epoch
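# Illustrative note on the `for ... else` idiom that ends the epoch loop above:
# the `else` suite runs only when the loop exhausts its range without `break`,
# so `epoch` is pinned to the final index exactly when early stopping never fired.
def last_epoch_run(num_epochs: int, stop_at: int = -1) -> int:
    for epoch in range(num_epochs):
        if epoch == stop_at:
            break
    else:
        epoch = num_epochs - 1
    return epoch

assert last_epoch_run(5) == 4             # no break: `else` assigns the final index
assert last_epoch_run(5, stop_at=2) == 2  # early break skips the `else`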
def train(self) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters.
    """
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    metrics['best_epoch'] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    # Set up visdom windows for live plots, if a visdom server is configured.
    if self.visdom:
        def create_plot_window(vis, xlabel, ylabel, title):
            return vis.line(X=np.array([1]), Y=np.array([np.nan]),
                            opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))

        self.train_loss_window = create_plot_window(
            self.visdom, '#Iterations', 'Loss', 'Training Loss')
        self.consume_time_window = create_plot_window(
            self.visdom, "#Epochs", "Seconds", "Consuming time")
        self.left_time_window = self.visdom.text("Waiting for training.......")
        metric_window = {}

    for epoch in range(epoch_counter, self._num_epochs):
        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        # Track peak memory usage.
        if 'cpu_memory_MB' in train_metrics:
            metrics['peak_cpu_memory_MB'] = max(
                metrics.get('peak_cpu_memory_MB', 0), train_metrics['cpu_memory_MB'])
        for key, value in train_metrics.items():
            if key.startswith('gpu_'):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        if self._validation_data is not None:
            with torch.no_grad():
                # We have a validation set, so compute all the metrics on it.
                val_loss, num_batches = self._validation_loss()
                val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True)

                # Check validation metric for early stopping
                this_epoch_val_metric = val_metrics[self._validation_metric]
                self._metric_tracker.add_metric(this_epoch_val_metric)

                if self._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience. Stopping training.")
                    break

        self._tensorboard.log_metrics(train_metrics, val_metrics=val_metrics, log_to_console=True)

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = time.strftime("%H:%M:%S", time.gmtime(training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        # print(train_metrics.keys())
        # print(val_metrics.keys())

        # Lazily create one visdom window per metric the first time it appears.
        if self.visdom:
            for key in train_metrics.keys():
                newkey = 'training_' + key
                if newkey not in metric_window:
                    metric_window[newkey] = create_plot_window(self.visdom, '#Epochs', key, newkey)
            for key in val_metrics.keys():
                newkey = 'validation_' + key
                if newkey not in metric_window:
                    metric_window[newkey] = create_plot_window(self.visdom, '#Epochs', key, newkey)

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
            if self.visdom:
                self.visdom.line(X=np.array([epoch]), Y=np.array([value]),
                                 win=metric_window["training_" + key], update='append')
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value
            if self.visdom:
                self.visdom.line(X=np.array([epoch]), Y=np.array([value]),
                                 win=metric_window["validation_" + key], update='append')

        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics['best_epoch'] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = val_metrics

        if self._serialization_dir:
            dump_metrics(os.path.join(self._serialization_dir, f'metrics_epoch_{epoch}.json'), metrics)

        # The Scheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
        if self._momentum_scheduler:
            self._momentum_scheduler.step(this_epoch_val_metric, epoch)

        self._save_checkpoint(epoch)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time)))

        if self.visdom:
            self.visdom.line(X=np.array([epoch]), Y=np.array([epoch_elapsed_time / 60]),
                             win=self.consume_time_window, update='append')

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * \
                ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)
            if self.visdom:
                self.visdom.text("Estimated training time remaining: {}".format(formatted_time),
                                 win=self.left_time_window, append=True)

        epochs_trained += 1

    # Load the best model state before returning
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics
def train(self) -> Dict[str, Any]:
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    metrics['best_epoch'] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    for epoch in range(epoch_counter, self._num_epochs):
        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        if self._validation_data is not None:
            with torch.no_grad():
                val_loss, num_batches = self._validation_loss()
                val_metrics = training_util.get_metrics(self.get_model(), val_loss, num_batches, reset=True)
                this_epoch_val_metric = val_metrics[self._validation_metric]
                self._metric_tracker.add_metric(this_epoch_val_metric)

                if self._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience. Stopping training.")
                    break

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        if self._metric_tracker.is_best_so_far():
            metrics['best_epoch'] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = val_metrics

        if self._serialization_dir and is_master_rank():
            dump_metrics(os.path.join(self._serialization_dir, f'metrics_epoch_{epoch}.json'), metrics)

        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)

        if is_master_rank():
            self._save_checkpoint(epoch)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * \
                ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        epochs_trained += 1

    # Load the best model state before returning
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics
def train(self) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters.
    """
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    for epoch in range(epoch_counter, self._num_epochs):
        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        # Track peak memory usage.
        if 'cpu_memory_MB' in train_metrics:
            metrics['peak_cpu_memory_MB'] = max(
                metrics.get('peak_cpu_memory_MB', 0), train_metrics['cpu_memory_MB'])
        for key, value in train_metrics.items():
            if key.startswith('gpu_'):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        if self._validation_data is not None:
            with torch.no_grad():
                # We have a validation set, so compute all the metrics on it.
                val_loss, num_batches = self._validation_loss()
                val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True)

                # Check validation metric for early stopping
                this_epoch_val_metric = val_metrics[self._validation_metric]
                self._metric_tracker.add_metric(this_epoch_val_metric)

                if self._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience. Stopping training.")
                    break

        self._tensorboard.log_metrics(train_metrics, val_metrics=val_metrics, log_to_console=True)

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = time.strftime("%H:%M:%S", time.gmtime(training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics['best_epoch'] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value

        if self._serialization_dir:
            dump_metrics(os.path.join(self._serialization_dir, f'metrics_epoch_{epoch}.json'), metrics)

        if self._learning_rate_scheduler:
            # The LRScheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)

        self._save_checkpoint(epoch)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time)))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * \
                ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        epochs_trained += 1

    # Load the best model state before returning
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics
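# Worked check of the remaining-time estimate used above. It assumes roughly
# constant epoch duration: after `done = epoch - epoch_counter + 1` epochs out
# of `num_epochs - epoch_counter`, the remainder is elapsed * (total/done - 1).
def estimated_time_remaining(elapsed_s: float, epoch: int,
                             epoch_counter: int, num_epochs: int) -> float:
    done = epoch - epoch_counter + 1
    return elapsed_s * ((num_epochs - epoch_counter) / float(done) - 1)

# Two of ten epochs took 200 s -> eight more at ~100 s each -> 800 s.
assert estimated_time_remaining(200.0, epoch=1, epoch_counter=0, num_epochs=10) == 800.0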
def _enable_gradient_clipping(self) -> None:
    training_util.enable_gradient_clipping(self._model, self._grad_clipping)
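# Minimal sketch of what `training_util.enable_gradient_clipping` plausibly
# does, modeled on AllenNLP's documented behavior (an assumption, not the
# library source): register a backward hook on every trainable parameter that
# clamps each gradient element into [-clip, +clip]. This element-wise clamping
# (`_grad_clipping`) is distinct from norm-based rescaling (`_grad_norm`),
# which rescales the whole gradient vector instead.
from typing import Optional

import torch

def enable_gradient_clipping_sketch(model: torch.nn.Module,
                                    grad_clipping: Optional[float]) -> None:
    if grad_clipping is None:
        return  # clipping disabled
    for parameter in model.parameters():
        if parameter.requires_grad:
            parameter.register_hook(
                lambda grad: grad.clamp(-grad_clipping, grad_clipping))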
def train(self, experiment: Optional[Experiment] = None) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters.
    """
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)
    self.experiment = experiment

    logger.info("Beginning training.")

    self.val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    self.metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    self.metrics["best_epoch"] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        self.metrics["best_validation_" + key] = value

    for callback in self._epoch_callbacks:
        callback(self, metrics={}, epoch=-1, is_master=self._master)

    for epoch in range(epoch_counter, self._num_epochs):
        self.epoch = epoch
        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        if experiment:
            with experiment.train():
                experiment.log_metrics(
                    {k: v for k, v in train_metrics.items() if np.isscalar(v)},
                    step=epoch)

        # Track peak memory usage.
        for key, value in train_metrics.items():
            if key.startswith("gpu_") and key.endswith("_memory_MB"):
                self.metrics["peak_" + key] = max(self.metrics.get("peak_" + key, 0), value)
            elif key.startswith("worker_") and key.endswith("_memory_MB"):
                self.metrics["peak_" + key] = max(self.metrics.get("peak_" + key, 0), value)

        if self._validation_data_loader is not None and epoch >= self.epochs_before_validate:
            with torch.no_grad():
                try:
                    if self.external_callbacks:
                        self.external_callbacks.call_if_registered(
                            CallbackName.BEFORE_VALIDATION,
                            annotator=self.annotator, model=self.model,
                            trainer=self, experiment=experiment)

                    # We have a validation set, so compute all the metrics on it.
                    val_loss, val_reg_loss, num_batches, preds = self._validation_loss(epoch)

                    # It is safe again to wait till the validation is done. This is
                    # important to get the metrics right.
                    if self._distributed:
                        dist.barrier()

                    self.val_metrics = training_util.get_metrics(
                        self.model,
                        val_loss,
                        val_reg_loss,
                        num_batches,
                        reset=True,
                        world_size=self._world_size,
                        cuda_device=self.cuda_device,
                    )

                    if self.dataset_writer:
                        if self.decoder:
                            preds = self.decoder.decode_batch(self.model.vocab, preds)
                        filename = self._serialization_dir + f"/pred_epoch_{epoch}.txt"
                        with open(filename, "w") as f:
                            self.dataset_writer.write_to_file(
                                self.model.vocab, OrderedDatasetReader.restore_order(preds), f)
                        if self.validation_command:
                            self.val_metrics.update(self.validation_command.evaluate(filename))

                    if self.external_callbacks:
                        self.external_callbacks.call_if_registered(
                            CallbackName.AFTER_VALIDATION,
                            annotator=self.annotator, model=self.model,
                            trainer=self, experiment=experiment)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = self.val_metrics[self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience. Stopping training.")
                        break
                except Exception:
                    logger.exception("An exception occurred during validation:")
                    self._checkpointer.save_checkpoint("validation-failed", trainer=self)
                    raise

        if self._master:
            self._tensorboard.log_metrics(
                train_metrics, val_metrics=self.val_metrics, log_to_console=True,
                epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        self.metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
        self.metrics["training_start_epoch"] = epoch_counter
        self.metrics["training_epochs"] = epochs_trained
        self.metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            self.metrics["training_" + key] = value
        for key, value in self.val_metrics.items():
            self.metrics["validation_" + key] = value

        if experiment:
            with experiment.validate():
                experiment.log_metrics(
                    {k: v for k, v in self.metrics.items() if np.isscalar(v)},
                    step=epoch)

        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            self.metrics["best_epoch"] = epoch
            for key, value in self.val_metrics.items():
                self.metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = self.val_metrics

        if self._serialization_dir and self._master:
            common_util.dump_metrics(
                os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"),
                self.metrics)

        # The Scheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric)
        if self._momentum_scheduler:
            self._momentum_scheduler.step(this_epoch_val_metric)

        if self._master:
            self._checkpointer.save_checkpoint(
                epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far())

        # Wait for the master to finish saving the checkpoint
        if self._distributed:
            dist.barrier()

        for callback in self._epoch_callbacks:
            callback(self, metrics=self.metrics, epoch=epoch, is_master=self._master)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * (
                (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        epochs_trained += 1

    # Make sure pending events are flushed to disk and files are closed properly.
    self._tensorboard.close()

    # Load the best model state before returning
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)

    if self.external_callbacks:
        self.external_callbacks.call_if_registered(
            CallbackName.AFTER_TRAINING,
            annotator=self.annotator, model=self.model,
            trainer=self, experiment=experiment)

    return self.metrics
def custom_train(self) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters.
    """
    logger.info("GAN TRAINER HM START")
    try:
        epoch_counter = self.trainer._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    # TODO - gradient clipping?
    training_util.enable_gradient_clipping(self.trainer.model, self.trainer._grad_clipping)

    # HACK:
    # self.trainer._metric_tracker._patience = 30

    logger.info("Beginning training.")

    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    metrics['best_epoch'] = self.trainer._metric_tracker.best_epoch
    for key, value in self.trainer._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    for epoch in range(epoch_counter, self.trainer._num_epochs):
        # Start tracemalloc
        # tracemalloc.start()
        epoch_start_time = time.time()
        train_metrics = self.semi_train_epoch(epoch)

        # Track peak memory usage.
        if 'cpu_memory_MB' in train_metrics:
            metrics['peak_cpu_memory_MB'] = max(
                metrics.get('peak_cpu_memory_MB', 0), train_metrics['cpu_memory_MB'])
        for key, value in train_metrics.items():
            if key.startswith('gpu_'):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        # Disabled: metrics for an extra pass over unlabelled data.
        # if self.unlabelled_dataset is not None:
        #     unlabelled_metrics = unlabelled_train_epoch(self.trainer, self.unlabelled_dataset, epoch)
        #     for key, value in unlabelled_metrics.items():
        #         if key.startswith('gpu_'):
        #             metrics["peak_" + 'un_' + key] = max(unlabelled_metrics.get("peak_" + key, 0), value)
        #         else:
        #             metrics['un_' + key] = value

        if self.trainer._validation_data is not None and (
                (epoch - epoch_counter) % self.calc_valid_freq == (self.calc_valid_freq - 1)):
            with torch.no_grad():
                # We have a validation set, so compute all the metrics on it.
                val_loss, num_batches = self.trainer._validation_loss()
                val_metrics = training_util.get_metrics(self.trainer.model, val_loss, num_batches, reset=True)

                # Check validation metric for early stopping
                this_epoch_val_metric = val_metrics[self.trainer._validation_metric]
                self.trainer._metric_tracker.add_metric(this_epoch_val_metric)

                if self.trainer._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience. Stopping training.")
                    break

        self.trainer._tensorboard.log_metrics(train_metrics, val_metrics=val_metrics, log_to_console=True)

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = time.strftime("%H:%M:%S", time.gmtime(training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        is_best_so_far = self.trainer._metric_tracker.is_best_so_far()
        if is_best_so_far:
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics['best_epoch'] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self.trainer._metric_tracker.best_epoch_metrics = val_metrics

        if self.trainer._serialization_dir:
            dump_metrics(os.path.join(self.trainer._serialization_dir, f'metrics_epoch_{epoch}.json'), metrics)

        # Pdb().set_trace()
        if self.trainer._learning_rate_scheduler:
            # The LRScheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            self.trainer._learning_rate_scheduler.step(this_epoch_val_metric, epoch)

        self.trainer._save_checkpoint(epoch)

        if self.constraints_model is not None:
            spath = self.save_constraints_model(epoch)
            if is_best_so_far:
                shutil.copyfile(spath, os.path.join(self.trainer._serialization_dir, 'best_dd_checkpoint.pth'))
            # Disabled: start saving checkpoint models after checkpoint_begin, every checkpoint_interval epochs.
            # if (self.trainer._checkpointer._save_intermediate_checkpoints) and \
            #         (epoch >= self.trainer._checkpointer._checkpoint_begin) and \
            #         (epoch % self.trainer._checkpointer._checkpoint_interval == 0):
            #     shutil.copyfile(spath, os.path.join(self.trainer._serialization_dir,
            #                                         'dd_checkpoint_epoch_' + str(epoch) + '.cpoint'))

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time)))

        if epoch < self.trainer._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * \
                ((self.trainer._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        self.trainer.model.train()
        epochs_trained += 1

        # Take snapshot and reveal top memory allocations:
        # snapshot = tracemalloc.take_snapshot()
        # top_stats = snapshot.statistics('lineno')
        # print("[ Top 10 ]")
        # for stat in top_stats[:10]:
        #     logger.info(stat)

    # Load the best model state before returning
    best_model_state = self.trainer._checkpointer.best_model_state()
    if best_model_state:
        self.trainer.model.load_state_dict(best_model_state)
    return metrics
def train(self) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters.
    """
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    if self.cold_step_count > 0:
        # Cold-start phase: freeze the embedding network and train only the
        # classifier layers built on top, at a dedicated cold learning rate.
        base_lr = self.optimizer.param_groups[0]['lr']
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.cold_lr
        self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True)

    metrics["best_epoch"] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    for epoch in range(epoch_counter, self._num_epochs):  # skip epochs already trained
        if epoch == self.cold_step_count and epoch != 0:
            # Cold start is over: restore the base learning rate and unfreeze
            # the embedding network.
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = base_lr
            self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False)

        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        # Track peak memory usage.
        if "cpu_memory_MB" in train_metrics:
            metrics["peak_cpu_memory_MB"] = max(
                metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"])
        for key, value in train_metrics.items():
            if key.startswith("gpu_"):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        # Clear cache before validation.
        torch.cuda.empty_cache()

        if self._validation_data is not None:
            with torch.no_grad():
                # We have a validation set, so compute all the metrics on it.
                val_loss, num_batches = self._validation_loss()
                val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True)

                # Check validation metric for early stopping
                this_epoch_val_metric = val_metrics[self._validation_metric]
                self._metric_tracker.add_metric(this_epoch_val_metric)

                if self._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience. Stopping training.")
                    break

        self._tensorboard.log_metrics(
            train_metrics, val_metrics=val_metrics, log_to_console=True,
            epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        # if self.cold_step_count <= epoch:
        self.scheduler.step(metrics['validation_loss'])

        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics["best_epoch"] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = val_metrics

        if self._serialization_dir:
            dump_metrics(os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics)

        # The Scheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
        if self._momentum_scheduler:
            self._momentum_scheduler.step(this_epoch_val_metric, epoch)

        # Save the model; pointing a later run at this directory lets it pick
        # up the newest checkpoint automatically for fine-tuning.
        self._save_checkpoint(epoch)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * (
                (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        epochs_trained += 1

    # Make sure pending events are flushed to disk and files are closed properly.
    # self._tensorboard.close()

    # Load the best model state before returning (the checkpointer tracks the
    # best model saved under the serialization directory).
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics
def enable_gradient_clipping(self, trainer):
    training_util.enable_gradient_clipping(trainer.model, self.grad_clipping)
def multi_task_training(main_trainer_name: Tuple[Trainer, str],
                        aux_trainers_names: Tuple[List[Trainer], List[str]]) -> Dict[str, Any]:
    '''
    Performs as many epochs as the main task requires; if early stopping is
    enabled, it is driven by the main task. Each round trains every auxiliary
    task for one epoch, then the main task for one epoch, then evaluates the
    main task's validation dataset to decide whether to stop early; otherwise
    training proceeds with another auxiliary-then-main round. A hypothetical
    call is sketched after this function.

    :param main_trainer_name: A tuple of 1. Trainer and 2. name of task.
    :param aux_trainers_names: A tuple of 1. A list of auxiliary trainers and
                               2. A list of names associated to those trainers.
    :returns: Metrics for both auxiliary and main tasks.
    '''
    main_trainer = main_trainer_name[0]
    main_task_name = main_trainer_name[1]

    training_util.enable_gradient_clipping(main_trainer.model, main_trainer._grad_clipping)
    for aux_trainer in aux_trainers_names[0]:
        training_util.enable_gradient_clipping(aux_trainer.model, aux_trainer._grad_clipping)

    all_metrics: Dict[str, Any] = {}
    # Metric keys follow the format: split name, auxiliary or main, task name.
    for epoch in range(main_trainer._num_epochs):
        aux_name_validation_metrics: Dict[str, float] = {}
        for aux_trainer, aux_name in zip(*aux_trainers_names):
            logger.warning(f'Training Auxiliary task {aux_name}')
            aux_metrics = train_one_epoch(aux_trainer, epoch)
            all_metrics[f'training_aux_{aux_name}'] = aux_metrics[0]
            all_metrics[f'validation_aux_{aux_name}'] = aux_metrics[1]
            aux_name_validation_metrics[aux_name] = aux_metrics[1]

        logger.warning(f'Training Main task {main_task_name}')
        main_train_metrics, main_val_metrics = train_one_epoch(main_trainer, epoch)
        all_metrics[f'training_main_{main_task_name}'] = main_train_metrics
        all_metrics[f'validation_main_{main_task_name}'] = main_val_metrics

        # Early stopping if applicable (main task) and tracking the best metric
        main_validation_metric_name = main_trainer._validation_metric
        main_validation_metric = main_val_metrics[main_validation_metric_name]
        main_trainer._metric_tracker.add_metric(main_validation_metric)

        for aux_trainer in aux_trainers_names[0]:
            multi_task_checkpoint_saver(aux_trainer, main_trainer._metric_tracker.is_best_so_far(), epoch)
        multi_task_checkpoint_saver(main_trainer, main_trainer._metric_tracker.is_best_so_far(), epoch)

        if main_trainer._metric_tracker.should_stop_early():
            logger.info("Ran out of patience. Stopping training.")
            break

        # Getting the best metrics for the main task
        if main_trainer._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            all_metrics['best_epoch'] = epoch
            for key, value in main_val_metrics.items():
                all_metrics["best_validation_" + key] = value
            main_trainer._metric_tracker.best_epoch_metrics = main_val_metrics

    # Load the best model state before returning
    main_best_model_state = main_trainer._checkpointer.best_model_state()
    if main_best_model_state:
        main_trainer.model.load_state_dict(main_best_model_state)
    for aux_trainer in aux_trainers_names[0]:
        aux_best_model_state = aux_trainer._checkpointer.best_model_state()
        if aux_best_model_state:
            aux_trainer.model.load_state_dict(aux_best_model_state)
    return all_metrics
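# Hypothetical usage (trainer objects invented for illustration, assumed to be
# constructed elsewhere with the usual Trainer setup): one main task
# interleaved with two auxiliary tasks, one epoch of each per round.
# metrics = multi_task_training(
#     (sentiment_trainer, "sentiment"),
#     ([pos_trainer, ner_trainer], ["pos", "ner"]),
# )
# print(metrics["validation_main_sentiment"])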
def train(self) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters. Everything recorded
    in the metric dictionaries below also ends up in the per-epoch JSON files
    written during training.
    """
    try:
        epoch_counter = self._restore_checkpoint()
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    # Gradient clipping guards against exploding gradients stepping past an optimum.
    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0

    # ----- training starts -----
    training_start_time = time.time()

    # `cold_step_count` is the number of epochs during which only the final
    # linear layer is trained (phases one and two): for the first
    # `cold_step_count` epochs the pretrained model stays frozen, after which
    # it is trained as well (phase three trains the pretrained parameters
    # directly, since they dominate the parameter count). The cold phase also
    # uses `cold_lr`; once it ends, training switches back to `base_lr`.
    if self.cold_step_count > 0:
        base_lr = self.optimizer.param_groups[0]['lr']  # e.g. 1e-5
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.cold_lr  # e.g. 1e-3
        self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True)

    metrics["best_epoch"] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    # epoch_counter == 0 unless a checkpoint was restored, in which case training continues.
    for epoch in range(epoch_counter, self._num_epochs):
        if epoch == self.cold_step_count and epoch != 0:
            # Cold phase over: restore the base learning rate and unfreeze.
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = base_lr
            self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False)

        # -- current epoch starts --
        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        # Track peak memory usage.
        if "cpu_memory_MB" in train_metrics:
            metrics["peak_cpu_memory_MB"] = max(
                metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"])
        for key, value in train_metrics.items():
            if key.startswith("gpu_"):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        # Clear cache before validation.
        torch.cuda.empty_cache()

        # Validation is optional, hence the check.
        if self._validation_data is not None:
            # As usual, no gradients are computed and no parameters updated during validation.
            with torch.no_grad():
                # We have a validation set, so compute all the metrics on it.
                val_loss, num_batches = self._validation_loss()
                val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True)

                # Check validation metric (the loss) for early stopping.
                this_epoch_val_metric = val_metrics[self._validation_metric]
                self._metric_tracker.add_metric(this_epoch_val_metric)

                if self._metric_tracker.should_stop_early():
                    # This is why there can be fewer checkpoints than epochs:
                    # `patience` is the early-stopping threshold, and training
                    # stops after that many epochs of declining validation performance.
                    logger.info("Ran out of patience. Stopping training.")
                    break

        self._tensorboard.log_metrics(
            train_metrics, val_metrics=val_metrics, log_to_console=True,
            epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

        # Create overall metrics dict (the epoch is over).
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        # Aggregate the metric records from the train and validation phases.
        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        # if self.cold_step_count <= epoch:
        self.scheduler.step(metrics['validation_loss'])

        # These updates live under the pretraingectors directory on server 119.
        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics["best_epoch"] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = val_metrics

        # Persist the metrics as JSON.
        if self._serialization_dir:
            dump_metrics(os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics)

        # The Scheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
        if self._momentum_scheduler:
            self._momentum_scheduler.step(this_epoch_val_metric, epoch)

        # Save a checkpoint.
        self._save_checkpoint(epoch)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * (
                (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        # One epoch done.
        epochs_trained += 1

    # Make sure pending events are flushed to disk and files are closed properly.
    # self._tensorboard.close()

    # Load the best model state before returning.
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics
def train(self, oldmodel) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters, loading the released
    checkpoint from `oldmodel` so training starts as a finetune.
    """
    try:
        # epoch_counter = self._restore_checkpoint()
        epoch_counter = 0  # Finetune directly from the loaded weights to speed up training.

        # Make sure all previously trained parameters are carried over.
        print(oldmodel, 'path of the old model')
        # This checkpoint already contains the xlnet network parameters plus
        # the final two linear layers.
        tmp = torch.load(oldmodel, map_location=torch.device('cpu'))
        out_shape = self.model.tag_labels_projection_layer._module.out_features

        # Pad the projection layer in the checkpoint up to `out_shape` rows by
        # concatenating zero rows (the label vocabulary grew).
        now_shape = tmp['tag_labels_projection_layer._module.weight'].shape[0]
        fix_shape = out_shape - now_shape
        tmp['tag_labels_projection_layer._module.weight'] = torch.cat(
            (tmp['tag_labels_projection_layer._module.weight'], torch.zeros(fix_shape, 768)), 0)
        tmp['tag_labels_projection_layer._module.bias'] = torch.cat(
            (tmp['tag_labels_projection_layer._module.bias'], torch.zeros(fix_shape)), 0)

        # `tmp` is just a state dict, so it can be edited freely before loading;
        # starting from these weights makes convergence much faster.
        self.model.load_state_dict(tmp)
    except RuntimeError:
        traceback.print_exc()
        raise ConfigurationError(
            "Could not recover training from the checkpoint. Did you mean to output to "
            "a different serialization directory or delete the existing serialization "
            "directory?")

    training_util.enable_gradient_clipping(self.model, self._grad_clipping)

    logger.info("Beginning training.")

    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    this_epoch_val_metric: Optional[float] = None
    metrics: Dict[str, Any] = {}
    epochs_trained = 0
    training_start_time = time.time()

    if self.cold_step_count > 0:
        # Cold-start epochs: freeze the embedding network and train only the
        # classifier layers built on top, at the cold learning rate.
        base_lr = self.optimizer.param_groups[0]['lr']
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.cold_lr
        self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True)

    metrics["best_epoch"] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    for epoch in range(epoch_counter, self._num_epochs):  # skip epochs already trained
        if epoch == self.cold_step_count and epoch != 0:
            # Cold start finished: restore the learning rate and unfreeze the embedder.
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = base_lr
            self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False)

        epoch_start_time = time.time()
        # The actual training for this epoch:
        train_metrics = self._train_epoch(epoch)

        # Track peak memory usage.
        if "cpu_memory_MB" in train_metrics:
            metrics["peak_cpu_memory_MB"] = max(
                metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"])
        for key, value in train_metrics.items():
            if key.startswith("gpu_"):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        # Clear cache before validation.
        torch.cuda.empty_cache()

        # Evaluate on the validation set to guard against overfitting.
        if self._validation_data is not None:
            with torch.no_grad():
                # We have a validation set, so compute all the metrics on it.
                val_loss, num_batches = self._validation_loss()
                val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True)

                # Check validation metric for early stopping
                this_epoch_val_metric = val_metrics[self._validation_metric]
                self._metric_tracker.add_metric(this_epoch_val_metric)

                if self._metric_tracker.should_stop_early():
                    logger.info("Ran out of patience. Stopping training.")
                    break

        self._tensorboard.log_metrics(
            train_metrics, val_metrics=val_metrics, log_to_console=True,
            epoch=epoch + 1)  # +1 because tensorboard doesn't like 0

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
        metrics["training_start_epoch"] = epoch_counter
        metrics["training_epochs"] = epochs_trained
        metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        # if self.cold_step_count <= epoch:
        self.scheduler.step(metrics['validation_loss'])

        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics["best_epoch"] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = val_metrics

        if self._serialization_dir:
            dump_metrics(os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics)

        # The Scheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
        if self._momentum_scheduler:
            self._momentum_scheduler.step(this_epoch_val_metric, epoch)

        # Only save the final checkpoint; saving every epoch uses a lot of disk.
        if self._num_epochs == epoch + 1:
            self._save_checkpoint(epoch)

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if epoch < self._num_epochs - 1:
            training_elapsed_time = time.time() - training_start_time
            estimated_time_remaining = training_elapsed_time * (
                (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)

        epochs_trained += 1

    # Make sure pending events are flushed to disk and files are closed properly.
    # self._tensorboard.close()

    # Load the best model state before returning (the checkpointer finds the
    # best saved model under the serialization directory).
    best_model_state = self._checkpointer.best_model_state()
    if best_model_state:
        self.model.load_state_dict(best_model_state)
    return metrics