def train_from_state(
    self,
    state: TrainingState,
    training_data: BatchIterator,
    eval_data: BatchIterator,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model starting from a given training state. The state
    object is modified in place (epoch counters, stage, best metric, ...).

    While ``self.continue_training(state)`` is true, each epoch does:
        1. Train the model on the training data (limited to
           ``self.config.num_batches_per_epoch`` batches when that is set),
           aggregate and report training results
        2. If ``self.config.do_eval`` is set, evaluate the model on the
           evaluation data; otherwise steps 3 and 4 are skipped as well
        3. Step the learning rate scheduler with the model-selection metric
        4. Compare the metric against the best so far, call
           ``self.update_best_model`` when it improved, and save a checkpoint
           when it improved or ``train_config.save_all_checkpoints`` is set

    After the loop, if ``self.optimizer.finalize()`` returns a truthy value,
    the finalized parameters go through the same (optional) evaluate /
    select / checkpoint step once more. Finally, on rank 0,
    ``self.load_best_model(state)`` is called when a best model state was
    recorded and ``self.config.load_best_model_after_train`` is set.

    Args:
        state (TrainingState): contains stateful information to be able to
            restore a training job; mutated in place
        training_data (BatchIterator): batch iterator of training data
        eval_data (BatchIterator): batch iterator of evaluation data
        metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
        train_config (PyTextConfig): training config

    Returns:
        model, best_metric: ``state.model`` together with
        ``state.best_model_metric``
    """
    # set_up_training may replace the iterator (and state.model), so both are
    # read back from its results before use.
    training_data = self.set_up_training(state, training_data)
    model = state.model
    rank = state.rank
    trainable_params = sum(p.numel() for p in state.model.parameters() if p.requires_grad)
    print(f"Model :{model}")
    print(f"Num trainable parameters: {trainable_params}")

    self.sparsifier.initialize(self, state, eval_data, metric_reporter, train_config)

    while self.continue_training(state):
        self.sparsifier.op_pre_epoch(self, state)
        state.epoch += 1
        state.epochs_since_last_improvement += 1
        lrs = learning_rates(state.optimizer)
        print(f"\nWorker {state.rank} starting epoch {state.epoch}")
        print(f"Learning rate(s): {', '.join(map(str, lrs))}")

        with timing.time("train epoch"):
            state.stage = Stage.TRAIN
            state.model.train()
            print(f"start training epoch {state.epoch}")
            epoch_data = training_data
            if self.config.num_batches_per_epoch:
                # We want to limit the number of batches in the epoch;
                # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                # NOTE(review): this presumes set_up_training made the training
                # iterator cycle (loop back to the beginning when exhausted)
                # in this case — confirm in set_up_training.
                epoch_data = itertools.islice(
                    epoch_data, self.config.num_batches_per_epoch)
            self.run_epoch(state, epoch_data, metric_reporter)

        if not self.config.do_eval:
            continue

        with timing.time("eval epoch"):
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            print(f"start evaluating epoch {state.epoch}")
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)

        # Step the learning rate scheduler(s)
        assert eval_metric is not None
        state.scheduler.step_epoch(
            metrics=metric_reporter.get_model_select_metric(eval_metric),
            epoch=state.epoch,
        )

        # Did we train a better model?
        better_model = metric_reporter.compare_metric(
            eval_metric, state.best_model_metric)
        if better_model:
            self.update_best_model(state, train_config, eval_metric)
        if better_model or train_config.save_all_checkpoints:
            self.save_checkpoint(state, train_config)

    # NOTE(review): the loop above reads state.optimizer, but this uses
    # self.optimizer; presumably they are the same object (also when the state
    # was restored from a checkpoint) — confirm.
    if self.optimizer.finalize():
        # Without evaluation the finalized state is treated as the best model
        # and eval_metric stays None.
        should_update_model = True
        eval_metric = None
        if self.config.do_eval:
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            print(f"start evaluating finalized state")
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)
            should_update_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric)
        if should_update_model:
            self.update_best_model(state, train_config, eval_metric)
        if should_update_model or train_config.save_all_checkpoints:
            self.save_checkpoint(state, train_config)

    # Only bother loading the best model for master worker
    if (rank == 0 and state.best_model_state is not None and
            self.config.load_best_model_after_train):
        self.load_best_model(state)

    return state.model, state.best_model_metric
def train(
    self,
    train_iter: BatchIterator,
    eval_iter: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    optimizer: torch.optim.Optimizer,
    scheduler=None,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model, the model states will be modified. This function
    iterates epochs specified in config, and for each epoch do:

        1. Train model using training data, aggregate and report training results
        2. Evaluate model using evaluation data
        3. Adjust learning rate if scheduler is specified
        4. Calculate metrics based on evaluation results and select best model

    Training stops early when the projected time of the next epoch exceeds
    ``self.config.target_time_limit_seconds`` (if > 0), or when the eval
    metric has not improved for ``self.config.early_stop_after`` epochs
    (if > 0).

    Args:
        train_iter (BatchIterator): batch iterator of training data
        eval_iter (BatchIterator): batch iterator of evaluation data
        model (Model): model to be trained
        metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
        train_config (PyTextConfig): training config
        optimizer (torch.optim.Optimizer): torch optimizer to be used
        scheduler (Optional[torch.optim.lr_scheduler]):
            learning rate scheduler, default is None
        rank (int): only used in distributed training, the rank of the current
            training thread; only rank 0 snapshots and reloads the best model

    Returns:
        model, best_metric: the trained model together with the best metric.
        On rank 0 the model holds the parameters of the best epoch.
    """
    timer = time_utils.StageTimer()
    world_size = 1
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()
        world_size = cuda_utils.DISTRIBUTED_WORLD_SIZE
        if world_size > 1:
            device_id = torch.cuda.current_device()
            model = DistributedModel(
                module=model,
                device_ids=[device_id],
                output_device=device_id,
                broadcast_buffers=False,
            )
            timer.add_stage(stage="init_distributed_model")

    best_metric = None
    last_best_epoch = 0
    # Initialized so that the final reload is skipped (instead of raising
    # UnboundLocalError) when no epoch ever produced a best model.
    best_model_state = None
    scheduler = self._prepare_scheduler(train_iter, scheduler)
    timer.add_stage(stage="pre_training")

    def training_pre_batch_callback():
        if world_size > 1:
            # replace optimizer.zero_grad() here to work with DDP
            # in cases where some parameters don't receive grads at each step
            # loss.backward will set grad for params in the computation graph
            # we can thus follow which params are left out and call .backward
            # on them manually
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad = None
        else:
            optimizer.zero_grad()

    def training_backprop(loss, timer):
        loss.backward()
        if world_size > 1:
            # DDP fix when some parameters don't receive grads
            for p in model.parameters():
                if p.requires_grad and p.grad is None:
                    p.backward(torch.zeros_like(p.data))
        timer.add_stage("backward")

        if scheduler:
            scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), self.config.max_clip_norm)
        else:
            grad_norm = None

        optimizer.step()
        timer.add_stage("update_grads")
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm

    time_start = time.time()
    for epoch in range(1, self.config.epochs + 1):
        if self.config.target_time_limit_seconds > 0 and epoch > 1:
            # Project the finish time of the next epoch from the mean epoch
            # duration so far and stop before overshooting the budget.
            time_elapsed = time.time() - time_start
            mean_epoch_time = time_elapsed / float(epoch - 1)
            expected_next_epoch_time = time_elapsed + mean_epoch_time
            if expected_next_epoch_time > self.config.target_time_limit_seconds:
                print(
                    f"Training stopped after {epoch - 1} epochs and "
                    f"{int(time_elapsed)} seconds, due to the target max training "
                    f"time of {self.config.target_time_limit_seconds} seconds."
                )
                break

        print(f"Rank {rank} worker: Starting epoch #{epoch}")
        model.train()
        lrs = (str(lr) for lr in learning_rates(optimizer))
        print(f"Learning rate(s): {', '.join(lrs)}")
        self._run_epoch(
            Stage.TRAIN,
            epoch,
            train_iter,
            model,
            metric_reporter,
            pre_batch=training_pre_batch_callback,
            backprop=training_backprop,
            rank=rank,
        )
        timer.add_stage(stage="epoch_train")

        model.eval(Stage.EVAL)
        with torch.no_grad():
            eval_metric = self._run_epoch(
                Stage.EVAL, epoch, eval_iter, model, metric_reporter, rank=rank)
        timer.add_stage(stage="epoch_eval")

        # Step the learning rate scheduler(s)
        if scheduler:
            assert eval_metric is not None
            scheduler.step(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=epoch,
            )

        # choose best model.
        if metric_reporter.compare_metric(eval_metric, best_metric):
            last_best_epoch = epoch
            best_metric = eval_metric
            # Only rank = 0 trainer saves modules.
            if train_config.save_module_checkpoints and rank == 0:
                model.save_modules(
                    base_path=train_config.modules_save_dir,
                    suffix=f"-ep{epoch}")

            if rank == 0:
                print(f"Rank {rank} worker: Found a better model!")
                model_state = model.state_dict()
                # Snapshot on CPU to avoid multiple model copies in GPU memory.
                # state_dict() tensors share storage with the live parameters,
                # and Tensor.cpu() returns the very same tensor when it is
                # already on the CPU, so clone in that case; otherwise the
                # "best" snapshot would keep changing with further training.
                for key, value in model_state.items():
                    if torch.is_tensor(value):
                        model_state[key] = (
                            value.cpu() if value.is_cuda else value.clone())
                best_model_state = model_state
            timer.add_stage(stage="epoch_save/load_module")

        if self.config.early_stop_after > 0 and (
                epoch - last_best_epoch == self.config.early_stop_after):
            print(f"Rank {rank} worker: Eval metric hasn't changed for "
                  f"{self.config.early_stop_after} epochs. Stopping now.")
            break
        sys.stdout.flush()

    if rank == 0 and best_model_state is not None:
        if cuda_utils.CUDA_ENABLED:
            for key, value in best_model_state.items():
                best_model_state[key] = value.cuda()
        model.load_state_dict(best_model_state)

    timer.report("Trainer train timer")
    return model, best_metric
def train(
    self,
    train_iter: BatchIterator,
    eval_iter: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    optimizers: List[torch.optim.Optimizer],
    scheduler=None,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model, the model states will be modified. This function
    iterates epochs specified in config, and for each epoch do:

        1. Train model using training data, aggregate and report training results
        2. Evaluate model using evaluation data
        3. Adjust learning rate if scheduler is specified
        4. Calculate metrics based on evaluation results and select best model

    Training stops early when the eval metric has not improved for
    ``self.config.early_stop_after`` epochs (if > 0).

    Args:
        train_iter (BatchIterator): batch iterator of training data
        eval_iter (BatchIterator): batch iterator of evaluation data
        model (Model): model to be trained
        metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
        train_config (PyTextConfig): training config
        optimizers (List[torch.optim.Optimizer]): a list of torch optimizers,
            in most of the case only contains one optimizer
        scheduler (Optional[torch.optim.lr_scheduler]):
            learning rate scheduler, default is None
        rank (int): only used in distributed training, the rank of the current
            training thread; only rank 0 saves module checkpoints

    Returns:
        model, best_metric: the trained model (reloaded with the parameters
        of the best epoch, if any) together with the best metric
    """
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()
        if cuda_utils.DISTRIBUTED_WORLD_SIZE > 1:
            device_id = torch.cuda.current_device()
            model = DistributedModel(
                module=model,
                device_ids=[device_id],
                output_device=device_id,
                broadcast_buffers=False,
            )

    best_metric = None
    last_best_epoch = 0
    best_model_state = None
    scheduler = self._prepare_scheduler(train_iter, scheduler)

    def training_pre_batch_callback():
        optimizer_zero_grad(optimizers)

    def training_backprop(loss):
        loss.backward()
        if scheduler:
            scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), self.config.max_clip_norm)
        else:
            grad_norm = None

        optimizer_step(optimizers)
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm

    for epoch in range(1, self.config.epochs + 1):
        print(f"Rank {rank} worker: Starting epoch #{epoch}")
        model.train()
        lrs = (str(lr) for lr in learning_rates(optimizers))
        print(f"Learning rate(s): {', '.join(lrs)}")
        self._run_epoch(
            Stage.TRAIN,
            epoch,
            train_iter,
            model,
            metric_reporter,
            pre_batch=training_pre_batch_callback,
            backprop=training_backprop,
            rank=rank,
        )

        model.eval(Stage.EVAL)
        # No backprop happens during evaluation; disabling autograd avoids
        # building (and holding memory for) graphs on every eval batch.
        with torch.no_grad():
            eval_metric = self._run_epoch(
                Stage.EVAL, epoch, eval_iter, model, metric_reporter, rank=rank)

        # Step the learning rate scheduler(s)
        if scheduler:
            assert eval_metric is not None
            scheduler.step(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=epoch,
            )

        # choose best model.
        if metric_reporter.compare_metric(eval_metric, best_metric):
            print(
                f"Rank {rank} worker: Found a better model! Saving the model state."
            )
            last_best_epoch = epoch
            best_metric = eval_metric
            # Only rank = 0 trainer saves modules.
            if train_config.save_module_checkpoints and rank == 0:
                model.save_modules(
                    base_path=train_config.modules_save_dir,
                    suffix=f"-ep{epoch}")
            # deepcopy: state_dict() tensors alias the live parameters.
            best_model_state = copy.deepcopy(model.state_dict())

        if self.config.early_stop_after > 0 and (
                epoch - last_best_epoch == self.config.early_stop_after):
            print(f"Rank {rank} worker: Eval metric hasn't changed for "
                  f"{self.config.early_stop_after} epochs. Stopping now.")
            break
        sys.stdout.flush()

    # Nothing to restore when no epoch ran (e.g. epochs == 0); loading None
    # would raise.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, best_metric
def train(
    self,
    training_data: BatchIterator,
    eval_data: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model, the model states will be modified. This function
    iterates epochs while ``self.continue_training(state)`` allows it, and for
    each epoch do:

        1. Train model using training data, aggregate and report training results
        2. If ``self.config.do_eval`` is set: evaluate model using evaluation
           data (otherwise steps 3 and 4 are skipped)
        3. Step the learning rate scheduler with the model-selection metric
        4. Calculate metrics based on evaluation results, and checkpoint the
           model when it is the best so far

    Args:
        training_data (BatchIterator): batch iterator of training data
        eval_data (BatchIterator): batch iterator of evaluation data
        model (Model): model to be trained
        metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
        train_config (PyTextConfig): training config
        rank (int): only used in distributed training, the rank of the current
            training thread; only rank 0 reloads the best model at the end

    Returns:
        model, best_metric: ``state.model`` together with the best metric
    """
    state = TrainingState(
        model=model,
        optimizer=self.optimizer,
        scheduler=self.scheduler,
        rank=rank,
    )
    self.set_up_training(state, training_data)

    while self.continue_training(state):
        state.epoch += 1
        state.epochs_since_last_improvement += 1
        print(f"Worker {state.rank} starting epoch {state.epoch}", flush=True)
        lrs = learning_rates(state.optimizer)
        print(f"Learning rate(s): {', '.join(map(str, lrs))}")

        with timing.time("train epoch"):
            state.stage = Stage.TRAIN
            state.model.train()
            self.run_epoch(state, training_data, metric_reporter)

        if not self.config.do_eval:
            continue

        with timing.time("eval epoch"):
            state.stage = Stage.EVAL
            # Use state.model (as for train() above): set_up_training may have
            # replaced it (e.g. with a distributed wrapper), and toggling only
            # the original `model` argument would leave that wrapper in
            # training mode.
            state.model.eval(Stage.EVAL)
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)

        # Step the learning rate scheduler(s)
        assert eval_metric is not None
        state.scheduler.step_epoch(
            metrics=metric_reporter.get_model_select_metric(eval_metric),
            epoch=state.epoch,
        )

        # Did we train a better model?
        if metric_reporter.compare_metric(eval_metric, state.best_model_metric):
            state.epochs_since_last_improvement = 0
            state.best_model_metric = eval_metric
            # NOTE(review): presumably save_checkpoint records
            # state.best_model_state, which the reload below depends on.
            self.save_checkpoint(state, train_config)

    # Only bother loading the best model for master worker
    if rank == 0 and state.best_model_state is not None:
        self.load_best_model(state)

    return state.model, state.best_model_metric
def train(
    self,
    train_iter: BatchIterator,
    eval_iter: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    optimizers: List[torch.optim.Optimizer],
    scheduler=None,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model from ``self.config.start_epoch`` to
    ``self.config.epochs`` (inclusive); the model states will be modified.

    Before each epoch, ``max_n_turns`` on both iterators is advanced
    according to ``self.config.length_schedule_per_epoch`` (a sequence of
    ``(start_epoch, max_n_turns)`` pairs). Each epoch then trains on
    ``len(train_iter)`` batches, evaluates on ``len(eval_iter)`` batches,
    steps the scheduler (if any) and tracks the best eval metric. Training
    stops early when the metric has not improved for
    ``self.config.early_stop_after`` epochs (if > 0).

    On rank 0 with ``train_config.save_module_checkpoints`` set, the best
    epoch is saved to ``<modules_save_dir>/best_model`` and reloaded at the
    end. Otherwise no checkpoint exists and the model from the last trained
    epoch is returned.

    Both iterators are closed when training finishes or fails.

    Args:
        train_iter (BatchIterator): training batch iterator; must support
            ``len()``, ``next()``, ``close()`` and a ``max_n_turns`` attribute
        eval_iter (BatchIterator): evaluation batch iterator, same protocol
        model (Model): model to be trained
        metric_reporter (MetricReporter): computes and reports metrics
        train_config (PyTextConfig): training config
        optimizers (List[torch.optim.Optimizer]): a single-element list when
            checkpoints are saved (it is unpacked as one optimizer)
        scheduler (Optional): learning rate scheduler, default is None
        rank (int): rank of the current training process in distributed
            training; only rank 0 saves / reloads the best model

    Returns:
        model, best_metric: the trained model together with the best metric
    """
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()
        if cuda_utils.DISTRIBUTED_WORLD_SIZE > 1:
            device_id = torch.cuda.current_device()
            model = DistributedModel(
                module=model,
                device_ids=[device_id],
                output_device=device_id,
                broadcast_buffers=False,
            )

    best_metric = None
    last_best_epoch = 0
    best_model_path = None
    scheduler = self._prepare_scheduler(train_iter, scheduler)

    def training_pre_batch_callback():
        optimizer_zero_grad(optimizers)

    def training_backprop(loss):
        loss.backward()
        if scheduler:
            scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), self.config.max_clip_norm)
        else:
            grad_norm = None

        optimizer_step(optimizers)
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm

    len_sched_ix = 0
    length_schedule = self.config.length_schedule_per_epoch

    # Used since we need the infinite iterator (only created and called once):
    # pull exactly len(it) batches from the shared iterator for one epoch.
    def batch_generator_for_epoch(it):
        n = len(it)
        while n > 0:
            yield next(it)
            n -= 1

    try:
        for epoch in range(self.config.start_epoch, self.config.epochs + 1):
            # Set the dialogue length in the fields, to be used by the postprocessor
            while (length_schedule
                   and len_sched_ix < len(length_schedule)
                   and epoch >= length_schedule[len_sched_ix][0]):
                max_n_turns = length_schedule[len_sched_ix][1]
                train_iter.max_n_turns = max_n_turns
                eval_iter.max_n_turns = max_n_turns
                len_sched_ix += 1

            LOG.info(f"\nRank {rank} worker: Starting epoch #{epoch}")
            model.train()
            lrs = (str(lr) for lr in learning_rates(optimizers))
            LOG.info(f"Learning rate(s): {', '.join(lrs)}")
            self._run_epoch(
                Stage.TRAIN,
                epoch,
                batch_generator_for_epoch(train_iter),
                model,
                metric_reporter,
                pre_batch=training_pre_batch_callback,
                backprop=training_backprop,
                rank=rank,
            )

            model.eval(Stage.EVAL)
            with torch.no_grad():
                eval_metric = self._run_epoch(
                    Stage.EVAL,
                    epoch,
                    batch_generator_for_epoch(eval_iter),
                    model,
                    metric_reporter,
                    rank=rank,
                )

            # Step the learning rate scheduler(s)
            if scheduler:
                assert eval_metric is not None
                scheduler.step(
                    metrics=metric_reporter.get_model_select_metric(eval_metric),
                    epoch=epoch,
                )

            # choose best model.
            if metric_reporter.compare_metric(eval_metric, best_metric):
                LOG.info(
                    f"Rank {rank} worker: Found a better model! Saving the model state for epoch #{epoch}."
                )
                last_best_epoch = epoch
                best_metric = eval_metric
                # Only rank = 0 trainer saves modules.
                if train_config.save_module_checkpoints and rank == 0:
                    best_model_path = os.path.join(
                        train_config.modules_save_dir, "best_model")
                    optimizer, = optimizers  # PyText only ever returns a single optimizer in this list
                    torch.save(
                        ModelState(
                            epoch=epoch,
                            parameters=model.state_dict(),
                            optimizer=optimizer.state_dict(),
                        ),
                        best_model_path,
                    )

            if (self.config.early_stop_after > 0
                    and epoch - last_best_epoch == self.config.early_stop_after):
                LOG.info(
                    f"Rank {rank} worker: Eval metric hasn't changed for "
                    f"{self.config.early_stop_after} epochs. Stopping now.")
                break
            sys.stdout.flush()
    finally:
        # Release the iterators (and whatever they hold) even if training fails.
        train_iter.close()
        eval_iter.close()

    # best_model_path is only set on rank 0 with save_module_checkpoints;
    # previously torch.load(None) crashed every other configuration.
    if best_model_path is not None:
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which may reject the pickled ModelState object — confirm the
        # installed torch version.
        model.load_state_dict(torch.load(best_model_path).parameters)
    return model, best_metric