def __init__(self, trainer):
    self.trainer = trainer
    self.early_stopping_accumulator = None
    self.checkpoint_accumulator = None
    self.accumulated_loss = None
    self._teardown_already_run = False
    self.running_loss = TensorRunningAccum(window_length=20)
def on_train_epoch_start(self, epoch):
    # update training progress in trainer
    self.trainer.current_epoch = epoch

    model = self.trainer.lightning_module

    # reset train dataloader
    if epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
        self.trainer.reset_train_dataloader(model)

    # todo: specify the possible exception
    with suppress(Exception):
        # set seed for distributed sampler (enables shuffling for each epoch)
        self.trainer.train_dataloader.sampler.set_epoch(epoch)

    # changing gradient according to accumulation_scheduler
    self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module)

    # stores accumulated grad fractions per batch
    self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)

    # structured result accumulators for callbacks
    self.early_stopping_accumulator = Accumulator()
    self.checkpoint_accumulator = Accumulator()

    # hook
    self.trainer.call_hook("on_epoch_start")
    self.trainer.call_hook("on_train_epoch_start")
def on_train_epoch_start(self, epoch):
    model = self.trainer.get_model()

    # set seed for distributed sampler (enables shuffling for each epoch)
    try:
        self.trainer.train_dataloader.sampler.set_epoch(epoch)
    except Exception:
        pass

    # update training progress in trainer and model
    model.current_epoch = epoch
    self.trainer.current_epoch = epoch

    # changing gradient according to accumulation_scheduler
    self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.get_model())

    # stores accumulated grad fractions per batch
    self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)

    # structured result accumulators for callbacks
    self.early_stopping_accumulator = Accumulator()
    self.checkpoint_accumulator = Accumulator()

    # hook
    self.trainer.call_hook('on_epoch_start')
    self.trainer.call_hook('on_train_epoch_start')
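# Hedged standalone sketch (not Lightning API) of why `accumulated_loss` above is
# sized with window_length=accumulate_grad_batches: training_step divides each
# micro-batch loss by the accumulation factor, so the mean over one full window
# times the factor recovers the unscaled loss that the running-loss display needs.
import torch

accumulate_grad_batches = 4
raw_losses = [torch.tensor(x) for x in (2.0, 2.4, 1.6, 2.0)]

# what training_step contributes per micro-batch
scaled = [loss / accumulate_grad_batches for loss in raw_losses]

# what the running-loss display reconstructs (see update_running_loss further down)
display_loss = torch.stack(scaled).mean() * accumulate_grad_batches
assert torch.isclose(display_loss, torch.stack(raw_losses).mean())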
def __init__(self, trainer):
    self.trainer = trainer
    self.early_stopping_accumulator = None
    self.checkpoint_accumulator = None
    self.accumulated_loss = None
    self.warning_cache = WarningCache()
    self._teardown_already_run = False
    self.running_loss = TensorRunningAccum(window_length=20)
    self.automatic_optimization = True
def __init__(self) -> None:
    super().__init__()
    self.accumulated_loss = TensorRunningAccum(window_length=20)
    self.running_loss = TensorRunningAccum(window_length=20)
    # the current split index when the batch gets split into chunks in truncated backprop through time
    self.split_idx: int = 0
    self.optimizer_loop = OptimizerLoop()
    self.manual_loop = ManualOptimization()

    self._outputs: _OUTPUTS_TYPE = []
    self._remaining_splits: List[Tuple[int, Any]] = []
def __init__(self, trainer, multiple_trainloader_mode: str):
    self.trainer = trainer
    self.accumulated_loss = None
    self.warning_cache = WarningCache()
    self._teardown_already_run = False
    self.running_loss = TensorRunningAccum(window_length=20)
    self._curr_step_result = None
    self._cur_grad_norm_dict = None
    self._multiple_trainloader_mode = multiple_trainloader_mode
    self._skip_backward = False
    self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
def __init__(self, trainer, multiple_trainloader_mode):
    self.trainer = trainer
    self.early_stopping_accumulator = None
    self.checkpoint_accumulator = None
    self.accumulated_loss = None
    self.warning_cache = WarningCache()
    self._teardown_already_run = False
    self.running_loss = TensorRunningAccum(window_length=20)
    self.automatic_optimization = True
    self._curr_step_result = None
    self._cur_grad_norm_dict = None
    self._multiple_trainloader_mode = multiple_trainloader_mode
    self._skip_backward = False
    self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
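# Hedged sketch of what the `multiple_trainloader_mode` argument threaded through
# the constructors above controls when several train dataloaders are combined
# (per the Lightning docs for these versions): "min_size" stops with the shortest
# loader, "max_size_cycle" cycles the shorter ones until the longest is exhausted.
import itertools

a, b = [1, 2, 3, 4], ["x", "y"]

min_size = list(zip(a, b))
assert min_size == [(1, "x"), (2, "y")]

max_size_cycle = list(zip(a, itertools.cycle(b)))
assert max_size_cycle == [(1, "x"), (2, "y"), (3, "x"), (4, "y")]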
def on_advance_start(self) -> None:
    """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and ``on_train_epoch_start``"""
    model = self.trainer.lightning_module

    # reset train dataloader
    if not self._is_fresh_start_epoch and self.trainer._should_reload_dl_epoch:
        self.trainer.reset_train_dataloader(model)
    self._is_fresh_start_epoch = False

    if self.trainer.train_dataloader is not None and callable(
        getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
    ):
        # set seed for distributed sampler (enables shuffling for each epoch)
        self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)

    # changing gradient according to accumulation_scheduler
    self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

    # stores accumulated grad fractions per batch
    self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum(
        window_length=self.trainer.accumulate_grad_batches
    )

    self.epoch_progress.increment_ready()
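# Hedged sketch of the duck-typed guard above: `set_epoch` only exists on
# samplers such as torch.utils.data.DistributedSampler, so the loop probes for
# a callable attribute rather than type-checking the sampler class.
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.arange(8).float()))

epoch = 3
if callable(getattr(loader.sampler, "set_epoch", None)):
    loader.sampler.set_epoch(epoch)  # reseeds per-epoch shuffling under DDP
# a plain SequentialSampler/RandomSampler has no set_epoch: the guard skips silently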
def __init__(
    self,
    trainer,
    max_epochs: Optional[int],
    min_epochs: Optional[int],
    max_steps: Optional[int],
    min_steps: Optional[int],
    num_sanity_val_steps: int,
):
    self.trainer = trainer
    self.accumulated_loss = None
    self.warning_cache = WarningCache()
    self._teardown_already_run = False
    self.running_loss = TensorRunningAccum(window_length=20)
    self._skip_backward = False
    self._optimizer_freq_cumsum = None
    self._hiddens = None

    self.global_step = 0
    self.current_epoch = 0
    self.trainer.should_stop = False

    # the total batch index across all epochs
    self.total_batch_idx = 0
    # the current batch index in the loop that runs over the dataloader(s)
    self.batch_idx = 0
    # the current split index when the batch gets split into chunks in truncated backprop through time
    self.split_idx = None

    self.trainer.num_training_batches = 0
    self.trainer.train_dataloader = None

    # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000
    self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
    # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1
    self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
    self.max_steps = max_steps
    self.min_steps = min_steps

    if num_sanity_val_steps == -1:
        self.trainer.num_sanity_val_steps = float("inf")
    else:
        self.trainer.num_sanity_val_steps = num_sanity_val_steps
def __init__(
    self,
    trainer,
    multiple_trainloader_mode: str,
    max_epochs: Optional[int],
    min_epochs: Optional[int],
    max_steps: Optional[int],
    min_steps: Optional[int],
    num_sanity_val_steps: int,
):
    self.trainer = trainer
    self.accumulated_loss = None
    self.warning_cache = WarningCache()
    self._teardown_already_run = False
    self.running_loss = TensorRunningAccum(window_length=20)
    self._curr_step_result = None
    self._cur_grad_norm_dict = None
    self._multiple_trainloader_mode = multiple_trainloader_mode
    self._skip_backward = False
    self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
    self._optimizer_freq_cumsum = None

    self.global_step = 0
    self.current_epoch = 0
    self.trainer.should_stop = False

    self.total_batch_idx = 0
    self.batch_idx = 0
    self.trainer.num_training_batches = 0
    self.trainer.train_dataloader = None

    # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000
    self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
    # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1
    self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
    self.max_steps = max_steps
    self.min_steps = min_steps

    if num_sanity_val_steps == -1:
        self.trainer.num_sanity_val_steps = float("inf")
    else:
        self.trainer.num_sanity_val_steps = num_sanity_val_steps
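# Minimal sketch of the defaulting rule used in both constructors above:
# max_epochs only falls back to 1000 (and min_epochs to 1) when *neither* the
# epoch- nor the step-based bound was given, so a pure step budget is respected.
from typing import Optional

def resolve_max_epochs(max_epochs: Optional[int], max_steps: Optional[int]) -> Optional[int]:
    return 1000 if (max_epochs is None and max_steps is None) else max_epochs

assert resolve_max_epochs(None, None) == 1000   # nothing set -> default
assert resolve_max_epochs(None, 500) is None    # step budget only -> unbounded epochs
assert resolve_max_epochs(3, None) == 3         # explicit epochs win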
def test_tensor_running_accum_reset():
    """ Test that reset would set all attributes to the initialization state """
    window_length = 10
    accum = TensorRunningAccum(window_length=window_length)
    assert accum.last() is None
    assert accum.mean() is None

    accum.append(torch.tensor(1.5))
    assert accum.last() == torch.tensor(1.5)
    assert accum.mean() == torch.tensor(1.5)

    accum.reset()
    assert accum.window_length == window_length
    assert accum.memory is None
    assert accum.current_idx == 0
    assert accum.last_idx is None
    assert not accum.rotated
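# Companion sketch (assuming TensorRunningAccum is importable from
# pytorch_lightning.trainer.supporters, as in the versions shown here): once
# more values than `window_length` arrive, the buffer rotates and `mean()`
# covers only the most recent window.
import torch
from pytorch_lightning.trainer.supporters import TensorRunningAccum

accum = TensorRunningAccum(window_length=3)
for x in (1.0, 2.0, 3.0, 10.0):  # the fourth append overwrites the oldest slot
    accum.append(torch.tensor(x))

assert accum.rotated
assert accum.last() == torch.tensor(10.0)
assert accum.mean() == torch.tensor(5.0)  # mean of (2, 3, 10)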
def _update_running_loss(self, outputs: Mapping[str, torch.Tensor]) -> None:
    for k, v in outputs.items():
        if "weighted_loss" in k:
            continue
        if "loss" not in k:
            continue
        self._running_loss.setdefault(k, TensorRunningAccum(window_length=20))
        self._running_loss[k].append(v.mean())
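# Hedged usage sketch of the filter above, with a plain list standing in for
# TensorRunningAccum: "weighted_loss" keys are skipped first (note the order
# matters, since "weighted_loss" also contains "loss"), then anything without
# "loss" in its name, so only raw per-task losses get a running window.
import torch

running_loss: dict = {}
outputs = {"loss": torch.tensor(0.5), "aux_loss": torch.tensor(0.1),
           "weighted_loss": torch.tensor(0.4), "accuracy": torch.tensor(0.9)}

for k, v in outputs.items():
    if "weighted_loss" in k or "loss" not in k:
        continue
    running_loss.setdefault(k, []).append(v.mean())

assert sorted(running_loss) == ["aux_loss", "loss"]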
def __init__(self) -> None:
    super().__init__()
    self.accumulated_loss: Optional[Tensor] = None
    self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None
    self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
    self.batch_idx: int = 0
    self.split_idx: Optional[int] = None
    self.warning_cache: WarningCache = WarningCache()

    self._hiddens: Optional[Tensor] = None
    self._optimizer_freq_cumsum: Optional[int] = None
    self._remaining_splits: Optional[List[Any]] = None
    self._skip_backward: bool = False
def __init__(self) -> None:
    super().__init__()
    self.accumulated_loss: Optional[Tensor] = None
    self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None
    self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
    # the current split index when the batch gets split into chunks in truncated backprop through time
    self.split_idx: Optional[int] = None
    self.optimizer_loop = OptimizerLoop()

    self._warning_cache: WarningCache = WarningCache()
    self._hiddens: Optional[Tensor] = None
    self._optimizer_freq_cumsum: Optional[int] = None
    self._remaining_splits: Optional[List[Any]] = None
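# Illustrative, self-contained analogue of the splitting that `split_idx` above
# indexes (Lightning's actual default, LightningModule.tbptt_split_batch, also
# handles nested list/tuple batches): slice the time axis into
# truncated_bptt_steps-sized chunks and run one training step per chunk.
import torch

truncated_bptt_steps = 25
batch = torch.randn(8, 100, 16)  # (batch, time, features)

splits = [batch[:, t:t + truncated_bptt_steps] for t in range(0, batch.size(1), truncated_bptt_steps)]
for split_idx, split_batch in enumerate(splits):
    assert split_batch.size(1) == truncated_bptt_steps  # 4 equal chunks here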
def on_advance_start(self) -> None:
    """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and ``on_train_epoch_start``"""
    model = self.trainer.lightning_module

    # reset train dataloader
    if self.current_epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
        self.trainer.reset_train_dataloader(model)

    # TODO: specify the possible exception
    with suppress(Exception):
        # set seed for distributed sampler (enables shuffling for each epoch)
        self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)

    # changing gradient according to accumulation_scheduler
    self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

    # stores accumulated grad fractions per batch
    self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum(
        window_length=self.trainer.accumulate_grad_batches
    )
def train(self):
    self.run_sanity_check(self.get_model())

    # TODO: shrink
    # clear cache before training
    if self.on_gpu and self.root_gpu is not None:
        # use context because of:
        # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
        with torch.cuda.device(f'cuda:{self.root_gpu}'):
            torch.cuda.empty_cache()

    # get model
    model = self.get_model()

    # enable train mode
    model.train()

    # enable gradients
    torch.set_grad_enabled(True)

    # load data
    # if reload_dataloaders_every_epoch, this is moved to the epoch loop
    if not self.reload_dataloaders_every_epoch:
        self.reset_train_dataloader(model)
    if self.val_dataloaders is None and not self.reload_dataloaders_every_epoch:
        self.reset_val_dataloader(model)

    # Train start events
    with self.profiler.profile('on_train_start'):
        # callbacks
        self.on_train_start()
        # model hooks
        model.on_train_start()

    try:
        # run all epochs
        for epoch in range(self.current_epoch, self.max_epochs):
            # reset train dataloader
            if self.reload_dataloaders_every_epoch:
                self.reset_train_dataloader(model)

            # set seed for distributed sampler (enables shuffling for each epoch)
            if (self.use_ddp or self.use_horovod or self.on_tpu) \
                    and hasattr(self.train_dataloader, 'sampler') \
                    and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                self.train_dataloader.sampler.set_epoch(epoch)

            # update training progress in trainer and model
            model.current_epoch = epoch
            self.current_epoch = epoch

            # changing gradient according to accumulation_scheduler
            self.accumulation_scheduler.on_epoch_start(self, self.get_model())

            # stores accumulated grad fractions per batch
            self.batch_loss_value = TensorRunningAccum(
                window_length=self.accumulate_grad_batches
            )

            # -----------------
            # RUN TNG EPOCH
            # -----------------
            self.run_training_epoch()

            if self.max_steps and self.max_steps <= self.global_step:
                self.run_training_teardown()
                return

            # update LR schedulers
            self.update_learning_rates(interval='epoch')

            # early stopping
            met_min_epochs = epoch >= self.min_epochs - 1
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

            if self.should_stop:
                if (met_min_epochs and met_min_steps):
                    self.run_training_teardown()
                    return
                else:
                    log.info('Trainer was signaled to stop but required minimum epochs'
                             f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                             ' not been met. Training will continue...')

        self.run_training_teardown()

    except KeyboardInterrupt:
        rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

        # user could press ctrl+c many times... only shutdown once
        if not self.interrupted:
            self.interrupted = True
            self._state = TrainerState.INTERRUPTED
            self.on_keyboard_interrupt()

            self.run_training_teardown()
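# Standalone sketch of the stop-gating logic at the bottom of train() above:
# a stop signal is honoured only once both minimum-progress constraints hold.
from typing import Optional

def can_stop(epoch: int, global_step: int, min_epochs: int, min_steps: Optional[int]) -> bool:
    met_min_epochs = epoch >= min_epochs - 1
    met_min_steps = global_step >= min_steps if min_steps else True
    return met_min_epochs and met_min_steps

assert not can_stop(epoch=0, global_step=10, min_epochs=3, min_steps=None)
assert can_stop(epoch=2, global_step=10, min_epochs=3, min_steps=None)
assert not can_stop(epoch=2, global_step=10, min_epochs=3, min_steps=50)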
class TrainLoop:

    def __init__(self, trainer, multiple_trainloader_mode):
        self.trainer = trainer
        self.early_stopping_accumulator = None
        self.checkpoint_accumulator = None
        self.accumulated_loss = None
        self.warning_cache = WarningCache()
        self._teardown_already_run = False
        self.running_loss = TensorRunningAccum(window_length=20)
        self.automatic_optimization = True
        self._curr_step_result = None
        self._cur_grad_norm_dict = None
        self._multiple_trainloader_mode = multiple_trainloader_mode
        self._skip_backward = False
        self.trainer._multiple_trainloader_mode = multiple_trainloader_mode

    def on_trainer_init(
        self,
        max_epochs,
        min_epochs,
        max_steps,
        min_steps,
        num_sanity_val_steps,
        weights_summary,
    ):
        self.trainer.global_step = 0
        self.trainer.current_epoch = 0
        self.trainer.should_stop = False
        self.trainer._state = TrainerState.INITIALIZING

        self.trainer.total_batch_idx = 0
        self.trainer.batch_idx = 0
        self.trainer.num_training_batches = 0
        self.trainer.train_dataloader = None

        # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000
        self.trainer.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
        # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1
        self.trainer.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
        self.trainer.max_steps = max_steps
        self.trainer.min_steps = min_steps

        if num_sanity_val_steps == -1:
            self.trainer.num_sanity_val_steps = float("inf")
        else:
            self.trainer.num_sanity_val_steps = num_sanity_val_steps

        self.trainer.weights_summary = weights_summary
        if weights_summary is not None and weights_summary not in ModelSummary.MODES:
            raise MisconfigurationException(
                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}"
            )

    @property
    def num_optimizers(self):
        num_optimizers = len(self.get_optimizers_iterable())
        return num_optimizers

    def should_skip_training(self):
        should_by_max_steps = self.trainer.max_steps is not None and self.trainer.global_step >= self.trainer.max_steps
        should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs
        return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0

    def on_train_start(self):
        # hook
        self.trainer.call_hook("on_train_start")

    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

        # check that model is configured correctly
        self.trainer.config_validator.verify_loop_configurations(model)

        # attach model log function to callback
        self.trainer.callback_connector.attach_model_logging_functions(model)

    def on_train_end(self):
        if self._teardown_already_run:
            return
        self._teardown_already_run = True

        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
        # when a checkpoint was saved at the last step
        self.trainer.global_step -= 1
        self.check_checkpoint_callback(should_update=True, is_last=True)
        self.trainer.global_step += 1

        # hook
        self.trainer.call_hook("on_train_end")

        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
        # It might be related to xla tensors blocked when moving the cpu
        # kill loggers
        if self.trainer.logger is not None:
            self.trainer.logger.finalize("success")

        # summarize profile results
        self.trainer.profiler.describe()

        # give accelerators a chance to finish
        self.trainer.accelerator.on_train_end()

        # reset bookkeeping
        self.trainer._running_stage = None

    def check_checkpoint_callback(self, should_update, is_last=False):
        # TODO bake this logic into the ModelCheckpoint callback
        if should_update and self.trainer.checkpoint_connector.has_trained:
            callbacks = self.trainer.checkpoint_callbacks

            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
                rank_zero_info("Saving latest checkpoint...")

            model = self.trainer.lightning_module

            for cb in callbacks:
                cb.on_validation_end(self.trainer, model)

    def check_early_stopping_callback(self, should_update):
        # TODO bake this logic into the EarlyStopping callback
        if should_update and self.trainer.checkpoint_connector.has_trained:
            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
            model = self.trainer.lightning_module

            for cb in callbacks:
                cb.on_validation_end(self.trainer, model)

    def on_train_epoch_start(self, epoch):
        # update training progress in trainer
        self.trainer.current_epoch = epoch

        model = self.trainer.lightning_module

        # reset train dataloader
        if epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_train_dataloader(model)

        # todo: specify the possible exception
        with suppress(Exception):
            # set seed for distributed sampler (enables shuffling for each epoch)
            self.trainer.train_dataloader.sampler.set_epoch(epoch)

        # changing gradient according to accumulation_scheduler
        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)

        # structured result accumulators for callbacks
        self.early_stopping_accumulator = Accumulator()
        self.checkpoint_accumulator = Accumulator()

        # hook
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")

    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
        # hook
        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)
        self.trainer.call_hook('on_batch_end')

        # figure out what to track for epoch end
        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)

        # reset batch logger internals
        self.trainer.logger_connector.on_train_batch_end()

    def reset_train_val_dataloaders(self, model):
        if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_train_dataloader(model)

        if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_val_dataloader(model)

    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
        # track the outputs to reduce at the end of the epoch
        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
            sample_output = opt_outputs[-1]

            # decide if we need to reduce at the end of the epoch automatically
            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
            hook_overridden = (
                is_overridden("training_epoch_end", model=self.trainer.lightning_module)
                or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module)
            )

            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
            if not (hook_overridden or auto_reduce_tng_result):
                continue

            # with 1 step (no tbptt) don't use a sequence at epoch end
            if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                opt_outputs = opt_outputs[0]

            epoch_output[opt_idx].append(opt_outputs)

    def get_optimizers_iterable(self):
        """
        Generates an iterable with (idx, optimizer) for each optimizer.
        """
        if not self.trainer.optimizer_frequencies:
            # call training_step once per optimizer
            return list(enumerate(self.trainer.optimizers))

        optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
        optimizers_loop_length = optimizer_freq_cumsum[-1]
        current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length

        # find optimizer index by looking for the first {item > current_place} in the cumsum list
        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
        return [[opt_idx, self.trainer.optimizers[opt_idx]]]

    def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
        is_result_obj = isinstance(training_step_output, Result)

        if is_result_obj:
            training_step_output = training_step_output.detach()
        else:
            training_step_output.batch_loss = training_step_output.batch_loss.detach()

        # insert after step hook
        self.trainer.call_hook("on_after_backward")

        # when in dev debugging track the losses
        self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

    def _check_training_step_output(self, training_step_output):
        if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
            if training_step_output.grad_fn is None:
                # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
                raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")

    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
        # give the PL module a result for logging
        model_ref = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):
            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)

            # manually capture logged metrics
            model_ref._current_fx_name = 'training_step'
            model_ref._results = Result()
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(args)
                self.trainer.accelerator.post_training_step()

            self.trainer.logger_connector.cache_logged_metrics()

            self._check_training_step_output(training_step_output)

            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

            training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
                training_step_output, split_batch
            )
            is_result_obj = isinstance(training_step_output, Result)

            if training_step_output_for_epoch_end is None:
                return None

        # enable empty loss when using manual opt
        closure_loss = None
        untouched_loss = None

        if self.automatic_optimization:
            # accumulate loss
            # (if accumulate_grad_batches = 1 no effect)
            if is_result_obj:
                closure_loss = training_step_output.minimize
            else:
                closure_loss = training_step_output.batch_loss

            closure_loss = closure_loss / self.trainer.accumulate_grad_batches

            # the loss will get scaled for amp. avoid any modifications to it
            untouched_loss = closure_loss.detach().clone()

        # result
        result = AttributeDict(
            closure_loss=closure_loss,
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result

    def _process_training_step_output(self, training_step_output, split_batch):
        training_step_output_for_epoch_end = training_step_output

        # enable validation_step return None
        if training_step_output_for_epoch_end is None:
            return None, None

        # -----------------------------------------
        # process hybrid (1.0)
        # -----------------------------------------
        # no need for these checks in 1.0.0
        # TODO: remove checks in 1.0.0
        is_tensor = isinstance(training_step_output_for_epoch_end, torch.Tensor)
        is_1_0_output = is_tensor or ("log" not in training_step_output and "progress_bar" not in training_step_output)
        if is_1_0_output:
            return self._process_training_step_output_1_0(training_step_output, split_batch)

        # -----------------------------------------
        # process old dict (deprecate 1.0)
        # -----------------------------------------
        training_step_output = self.trainer.process_dict_result(training_step_output, train=True)

        training_step_output = AttributeDict(
            batch_loss=training_step_output[0],
            pbar_on_batch_end=training_step_output[1],
            log_metrics=training_step_output[2],
            callback_metrics=training_step_output[3],
            hiddens=training_step_output[4],
        )
        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
        else:
            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

        return training_step_output_for_epoch_end, training_step_output

    def _process_training_step_output_1_0(self, training_step_output, split_batch):
        result = self.trainer.lightning_module._results

        loss = None
        hiddens = None

        # handle dict return
        if isinstance(training_step_output, dict):
            loss = training_step_output.pop("loss", None)
            hiddens = training_step_output.pop("hiddens", None)
            result["extra"] = training_step_output

        # handle scalar return
        elif isinstance(training_step_output, torch.Tensor):
            loss = training_step_output
            result["extra"] = {}

        # map to results under the hood
        result.minimize = loss
        result.hiddens = hiddens

        # track batch for manual reduction with result
        result.track_batch_size(len(split_batch))

        # track metrics without grads for epoch reduction
        training_step_output_for_epoch_end = copy(result)
        training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
        if self.trainer.move_metrics_to_cpu:
            training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu()

        # what flows back into the system
        training_step_output = result

        return training_step_output_for_epoch_end, training_step_output

    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
        model_ref = self.trainer.lightning_module

        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE

        # native amp + lbfgs is a no go right now
        if using_native_amp and is_lbfgs:
            raise MisconfigurationException(
                'native PyTorch amp and lbfgs are not compatible.'
                ' To request, please file a Github issue in PyTorch and tag @mcarilli'
            )

        # wraps into LightningOptimizer only for running step
        optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)

        # model hook
        model_ref.optimizer_step(
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            train_step_and_backward_closure,
            on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
            using_native_amp=using_native_amp,
            using_lbfgs=is_lbfgs,
        )

    def on_before_zero_grad(self, optimizer):
        self.trainer.call_hook('on_before_zero_grad', optimizer)

    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
        self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

    def track_and_norm_grad(self, optimizer):
        # track gradient norms
        grad_norm_dic = self._track_gradient_norm()

        # clip gradients
        self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val)
        self._cur_grad_norm_dict = grad_norm_dic

    def _track_gradient_norm(self):
        grad_norm_dict = {}
        if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0:
            if float(self.trainer.track_grad_norm) > 0:
                model = self.trainer.lightning_module
                grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm)
        return grad_norm_dict

    def process_hiddens(self, opt_closure_result):
        hiddens = opt_closure_result.hiddens
        if isinstance(opt_closure_result.training_step_output, Result):
            opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()
        return hiddens

    def tbptt_split_batch(self, batch):
        splits = [batch]
        if self.trainer.truncated_bptt_steps is not None:
            model_ref = self.trainer.lightning_module
            with self.trainer.profiler.profile("tbptt_split_batch"):
                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
        return splits

    def run_training_epoch(self):
        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)

        # track epoch output
        epoch_output = [[] for _ in range(self.num_optimizers)]

        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
        dataloader_idx = 0
        val_loop_called = False

        for batch_idx, (batch, is_last_batch) in train_dataloader:

            self.trainer.batch_idx = batch_idx

            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
            # ------------------------------------
            with self.trainer.profiler.profile("run_training_batch"):
                batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)

            # when returning -1 from train_step, we end epoch early
            if batch_output.signal == -1:
                break

            batch_end_outputs = self.process_train_step_outputs(
                batch_output.training_step_output_for_epoch_end,
                self.early_stopping_accumulator,
                self.checkpoint_accumulator,
            )
            # hook
            # TODO: add outputs to batches
            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)

            # -----------------------------------------
            # SAVE METRICS TO LOGGERS
            # -----------------------------------------
            self.trainer.logger_connector.log_train_step_metrics(batch_output)

            # -----------------------------------------
            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
            # -----------------------------------------
            should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
            if should_check_val:
                self.trainer.validating = True
                self.trainer.run_evaluation()
                self.trainer.training = True
                val_loop_called = True

            # -----------------------------------------
            # SAVE LOGGERS (ie: Tensorboard, etc...)
            # -----------------------------------------
            self.save_loggers_on_train_batch_end()

            # update LR schedulers
            monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
            self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
            self.trainer.checkpoint_connector.has_trained = True

            # max steps reached, end training
            if (
                self.trainer.max_steps is not None
                and self.trainer.max_steps == self.trainer.global_step + 1
                and self._accumulated_batches_reached()
            ):
                break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if self.trainer.should_stop:
                break

            self.trainer.total_batch_idx += 1

            # stop epoch if we limited the number of training batches
            if self._num_training_batches_reached(is_last_batch):
                break

            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()

        # epoch end hook
        self.run_on_epoch_end_hook(epoch_output)

        # log epoch metrics
        self.trainer.logger_connector.log_train_epoch_end_metrics(
            epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
        )

        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
        should_train_only = self.trainer.disable_validation or should_skip_eval

        # update epoch level lr_schedulers if no val loop outside train loop is triggered
        if (val_loop_called and not should_check_val) or should_train_only:
            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

        if should_train_only:
            self.check_checkpoint_callback(True)
            self.check_early_stopping_callback(True)

        if should_check_val:
            self.trainer.validating = True
            self.trainer.run_evaluation(on_epoch=True)
            self.trainer.training = True

        # increment the global step once
        # progress global step according to grads progress
        self.increment_accumulated_grad_global_step()

    def run_training_batch(self, batch, batch_idx, dataloader_idx):
        # track grad norms
        grad_norm_dic = {}

        # bookkeeping
        self.trainer.hiddens = None

        # track all outputs across time and num of optimizers
        batch_outputs = [[] for _ in range(len(self.get_optimizers_iterable()))]

        if batch is None:
            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

        # hook
        response = self.trainer.call_hook("on_batch_start")
        if response == -1:
            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # hook
        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx)
        if response == -1:
            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # lightning module hook
        splits = self.tbptt_split_batch(batch)

        for split_idx, split_batch in enumerate(splits):

            # create an iterable for optimizers and loop over them
            for opt_idx, optimizer in self.prepare_optimizers():

                # toggle model params + set info to logger_connector
                self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer)

                if self.should_accumulate():
                    # For gradient accumulation

                    # -------------------
                    # calculate loss (train step + train step end)
                    # -------------------

                    # automatic_optimization=True: perform ddp sync only when performing optimizer_step
                    # automatic_optimization=False: don't block synchronization here
                    with self.block_ddp_sync_behaviour():
                        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)

                    batch_outputs = self._process_closure_result(
                        batch_outputs=batch_outputs,
                        opt_idx=opt_idx,
                    )

                # ------------------------------
                # BACKWARD PASS
                # ------------------------------
                # gradient update with accumulated gradients
                else:
                    if self.automatic_optimization:

                        def train_step_and_backward_closure():
                            result = self.training_step_and_backward(
                                split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
                            )
                            return None if result is None else result.loss

                        # optimizer step
                        self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)

                    else:
                        self._curr_step_result = self.training_step(
                            split_batch, batch_idx, opt_idx, self.trainer.hiddens
                        )

                    if self._curr_step_result is None:
                        # user decided to skip optimization
                        # make sure to zero grad.
                        continue

                    batch_outputs = self._process_closure_result(
                        batch_outputs=batch_outputs,
                        opt_idx=opt_idx,
                    )

                    # todo: Properly aggregate grad_norm across opt_idx and split_idx
                    grad_norm_dic = self._cur_grad_norm_dict
                    self._cur_grad_norm_dict = None

                    # update running loss + reset accumulated loss
                    self.update_running_loss()

        result = AttributeDict(
            signal=0,
            grad_norm_dic=grad_norm_dic,
            training_step_output_for_epoch_end=batch_outputs,
        )
        return result

    @contextmanager
    def block_ddp_sync_behaviour(self, should_block_sync: bool = False):
        """
        automatic_optimization = True
        Blocks ddp sync gradients behaviour on backwards pass.
        This is useful for skipping sync when accumulating gradients, reducing communication overhead

        automatic_optimization = False
        do not block ddp gradient sync when using manual optimization
        as gradients are needed within the training step

        Returns:
            context manager with sync behaviour off
        """
        if (
            isinstance(self.trainer.training_type_plugin, ParallelPlugin)
            and (self.automatic_optimization or should_block_sync)
        ):
            with self.trainer.training_type_plugin.block_backward_sync():
                yield None
        else:
            yield None

    def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:
        opt_closure_result = self._curr_step_result

        if opt_closure_result is not None:

            # cache metrics
            self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result)

            # track hiddens
            self.trainer.hiddens = self.process_hiddens(opt_closure_result)

            # check if loss or model weights are nan
            if self.trainer.terminate_on_nan:
                self.trainer.detect_nan_tensors(opt_closure_result.loss)

            # track all the outputs across all steps
            batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
            batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)

            if self.automatic_optimization:
                # track total loss for logging (avoid mem leaks)
                self.accumulated_loss.append(opt_closure_result.loss)

        self._curr_step_result = None

        return batch_outputs

    def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
        """
        wrap the forward step in a closure so second order methods work
        """
        with self.trainer.profiler.profile("training_step_and_backward"):
            # lightning module hook
            result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
            self._curr_step_result = result

            if not self._skip_backward and self.automatic_optimization:
                is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0

                if is_first_batch_to_accumulate:
                    self.on_before_zero_grad(optimizer)
                    self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

                # backward pass
                if result is not None:
                    with self.trainer.profiler.profile("model_backward"):
                        self.backward(result, optimizer, opt_idx)

                    # hook - call this hook only
                    # when gradients have finished accumulating
                    if not self.should_accumulate():
                        self.on_after_backward(result.training_step_output, batch_idx, result.loss)

                    # check if loss or model weights are nan
                    if self.trainer.terminate_on_nan:
                        self.trainer.detect_nan_tensors(result.loss)

                else:
                    self.warning_cache.warn("training_step returned None. If it was on purpose, ignore this warning...")

        if len(self.trainer.optimizers) > 1:
            # revert back to previous state
            self.trainer.lightning_module.untoggle_optimizer(opt_idx)

        return result

    def backward(self, result, optimizer, opt_idx, *args, **kwargs):
        self.trainer.dev_debugger.track_event("backward_call")

        should_accumulate = self.should_accumulate()

        # backward can be called manually in the training loop
        if isinstance(result, torch.Tensor):
            self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
        else:
            result.closure_loss = self.trainer.accelerator.backward(
                result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
            )

        if not self.should_accumulate():
            # track gradients
            self.track_and_norm_grad(optimizer=optimizer)

    def update_train_loop_lr_schedulers(self, monitor_metrics=None):
        num_accumulated_batches_reached = self._accumulated_batches_reached()
        num_training_batches_reached = self._num_training_batches_reached()

        if num_accumulated_batches_reached or num_training_batches_reached:
            # update lr
            self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)

    def run_on_epoch_end_hook(self, epoch_output):
        # inform logger the batch loop has finished
        self.trainer.logger_connector.on_train_epoch_end()

        self.trainer.call_hook('on_train_epoch_end', epoch_output)
        self.trainer.call_hook('on_epoch_end')

    def increment_accumulated_grad_global_step(self):
        num_accumulated_batches_reached = self._accumulated_batches_reached()
        num_training_batches_reached = self._num_training_batches_reached()

        # progress global step according to grads progress
        if num_accumulated_batches_reached or num_training_batches_reached:
            self.trainer.global_step += 1

    def _accumulated_batches_reached(self):
        return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

    def _num_training_batches_reached(self, is_last_batch=False):
        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch

    def should_accumulate(self):
        # checks if backward or backward + optimizer step (via closure)
        accumulation_done = self._accumulated_batches_reached()
        is_final_batch = self._num_training_batches_reached()
        return not (accumulation_done or is_final_batch)

    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
        # decide if we should run validation
        is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
        can_check_val = self.trainer.enable_validation and is_val_check_epoch
        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0

        should_check_val = (
            (is_val_check_batch and epoch_end_val_check)
            or self.trainer.should_stop
            or is_last_batch_for_infinite_dataset
        ) if on_epoch else (is_val_check_batch and not epoch_end_val_check)

        return should_check_val and can_check_val

    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
        # enable not needing to add opt_idx to training_step
        args = [batch, batch_idx]

        if len(self.trainer.optimizers) > 1:
            if self.trainer.has_arg("training_step", "optimizer_idx"):
                if not self.automatic_optimization:
                    self.warning_cache.warn(
                        "`training_step` hook signature has changed in v1.3."
                        " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
                        " the old signature will be removed in v1.5", DeprecationWarning
                    )
                args.append(opt_idx)
            elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization:
                raise ValueError(
                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
                    ' `training_step` is missing the `optimizer_idx` argument.'
                )

        # pass hiddens if using tbptt
        if self.trainer.truncated_bptt_steps is not None:
            args.append(hiddens)

        return args

    def save_loggers_on_train_batch_end(self):
        # when loggers should save to disk
        should_flush_logs = self.trainer.logger_connector.should_flush_logs
        if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
            self.trainer.logger.save()

    def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
        """
        Figure out what needs to be tracked/logged at the end of the epoch
        """
        # the training step outputs a list per optimizer. The list contains the outputs at each time step
        # when no TBPTT is used, then the list has 1 item per batch
        # when TBPTT IS used, then the list has n items (1 per time step)
        batch_end_outputs = []
        for optimizer_idx_outputs in all_train_step_outputs:
            # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
            if len(optimizer_idx_outputs) == 0:
                continue

            sample_output = optimizer_idx_outputs[-1]

            # pull out callback info if available (ie: Results object)
            if isinstance(sample_output, dict) and "early_stop_on" in sample_output:
                early_stopping_accumulator.accumulate(sample_output["early_stop_on"])

            if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])

            batch_end_outputs.append(optimizer_idx_outputs)

        return batch_end_outputs

    def prepare_optimizers(self):
        # in manual optimization we loop over all optimizers at once
        optimizers = self.get_optimizers_iterable()
        if not self.automatic_optimization:
            optimizers = [optimizers[0]]
        return optimizers

    def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
        # set split_idx to trainer for tracking
        self.trainer.split_idx = split_idx

        # make sure only the gradients of the current optimizer's parameters are calculated
        # in the training step to prevent dangling gradients in multiple-optimizer setup.
        if self.automatic_optimization and len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.toggle_optimizer(optimizer, opt_idx)

        # use to track metrics internally
        self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch)

    def update_running_loss(self):
        accumulated_loss = self.accumulated_loss.mean()

        if accumulated_loss is not None:
            # calculate running loss for display
            self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)

        # reset for next set of accumulated grads
        self.accumulated_loss.reset()
class TrainLoop: def __init__(self, trainer): self.trainer = trainer self.should_check_val = False self.early_stopping_accumulator = None self.checkpoint_accumulator = None self.accumulated_loss = None self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps): self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.interrupted = False self.trainer.should_stop = False self.trainer._state = TrainerState.INITIALIZING self.trainer.total_batch_idx = 0 self.trainer.batch_idx = 0 self.trainer.num_training_batches = 0 self.trainer.train_dataloader = None self.trainer.max_epochs = max_epochs self.trainer.min_epochs = min_epochs self.trainer.max_steps = max_steps self.trainer.min_steps = min_steps if num_sanity_val_steps == -1: self.trainer.num_sanity_val_steps = float('inf') else: self.trainer.num_sanity_val_steps = num_sanity_val_steps @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) return num_optimizers def on_train_start(self): # clear cache before training if self.trainer.on_gpu and self.trainer.root_gpu is not None: # use context because of: # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 with torch.cuda.device(f'cuda:{self.trainer.root_gpu}'): torch.cuda.empty_cache() # hook self.trainer.call_hook('on_train_start') def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # bind logger and other properties self.trainer.model_connector.copy_trainer_model_properties(model) # clean hparams if hasattr(model, 'hparams'): parsing.clean_namespace(model.hparams) # links data to the trainer self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) # check that model is configured correctly self.trainer.config_validator.verify_loop_configurations(model) def setup_training(self, model: LightningModule): """Sanity check a few things before starting actual training. Args: model: The model to run sanity test on. """ # -------------------------- # Setup?? # -------------------------- ref_model = model if self.trainer.data_parallel: ref_model = model.module # give model convenience properties ref_model.trainer = self.trainer # set local properties on the model self.trainer.model_connector.copy_trainer_model_properties(ref_model) # init amp. 
Must be done here instead of __init__ to allow ddp to work if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu: self.trainer.scaler = torch.cuda.amp.GradScaler() # log hyper-parameters if self.trainer.logger is not None: # save exp to get started self.trainer.logger.log_hyperparams(ref_model.hparams) self.trainer.logger.log_graph(ref_model) self.trainer.logger.save() # wait for all to join if on distributed self.trainer.accelerator_backend.barrier('setup_training') # register auto-resubmit when on SLURM self.trainer.slurm_connector.register_slurm_signal_handlers() # -------------------------- # Pre-train # -------------------------- # on pretrain routine start self.trainer.on_pretrain_routine_start(ref_model) if self.trainer.is_function_implemented('on_pretrain_routine_start'): ref_model.on_pretrain_routine_start() # print model summary if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing: if self.trainer.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.trainer.weights_summary) else: raise MisconfigurationException( "weights_summary can be None, " + ", ".join(ModelSummary.MODES)) # track model now. # if cluster resets state, the model will update with the saved weights self.trainer.model = model # restore training and model before hpc is called self.trainer.checkpoint_connector.restore_weights(model) # on pretrain routine end self.trainer.on_pretrain_routine_end(ref_model) if self.trainer.is_function_implemented('on_pretrain_routine_end'): ref_model.on_pretrain_routine_end() def on_train_end(self): if self._teardown_already_run: return self._teardown_already_run = True # Save latest checkpoint rank_zero_warn('Saving latest checkpoint..') self.check_checkpoint_callback(should_check_val=False, force_save=True) # hook self.trainer.call_hook('on_train_end') # kill loggers if self.trainer.logger is not None: self.trainer.logger.finalize("success") # summarize profile results if self.trainer.global_rank == 0: self.trainer.profiler.describe() if self.trainer.global_rank == 0: for proc in self.trainer.interactive_ddp_procs: subprocess.Popen.kill(proc) # clean up dist group if self.trainer.use_ddp or self.trainer.use_ddp2: torch_distrib.destroy_process_group() # clear mem if self.trainer.on_gpu: model = self.trainer.get_model() model.cpu() torch.cuda.empty_cache() def check_checkpoint_callback(self, should_check_val, force_save=False): model = self.trainer.get_model() # when no val loop is present or fast-dev-run still need to call checkpoints # TODO bake this logic into the checkpoint callback should_activate = not is_overridden('validation_step', model) and not should_check_val if should_activate or force_save: checkpoint_callbacks = [ c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint) ] [ c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks ] def on_train_epoch_start(self, epoch): model = self.trainer.get_model() # set seed for distributed sampler (enables shuffling for each epoch) try: self.trainer.train_dataloader.sampler.set_epoch(epoch) except Exception: pass # update training progress in trainer and model model.current_epoch = epoch self.trainer.current_epoch = epoch # changing gradient according accumulation_scheduler self.trainer.accumulation_scheduler.on_epoch_start( self.trainer, self.trainer.get_model()) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum( 
window_length=self.trainer.accumulate_grad_batches) # bookkeeping self.should_check_val = False # structured result accumulators for callbacks self.early_stopping_accumulator = Accumulator() self.checkpoint_accumulator = Accumulator() # hook self.trainer.call_hook('on_epoch_start') self.trainer.call_hook('on_train_epoch_start') def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) # hook self.trainer.call_hook('on_batch_end') self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx) def reset_train_val_dataloaders(self, model): if not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs): # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(epoch_end_outputs): # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance( opt_outputs[0], Result): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) def get_optimizers_iterable(self): """ Generates an iterable with (idx, optimizer) for each optimizer. """ if not self.trainer.optimizer_frequencies: # call training_step once per optimizer return list(enumerate(self.trainer.optimizers)) optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) optimizers_loop_length = optimizer_freq_cumsum[-1] current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length # find optimzier index by looking for the first {item > current_place} in the cumsum list opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) return [(opt_idx, self.trainer.optimizers[opt_idx])] def backward(self, result, optimizer, opt_idx): # backward pass with self.trainer.profiler.profile('model_backward'): result.closure_loss = self.trainer.accelerator_backend.backward( result.closure_loss, optimizer, opt_idx) def on_after_backward(self, training_step_output, batch_idx, untouched_loss): is_result_obj = isinstance(training_step_output, Result) if is_result_obj: training_step_output.detach() else: training_step_output.batch_loss = training_step_output.batch_loss.detach( ) # insert after step hook self.trainer.call_hook('on_after_backward') # when in dev debugging track the losses self.trainer.dev_debugger.track_train_loss_history( batch_idx, untouched_loss.detach()) def training_step(self, split_batch, batch_idx, opt_idx, hiddens): with self.trainer.profiler.profile('model_forward'): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) training_step_output = self.trainer.accelerator_backend.training_step( args) training_step_output = self.trainer.call_hook( 'training_step_end', training_step_output) # ---------------------------- # PROCESS THE RESULT # ---------------------------- # format and reduce outputs accordingly training_step_output_for_epoch_end = training_step_output is_result_obj = isinstance(training_step_output, Result) # track batch size for weighted average if is_result_obj: training_step_output.track_batch_size(len(split_batch)) # don't allow EvalResult in the training_step if isinstance(training_step_output, EvalResult): raise MisconfigurationException( 'training_step cannot 
return EvalResult, ' 'use a dict or TrainResult instead') # handle regular dicts if not is_result_obj: training_step_output = self.trainer.process_output( training_step_output, train=True) training_step_output = AttributeDict( batch_loss=training_step_output[0], pbar_on_batch_end=training_step_output[1], log_metrics=training_step_output[2], callback_metrics=training_step_output[3], hiddens=training_step_output[4], ) # if the user decides to finally reduce things in epoch_end, save raw output without graphs if isinstance(training_step_output_for_epoch_end, torch.Tensor): training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach( ) elif is_result_obj: training_step_output_for_epoch_end = copy(training_step_output) training_step_output_for_epoch_end.detach() else: training_step_output_for_epoch_end = recursive_detach( training_step_output_for_epoch_end) # accumulate loss # (if accumulate_grad_batches = 1 no effect) closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss closure_loss = closure_loss / self.trainer.accumulate_grad_batches # the loss will get scaled for amp. avoid any modifications to it untouched_loss = closure_loss.detach().clone() # result result = AttributeDict( closure_loss=closure_loss, loss=untouched_loss, training_step_output=training_step_output, training_step_output_for_epoch_end= training_step_output_for_epoch_end, hiddens=training_step_output.hiddens, ) return result def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): with self.trainer.profiler.profile('optimizer_step'): # optimizer step lightningModule hook self.trainer.accelerator_backend.optimizer_step( optimizer, batch_idx, opt_idx, train_step_and_backward_closure) def on_before_zero_grad(self, optimizer): model = self.trainer.get_model() model.on_before_zero_grad(optimizer) def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): self.trainer.accelerator_backend.optimizer_zero_grad( batch_idx, optimizer, opt_idx) def on_before_backward(self, batch_idx, optimizer): # track gradient norms grad_norm_dic = self._track_gradient_norm(batch_idx) # clip gradients self.trainer.accelerator_backend.clip_gradients(optimizer) return grad_norm_dic def _track_gradient_norm(self, batch_idx): grad_norm_dic = {} if batch_idx % self.trainer.row_log_interval == 0: if float(self.trainer.track_grad_norm) > 0: model = self.trainer.get_model() grad_norm_dic = model.grad_norm(self.trainer.track_grad_norm) return grad_norm_dic def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, batch_log_metrics): # track callback metrics callback_metrics = opt_closure_result.training_step_output.callback_metrics batch_callback_metrics.append(callback_metrics) # decide which metrics to log (results vs dict return) using_results_obj = isinstance(opt_closure_result.training_step_output, Result) if using_results_obj: metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics else: metrics_to_log = opt_closure_result.training_step_output.log_metrics step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end # track batch log metrics batch_log_metrics.append(metrics_to_log) # track progress bar metrics if len(step_pbar_metrics) > 0: self.trainer.logger_connector.add_progress_bar_metrics( step_pbar_metrics) def process_hiddens(self, opt_closure_result): hiddens = opt_closure_result.hiddens if 
isinstance(opt_closure_result.training_step_output, Result): opt_closure_result.training_step_output_for_epoch_end.drop_hiddens( ) return hiddens def tbptt_split_batch(self, batch): splits = [batch] if self.trainer.truncated_bptt_steps is not None: model_ref = self.trainer.get_model() with self.trainer.profiler.profile('tbptt_split_batch'): splits = model_ref.tbptt_split_batch( batch, self.trainer.truncated_bptt_steps) return splits def run_training_epoch(self): # get model model = self.trainer.get_model() # modify dataloader if needed (ddp, etc...) train_dataloader = self.trainer.accelerator_backend.process_dataloader( self.trainer.train_dataloader) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] # enable profiling for the dataloader train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader( train_dataloader) dataloader_idx = 0 for batch_idx, (batch, is_last_batch) in train_dataloader: # stop epoch if we limited the number of training batches if batch_idx >= self.trainer.num_training_batches: break self.trainer.batch_idx = batch_idx model.global_step = self.trainer.global_step # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx) # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory epoch_end_outputs = self.process_train_step_outputs( batch_output.training_step_output_for_epoch_end, self.early_stopping_accumulator, self.checkpoint_accumulator) # hook self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx) # when returning -1 from train_step, we end epoch early self.trainer.should_stop = batch_output.signal == -1 # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- should_check_val = self.should_check_val_fx( batch_idx, is_last_batch) if should_check_val: self.trainer.run_evaluation(test_mode=False) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
# ----------------------------------------- self.save_loggers_on_train_batch_end(batch_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- self.trainer.logger_connector.save_train_loop_metrics_to_loggers( batch_idx, batch_output) # update LR schedulers monitor_metrics = deepcopy( self.trainer.logger_connector.callback_metrics) monitor_metrics.update(batch_output.batch_log_metrics) self.update_train_loop_lr_schedulers( monitor_metrics=monitor_metrics) # progress global step according to grads progress self.increment_accumulated_grad_global_step() # max steps reached, end training if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if self.trainer.should_stop: break # process epoch outputs self.trainer.logger_connector.on_train_epoch_end( epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers) # checkpoint callback (fixed: use the local `should_check_val` computed in the batch loop; `self.should_check_val` is not an attribute of this class) self.check_checkpoint_callback(should_check_val) # epoch end hook self.run_on_epoch_end_hook() def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dic = {} # track all metrics for callbacks batch_callback_metrics = [] # track metrics to log batch_log_metrics = [] # bookkeeping using_results_obj = False self.trainer.hiddens = None # track all outputs across time and num of optimizers batch_outputs = [[] for _ in range(len(self.get_optimizers_iterable()))] if batch is None: return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) # hook response = self.trainer.call_hook('on_batch_start') if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # hook response = self.trainer.call_hook('on_train_batch_start', batch, batch_idx, dataloader_idx) if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # lightning module hook splits = self.tbptt_split_batch(batch) for split_idx, split_batch in enumerate(splits): self.trainer.split_idx = split_idx # loop over optimizers for opt_idx, optimizer in self.get_optimizers_iterable(): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. 
if len(self.trainer.optimizers) > 1: for param in self.trainer.get_model().parameters(): param.requires_grad = False for group in optimizer.param_groups: for param in group['params']: param.requires_grad = True # ------------------- # calculate loss (train step + train step end) # ------------------- opt_closure_result = self.training_step_and_backward( split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) # log metrics self.log_training_step_metrics(opt_closure_result, batch_callback_metrics, batch_log_metrics) # track hiddens self.trainer.hiddens = self.process_hiddens(opt_closure_result) # check if loss or model weights are nan if self.trainer.terminate_on_nan: self.trainer.detect_nan_tensors(opt_closure_result.loss) # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(opt_closure_result.loss) # track all the outputs across all steps batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 batch_outputs[batch_opt_idx].append( opt_closure_result.training_step_output_for_epoch_end) # ------------------------------ # BACKWARD PASS # ------------------------------ # gradient update with accumulated gradients accumulation_done = ( self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 is_final_batch = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches if accumulation_done or is_final_batch: # hook grad_norm_dic = self.on_before_backward( batch_idx, optimizer) # wrap forward + backward pass in closure for 2nd order optimizers train_step_and_backward_closure = lambda: self.training_step_and_backward( split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens, ).loss # optimizer step self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) # hook self.on_before_zero_grad(optimizer) # clear gradients self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) # calculate running loss for display self.running_loss.append( self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) # reset for next set of accumulated grads self.accumulated_loss.reset() # collapse all metrics into one dict batch_log_metrics = { k: v for d in batch_log_metrics for k, v in d.items() } # track all metrics for callbacks if not using_results_obj: self.trainer.logger_connector.callback_metrics.update( {k: v for d in batch_callback_metrics for k, v in d.items()}) result = AttributeDict( signal=0, grad_norm_dic=grad_norm_dic, batch_log_metrics=batch_log_metrics, training_step_output_for_epoch_end=batch_outputs) return result def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): """ wrap the forward step in a closure so second order methods work """ # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) # backward pass self.backward(result, optimizer, opt_idx) # hook self.on_after_backward(result.training_step_output, batch_idx, result.loss) return result def update_train_loop_lr_schedulers(self, monitor_metrics=None): num_accumulated_batches_reached = ( self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 num_training_batches_reached = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches if num_accumulated_batches_reached or num_training_batches_reached: # update lr self.trainer.optimizer_connector.update_learning_rates( interval='step', monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self): self.trainer.call_hook('on_epoch_end') self.trainer.call_hook('on_train_epoch_end') def 
increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = ( self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 num_training_batches_reached = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: self.trainer.global_step += 1 self.trainer.total_batch_idx += 1 def should_check_val_fx(self, batch_idx, is_last_batch): # decide if we should run validation is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 can_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 can_check_val = self.trainer.enable_validation and can_check_epoch should_check_val = is_val_check_batch or self.trainer.should_stop is_last_batch_for_infinite_dataset = (is_last_batch and self.trainer.val_check_batch == float('inf')) should_check_val = can_check_val and ( should_check_val or is_last_batch_for_infinite_dataset) return should_check_val def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step args = [batch, batch_idx] if len(self.trainer.optimizers) > 1: if self.trainer.has_arg('training_step', 'optimizer_idx'): args.append(opt_idx) else: num_opts = len(self.trainer.optimizers) raise ValueError( f'Your LightningModule defines {num_opts} optimizers but ' f'training_step is missing the "optimizer_idx" argument.') # pass hiddens if using tbptt if self.trainer.truncated_bptt_steps is not None: args.append(hiddens) return args def save_loggers_on_train_batch_end(self, batch_idx): # when loggers should save to disk should_save_log = ( batch_idx + 1 ) % self.trainer.log_save_interval == 0 or self.trainer.should_stop if should_save_log or self.trainer.fast_dev_run: if self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator): """ Figure out what needs to be tracked/logged at the end of the epoch """ # the training step outputs a list per optimizer. The list contains the outputs at each time step # when no TBPTT is used, then the list has 1 item per batch # when TBPTT IS used, then the list has n items (1 per time step) epoch_end_outputs = [] for optimizer_idx_outputs in all_train_step_outputs: # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer sample_output = optimizer_idx_outputs[-1] # pull out callback info if available (ie: Results object) if isinstance(sample_output, dict) and 'early_stop_on' in sample_output: early_stopping_accumulator.accumulate( sample_output['early_stop_on']) if isinstance(sample_output, dict) and 'checkpoint_on' in sample_output: checkpoint_accumulator.accumulate( sample_output['checkpoint_on']) # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance( sample_output, Result) and sample_output.should_reduce_on_epoch_end # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if is_overridden( 'training_epoch_end', model=self.trainer.get_model()) or auto_reduce_tng_result: epoch_end_outputs.append(optimizer_idx_outputs) return epoch_end_outputs
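# --------------------------------------------------------------------------
# A minimal standalone sketch of the decision rule implemented by
# `should_check_val_fx` above. The free-function form and its flat argument
# names are illustrative assumptions, not the trainer's API; it only
# demonstrates the modular arithmetic that decides when validation runs.
# --------------------------------------------------------------------------
def _should_check_val_sketch(batch_idx, epoch, is_last_batch, val_check_batch,
                             check_val_every_n_epoch, enable_validation=True,
                             should_stop=False):
    # validation is due every `val_check_batch` batches...
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    # ...but only on epochs matching `check_val_every_n_epoch`
    can_check_epoch = (epoch + 1) % check_val_every_n_epoch == 0
    can_check_val = enable_validation and can_check_epoch
    # a pending stop signal also forces a final validation pass
    check_now = is_val_check_batch or should_stop
    # for infinite/iterable datasets, validate once the last batch is seen
    is_last_batch_for_infinite_dataset = is_last_batch and val_check_batch == float('inf')
    return can_check_val and (check_now or is_last_batch_for_infinite_dataset)

# e.g. with val_check_batch=4 and check_val_every_n_epoch=2, validation runs
# after batches 3, 7, 11, ... of epochs 1, 3, 5, ... (all indices 0-based):
assert _should_check_val_sketch(batch_idx=3, epoch=1, is_last_batch=False,
                                val_check_batch=4, check_val_every_n_epoch=2)
assert not _should_check_val_sketch(batch_idx=3, epoch=0, is_last_batch=False,
                                    val_check_batch=4, check_val_every_n_epoch=2)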
def train(self): # add signal handlers for process kills # def _signal_kill_handler(*args): # return TrainerTrainLoopMixin.run_training_teardown(self) # # orig_signal_handlers = {} # for sig_name in SIGNAL_TERMINATE: # orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name), # _signal_kill_handler) # get model model = self.get_model() # enable train mode model.train() # enable gradients torch.set_grad_enabled(True) # load data # if reload_dataloaders_every_epoch, this is moved to the epoch loop if not self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) self.reset_val_dataloader(model) # Train start events with self.profiler.profile('on_train_start'): # callbacks self.on_train_start() # model hooks model.on_train_start() try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_horovod) \ and hasattr(self.train_dataloader, 'sampler') \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) # update training progress in trainer and model model.current_epoch = epoch self.current_epoch = epoch # changing gradient according accumulation_scheduler self.accumulation_scheduler.on_epoch_start( self, self.get_model()) # stores accumulated grad fractions per batch self.batch_loss_value = TensorRunningAccum( window_length=self.accumulate_grad_batches) # ----------------- # RUN TNG EPOCH # ----------------- self.run_training_epoch() if self.max_steps and self.max_steps <= self.global_step: self.run_training_teardown() return # update LR schedulers self.update_learning_rates(interval='epoch') # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if self.should_stop: if (met_min_epochs and met_min_steps) or self.fast_dev_run: self.run_training_teardown() return else: log.info( 'Trainer was signaled to stop but required minimum epochs' f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' ' not been met. Training will continue...') self.run_training_teardown() except KeyboardInterrupt: rank_zero_warn( 'Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True self.on_keyboard_interrupt() self.run_training_teardown()
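# --------------------------------------------------------------------------
# The early-stopping gate in `train()` above only honours a stop signal once
# `min_epochs`/`min_steps` are met (or when fast_dev_run is set). A minimal
# standalone sketch of that predicate; the function name and flat arguments
# are illustrative assumptions, not the trainer's API.
# --------------------------------------------------------------------------
def _can_stop_sketch(epoch, global_step, min_epochs, min_steps=None,
                     fast_dev_run=False):
    met_min_epochs = epoch >= min_epochs - 1
    met_min_steps = global_step >= min_steps if min_steps else True
    return (met_min_epochs and met_min_steps) or fast_dev_run

# a stop signal during epoch 0 with min_epochs=3 is deferred and training
# continues; by epoch 2 (the third epoch) the same signal is honoured:
assert not _can_stop_sketch(epoch=0, global_step=50, min_epochs=3)
assert _can_stop_sketch(epoch=2, global_step=50, min_epochs=3)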
class TrainerTrainLoopMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class max_epochs: int min_epochs: int on_gpu: bool use_ddp: bool use_dp: bool use_ddp2: bool use_horovod: bool single_gpu: bool use_tpu: bool data_parallel_device_ids:... check_val_every_n_epoch:... num_training_batches: int val_check_batch:... disable_validation: bool fast_dev_run:... accumulation_scheduler:... lr_schedulers:... early_stop_callback:... callback_metrics:... logger: Union[LightningLoggerBase, bool] global_step: int testing: bool log_save_interval: float global_rank: int row_log_interval: float truncated_bptt_steps:... optimizers:... optimizer_frequencies:... accumulate_grad_batches: int track_grad_norm:... model: LightningModule interrupted: bool running_loss:... progress_bar_dict:... reduce_lr_on_plateau_scheduler:... profiler:... batch_idx: int precision:... train_dataloader: DataLoader reload_dataloaders_every_epoch: bool max_steps: int min_steps: int total_batch_idx: int terminate_on_nan: bool tpu_id: int interactive_ddp_procs:... # Callback system callbacks: List[Callback] on_train_start: Callable on_train_end: Callable on_batch_start: Callable on_batch_end: Callable on_epoch_start: Callable on_epoch_end: Callable on_validation_end: Callable on_keyboard_interrupt: Callable @abstractmethod def get_model(self) -> LightningModule: """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def is_function_implemented(self, *args, **kwargs): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def run_evaluation(self, *args, **kwargs): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def transfer_batch_to_gpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def transfer_batch_to_tpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def clip_gradients(self): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def detect_nan_tensors(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def is_overridden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def add_progress_bar_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def log_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def process_output(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def reset_train_dataloader(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def reset_val_dataloader(self, model): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def has_arg(self, *args): """Warning: this is just empty shell for code implemented in other class.""" def train(self): # add signal handlers for process kills # def _signal_kill_handler(*args): # return TrainerTrainLoopMixin.run_training_teardown(self) # # orig_signal_handlers = {} # for sig_name in SIGNAL_TERMINATE: # orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name), # _signal_kill_handler) # get model model = self.get_model() # 
enable train mode model.train() # enable gradients torch.set_grad_enabled(True) # load data # if reload_dataloaders_every_epoch, this is moved to the epoch loop if not self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) self.reset_val_dataloader(model) # Train start events with self.profiler.profile('on_train_start'): # callbacks self.on_train_start() # model hooks model.on_train_start() try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_horovod) \ and hasattr(self.train_dataloader, 'sampler') \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) # update training progress in trainer and model model.current_epoch = epoch self.current_epoch = epoch # changing gradient according accumulation_scheduler self.accumulation_scheduler.on_epoch_start( self, self.get_model()) # stores accumulated grad fractions per batch self.batch_loss_value = TensorRunningAccum( window_length=self.accumulate_grad_batches) # ----------------- # RUN TNG EPOCH # ----------------- self.run_training_epoch() if self.max_steps and self.max_steps <= self.global_step: self.run_training_teardown() return # update LR schedulers self.update_learning_rates(interval='epoch') # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if self.should_stop: if (met_min_epochs and met_min_steps) or self.fast_dev_run: self.run_training_teardown() return else: log.info( 'Trainer was signaled to stop but required minimum epochs' f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' ' not been met. Training will continue...') self.run_training_teardown() except KeyboardInterrupt: rank_zero_warn( 'Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True self.on_keyboard_interrupt() self.run_training_teardown() def prepare_train_loop_dataloader(self, train_dataloader): # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device(self.tpu_id) train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) train_dataloader = train_dataloader.per_device_loader(device) return train_dataloader def run_on_epoch_start_hook(self, model): # Epoch start events with self.profiler.profile('on_epoch_start'): # callbacks self.on_epoch_start() # model hooks if self.is_function_implemented('on_epoch_start'): model.on_epoch_start() def run_training_epoch(self): # get model model = self.get_model() # Epoch start events self.run_on_epoch_start_hook(model) # modify dataloader if needed (ddp, etc...) 
train_dataloader = self.prepare_train_loop_dataloader( self.train_dataloader) # bookkeeping epoch_output = [] should_check_val = False # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( enumerate(_with_is_last(train_dataloader)), "get_train_batch"): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: break self.batch_idx = batch_idx model.global_step = self.global_step # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ batch_output = self.run_training_batch(batch, batch_idx) # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory if self.is_overridden('training_epoch_end', model=self.get_model()): epoch_output.append( batch_output.training_step_output_for_epoch_end) # update LR schedulers self.update_train_loop_lr_schedulers() # when returning -1 from train_step, we end epoch early self.should_stop = batch_output.signal == -1 # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- should_check_val = self.should_check_val(batch_idx, is_last_batch) if self.fast_dev_run or should_check_val: self.run_evaluation(test_mode=False) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) # ----------------------------------------- self.save_loggers_in_training_loop(batch_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- self.save_train_loop_metrics_to_loggers(batch_idx, batch_output) # progress global step according to grads progress self.increment_accumulated_grad_global_step() # max steps reached, end training if self.max_steps is not None and self.max_steps == self.global_step: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if self.fast_dev_run or self.should_stop: break # let ddp devices catch up when using horovod self.sync_horovod() # process epoch outputs self.run_training_epoch_end(epoch_output) # checkpoint callback self.check_checkpoint_callback(should_check_val) # epoch end hook self.run_on_epoch_end_hook(model) def check_checkpoint_callback(self, should_check_val): # when no val loop is present or fast-dev-run still need to call checkpoints # TODO bake this logic into the checkpoint callback should_activate = not self.is_overridden('validation_step') and not ( self.fast_dev_run or should_check_val) if should_activate: checkpoint_callbacks = [ c for c in self.callbacks if isinstance(c, ModelCheckpoint) ] [ c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks ] def update_train_loop_lr_schedulers(self): if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: # update lr self.update_learning_rates(interval='step') def run_on_epoch_end_hook(self, model): with self.profiler.profile('on_epoch_end'): # callbacks self.on_epoch_end() # model hooks if self.is_function_implemented('on_epoch_end'): model.on_epoch_end() def run_training_epoch_end(self, epoch_output): model = self.get_model() if self.is_overridden('training_epoch_end', model=model): self.global_step += 1 epoch_output = model.training_epoch_end(epoch_output) _processed_outputs = self.process_output(epoch_output) log_epoch_metrics = _processed_outputs[2] callback_epoch_metrics = _processed_outputs[3] # add the metrics to the loggers self.log_metrics(log_epoch_metrics, {}) # add 
metrics to callbacks self.callback_metrics.update(callback_epoch_metrics) # add metrics to progress_bar self.add_progress_bar_metrics(_processed_outputs[1]) def sync_horovod(self): if self.use_horovod: hvd.join(hvd.local_rank() if self.on_gpu else -1) def increment_accumulated_grad_global_step(self): # progress global step according to grads progress if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: self.global_step += 1 self.total_batch_idx += 1 def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output): # when metrics should be logged should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop if should_log_metrics or self.fast_dev_run: # logs user requested information to logger self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic) def save_loggers_in_training_loop(self, batch_idx): # when loggers should save to disk should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or self.should_stop if should_save_log or self.fast_dev_run: if self.is_global_zero and self.logger is not None: self.logger.save() def should_check_val(self, batch_idx, is_last_batch): # decide if we should run validation is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 can_check_val = not self.disable_validation and can_check_epoch should_check_val = is_val_check_batch or self.should_stop is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf')) should_check_val = can_check_val and ( should_check_val or is_last_batch_for_infinite_dataset) return should_check_val def run_training_batch(self, batch, batch_idx): # track grad norms grad_norm_dic = {} # track all metrics for callbacks batch_callback_metrics = [] # track metrics to log batch_log_metrics = [] if batch is None: return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) # Batch start events with self.profiler.profile('on_batch_start'): # callbacks self.on_batch_start() # hooks if self.is_function_implemented('on_batch_start'): response = self.get_model().on_batch_start(batch) if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) splits = [batch] if self.truncated_bptt_steps is not None: model_ref = self.get_model() with self.profiler.profile('tbptt_split_batch'): splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps) self.hiddens = None for split_idx, split_batch in enumerate(splits): self.split_idx = split_idx for opt_idx, optimizer in self._get_optimizers_iterable(): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. 
if len(self.optimizers) > 1: for param in self.get_model().parameters(): param.requires_grad = False for group in optimizer.param_groups: for param in group['params']: param.requires_grad = True # ------------------- # calculate loss # ------------------- opt_closure_result = self.optimizer_closure( split_batch, batch_idx, opt_idx, optimizer, self.hiddens) # ------------------------------ # POST forward bookkeeping # ------------------------------ batch_callback_metrics.append( opt_closure_result.training_step_output.callback_metrics) batch_log_metrics.append( opt_closure_result.training_step_output.log_metrics) self.add_progress_bar_metrics( opt_closure_result.training_step_output.pbar_on_batch_end) # track hiddens self.hiddens = opt_closure_result.hiddens # check if loss or model weights are nan if self.terminate_on_nan: self.detect_nan_tensors(opt_closure_result.loss) # track total loss for logging (avoid mem leaks) self.batch_loss_value.append(opt_closure_result.loss) # ------------------------------ # BACKWARD PASS # ------------------------------ # gradient update with accumulated gradients if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: # backward grad_norm_dic = self.run_batch_backward_pass( split_batch, batch_idx, opt_idx, optimizer) # calculate running loss for display self.running_loss.append(self.batch_loss_value.mean()) # reset for next set of accumulated grads self.batch_loss_value.reset() # Batch end events with self.profiler.profile('on_batch_end'): # callbacks self.on_batch_end() # model hooks if self.is_function_implemented('on_batch_end'): self.get_model().on_batch_end() # collapse all metrics into one dict batch_log_metrics = { k: v for d in batch_log_metrics for k, v in d.items() } # track all metrics for callbacks self.callback_metrics.update( {k: v for d in batch_callback_metrics for k, v in d.items()}) result = AttributeDict( signal=0, grad_norm_dic=grad_norm_dic, batch_log_metrics=batch_log_metrics, training_step_output_for_epoch_end=opt_closure_result. training_step_output_for_epoch_end) return result def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer): # ------------------ # GRAD NORMS # ------------------ # track gradient norms when requested grad_norm_dic = {} if batch_idx % self.row_log_interval == 0: if float(self.track_grad_norm) > 0: model = self.get_model() grad_norm_dic = model.grad_norm(self.track_grad_norm) # ------------------ # CLIP GRADS # ------------------ if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: self.scaler.unscale_(optimizer) self.clip_gradients() # ------------------ # .STEP + ZERO_GRAD # ------------------ self.call_optimizer_step(optimizer, opt_idx, batch_idx, split_batch) return grad_norm_dic def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch): # calls .step(), .zero_grad() # override function to modify this behavior model = self.get_model() with self.profiler.profile('optimizer_step'): lambda_closure = lambda: self.optimizer_closure( split_batch, batch_idx, opt_idx, optimizer, self.hiddens).loss # apply TPU optimizer if self.use_tpu and XLA_AVAILABLE: model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure, on_tpu=True) # for LBFGS do something a bit different elif isinstance(optimizer, torch.optim.LBFGS): # native amp + lbfgs is a no go right now if self.use_amp and NATIVE_AMP_AVALAIBLE: raise MisconfigurationException( 'native PyTorch amp and lbfgs are not compatible.' 
' To request, please file a Github issue in PyTorch and tag @mcarilli' ) model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure, using_lbfgs=True) # when using 16-bit else: native_amp = self.use_amp and NATIVE_AMP_AVALAIBLE model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure, using_native_amp=native_amp) # in native 16-bit we need to update scaler after optimizer step if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: self.scaler.update() # model hook model.on_before_zero_grad(optimizer) # clear gradients model.optimizer_zero_grad(self.current_epoch, batch_idx, optimizer, opt_idx) def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): """ wrap the forward step in a closure so second order methods work """ # --------------------------- # FORWARD # --------------------------- with self.profiler.profile('model_forward'): if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: with torch.cuda.amp.autocast(): training_step_output = self.training_forward( split_batch, batch_idx, opt_idx, hiddens) else: training_step_output = self.training_forward( split_batch, batch_idx, opt_idx, hiddens) # ---------------------------- # PROCESS THE RESULT # ---------------------------- # format and reduce outputs accordingly training_step_output_for_epoch_end = training_step_output training_step_output = self.process_output(training_step_output, train=True) # TODO: temporary part of structured results PR training_step_output = AttributeDict( batch_loss=training_step_output[0], pbar_on_batch_end=training_step_output[1], log_metrics=training_step_output[2], callback_metrics=training_step_output[3], hiddens=training_step_output[4], ) # if the user decides to finally reduce things in epoch_end, save raw output without graphs if isinstance(training_step_output_for_epoch_end, torch.Tensor): training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach( ) else: training_step_output_for_epoch_end = recursive_detach( training_step_output_for_epoch_end) # accumulate loss # (if accumulate_grad_batches = 1 no effect) closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches # the loss will get scaled for amp. 
avoid any modifications to it untouched_loss = closure_loss.detach().clone() # backward pass model_ref = self.get_model() with self.profiler.profile('model_backward'): # scale loss for 16 bit if self.precision == 16 and not self.on_tpu: closure_loss = model_ref.amp_scale_loss( closure_loss, optimizer, opt_idx) # enter amp context if not NATIVE_AMP_AVALAIBLE: context = closure_loss closure_loss = closure_loss.__enter__() # do backward pass model_ref.backward(self, closure_loss, optimizer, opt_idx) # exit amp context if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu: a, b, c = None, None, None error = context.__exit__(a, b, c) if error: rank_zero_warn(a, b, c) raise Exception('apex unscale error') # once backward has been applied, release graph closure_loss = closure_loss.detach() training_step_output.batch_loss = training_step_output.batch_loss.detach( ) if self.use_horovod: # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid optimizer.synchronize() # insert after step hook if self.is_function_implemented('on_after_backward'): model_ref = self.get_model() with self.profiler.profile('on_after_backward'): model_ref.on_after_backward() result = AttributeDict( loss=untouched_loss, training_step_output=training_step_output, training_step_output_for_epoch_end= training_step_output_for_epoch_end, hiddens=training_step_output.hiddens, ) return result def _get_optimizers_iterable(self): if not self.optimizer_frequencies: # call training_step once per optimizer return list(enumerate(self.optimizers)) optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies) optimizers_loop_length = optimizer_freq_cumsum[-1] current_place_in_loop = self.total_batch_idx % optimizers_loop_length # find optimizer index by looking for the first {item > current_place} in the cumsum list opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) return [(opt_idx, self.optimizers[opt_idx])] # @atexit.register def run_training_teardown(self): if hasattr(self, '_teardown_already_run') and self._teardown_already_run: return self._teardown_already_run = True # Train end events with self.profiler.profile('on_train_end'): # callbacks self.on_train_end() # model hooks if self.is_function_implemented('on_train_end'): self.get_model().on_train_end() if self.logger is not None: self.logger.finalize("success") # summarize profile results if self.global_rank == 0: self.profiler.describe() if self.global_rank == 0: for proc in self.interactive_ddp_procs: subprocess.Popen.kill(proc) # clean up dist group if self.use_ddp or self.use_ddp2: torch_distrib.destroy_process_group() # clear mem if self.on_gpu: model = self.get_model() model.cpu() torch.cuda.empty_cache() def training_forward(self, batch, batch_idx, opt_idx, hiddens): """ Handle forward for each training case (distributed, single gpu, etc...) 
:param batch: :param batch_idx: :return: """ # --------------- # FORWARD # --------------- # enable not needing to add opt_idx to training_step args = [batch, batch_idx] if len(self.optimizers) > 1: if self.has_arg('training_step', 'optimizer_idx'): args.append(opt_idx) else: num_opts = len(self.optimizers) raise ValueError( f'Your LightningModule defines {num_opts} optimizers but ' f'training_step is missing the "optimizer_idx" argument.') # pass hiddens if using tbptt if self.truncated_bptt_steps is not None: args.append(hiddens) # distributed forward if self.use_ddp or self.use_ddp2 or self.use_dp: output = self.model(*args) # Horovod elif self.use_horovod and self.on_gpu: batch = self.transfer_batch_to_gpu(batch, hvd.local_rank()) args[0] = batch output = self.model.training_step(*args) # single GPU forward elif self.single_gpu: gpu_id = 0 if isinstance(self.data_parallel_device_ids, list): gpu_id = self.data_parallel_device_ids[0] # Don't copy the batch since there is a single gpu that the batch could # be referenced from and if there are multiple optimizers the batch will # wind up copying it to the same device repeatedly. batch = self.transfer_batch_to_gpu(batch, gpu_id) args[0] = batch output = self.model.training_step(*args) # TPU support elif self.use_tpu: batch = self.transfer_batch_to_tpu(batch, self.tpu_id) args[0] = batch output = self.model.training_step(*args) # CPU forward else: output = self.model.training_step(*args) # allow any mode to define training_step_end # do something with all the dp outputs (like softmax) if self.is_overridden('training_step_end'): model_ref = self.get_model() with self.profiler.profile('training_step_end'): output = model_ref.training_step_end(output) # allow any mode to define training_end # TODO: remove in 1.0.0 if self.is_overridden('training_end'): model_ref = self.get_model() with self.profiler.profile('training_end'): output = model_ref.training_end(output) rank_zero_warn( '`training_end` was deprecated in 0.7.0 and will be removed in 1.0.0.' ' Use training_epoch_end instead', DeprecationWarning) return output def update_learning_rates(self, interval: str): """Update learning rates. Args: interval: either 'epoch' or 'step'. """ if not self.lr_schedulers: return for lr_scheduler in self.lr_schedulers: current_idx = self.batch_idx if interval == 'step' else self.current_epoch current_idx += 1 # account for both batch and epoch starting from 0 # Take step if call to update_learning_rates matches the interval key and # the current step modulo the scheduler's frequency is zero if lr_scheduler[ 'interval'] == interval and current_idx % lr_scheduler[ 'frequency'] == 0: # If instance of ReduceLROnPlateau, we need to pass validation loss if lr_scheduler['reduce_on_plateau']: monitor_key = lr_scheduler['monitor'] monitor_val = self.callback_metrics.get(monitor_key) if monitor_val is None: avail_metrics = ','.join( list(self.callback_metrics.keys())) raise MisconfigurationException( f'ReduceLROnPlateau conditioned on metric {monitor_key}' f' which is not available. Available metrics are: {avail_metrics}.' ' Condition can be set using `monitor` key in lr scheduler dict' ) lr_scheduler['scheduler'].step(monitor_val) else: lr_scheduler['scheduler'].step()
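# --------------------------------------------------------------------------
# `update_learning_rates` above expects each entry of `lr_schedulers` to be a
# dict carrying 'scheduler', 'interval', 'frequency' and, for plateau
# schedulers, 'reduce_on_plateau'/'monitor'. A minimal sketch of building such
# a configuration with plain torch objects; `make_optimizer_config` and the
# 'val_loss' key are illustrative assumptions, not part of the sources above.
# --------------------------------------------------------------------------
import torch

def make_optimizer_config(params):
    optimizer = torch.optim.Adam(params, lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    return optimizer, {
        'scheduler': scheduler,
        'interval': 'epoch',        # matched against update_learning_rates(interval=...)
        'frequency': 1,             # step on every matching interval
        'reduce_on_plateau': True,  # route through scheduler.step(monitor_val)
        'monitor': 'val_loss',      # looked up in callback_metrics
    }

# usage: build the config for a toy module and step the scheduler once,
# passing the value that would normally come from callback_metrics['val_loss']
_model = torch.nn.Linear(4, 2)
_optimizer, _sched_cfg = make_optimizer_config(_model.parameters())
_sched_cfg['scheduler'].step(0.5)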
def train(self): # get model model = self.get_model() # load data # if reload_dataloaders_every_epoch, this is moved to the epoch loop if not self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) self.reset_val_dataloader(model) # Train start events with self.profiler.profile('on_train_start'): # callbacks self.on_train_start() # initialize early stop callback if self.early_stop_callback is not None: self.early_stop_callback.on_train_start(self, self.get_model()) # model hooks model.on_train_start() try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_horovod) \ and hasattr(self.train_dataloader, 'sampler') \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) # update training progress in trainer and model model.current_epoch = epoch self.current_epoch = epoch # changing gradient according accumulation_scheduler self.accumulation_scheduler.on_epoch_start(self, self.get_model()) # stores accumulated grad fractions per batch self.batch_loss_value = TensorRunningAccum( window_length=self.accumulate_grad_batches ) # ----------------- # RUN TNG EPOCH # ----------------- self.run_training_epoch() if self.max_steps and self.max_steps == self.global_step: self.run_training_teardown() return # update LR schedulers self.update_learning_rates(interval='epoch') # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True # TODO wrap this logic into the callback if self.enable_early_stop: if (met_min_epochs and met_min_steps) or self.fast_dev_run: should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model()) # stop training stop = should_stop and met_min_epochs if stop: self.run_training_teardown() return self.run_training_teardown() except KeyboardInterrupt: if self.proc_rank == 0: log.info('Detected KeyboardInterrupt, attempting graceful shutdown...') self.interrupted = True self.run_training_teardown()
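# --------------------------------------------------------------------------
# `running_loss` / `batch_loss_value` above are TensorRunningAccum instances:
# fixed-window running aggregates over scalar loss tensors. The class below is
# a simplified stand-in with the same append/mean/reset surface, sketched for
# illustration; it is not the real TensorRunningAccum implementation.
# --------------------------------------------------------------------------
import torch

class _RunningAccumSketch:
    def __init__(self, window_length):
        self.window_length = window_length
        self.memory = []

    def append(self, x):
        # detach so the window never keeps autograd graphs alive
        self.memory.append(x.detach())
        if len(self.memory) > self.window_length:
            self.memory.pop(0)  # slide the window: drop the oldest value

    def mean(self):
        return torch.stack(self.memory).mean() if self.memory else None

    def reset(self):
        self.memory.clear()

# only the last `window_length` values contribute to the displayed mean:
_acc = _RunningAccumSketch(window_length=2)
for _v in (1.0, 2.0, 4.0):
    _acc.append(torch.tensor(_v))
assert _acc.mean() == torch.tensor(3.0)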
class TrainLoop: def __init__( self, trainer, max_epochs: Optional[int], min_epochs: Optional[int], max_steps: Optional[int], min_steps: Optional[int], num_sanity_val_steps: int, ): self.trainer = trainer self.accumulated_loss = None self.warning_cache = WarningCache() self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) self._skip_backward = False self._optimizer_freq_cumsum = None self._hiddens = None self.global_step = 0 self.current_epoch = 0 self.trainer.should_stop = False # the total batch index across all epochs self.total_batch_idx = 0 # the current batch index in the loop that runs over the dataloader(s) self.batch_idx = 0 # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx = None self.trainer.num_training_batches = 0 self.trainer.train_dataloader = None # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000 self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs # If neither min_epochs or min_steps is set, then use existing default of min_epochs = 1 self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs self.max_steps = max_steps self.min_steps = min_steps if num_sanity_val_steps == -1: self.trainer.num_sanity_val_steps = float("inf") else: self.trainer.num_sanity_val_steps = num_sanity_val_steps @property def num_active_optimizers(self) -> int: return len(self.get_active_optimizers()) @property def optimizer_freq_cumsum(self): if self._optimizer_freq_cumsum is None: self._optimizer_freq_cumsum = np.cumsum( self.trainer.optimizer_frequencies) return self._optimizer_freq_cumsum def should_skip_training(self) -> bool: should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0 def on_train_start(self): # hook self.trainer.call_hook("on_train_start") def on_train_end(self): if self._teardown_already_run: return self._teardown_already_run = True # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step self.global_step -= 1 self.check_checkpoint_callback(should_update=True, is_last=True) self.global_step += 1 # hook self.trainer.call_hook("on_train_end") # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. 
# It might be related to xla tensors blocked when moving the cpu # kill loggers if self.trainer.logger is not None: self.trainer.logger.finalize("success") # summarize profile results self.trainer.profiler.describe() # give accelerators a chance to finish self.trainer.accelerator.on_train_end() # reset bookkeeping self.trainer.state.stage = None def check_checkpoint_callback(self, should_update, is_last=False): # TODO bake this logic into the ModelCheckpoint callback if should_update and self.trainer.checkpoint_connector.has_trained: callbacks = self.trainer.checkpoint_callbacks if is_last and any(cb.save_last and cb.verbose for cb in callbacks): rank_zero_info("Saving latest checkpoint...") model = self.trainer.lightning_module for cb in callbacks: cb.on_validation_end(self.trainer, model) def on_train_epoch_start(self, epoch): # update training progress in trainer self.current_epoch = epoch model = self.trainer.lightning_module # reset train dataloader if epoch != 0 and self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) # todo: specify the possible exception with suppress(Exception): # set seed for distributed sampler (enables shuffling for each epoch) self.trainer.train_dataloader.sampler.set_epoch(epoch) # changing gradient according accumulation_scheduler self.trainer.accumulation_scheduler.on_train_epoch_start( self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum( window_length=self.trainer.accumulate_grad_batches) # hook self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): batch_end_outputs = [ opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out) ] processed_batch_end_outputs = TrainLoop._prepare_outputs( batch_end_outputs, batch_mode=True) # hook self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx) self.trainer.call_hook('on_batch_end') # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) # reset batch logger internals self.trainer.logger_connector.on_train_batch_end() def reset_train_val_dataloaders(self, model) -> None: """ Resets train and val dataloaders if none are attached to the trainer. The val dataloader must be initialized before training loop starts, as the training loop inspects the val dataloader to determine whether to run the evaluation loop. 
""" if self.trainer.train_dataloader is None: self.trainer.reset_train_dataloader(model) if self.trainer.val_dataloaders is None: self.trainer.reset_val_dataloader(model) def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): hook_overridden = self._should_add_batch_output_to_epoch_output() # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(batch_end_outputs): sample_output = opt_outputs[-1] # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance( sample_output, Result) and sample_output.should_reduce_on_epoch_end # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if not (hook_overridden or auto_reduce_tng_result): continue # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance( opt_outputs[0], Result): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) def _should_add_batch_output_to_epoch_output(self) -> bool: # We add to the epoch outputs if # 1. The model defines training_epoch_end OR # 2. The model overrides on_train_epoch_end which has `outputs` in the signature # TODO: in v1.5 this only needs to check if training_epoch_end is overridden lightning_module = self.trainer.lightning_module if is_overridden("training_epoch_end", model=lightning_module): return True if is_overridden("on_train_epoch_end", model=lightning_module): model_hook_fx = getattr(lightning_module, "on_train_epoch_end") if is_param_in_hook_signature(model_hook_fx, "outputs"): return True return False def get_active_optimizers( self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: """ Returns the currently active optimizers. When multiple optimizers are used with different frequencies, only one of the optimizers is active at a time. Returns: A list of tuples (opt_idx, optimizer) of currently active optimizers. """ if not self.trainer.optimizer_frequencies: # call training_step once per optimizer return list(enumerate(self.trainer.optimizers)) batch_idx = self.total_batch_idx if batch_idx is None else batch_idx optimizers_loop_length = self.optimizer_freq_cumsum[-1] current_place_in_loop = batch_idx % optimizers_loop_length # find optimzier index by looking for the first {item > current_place} in the cumsum list opt_idx = int( np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)) return [(opt_idx, self.trainer.optimizers[opt_idx])] def on_after_backward(self, training_step_output, batch_idx, untouched_loss): training_step_output.detach() # insert after step hook self.trainer.call_hook("on_after_backward") # when in dev debugging track the losses self.trainer.dev_debugger.track_train_loss_history( batch_idx, untouched_loss.detach()) def _check_training_step_output(self, training_step_output): if isinstance( training_step_output, torch.Tensor ) and not self.trainer.lightning_module.automatic_optimization: if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... 
raise MisconfigurationException( "In manual optimization, `training_step` should not return a Tensor" ) def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("model_forward"): step_kwargs = self._build_kwargs(split_batch, batch_idx, opt_idx, hiddens) # manually capture logged metrics model_ref._current_fx_name = 'training_step' model_ref._results = Result() with self.trainer.profiler.profile("training_step"): training_step_output = self.trainer.accelerator.training_step( step_kwargs) self.trainer.accelerator.post_training_step() self.trainer.logger_connector.cache_logged_metrics() self._check_training_step_output(training_step_output) training_step_output = self.trainer.call_hook( "training_step_end", training_step_output) training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( training_step_output, split_batch) if training_step_output_for_epoch_end is None: return # enable empty loss when using manual opt closure_loss = None untouched_loss = None if self.trainer.lightning_module.automatic_optimization: # accumulate loss. if accumulate_grad_batches==1, no effect closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches # the loss will get scaled for amp. avoid any modifications to it untouched_loss = closure_loss.detach().clone() # result result = AttributeDict( closure_loss=closure_loss, loss=untouched_loss, training_step_output=training_step_output, training_step_output_for_epoch_end= training_step_output_for_epoch_end, ) return result def _process_training_step_output(self, training_step_output, split_batch): training_step_output_for_epoch_end = training_step_output # enable validation_step return None if training_step_output_for_epoch_end is None: return None, None result = self.trainer.lightning_module._results loss = None hiddens = None result["extra"] = {} # handle dict return if isinstance(training_step_output, dict): loss = training_step_output.pop("loss", None) hiddens = training_step_output.pop("hiddens", None) if hiddens is not None: hiddens = hiddens.detach() result["extra"] = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): loss = training_step_output # map to results under the hood result.minimize = loss self._hiddens = hiddens # track batch for manual reduction with result result.track_batch_size(len(split_batch)) # track metrics without grads for epoch reduction training_step_output_for_epoch_end = copy(result) training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach( ) if self.trainer.move_metrics_to_cpu: training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu( ) return training_step_output_for_epoch_end, result @staticmethod def _prepare_outputs( outputs: List[List[List[Result]]], batch_mode: bool, ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: """ Extract required information from batch or epoch end results. Args: outputs: A 3-dimensional list of ``Result`` objects with dimensions: [optimizer outs][batch outs][tbptt steps]. batch_mode: If True, ignore the batch output dimension. Returns: The cleaned outputs with ``Result`` objects converted to dictionaries. All list dimensions of size one will be collapsed. 
""" processed_outputs = [] for opt_outputs in outputs: # handle an edge case where an optimizer output is the empty list if len(opt_outputs) == 0: continue processed_batch_outputs = [] if batch_mode: opt_outputs = [opt_outputs] for batch_outputs in opt_outputs: processed_tbptt_outputs = [] for tbptt_output in batch_outputs: out = tbptt_output.extra out['loss'] = tbptt_output.minimize processed_tbptt_outputs.append(out) # if there was only one tbptt step then we can collapse that dimension if len(processed_tbptt_outputs) == 1: processed_tbptt_outputs = processed_tbptt_outputs[0] processed_batch_outputs.append(processed_tbptt_outputs) # batch_outputs should be just one dict (or a list of dicts if using tbptt) per optimizer if batch_mode: processed_batch_outputs = processed_batch_outputs[0] processed_outputs.append(processed_batch_outputs) # if there is only one optimiser then we collapse that dimension if len(processed_outputs) == 1: processed_outputs = processed_outputs[0] return processed_outputs def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): model_ref = self.trainer.lightning_module is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) using_native_amp = self.trainer.amp_backend == AMPType.NATIVE # native amp + lbfgs is a no go right now if using_native_amp and is_lbfgs: raise MisconfigurationException( 'native PyTorch amp and lbfgs are not compatible.' ' To request, please file a Github issue in PyTorch and tag @mcarilli' ) # wraps into LightningOptimizer only for running step optimizer = LightningOptimizer._to_lightning_optimizer( optimizer, self.trainer, opt_idx) # model hook model_ref.optimizer_step( self.trainer.current_epoch, batch_idx, optimizer, opt_idx, train_step_and_backward_closure, on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE, using_native_amp=using_native_amp, using_lbfgs=is_lbfgs, ) def on_before_zero_grad(self, optimizer): self.trainer.call_hook('on_before_zero_grad', optimizer) def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): self.trainer.accelerator.optimizer_zero_grad( self.trainer.current_epoch, batch_idx, optimizer, opt_idx) def track_and_norm_grad(self, optimizer) -> dict: # track gradient norms grad_norm_dict = self._track_gradient_norm() # clip gradients self.trainer.accelerator.clip_gradients( optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm) return grad_norm_dict def _track_gradient_norm(self): grad_norm_dict = {} if (self.global_step + 1) % self.trainer.log_every_n_steps == 0: if float(self.trainer.track_grad_norm) > 0: model = self.trainer.lightning_module grad_norm_dict = grad_norm(model, self.trainer.track_grad_norm) return grad_norm_dict def _tbptt_split_batch(self, batch: Any) -> List[Any]: splits = [batch] truncated_bptt_enabled = self._truncated_bptt_enabled() if truncated_bptt_enabled: model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("tbptt_split_batch"): splits = model_ref.tbptt_split_batch( batch, self._truncated_bptt_steps()) return splits def run_training_epoch(self): # modify dataloader if needed (ddp, etc...) 
train_dataloader = self.trainer.accelerator.process_dataloader( self.trainer.train_dataloader) # track epoch output epoch_output = [[] for _ in range(self.num_active_optimizers)] train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader( train_dataloader) dataloader_idx = 0 batch_idx = None is_last_batch = None for batch_idx, (batch, is_last_batch) in train_dataloader: self.batch_idx = batch_idx # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ with self.trainer.profiler.profile("run_training_batch"): batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx) # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: break # hook # TODO: add outputs to batches self.on_train_batch_end( epoch_output, batch_output.training_step_output_for_epoch_end, batch, batch_idx, dataloader_idx, ) # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- self.trainer.logger_connector.log_train_step_metrics(batch_output) # ----------------------------------------- # VALIDATE IF NEEDED # ----------------------------------------- should_check_val = self._should_check_val_fx( batch_idx, is_last_batch) if should_check_val: self.trainer.validating = True self.trainer._run_evaluation() self.trainer.training = True # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) # ----------------------------------------- self.save_loggers_on_train_batch_end() # update LR schedulers monitor_metrics = deepcopy( self.trainer.logger_connector.callback_metrics) self.update_train_loop_lr_schedulers( monitor_metrics=monitor_metrics) self.trainer.checkpoint_connector.has_trained = True # max steps reached, end training if (self.max_steps is not None and self.max_steps <= self.global_step + 1 and self._accumulated_batches_reached()): break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if self.trainer.should_stop: break self.total_batch_idx += 1 # stop epoch if we limited the number of training batches if self._num_training_batches_reached(is_last_batch): break # progress global step according to grads progress self.increment_accumulated_grad_global_step() if batch_idx is None: # dataloader/iterator did not produce a batch return # handle epoch_output on epoch end self.on_train_epoch_end(epoch_output) # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation( self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval # update epoch level lr_schedulers if no val loop outside train loop is triggered if not should_check_val or should_train_only: self.trainer.optimizer_connector.update_learning_rates( interval='epoch') if should_train_only: self.check_checkpoint_callback(True) if should_check_val: self.trainer.validating = True self.trainer._run_evaluation(on_epoch=True) self.trainer.training = True # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() # prepare epoch output 
processed_epoch_output = TrainLoop._prepare_outputs(epoch_output, batch_mode=False) # get the model and call model.training_epoch_end model = self.trainer.lightning_module if is_overridden('training_epoch_end', model=model): # run training_epoch_end # refresh the result for custom logging at the epoch level model._current_fx_name = 'training_epoch_end' training_epoch_end_output = model.training_epoch_end( processed_epoch_output) if training_epoch_end_output is not None: raise MisconfigurationException( 'training_epoch_end expects a return of None. ' 'HINT: remove the return statement in training_epoch_end') # capture logging self.trainer.logger_connector.cache_logged_metrics() # call train epoch end hooks self._on_train_epoch_end_hook(processed_epoch_output) self.trainer.call_hook('on_epoch_end') def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # We cannot rely on Trainer.call_hook because the signatures might be different across # the lightning module and callbacks # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end` # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" # set hook_name to model + reset Result obj skip = self.trainer._reset_result_and_set_fx_name(hook_name) # always profile hooks with self.trainer.profiler.profile(hook_name): # first call trainer hook if hasattr(self.trainer, hook_name): trainer_hook = getattr(self.trainer, hook_name) trainer_hook(processed_epoch_output) # next call hook in the LightningModule model_ref = self.trainer.lightning_module if is_overridden(hook_name, model_ref): hook_fx = getattr(model_ref, hook_name) if is_param_in_hook_signature(hook_fx, "outputs"): self.warning_cache.warn( "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3." " `outputs` parameter has been deprecated." " Support for the old signature will be removed in v1.5", DeprecationWarning) model_ref.on_train_epoch_end(processed_epoch_output) else: model_ref.on_train_epoch_end() # if the PL module doesn't have the hook then call the accelerator # used to auto-reduce things for the user with Results obj elif hasattr(self.trainer.accelerator, hook_name): accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook() if not skip: self.trainer._cache_logged_metrics() def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dict = {} # bookkeeping self._hiddens = None optimizers = list(enumerate(self.trainer.optimizers)) # track all outputs across time and num of optimizers batch_outputs = [[] for _ in range(len(optimizers))] if batch is None: self.warning_cache.warn( "train_dataloader yielded None. If this was on purpose, ignore this warning..."
) return AttributeDict( signal=0, grad_norm_dict={}, training_step_output_for_epoch_end=batch_outputs, ) # hook response = self.trainer.call_hook("on_batch_start") if response == -1: return AttributeDict(signal=-1, grad_norm_dict={}) # hook response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) if response == -1: return AttributeDict(signal=-1, grad_norm_dict={}) # lightning module hook splits = self._tbptt_split_batch(batch) for split_idx, split_batch in enumerate(splits): self.split_idx = split_idx if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers( batch_idx): result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer) if result: batch_outputs[opt_idx].append( result.training_step_output_for_epoch_end) grad_norm_dict = result.get("grad_norm_dict", {}) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_idx, split_batch) if result: batch_outputs[0].append( result.training_step_output_for_epoch_end) output = AttributeDict( signal=0, # todo: Properly aggregate grad_norm across opt_idx and split_idx grad_norm_dict=grad_norm_dict, training_step_output_for_epoch_end=batch_outputs, ) return output def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None): # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change # opt_idx=0 to opt_idx=None in the signature here # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) result = AttributeDict() closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result) if self.should_accumulate(): # For gradient accumulation # ------------------- # calculate loss (train step + train step end) # ------------------- # automatic_optimization=True: perform ddp sync only when performing optimizer_step # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): closure() # ------------------------------ # BACKWARD PASS # ------------------------------ # gradient update with accumulated gradients else: if self.trainer.lightning_module.automatic_optimization: self.optimizer_step(optimizer, opt_idx, batch_idx, closure) else: result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens) if not result: # user decided to skip optimization return result # update running loss + reset accumulated loss self.update_running_loss(result.loss) self._process_closure_result(result) return result def training_step_and_backward_closure( self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: Optimizer, hiddens, return_result: AttributeDict, ) -> Optional[torch.Tensor]: step_result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) if step_result is not None: return_result.update(step_result) return return_result.loss def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable: """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``.
""" partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs) return update_wrapper(partial_func, self.training_step_and_backward_closure) @contextmanager def block_ddp_sync_behaviour(self, should_block_sync: bool = False): """ automatic_optimization = True Blocks ddp sync gradients behaviour on backwards pass. This is useful for skipping sync when accumulating gradients, reducing communication overhead automatic_optimization = False do not block ddp gradient sync when using manual optimization as gradients are needed within the training step Returns: context manager with sync behaviour off """ if (isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (self.trainer.lightning_module.automatic_optimization or should_block_sync)): with self.trainer.training_type_plugin.block_backward_sync(): yield None else: yield None def _process_closure_result( self, opt_closure_result: Optional[AttributeDict]) -> None: if not opt_closure_result: return # cache metrics self.trainer.logger_connector.cache_training_step_metrics( opt_closure_result) # check if loss or model weights are nan if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): """Wrap forward, zero_grad and backward in a closure so second order methods work""" with self.trainer.profiler.profile("training_step_and_backward"): # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 if is_first_batch_to_accumulate: self.on_before_zero_grad(optimizer) self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) # backward pass if result is not None: with self.trainer.profiler.profile("backward"): self.backward(result, optimizer, opt_idx) # hook - call this hook only # when gradients have finished to accumulate if not self.should_accumulate(): self.on_after_backward(result.training_step_output, batch_idx, result.loss) # check if loss or model weights are nan if self.trainer.terminate_on_nan: self._check_finite(result.loss) else: self.warning_cache.warn( "training_step returned None. If this was on purpose, ignore this warning..." 
) if len(self.trainer.optimizers) > 1: # revert back to previous state self.trainer.lightning_module.untoggle_optimizer(opt_idx) return result def _check_finite(self, loss: torch.Tensor) -> None: if not torch.isfinite(loss).all(): raise ValueError( f'The loss returned in `training_step` is {loss}.') model = self.trainer.lightning_module detect_nan_parameters(model) def backward(self, result, optimizer, opt_idx, *args, **kwargs): self.trainer.dev_debugger.track_event("backward_call") should_accumulate = self.should_accumulate() # backward can be called manually in the training loop if isinstance(result, torch.Tensor): self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) else: result.closure_loss = self.trainer.accelerator.backward( result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs) if not self.should_accumulate(): # track gradients result.grad_norm_dict = self.track_and_norm_grad( optimizer=optimizer) def update_train_loop_lr_schedulers(self, monitor_metrics=None): num_accumulated_batches_reached = self._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() if num_accumulated_batches_reached or num_training_batches_reached: # update lr self.trainer.optimizer_connector.update_learning_rates( interval="step", monitor_metrics=monitor_metrics, opt_indices=[ opt_idx for opt_idx, _ in self.get_active_optimizers() ], ) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: self.global_step = self.trainer.accelerator.update_global_step( self.total_batch_idx, self.global_step) def _accumulated_batches_reached(self): return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self, is_last_batch=False): return (self.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch def should_accumulate(self): # checks if backward or backward + optimizer step (via closure) accumulation_done = self._accumulated_batches_reached() is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: """ Decide if we should run validation. 
""" if not self.trainer.enable_validation: return False # check if this epoch is eligible to run validation if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: return False # val_check_batch is inf for iterable datasets with no length defined # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = False if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 elif self.trainer.val_check_batch != float('inf'): is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 # Note: num_training_batches is also inf for iterable datasets with no length defined epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float( "inf") if on_epoch: return ( is_val_check_batch and epoch_end_val_check ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset else: return is_val_check_batch and not epoch_end_val_check def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) lightning_module = self.trainer.lightning_module if len(self.trainer.optimizers) > 1: training_step_fx = getattr(lightning_module, "training_step") has_opt_idx_in_train_step = is_param_in_hook_signature( training_step_fx, "optimizer_idx") if has_opt_idx_in_train_step: if not lightning_module.automatic_optimization: self.warning_cache.warn( "`training_step` hook signature has changed in v1.3." " `optimizer_idx` argument has been removed in case of manual optimization. Support for" " the old signature will be removed in v1.5", DeprecationWarning) step_kwargs['optimizer_idx'] = opt_idx elif not has_opt_idx_in_train_step and self.trainer.lightning_module.automatic_optimization: raise ValueError( f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" ' `training_step` is missing the `optimizer_idx` argument.' ) # pass hiddens if using tbptt if self._truncated_bptt_enabled(): step_kwargs['hiddens'] = hiddens return step_kwargs def _truncated_bptt_enabled(self) -> bool: """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """ return self._truncated_bptt_steps() > 0 def _truncated_bptt_steps(self) -> int: lightning_module = self.trainer.lightning_module # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 if lightning_module.truncated_bptt_steps > 0: return lightning_module.truncated_bptt_steps return self.trainer.truncated_bptt_steps or 0 def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. 
if self.trainer.lightning_module.automatic_optimization and len( self.trainer.optimizers) > 1: model = self.trainer.lightning_module model.toggle_optimizer(optimizer, opt_idx) # used to track metrics internally self.trainer.logger_connector.on_train_split_start( split_idx, opt_idx, split_batch) def update_running_loss(self, current_loss: torch.Tensor) -> None: if self.trainer.lightning_module.automatic_optimization: # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(current_loss) accumulated_loss = self.accumulated_loss.mean() if accumulated_loss is not None: # calculate running loss for display self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) # reset for next set of accumulated grads self.accumulated_loss.reset()
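# --- Illustrative sketch, not part of the trainer source ---
# update_running_loss above leans on TensorRunningAccum, a fixed-window
# accumulator. The minimal stand-in below (assumed behaviour, not the real
# class) shows the contract it depends on: append per-step losses, read back
# a windowed mean.
import torch

class RunningAccumSketch:
    """Minimal fixed-window running accumulator (illustrative only)."""

    def __init__(self, window_length: int) -> None:
        self.memory = torch.zeros(window_length)
        self.current_idx = 0
        self.rotated = False  # becomes True once the window has wrapped around

    def append(self, x: torch.Tensor) -> None:
        self.memory[self.current_idx] = x.detach()
        self.current_idx = (self.current_idx + 1) % len(self.memory)
        if self.current_idx == 0:
            self.rotated = True

    def mean(self) -> torch.Tensor:
        # before the first wrap, only average the entries actually filled
        filled = self.memory if self.rotated else self.memory[:self.current_idx]
        return filled.mean()

accum = RunningAccumSketch(window_length=4)
for step_loss in (1.0, 0.5, 0.75):
    accum.append(torch.tensor(step_loss))
assert torch.isclose(accum.mean(), torch.tensor(0.75))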
def train(self): self.run_sanity_check(self.get_model()) # enable train mode model = self.get_model() model.train() torch.set_grad_enabled(True) # reload data when needed self.train_loop.reset_train_val_dataloaders(model) # hook self.train_loop.on_train_start() try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_horovod or self.on_tpu) \ and hasattr(self.train_dataloader, 'sampler') \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) # update training progress in trainer and model model.current_epoch = epoch self.current_epoch = epoch # change gradient accumulation according to the accumulation_scheduler self.accumulation_scheduler.on_epoch_start( self, self.get_model()) # stores accumulated grad fractions per batch self.batch_loss_value = TensorRunningAccum( window_length=self.accumulate_grad_batches) # ----------------- # RUN TNG EPOCH # ----------------- self.run_training_epoch() if self.max_steps and self.max_steps <= self.global_step: # hook self.train_loop.on_train_end() return # update LR schedulers self.update_learning_rates(interval='epoch') # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if self.should_stop: if (met_min_epochs and met_min_steps): self.train_loop.on_train_end() return else: log.info( 'Trainer was signaled to stop but required minimum epochs' f' ({self.min_epochs}) or minimum steps ({self.min_steps}) have' ' not been met. Training will continue...') # hook self.train_loop.on_train_end() except KeyboardInterrupt: rank_zero_warn( 'Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True self._state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() # hook self.train_loop.on_train_end()
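# The early-stopping gate above only honours `should_stop` once both minimum
# requirements are met. A condensed stand-alone restatement of that gate
# (illustrative helper, not a Trainer method):
def met_min_training(epoch: int, global_step: int, min_epochs: int, min_steps) -> bool:
    met_min_epochs = epoch >= min_epochs - 1  # epoch counter is 0-based
    met_min_steps = global_step >= min_steps if min_steps else True
    return met_min_epochs and met_min_steps

# e.g. with min_epochs=3 and no min_steps, epoch 1 must keep training:
assert not met_min_training(epoch=1, global_step=100, min_epochs=3, min_steps=None)
assert met_min_training(epoch=2, global_step=100, min_epochs=3, min_steps=None)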
def train(self): # add signal handlers for process kills # def _signal_kill_handler(*args): # return TrainerTrainLoopMixin.run_training_teardown(self) # # orig_signal_handlers = {} # for sig_name in SIGNAL_TERMINATE: # orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name), # _signal_kill_handler) # get model model = self.get_model() # load data # if reload_dataloaders_every_epoch, this is moved to the epoch loop if not self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) self.reset_val_dataloader(model) # Train start events with self.profiler.profile('on_train_start'): # callbacks self.on_train_start() # initialize early stop callback if self.early_stop_callback is not None: self.early_stop_callback.on_train_start(self, self.get_model()) # model hooks model.on_train_start() try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_horovod) \ and hasattr(self.train_dataloader, 'sampler') \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) # update training progress in trainer and model model.current_epoch = epoch self.current_epoch = epoch # change gradient accumulation according to the accumulation_scheduler self.accumulation_scheduler.on_epoch_start( self, self.get_model()) # stores accumulated grad fractions per batch self.batch_loss_value = TensorRunningAccum( window_length=self.accumulate_grad_batches) # ----------------- # RUN TNG EPOCH # ----------------- self.run_training_epoch() if self.max_steps and self.max_steps == self.global_step: self.run_training_teardown() return # update LR schedulers self.update_learning_rates(interval='epoch') # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True # TODO wrap this logic into the callback if self.enable_early_stop: if (met_min_epochs and met_min_steps) or self.fast_dev_run: should_stop = self.early_stop_callback.on_validation_end( self, self.get_model()) # stop training stop = should_stop and met_min_epochs if stop: self.run_training_teardown() return self.run_training_teardown() except KeyboardInterrupt: rank_zero_warn( 'Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True for proc in self.interactive_ddp_procs: subprocess.Popen.kill(proc) self.run_training_teardown()
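# Both train() variants call `sampler.set_epoch(epoch)` before each epoch. The
# stand-alone snippet below (plain torch, no Lightning) shows why: the epoch
# value seeds DistributedSampler's shuffle, so each epoch yields a fresh,
# rank-consistent permutation. num_replicas/rank are passed explicitly so no
# process group is required for the illustration.
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
epoch0_order = list(sampler)
sampler.set_epoch(1)
epoch1_order = list(sampler)
# typically prints two different orderings of this rank's 4 indices
print(epoch0_order, epoch1_order)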
class TrainerTrainLoopMixin(ABC): # this is just a summary of variables used in this abstract class, # the proper values/initialisation should be done in child class max_epochs: int min_epochs: int on_gpu: bool use_ddp: bool use_dp: bool use_ddp2: bool use_horovod: bool single_gpu: bool use_tpu: bool data_parallel_device_ids:... check_val_every_n_epoch:... num_training_batches: int val_check_batch:... num_val_batches: int disable_validation: bool fast_dev_run:... accumulation_scheduler:... lr_schedulers:... enable_early_stop:... early_stop_callback:... callback_metrics:... logger: Union[LightningLoggerBase, bool] global_step: int testing: bool log_save_interval: float proc_rank: int row_log_interval: float truncated_bptt_steps:... optimizers:... optimizer_frequencies:... accumulate_grad_batches: int track_grad_norm:... model: LightningModule interrupted: bool running_loss:... progress_bar_dict:... reduce_lr_on_plateau_scheduler:... profiler:... batch_idx: int precision:... train_dataloader: DataLoader reload_dataloaders_every_epoch: bool max_steps: int min_steps: int total_batch_idx: int checkpoint_callback:... terminate_on_nan: bool tpu_id: int # Callback system callbacks: List[Callback] on_train_start: Callable on_train_end: Callable on_batch_start: Callable on_batch_end: Callable on_epoch_start: Callable on_epoch_end: Callable on_validation_end: Callable @abstractmethod def get_model(self): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def is_function_implemented(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def run_evaluation(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def transfer_batch_to_gpu(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def transfer_batch_to_tpu(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def clip_gradients(self): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def detect_nan_tensors(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def is_overridden(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def add_progress_bar_metrics(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def log_metrics(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def process_output(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def reset_train_dataloader(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def reset_val_dataloader(self, model): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def has_arg(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" def train(self): # add signal handlers for process kills # def _signal_kill_handler(*args): # return TrainerTrainLoopMixin.run_training_teardown(self) # # orig_signal_handlers = {} # for sig_name in SIGNAL_TERMINATE: # orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name), # _signal_kill_handler) # get model model = self.get_model() # load data # if
reload_dataloaders_every_epoch, this is moved to the epoch loop if not self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) self.reset_val_dataloader(model) # Train start events with self.profiler.profile('on_train_start'): # callbacks self.on_train_start() # initialize early stop callback if self.early_stop_callback is not None: self.early_stop_callback.on_train_start(self, self.get_model()) # model hooks model.on_train_start() try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_horovod) \ and hasattr(self.train_dataloader, 'sampler') \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) # update training progress in trainer and model model.current_epoch = epoch self.current_epoch = epoch # change gradient accumulation according to the accumulation_scheduler self.accumulation_scheduler.on_epoch_start( self, self.get_model()) # stores accumulated grad fractions per batch self.batch_loss_value = TensorRunningAccum( window_length=self.accumulate_grad_batches) # ----------------- # RUN TNG EPOCH # ----------------- self.run_training_epoch() if self.max_steps and self.max_steps == self.global_step: self.run_training_teardown() return # update LR schedulers self.update_learning_rates(interval='epoch') # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True # TODO wrap this logic into the callback if self.enable_early_stop: if (met_min_epochs and met_min_steps) or self.fast_dev_run: should_stop = self.early_stop_callback.on_validation_end( self, self.get_model()) # stop training stop = should_stop and met_min_epochs if stop: self.run_training_teardown() return self.run_training_teardown() except KeyboardInterrupt: rank_zero_warn( 'Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press ctrl+c many times...
only shutdown once if not self.interrupted: self.interrupted = True for proc in self.interactive_ddp_procs: subprocess.Popen.kill(proc) self.run_training_teardown() def run_training_epoch(self): # get model model = self.get_model() # Epoch start events with self.profiler.profile('on_epoch_start'): # callbacks self.on_epoch_start() # model hooks if self.is_function_implemented('on_epoch_start'): model.on_epoch_start() # track local dataloader so TPU can wrap each epoch train_dataloader = self.train_dataloader # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device(self.tpu_id) train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) train_dataloader = train_dataloader.per_device_loader(device) # bookkeeping outputs = [] # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( enumerate(_with_is_last(train_dataloader)), "get_train_batch"): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: break self.batch_idx = batch_idx model.global_step = self.global_step # --------------- # RUN TRAIN STEP # --------------- _outputs = self.run_training_batch(batch, batch_idx) batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory if self.is_overridden('training_epoch_end', model=self.get_model()): outputs.append(batch_output) # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_result == -1 # TODO: consolidate all actions that need to take place only after # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment) if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: # update lr self.update_learning_rates(interval='step') # --------------- # RUN VAL STEP # --------------- is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 can_check_val = not self.disable_validation and can_check_epoch should_check_val = is_val_check_batch or early_stop_epoch should_check_val = should_check_val or ( is_last_batch and self.val_check_batch == float('inf')) should_check_val = can_check_val and should_check_val # --------------- # CHECKPOINTING, EARLY STOPPING # --------------- # fast_dev_run always forces val checking after train batch if self.fast_dev_run or should_check_val: self.run_evaluation(test_mode=self.testing) self.call_checkpoint_callback() # when logs should be saved should_save_log = ( batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch if should_save_log or self.fast_dev_run: if self.proc_rank == 0 and self.logger is not None: self.logger.save() # when metrics should be logged should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch if should_log_metrics or self.fast_dev_run: # logs user requested information to logger self.log_metrics(batch_step_metrics, grad_norm_dic) # progress global step according to grads progress if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: self.global_step += 1 self.total_batch_idx += 1 # max steps reached, end training if self.max_steps is not None and self.max_steps == self.global_step: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if early_stop_epoch or self.fast_dev_run: break if self.use_horovod: hvd.join(hvd.local_rank() if self.on_gpu else -1) # process epoch 
outputs model = self.get_model() if self.is_overridden('training_epoch_end', model=model): epoch_output = model.training_epoch_end(outputs) _processed_outputs = self.process_output(epoch_output) log_epoch_metrics = _processed_outputs[2] callback_epoch_metrics = _processed_outputs[3] self.log_metrics(log_epoch_metrics, {}) self.callback_metrics.update(callback_epoch_metrics) self.add_progress_bar_metrics(_processed_outputs[1]) # when no val loop is present or in fast-dev-run we still need to call the checkpoint callback if not self.is_overridden('validation_step') and not ( self.fast_dev_run or should_check_val): self.call_checkpoint_callback() # Epoch end events with self.profiler.profile('on_epoch_end'): # callbacks self.on_epoch_end() # model hooks if self.is_function_implemented('on_epoch_end'): model.on_epoch_end() def run_training_batch(self, batch, batch_idx): # track grad norms grad_norm_dic = {} # track all metrics for callbacks all_callback_metrics = [] # track metrics to log all_log_metrics = [] if batch is None: return 0, grad_norm_dic, {}, {} # Batch start events with self.profiler.profile('on_batch_start'): # callbacks self.on_batch_start() # hooks if self.is_function_implemented('on_batch_start'): response = self.get_model().on_batch_start(batch) if response == -1: return -1, grad_norm_dic, {}, {} splits = [batch] if self.truncated_bptt_steps is not None: model_ref = self.get_model() with self.profiler.profile('tbptt_split_batch'): splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps) self.hiddens = None for split_idx, split_batch in enumerate(splits): self.split_idx = split_idx for opt_idx, optimizer in self._get_optimizers_iterable(): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup.
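# The requires_grad juggling below is the pre-`toggle_optimizer` approach:
# freeze every parameter, then re-enable only the param_groups of the
# optimizer currently stepping, so its backward pass cannot deposit grads
# into parameters owned by another optimizer.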
if len(self.optimizers) > 1: for param in self.get_model().parameters(): param.requires_grad = False for group in optimizer.param_groups: for param in group['params']: param.requires_grad = True # wrap the forward step in a closure so second order methods work def optimizer_closure(): # forward pass with self.profiler.profile('model_forward'): if self.use_amp and self.use_native_amp: with torch.cuda.amp.autocast(): output_dict = self.training_forward( split_batch, batch_idx, opt_idx, self.hiddens) else: output_dict = self.training_forward( split_batch, batch_idx, opt_idx, self.hiddens) # format and reduce outputs accordingly processed_output = self.process_output(output_dict, train=True) closure_loss, progress_bar_metrics, log_metrics, callback_metrics, self.hiddens = processed_output # accumulate loss # (if accumulate_grad_batches = 1 no effect) closure_loss = closure_loss / self.accumulate_grad_batches # backward pass model_ref = self.get_model() with self.profiler.profile('model_backward'): model_ref.backward(self, closure_loss, optimizer, opt_idx) # track metrics for callbacks all_callback_metrics.append(callback_metrics) # track progress bar metrics self.add_progress_bar_metrics(progress_bar_metrics) all_log_metrics.append(log_metrics) if self.use_horovod: # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid optimizer.synchronize() # insert after step hook if self.is_function_implemented('on_after_backward'): model_ref = self.get_model() with self.profiler.profile('on_after_backward'): model_ref.on_after_backward() return closure_loss, callback_metrics # calculate loss loss, batch_output = optimizer_closure() # check if loss or model weights are nan if self.terminate_on_nan: self.detect_nan_tensors(loss) # track total loss for logging (avoid mem leaks) self.batch_loss_value.append(loss) # gradient update with accumulated gradients if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: # track gradient norms when requested if batch_idx % self.row_log_interval == 0: if float(self.track_grad_norm) > 0: model = self.get_model() grad_norm_dic = model.grad_norm( self.track_grad_norm) # clip gradients if self.use_amp and self.use_native_amp: self.scaler.unscale_(optimizer) self.clip_gradients() # calls .step(), .zero_grad() # override function to modify this behavior model = self.get_model() with self.profiler.profile('optimizer_step'): model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda: optimizer_closure()[0]) # calculate running loss for display self.running_loss.append(self.batch_loss_value.mean()) # reset for next set of accumulated grads self.batch_loss_value.reset() # Batch end events with self.profiler.profile('on_batch_end'): # callbacks self.on_batch_end() # model hooks if self.is_function_implemented('on_batch_end'): self.get_model().on_batch_end() # collapse all metrics into one dict all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} # track all metrics for callbacks self.callback_metrics.update( {k: v for d in all_callback_metrics for k, v in d.items()}) return 0, grad_norm_dic, all_log_metrics, batch_output def _get_optimizers_iterable(self): if not self.optimizer_frequencies: # call training_step once per optimizer return list(enumerate(self.optimizers)) optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies) optimizers_loop_length = optimizer_freq_cumsum[-1] current_place_in_loop = self.total_batch_idx % optimizers_loop_length # find optimizer index by looking for the first {item >
current_place} in the cumsum list opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) return [(opt_idx, self.optimizers[opt_idx])] # @atexit.register def run_training_teardown(self): if hasattr(self, '_teardown_already_run') and self._teardown_already_run: return # Train end events with self.profiler.profile('on_train_end'): # callbacks self.on_train_end() # model hooks if self.is_function_implemented('on_train_end'): self.get_model().on_train_end() if self.logger is not None: self.logger.finalize("success") # summarize profile results self.profiler.describe() self._teardown_already_run = True def training_forward(self, batch, batch_idx, opt_idx, hiddens): """ Handle forward for each training case (distributed, single gpu, etc...) :param batch: :param batch_idx: :return: """ # --------------- # FORWARD # --------------- # enable not needing to add opt_idx to training_step args = [batch, batch_idx] if len(self.optimizers) > 1: if self.has_arg('training_step', 'optimizer_idx'): args.append(opt_idx) else: num_opts = len(self.optimizers) raise ValueError( f'Your LightningModule defines {num_opts} optimizers but ' f'training_step is missing the "optimizer_idx" argument.') # pass hiddens if using tbptt if self.truncated_bptt_steps is not None: args.append(hiddens) # distributed forward if self.use_ddp or self.use_ddp2 or self.use_dp: output = self.model(*args) # Horovod elif self.use_horovod and self.on_gpu: batch = self.transfer_batch_to_gpu(batch, hvd.local_rank()) args[0] = batch output = self.model.training_step(*args) # single GPU forward elif self.single_gpu: gpu_id = 0 if isinstance(self.data_parallel_device_ids, list): gpu_id = self.data_parallel_device_ids[0] # Don't copy the batch since there is a single gpu that the batch could # be referenced from, and if there are multiple optimizers it would # wind up being copied to the same device repeatedly. batch = self.transfer_batch_to_gpu(batch, gpu_id) args[0] = batch output = self.model.training_step(*args) # TPU support elif self.use_tpu: batch = self.transfer_batch_to_tpu(batch, self.tpu_id) args[0] = batch output = self.model.training_step(*args) # CPU forward else: output = self.model.training_step(*args) # allow any mode to define training_step_end # do something with all the dp outputs (like softmax) if self.is_overridden('training_step_end'): model_ref = self.get_model() with self.profiler.profile('training_step_end'): output = model_ref.training_step_end(output) # allow any mode to define training_end # TODO: remove in 1.0.0 if self.is_overridden('training_end'): model_ref = self.get_model() with self.profiler.profile('training_end'): output = model_ref.training_end(output) rank_zero_warn( '`training_end` was deprecated in 0.7.0 and will be removed in 1.0.0.' ' Use training_epoch_end instead', DeprecationWarning) return output def update_learning_rates(self, interval: str): """Update learning rates. Args: interval: either 'epoch' or 'step'.
""" if not self.lr_schedulers: return for lr_scheduler in self.lr_schedulers: current_idx = self.batch_idx if interval == 'step' else self.current_epoch current_idx += 1 # account for both batch and epoch starts from 0 # Take step if call to update_learning_rates matches the interval key and # the current step modulo the schedulers frequency is zero if lr_scheduler[ 'interval'] == interval and current_idx % lr_scheduler[ 'frequency'] == 0: # If instance of ReduceLROnPlateau, we need to pass validation loss if lr_scheduler['reduce_on_plateau']: monitor_key = lr_scheduler['monitor'] monitor_val = self.callback_metrics.get(monitor_key) if monitor_val is None: avail_metrics = ','.join( list(self.callback_metrics.keys())) raise MisconfigurationException( f'ReduceLROnPlateau conditioned on metric {monitor_key}' f' which is not available. Available metrics are: {avail_metrics}.' ' Condition can be set using `monitor` key in lr scheduler dict' ) lr_scheduler['scheduler'].step(monitor_val) else: lr_scheduler['scheduler'].step() def call_checkpoint_callback(self): if self.checkpoint_callback is not None: self.checkpoint_callback.on_validation_end(self, self.get_model())
def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, checkpoint_callback: Union[ModelCheckpoint, bool] = True, early_stop_callback: Optional[Union[EarlyStopping, bool]] = False, callbacks: Optional[List[Callback]] = None, default_root_dir: Optional[str] = None, gradient_clip_val: float = 0, process_position: int = 0, num_nodes: int = 1, num_processes: int = 1, gpus: Optional[Union[List[int], str, int]] = None, auto_select_gpus: bool = False, tpu_cores: Optional[Union[List[int], int]] = None, log_gpu_memory: Optional[str] = None, progress_bar_refresh_rate: int = 1, overfit_pct: float = 0.0, track_grad_norm: int = -1, check_val_every_n_epoch: int = 1, fast_dev_run: bool = False, accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, max_epochs: int = 1000, min_epochs: int = 1, max_steps: Optional[int] = None, min_steps: Optional[int] = None, train_percent_check: float = 1.0, val_percent_check: float = 1.0, test_percent_check: float = 1.0, val_check_interval: float = 1.0, log_save_interval: int = 100, row_log_interval: int = 10, add_row_log_interval=None, # backward compatible, todo: remove in v0.8.0 distributed_backend: Optional[str] = None, precision: int = 32, print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0 weights_summary: Optional[str] = 'full', weights_save_path: Optional[str] = None, num_sanity_val_steps: int = 2, truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[str] = None, profiler: Optional[Union[BaseProfiler, bool]] = None, benchmark: bool = False, deterministic: bool = False, reload_dataloaders_every_epoch: bool = False, auto_lr_find: Union[bool, str] = False, replace_sampler_ddp: bool = True, terminate_on_nan: bool = False, auto_scale_batch_size: Union[str, bool] = False, num_tpu_cores: Optional[ int] = None, # backward compatible, todo: remove in v0.9.0 amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0 default_save_path=None, # backward compatible, todo: remove in v0.8.0 gradient_clip=None, # backward compatible, todo: remove in v0.8.0 nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0 max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0 min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0 use_amp=None, # backward compatible, todo: remove in v0.9.0 show_progress_bar=None, # backward compatible, todo: remove in v0.9.0 nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0 ): r""" Customize every aspect of training via flags Args: logger: Logger (or iterable collection of loggers) for experiment tracking. checkpoint_callback: Callback for checkpointing. early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`): callbacks: Add a list of callbacks. default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed default_save_path: .. warning:: .. deprecated:: 0.7.3 Use `default_root_dir` instead. Will remove 0.9.0. gradient_clip_val: 0 means don't clip. gradient_clip: .. warning:: .. deprecated:: 0.7.0 Use `gradient_clip_val` instead. Will remove 0.9.0. process_position: orders the progress bar when running multiple models on same machine. num_nodes: number of GPU nodes for distributed training. nb_gpu_nodes: .. warning:: .. deprecated:: 0.7.0 Use `num_nodes` instead. Will remove 0.9.0. gpus: Which GPUs to train on. auto_select_gpus: If enabled and `gpus` is an integer, pick available gpus automatically. 
This is especially useful when GPUs are configured to be in "exclusive mode", such that only one process at a time can access them. tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1] num_tpu_cores: How many TPU cores to train on (1 or 8) .. warning:: .. deprecated:: 0.7.6. Will remove 0.9.0. log_gpu_memory: None, 'min_max', 'all'. Might slow performance show_progress_bar: .. warning:: .. deprecated:: 0.7.2 Set `progress_bar_refresh_rate` to positive integer to enable. Will remove 0.9.0. progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`. overfit_pct: How much of training-, validation-, and test dataset to check. track_grad_norm: -1 no tracking. Otherwise tracks that norm check_val_every_n_epoch: Check val every n train epochs. fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test). accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict. max_epochs: Stop training once this number of epochs is reached. max_nb_epochs: .. warning:: .. deprecated:: 0.7.0 Use `max_epochs` instead. Will remove 0.9.0. min_epochs: Force training for at least this many epochs min_nb_epochs: .. warning:: .. deprecated:: 0.7.0 Use `min_epochs` instead. Will remove 0.9.0. max_steps: Stop training after this number of steps. Disabled by default (None). min_steps: Force training for at least this number of steps. Disabled by default (None). train_percent_check: How much of training dataset to check. val_percent_check: How much of validation dataset to check. test_percent_check: How much of test dataset to check. val_check_interval: How often within one training epoch to check the validation set log_save_interval: Writes logs to disk this often row_log_interval: How often to add logging rows (does not write to disk) add_row_log_interval: .. warning:: .. deprecated:: 0.7.0 Use `row_log_interval` instead. Will remove 0.9.0. distributed_backend: The distributed backend to use. use_amp: .. warning:: .. deprecated:: 0.7.0 Use `precision` instead. Will remove 0.9.0. precision: Full precision (32), half precision (16). print_nan_grads: .. warning:: .. deprecated:: 0.7.2 Has no effect. When detected, NaN grads will be printed automatically. Will remove 0.9.0. weights_summary: Prints a summary of the weights when training begins. weights_save_path: Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in `default_root_dir`. amp_level: The optimization level to use (O1, O2, etc...). num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine. nb_sanity_val_steps: .. warning:: .. deprecated:: 0.7.0 Use `num_sanity_val_steps` instead. Will remove 0.8.0. truncated_bptt_steps: Truncated back prop performs backprop every k steps of a much longer sequence. resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here. profiler: To profile individual steps during training and assist in identifying bottlenecks. reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch auto_lr_find: If set to True, will `initially` run a learning rate finder, trying to optimize initial learning for faster convergence. Sets learning rate in self.lr or self.learning_rate in the LightningModule.
To use a different key, set a string instead of True with the key name. replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this will be toggled automatically when ddp is used benchmark: If true enables cudnn.benchmark. deterministic: If true enables cudnn.deterministic terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the end of each training batch, if any of the parameters or the loss are NaN or +/-inf. auto_scale_batch_size: If set to True, will `initially` run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either `power` that estimates the batch size through a power search or `binsearch` that estimates the batch size through a binary search. """ super().__init__() self.deterministic = deterministic torch.backends.cudnn.deterministic = self.deterministic if self.deterministic: # fixing non-deterministic part of horovod # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383 os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0) # Init callbacks self.callbacks = callbacks or [] self.on_init_start() # benchmarking self.benchmark = benchmark torch.backends.cudnn.benchmark = self.benchmark # Transfer params self.num_nodes = num_nodes # Backward compatibility, TODO: remove in v0.8.0 if nb_gpu_nodes is not None: rank_zero_warn( "Argument `nb_gpu_nodes` has been renamed to `num_nodes` since v0.5.0" " and this method will be removed in v0.8.0", DeprecationWarning) self.num_gpu_nodes = nb_gpu_nodes self.log_gpu_memory = log_gpu_memory self.gradient_clip_val = gradient_clip_val # Backward compatibility, TODO: remove in v0.8.0 if gradient_clip is not None: rank_zero_warn( "Argument `gradient_clip` has been renamed to `gradient_clip_val` since v0.5.0" " and this method will be removed in v0.8.0", DeprecationWarning) self.gradient_clip = gradient_clip self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False # tpu config if num_tpu_cores is not None: rank_zero_warn( "Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6" " and this argument will be removed in v0.9.0", DeprecationWarning) if tpu_cores is None: tpu_cores = num_tpu_cores self.on_tpu = tpu_cores is not None self.tpu_cores = tpu_cores assert self.tpu_cores in (1, 8, None) or ( isinstance(self.tpu_cores, (list, tuple, set)) and len(self.tpu_cores) == 1), '`tpu_cores` can only be 1, 8 or [<1-8>]' self.tpu_id = tpu_cores[0] if isinstance(tpu_cores, list) else None if num_processes != 1 and distributed_backend != "ddp_cpu": rank_zero_warn( "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
) self.num_processes = num_processes self.weights_summary = weights_summary self.max_epochs = max_epochs # Backward compatibility, TODO: remove in v0.8.0 if max_nb_epochs is not None: rank_zero_warn( "Argument `max_nb_epochs` has been renamed to `max_epochs` since v0.5.0" " and this method will be removed in v0.8.0", DeprecationWarning) self.max_nb_epochs = max_nb_epochs self.min_epochs = min_epochs # Backward compatibility, TODO: remove in v0.8.0 if min_nb_epochs is not None: rank_zero_warn( "Argument `min_nb_epochs` has been renamed to `min_epochs` since v0.5.0" " and this method will be removed in v0.8.0", DeprecationWarning) self.min_nb_epochs = min_nb_epochs self.max_steps = max_steps self.min_steps = min_steps self.num_sanity_val_steps = num_sanity_val_steps # Backward compatibility, TODO: remove in v0.8.0 if nb_sanity_val_steps is not None: rank_zero_warn( "Argument `nb_sanity_val_steps` has been renamed to " "`num_sanity_val_steps` since v0.5.0" " and this method will be removed in v0.8.0", DeprecationWarning) self.nb_sanity_val_steps = nb_sanity_val_steps # Backward compatibility, TODO: remove in v0.9.0 if print_nan_grads: rank_zero_warn( "Argument `print_nan_grads` has no effect and will be removed in v0.9.0." " NaN grads will be printed automatically when detected.", DeprecationWarning) self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch self.auto_lr_find = auto_lr_find self.auto_scale_batch_size = auto_scale_batch_size self._is_data_prepared = False self.replace_sampler_ddp = replace_sampler_ddp self.truncated_bptt_steps = truncated_bptt_steps self.resume_from_checkpoint = resume_from_checkpoint self.terminate_on_nan = terminate_on_nan self.shown_warnings = set() self.fast_dev_run = fast_dev_run if self.fast_dev_run: self.num_sanity_val_steps = 0 self.max_epochs = 1 log.info('Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch') # set default save path if user didn't provide one self.default_root_dir = default_root_dir # Backward compatibility, TODO: remove in v0.8.0 if default_save_path is not None: self.default_root_dir = default_save_path if self.default_root_dir is None: self.default_root_dir = os.getcwd() # training bookkeeping self.total_batch_idx = 0 self.running_loss = TensorRunningAccum(window_length=20) self.batch_idx = 0 self.progress_bar_metrics = {} self.callback_metrics = {} self.num_val_batches = 0 self.num_training_batches = 0 self.num_test_batches = 0 self.train_dataloader = None self.test_dataloaders = None self.val_dataloaders = None # training state self.model = None self.testing = False self.disable_validation = False self.lr_schedulers = [] self.optimizers = None self.optimizer_frequencies = [] self.global_step = 0 self.current_epoch = 0 self.interrupted = False # configure logger self.configure_logger(logger) # configure profiler if profiler is True: profiler = SimpleProfiler() self.profiler = profiler or PassThroughProfiler() # configure early stop callback # creates a default one if none passed in self.configure_early_stopping(early_stop_callback) # configure checkpoint callback self.checkpoint_callback = checkpoint_callback self.weights_save_path = weights_save_path # accumulated grads self.accumulate_grad_batches = accumulate_grad_batches self.configure_accumulated_gradients(accumulate_grad_batches) # for gpus allow int, string and gpu list if auto_select_gpus and isinstance(gpus, int): self.gpus = pick_multiple_gpus(gpus) else: self.gpus = gpus self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
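# parse_gpu_ids normalizes the user-facing `gpus` flag (int, comma-separated
# string or explicit list) into a list of device indices, or None for CPU;
# the root GPU derived next is where the model is placed before any
# data-parallel wrapping.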
self.root_gpu = determine_root_gpu_device( self.data_parallel_device_ids) self.root_device = torch.device("cpu") # tpu state flags self.use_tpu = False self.tpu_local_core_rank = None self.tpu_global_core_rank = None # distributed backend choice self.distributed_backend = distributed_backend self.set_distributed_mode(distributed_backend) # override dist backend when using tpus if self.on_tpu: self.init_tpu() # init flags for SLURM+ddp to work self.proc_rank = 0 self.world_size = 1 self.configure_slurm_ddp(self.num_nodes) self.node_rank = self.determine_ddp_node_rank() # nvidia setup self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) # backward compatibility if show_progress_bar is not None: self.show_progress_bar = show_progress_bar self._progress_bar_callback = self.configure_progress_bar( progress_bar_refresh_rate, process_position) # logging self.log_save_interval = log_save_interval self.val_check_interval = val_check_interval # backward compatibility if add_row_log_interval is not None: rank_zero_warn( "`add_row_log_interval` has been renamed to `row_log_interval` since v0.5.0" " and this method will be removed in v0.8.0", DeprecationWarning) if not row_log_interval: # in case you did not set the proper value row_log_interval = add_row_log_interval self.row_log_interval = row_log_interval # how much of the data to use self.overfit_pct = overfit_pct self.determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct) # AMP init # These are the only lines needed after v0.8.0 # we wrap the user's forward with autocast and give it back at the end of fit self.autocast_original_forward = None self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr( torch.cuda.amp, "autocast") self.precision = precision self.scaler = None # TODO: remove for v0.8.0 self.amp_level = amp_level self.init_amp(use_amp) self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv( 'KAGGLE_URL_BASE') # Callback system self.on_init_end()
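# configure_accumulated_gradients above accepts `accumulate_grad_batches` as
# an int or an {epoch: factor} schedule. The helper below is an illustrative
# resolver (not the actual GradientAccumulationScheduler) showing how such a
# schedule maps an epoch to its accumulation factor:
def accumulation_factor(schedule, epoch: int) -> int:
    if isinstance(schedule, int):
        return schedule
    factor = 1  # default until the first scheduled epoch is reached
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]
    return factor

assert accumulation_factor(4, epoch=0) == 4
assert accumulation_factor({0: 1, 4: 2, 8: 4}, epoch=5) == 2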
def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, checkpoint_callback: Union[ModelCheckpoint, bool] = True, early_stop_callback: Optional[Union[EarlyStopping, bool]] = False, callbacks: List[Callback] = [], default_root_dir: Optional[str] = None, gradient_clip_val: float = 0, process_position: int = 0, num_nodes: int = 1, gpus: Optional[Union[List[int], str, int]] = None, auto_select_gpus: bool = False, num_tpu_cores: Optional[int] = None, log_gpu_memory: Optional[str] = None, progress_bar_refresh_rate: int = 1, overfit_pct: float = 0.0, track_grad_norm: int = -1, check_val_every_n_epoch: int = 1, fast_dev_run: bool = False, accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, max_epochs: int = 1000, min_epochs: int = 1, max_steps: Optional[int] = None, min_steps: Optional[int] = None, train_percent_check: float = 1.0, val_percent_check: float = 1.0, test_percent_check: float = 1.0, val_check_interval: float = 1.0, log_save_interval: int = 100, row_log_interval: int = 10, add_row_log_interval=None, # backward compatible, todo: remove in v0.8.0 distributed_backend: Optional[str] = None, precision: int = 32, print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0 weights_summary: Optional[str] = 'full', weights_save_path: Optional[str] = None, amp_level: str = 'O1', num_sanity_val_steps: int = 5, truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[str] = None, profiler: Optional[BaseProfiler] = None, benchmark: bool = False, reload_dataloaders_every_epoch: bool = False, auto_lr_find: Union[bool, str] = False, default_save_path=None, # backward compatible, todo: remove in v0.8.0 gradient_clip=None, # backward compatible, todo: remove in v0.8.0 nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0 max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0 min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0 use_amp=None, # backward compatible, todo: remove in v0.9.0 show_progress_bar=None, # backward compatible, todo: remove in v0.9.0 nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0 terminate_on_nan: bool = False, **kwargs ): r""" Customize every aspect of training via flags Args: logger: Logger (or iterable collection of loggers) for experiment tracking. checkpoint_callback: Callback for checkpointing. early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`): callbacks: Add a list of callbacks. default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed default_save_path: .. warning:: .. deprecated:: 0.7.3 Use `default_root_dir` instead. Will remove 0.9.0. gradient_clip_val: 0 means don't clip. gradient_clip: .. warning:: .. deprecated:: 0.7.0 Use `gradient_clip_val` instead. Will remove 0.9.0. process_position: orders the tqdm bar when running multiple models on same machine. num_nodes: number of GPU nodes for distributed training. nb_gpu_nodes: .. warning:: .. deprecated:: 0.7.0 Use `num_nodes` instead. Will remove 0.9.0. gpus: Which GPUs to train on. auto_select_gpus: If enabled and `gpus` is an integer, pick available gpus automatically. This is especially useful when GPUs are configured to be in "exclusive mode", such that only one process at a time can access them. num_tpu_cores: How many TPU cores to train on (1 or 8). log_gpu_memory: None, 'min_max', 'all'. Might slow performance show_progress_bar: .. warning:: .. 
deprecated:: 0.7.2 Set `progress_bar_refresh_rate` to a positive integer to enable. Will remove 0.9.0. progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. overfit_pct: How much of training-, validation-, and test dataset to check. track_grad_norm: -1 no tracking. Otherwise tracks that norm check_val_every_n_epoch: Check val every n train epochs. fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test). accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict. max_epochs: Stop training once this number of epochs is reached. max_nb_epochs: .. warning:: .. deprecated:: 0.7.0 Use `max_epochs` instead. Will remove 0.9.0. min_epochs: Force training for at least this many epochs min_nb_epochs: .. warning:: .. deprecated:: 0.7.0 Use `min_epochs` instead. Will remove 0.9.0. max_steps: Stop training after this number of steps. Disabled by default (None). min_steps: Force training for at least this number of steps. Disabled by default (None). train_percent_check: How much of training dataset to check. val_percent_check: How much of validation dataset to check. test_percent_check: How much of test dataset to check. val_check_interval: How often within one training epoch to check the validation set log_save_interval: Writes logs to disk this often row_log_interval: How often to add logging rows (does not write to disk) add_row_log_interval: .. warning:: .. deprecated:: 0.7.0 Use `row_log_interval` instead. Will remove 0.9.0. distributed_backend: The distributed backend to use. use_amp: .. warning:: .. deprecated:: 0.7.0 Use `precision` instead. Will remove 0.9.0. precision: Full precision (32), half precision (16). print_nan_grads: .. warning:: .. deprecated:: 0.7.2 Has no effect. When detected, NaN grads will be printed automatically. Will remove 0.9.0. weights_summary: Prints a summary of the weights when training begins. weights_save_path: Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in `default_root_dir`. amp_level: The optimization level to use (O1, O2, etc...). num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine. nb_sanity_val_steps: .. warning:: .. deprecated:: 0.7.0 Use `num_sanity_val_steps` instead. Will remove 0.8.0. truncated_bptt_steps: Truncated back prop performs backprop every k steps of a much longer sequence. resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here. profiler: To profile individual steps during training and assist in identifying bottlenecks. reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch auto_lr_find: If set to True, will `initially` run a learning rate finder, trying to optimize initial learning for faster convergence. Sets learning rate in self.hparams.lr | self.hparams.learning_rate in the lightning module. To use a different key, set a string instead of True with the key name. benchmark: If true enables cudnn.benchmark. terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
""" # Init callbacks self.callbacks = callbacks self.on_init_start() # benchmarking self.benchmark = benchmark torch.backends.cudnn.benchmark = self.benchmark # Transfer params self.num_nodes = num_nodes # Backward compatibility, TODO: remove in v0.8.0 if nb_gpu_nodes is not None: rank_zero_warn("Argument `nb_gpu_nodes` has been renamed to `num_nodes` since v0.5.0" " and this argument will be removed in v0.8.0", DeprecationWarning) self.num_gpu_nodes = nb_gpu_nodes self.log_gpu_memory = log_gpu_memory self.gradient_clip_val = gradient_clip_val # Backward compatibility, TODO: remove in v0.8.0 if gradient_clip is not None: rank_zero_warn("Argument `gradient_clip` has been renamed to `gradient_clip_val` since v0.5.0" " and this argument will be removed in v0.8.0", DeprecationWarning) self.gradient_clip = gradient_clip self.progress_bar_refresh_rate = progress_bar_refresh_rate self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False # tpu config self.on_tpu = num_tpu_cores is not None self.num_tpu_cores = num_tpu_cores assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1 or 8' self.process_position = process_position self.weights_summary = weights_summary self.max_epochs = max_epochs # Backward compatibility, TODO: remove in v0.8.0 if max_nb_epochs is not None: rank_zero_warn("Argument `max_nb_epochs` has been renamed to `max_epochs` since v0.5.0" " and this argument will be removed in v0.8.0", DeprecationWarning) self.max_nb_epochs = max_nb_epochs self.min_epochs = min_epochs # Backward compatibility, TODO: remove in v0.8.0 if min_nb_epochs is not None: rank_zero_warn("Argument `min_nb_epochs` has been renamed to `min_epochs` since v0.5.0" " and this argument will be removed in v0.8.0", DeprecationWarning) self.min_nb_epochs = min_nb_epochs self.max_steps = max_steps self.min_steps = min_steps self.num_sanity_val_steps = num_sanity_val_steps # Backward compatibility, TODO: remove in v0.8.0 if nb_sanity_val_steps is not None: rank_zero_warn("Argument `nb_sanity_val_steps` has been renamed to " "`num_sanity_val_steps` since v0.5.0" " and this argument will be removed in v0.8.0", DeprecationWarning) self.nb_sanity_val_steps = nb_sanity_val_steps # Backward compatibility, TODO: remove in v0.9.0 if print_nan_grads: rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
" NaN grads will be printed automatically when detected.", DeprecationWarning) self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch self.auto_lr_find = auto_lr_find self.truncated_bptt_steps = truncated_bptt_steps self.resume_from_checkpoint = resume_from_checkpoint self.terminate_on_nan = terminate_on_nan self.shown_warnings = set() self.fast_dev_run = fast_dev_run if self.fast_dev_run: self.num_sanity_val_steps = 0 self.max_epochs = 1 log.info('Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch') # set default save path if user didn't provide one self.default_root_dir = default_root_dir # Backward compatibility, TODO: remove in v0.8.0 if default_save_path is not None: self.default_root_dir = default_save_path if self.default_root_dir is None: self.default_root_dir = os.getcwd() # training bookkeeping self.total_batch_idx = 0 self.running_loss = TensorRunningAccum(window_length=20) self.batch_idx = 0 self.tqdm_metrics = {} self.callback_metrics = {} self.num_val_batches = 0 self.num_training_batches = 0 self.num_test_batches = 0 self.train_dataloader = None self.test_dataloaders = None self.val_dataloaders = None # training state self.model = None self.testing = False self.disable_validation = False self.lr_schedulers = [] self.optimizers = None self.optimizer_frequencies = [] self.global_step = 0 self.current_epoch = 0 self.total_batches = 0 self.interrupted = False # configure logger self.configure_logger(logger) # configure profiler if profiler is True: profiler = SimpleProfiler() self.profiler = profiler or PassThroughProfiler() # configure early stop callback # creates a default one if none passed in self.configure_early_stopping(early_stop_callback) # configure checkpoint callback self.checkpoint_callback = checkpoint_callback self.weights_save_path = weights_save_path # accumulated grads self.accumulate_grad_batches = accumulate_grad_batches self.configure_accumulated_gradients(accumulate_grad_batches) # for gpus allow int, string and gpu list if auto_select_gpus and isinstance(gpus, int): self.gpus = pick_multiple_gpus(gpus) else: self.gpus = gpus self.data_parallel_device_ids = parse_gpu_ids(self.gpus) self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids) self.root_device = torch.device("cpu") # tpu state flags self.use_tpu = False self.tpu_local_core_rank = None self.tpu_global_core_rank = None # distributed backend choice self.use_ddp = False self.use_ddp2 = False self.use_dp = False self.single_gpu = False self.distributed_backend = distributed_backend self.set_distributed_mode(distributed_backend, self.num_nodes) # override dist backend when using tpus if self.on_tpu: self.init_tpu() self.current_tpu_idx = None # init flags for SLURM+ddp to work self.proc_rank = 0 self.world_size = 1 self.node_rank = 0 self.configure_slurm_ddp(self.num_nodes) # nvidia setup self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) # can't init progress bar here because starting a new process # means the progress_bar won't survive pickling # backward compatibility if show_progress_bar is not None: self.show_progress_bar = show_progress_bar # logging self.log_save_interval = log_save_interval self.val_check_interval = val_check_interval # backward compatibility if add_row_log_interval is not None: rank_zero_warn("`add_row_log_interval` has been renamed to `row_log_interval` since v0.5.0" " and this argument will be removed in v0.8.0", DeprecationWarning) if not row_log_interval: # in case you did not
set the proper value row_log_interval = add_row_log_interval self.row_log_interval = row_log_interval # how much of the data to use self.overfit_pct = overfit_pct self.determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct) # 16 bit mixed precision training using apex self.amp_level = amp_level self.precision = precision # Backward compatibility, TODO: remove in v0.9.0 if use_amp is not None: rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0" " and this argument will be removed in v0.9.0", DeprecationWarning) self.precision = 16 if use_amp else 32 assert self.precision in (16, 32), 'only 32 or 16 bit precision supported' if self.precision == 16 and self.num_tpu_cores is None: use_amp = True self.init_amp(use_amp) # Callback system self.on_init_end()
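# A minimal sketch of the `use_amp` -> `precision` mapping above; `resolve_precision`
# is a hypothetical helper, not trainer API. The deprecated boolean, when given,
# overrides the numeric flag.
import warnings
from typing import Optional

def resolve_precision(precision: int = 32, use_amp: Optional[bool] = None) -> int:
    if use_amp is not None:
        warnings.warn("`use_amp` has been replaced by `precision` since v0.7.0",
                      DeprecationWarning)
        precision = 16 if use_amp else 32
    assert precision in (16, 32), 'only 32 or 16 bit precision supported'
    return precision

assert resolve_precision(use_amp=True) == 16
assert resolve_precision(use_amp=False) == 32
assert resolve_precision(precision=16) == 16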
class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]): """Runs over a single batch of data.""" def __init__(self) -> None: super().__init__() self.accumulated_loss = TensorRunningAccum(window_length=20) self.running_loss = TensorRunningAccum(window_length=20) # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx: int = 0 self.optimizer_loop = OptimizerLoop() self.manual_loop = ManualOptimization() self._outputs: _OUTPUTS_TYPE = [] self._remaining_splits: List[Tuple[int, Any]] = [] @property def done(self) -> bool: """Returns if all batch splits have been processed already.""" return len(self._remaining_splits) == 0 def connect( # type: ignore[override] self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None) -> None: if optimizer_loop is not None: self.optimizer_loop = optimizer_loop if manual_loop is not None: self.manual_loop = manual_loop def reset(self) -> None: """Resets the loop state.""" self._outputs = [] def on_run_start(self, kwargs: OrderedDict) -> None: # type: ignore[override] """Splits the data into tbptt splits. Args: kwargs: the kwargs passed down to the hooks. """ batch = kwargs["batch"] self._remaining_splits = list(enumerate( self._tbptt_split_batch(batch))) def advance(self, kwargs: OrderedDict) -> None: # type: ignore[override] """Runs the train step together with optimization (if necessary) on the current batch split. Args: kwargs: the kwargs passed down to the hooks. """ # replace the batch with the split batch self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0) self.trainer._logger_connector.on_train_split_start(self.split_idx) outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None # for mypy # choose which loop will run the optimization if self.trainer.lightning_module.automatic_optimization: optimizers = _get_active_optimizers( self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)) outputs = self.optimizer_loop.run(optimizers, kwargs) else: outputs = self.manual_loop.run(kwargs) if outputs: # automatic: can be empty if all optimizers skip their batches # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens, # then `advance` doesn't finish and an empty dict is returned self._outputs.append(outputs) def on_run_end(self) -> _OUTPUTS_TYPE: self.optimizer_loop._hiddens = None # this is not necessary as the manual loop runs for only 1 iteration, but just in case self.manual_loop._hiddens = None output, self._outputs = self._outputs, [] # free memory self._remaining_splits = [] return output def teardown(self) -> None: self.optimizer_loop.teardown() self.manual_loop.teardown() # release memory if self.accumulated_loss.memory is not None: self.accumulated_loss.memory = self.accumulated_loss.memory.cpu() if self.running_loss.memory is not None: self.running_loss.memory = self.running_loss.memory.cpu() def _tbptt_split_batch(self, batch: Any) -> List[Any]: """Splits a single batch into a list of sequence steps for tbptt. 
Args: batch: the current batch to split """ tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps if tbptt_steps == 0: return [batch] splits = self.trainer._call_lightning_module_hook( "tbptt_split_batch", batch, tbptt_steps) return splits def _update_running_loss(self, current_loss: Tensor) -> None: """Updates the running loss value with the current value.""" if self.trainer.lightning_module.automatic_optimization: # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(current_loss) accumulated_loss = self.accumulated_loss.mean() if accumulated_loss is not None: # calculate running loss for display self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) # reset for next set of accumulated grads self.accumulated_loss.reset()
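# A hedged sketch of the kind of splitting `tbptt_split_batch` performs by default,
# assuming a batch of (inputs, targets) with time on dimension 1; the hook can be
# overridden by the LightningModule, so this is illustrative only.
import torch

def default_tbptt_split(batch, tbptt_steps: int):
    x, y = batch  # assumed structure: (inputs, targets), time on dim 1
    return [
        (x[:, t:t + tbptt_steps], y[:, t:t + tbptt_steps])
        for t in range(0, x.size(1), tbptt_steps)
    ]

x, y = torch.zeros(8, 100, 32), torch.zeros(8, 100)
splits = default_tbptt_split((x, y), tbptt_steps=20)
assert len(splits) == 5 and splits[0][0].shape == (8, 20, 32)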
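# Why _update_running_loss multiplies by accumulate_grad_batches: each step's loss
# is divided by accumulate_grad_batches before it is accumulated (see the
# closure_loss scaling in the training_step code further below), so the displayed
# running loss multiplies the window mean back to the raw per-batch scale. A tiny
# numeric check with assumed values:
accumulate_grad_batches = 4
raw_losses = [2.0, 2.4, 1.6, 2.0]  # hypothetical per-step losses
scaled = [loss / accumulate_grad_batches for loss in raw_losses]  # what gets accumulated
display = sum(scaled) / len(scaled) * accumulate_grad_batches
assert abs(display - 2.0) < 1e-9  # back on the raw-loss scale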
class TrainLoop: def __init__(self, trainer): self.trainer = trainer self.early_stopping_accumulator = None self.checkpoint_accumulator = None self.accumulated_loss = None self.warning_cache = WarningCache() self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) self.automatic_optimization = True self._curr_step_result = None self._cur_grad_norm_dict = None def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization): self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.interrupted = False self.trainer.should_stop = False self.trainer._state = TrainerState.INITIALIZING self.trainer.total_batch_idx = 0 self.trainer.batch_idx = 0 self.trainer.num_training_batches = 0 self.trainer.train_dataloader = None self.automatic_optimization = automatic_optimization self.trainer.max_epochs = max_epochs self.trainer.min_epochs = min_epochs self.trainer.max_steps = max_steps self.trainer.min_steps = min_steps if num_sanity_val_steps == -1: self.trainer.num_sanity_val_steps = float("inf") else: self.trainer.num_sanity_val_steps = num_sanity_val_steps @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) return num_optimizers def should_skip_training(self): if self.trainer.current_epoch >= self.trainer.max_epochs: return True if self.trainer.limit_train_batches == 0: return True return False def on_train_start(self): # clear cache before training if self.trainer.on_gpu and self.trainer.root_gpu is not None: # use context because of: # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 with torch.cuda.device(f"cuda:{self.trainer.root_gpu}"): torch.cuda.empty_cache() # hook self.trainer.call_hook("on_train_start") def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # bind logger and other properties self.trainer.model_connector.copy_trainer_model_properties(model) # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) # links data to the trainer self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) # check that model is configured correctly self.trainer.config_validator.verify_loop_configurations(model) def setup_training(self, model: LightningModule): """Sanity check a few things before starting actual training. Args: model: The model to run sanity test on. """ # -------------------------- # Setup?? # -------------------------- ref_model = model if self.trainer.data_parallel: ref_model = model.module # set the ranks and devices self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank self.trainer.accelerator_backend.dist.device = ref_model.device # give model convenience properties ref_model.trainer = self.trainer # set local properties on the model self.trainer.model_connector.copy_trainer_model_properties(ref_model) # init amp. 
Must be done here instead of __init__ to allow ddp to work if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu: self.trainer.scaler = self.trainer.precision_connector.backend.scaler # log hyper-parameters if self.trainer.logger is not None: # save exp to get started (this is where the first experiment logs are written) self.trainer.logger.log_hyperparams(ref_model.hparams_initial) self.trainer.logger.log_graph(ref_model) self.trainer.logger.save() # wait for all to join if on distributed self.trainer.accelerator_backend.barrier("setup_training") # register auto-resubmit when on SLURM self.trainer.slurm_connector.register_slurm_signal_handlers() # -------------------------- # Pre-train # -------------------------- # on pretrain routine start self.trainer.on_pretrain_routine_start(ref_model) if self.trainer.is_function_implemented("on_pretrain_routine_start"): ref_model.on_pretrain_routine_start() # print model summary if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing: if self.trainer.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.trainer.weights_summary) else: raise MisconfigurationException( "weights_summary can be None, " + ", ".join(ModelSummary.MODES)) # track model now. # if cluster resets state, the model will update with the saved weights self.trainer.model = model # restore training and model before hpc is called self.trainer.checkpoint_connector.restore_weights(model) # on pretrain routine end self.trainer.on_pretrain_routine_end(ref_model) if self.trainer.is_function_implemented("on_pretrain_routine_end"): ref_model.on_pretrain_routine_end() def on_train_end(self): if self._teardown_already_run: return self._teardown_already_run = True # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step self.trainer.global_step -= 1 self.check_checkpoint_callback(should_save=True, is_last=True) self.trainer.global_step += 1 # hook self.trainer.call_hook("on_train_end") # kill loggers if self.trainer.logger is not None: self.trainer.logger.finalize("success") # summarize profile results if self.trainer.global_rank == 0: self.trainer.profiler.describe() # give accelerators a chance to finish self.trainer.accelerator_backend.on_train_end() # clear mem if self.trainer.on_gpu: model = self.trainer.get_model() model.cpu() torch.cuda.empty_cache() def check_checkpoint_callback(self, should_save, is_last=False): # TODO bake this logic into the checkpoint callback if should_save and self.trainer.checkpoint_connector.has_trained: checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] if is_last and any(c.save_last for c in checkpoint_callbacks): rank_zero_info("Saving latest checkpoint...") model = self.trainer.get_model() for callback in checkpoint_callbacks: callback.on_validation_end(self.trainer, model) def on_train_epoch_start(self, epoch): # update training progress in trainer self.trainer.current_epoch = epoch model = self.trainer.get_model() # reset train dataloader if self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) try: self.trainer.train_dataloader.sampler.set_epoch(epoch) except Exception: pass # changing gradient accumulation according to the accumulation_scheduler self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.get_model()) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches) # structured result accumulators for callbacks self.early_stopping_accumulator = Accumulator() self.checkpoint_accumulator = Accumulator() # hook self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): # hook self.trainer.call_hook('on_batch_end') self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx) # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) # reset batch logger internals self.trainer.logger_connector.on_train_batch_end() def reset_train_val_dataloaders(self, model): if not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs): # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(epoch_end_outputs): # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) def get_optimizers_iterable(self): """ Generates an iterable with (idx, optimizer) for each optimizer.
""" if not self.trainer.optimizer_frequencies: # call training_step once per optimizer return list(enumerate(self.trainer.optimizers)) optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) optimizers_loop_length = optimizer_freq_cumsum[-1] current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length # find optimizer index by looking for the first {item > current_place} in the cumsum list opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) return [[opt_idx, self.trainer.optimizers[opt_idx]]] def on_after_backward(self, training_step_output, batch_idx, untouched_loss): is_result_obj = isinstance(training_step_output, Result) if is_result_obj: training_step_output.detach() else: training_step_output.batch_loss = training_step_output.batch_loss.detach() # insert after step hook self.trainer.call_hook("on_after_backward") # when in dev debugging track the losses self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) def _check_training_step_output(self, training_step_output): if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging model_ref = self.trainer.get_model() with self.trainer.profiler.profile("model_forward"): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) # manually capture logged metrics model_ref._current_fx_name = 'training_step' model_ref._results = Result() training_step_output = self.trainer.accelerator_backend.training_step(args) self.trainer.logger_connector.cache_logged_metrics() self._check_training_step_output(training_step_output) training_step_output = self.trainer.call_hook("training_step_end", training_step_output) training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(training_step_output, split_batch) is_result_obj = isinstance(training_step_output, Result) if training_step_output_for_epoch_end is None: return None # enable empty loss when using manual opt closure_loss = None untouched_loss = None if self.trainer.train_loop.automatic_optimization: # accumulate loss # (if accumulate_grad_batches = 1 no effect) if is_result_obj: closure_loss = training_step_output.minimize else: closure_loss = training_step_output.batch_loss closure_loss = closure_loss / self.trainer.accumulate_grad_batches # the loss will get scaled for amp.
avoid any modifications to it untouched_loss = closure_loss.detach().clone() # result result = AttributeDict( closure_loss=closure_loss, loss=untouched_loss, training_step_output=training_step_output, training_step_output_for_epoch_end= training_step_output_for_epoch_end, hiddens=training_step_output.hiddens, ) return result def _process_training_step_output(self, training_step_output, split_batch): training_step_output_for_epoch_end = training_step_output # enable validation_step return None if training_step_output_for_epoch_end is None: return None, None # ----------------------------------------- # process result return (DEPRECATE in 1.0) # ----------------------------------------- if isinstance(training_step_output, Result): training_step_output_for_epoch_end = self._process_result( training_step_output, split_batch) return training_step_output_for_epoch_end, training_step_output # ----------------------------------------- # process hybrid (1.0) # ----------------------------------------- # no need for these checks in 1.0.0 # TODO: remove checks in 1.0.0 is_tensor = isinstance(training_step_output_for_epoch_end, torch.Tensor) is_1_0_output = is_tensor or ("log" not in training_step_output and "progress_bar" not in training_step_output) if is_1_0_output: return self._process_training_step_output_1_0( training_step_output, split_batch) # ----------------------------------------- # process old dict (deprecate 1.0) # ----------------------------------------- training_step_output = self.trainer.process_dict_result( training_step_output, train=True) training_step_output = AttributeDict( batch_loss=training_step_output[0], pbar_on_batch_end=training_step_output[1], log_metrics=training_step_output[2], callback_metrics=training_step_output[3], hiddens=training_step_output[4], ) # if the user decides to finally reduce things in epoch_end, save raw output without graphs if isinstance(training_step_output_for_epoch_end, torch.Tensor): training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach( ) else: training_step_output_for_epoch_end = recursive_detach( training_step_output_for_epoch_end) return training_step_output_for_epoch_end, training_step_output def _process_training_step_output_1_0(self, training_step_output, split_batch): result = self.trainer.get_model()._results loss = None hiddens = None # handle dict return if isinstance(training_step_output, dict): loss = training_step_output.pop("loss", None) hiddens = training_step_output.pop("hiddens", None) result["extra"] = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): loss = training_step_output result["extra"] = {} # map to results under the hood result.minimize = loss result.hiddens = hiddens # track batch for manual reduction with result result.track_batch_size(len(split_batch)) # track metrics without grads for epoch reduction training_step_output_for_epoch_end = copy(result) training_step_output_for_epoch_end.detach() if self.trainer.move_metrics_to_cpu: training_step_output_for_epoch_end.cpu() # what flows back into the system training_step_output = result return training_step_output_for_epoch_end, training_step_output def _process_result(self, training_step_output, split_batch): training_step_output.track_batch_size(len(split_batch)) m = """ TrainResult and EvalResult were deprecated in 0.9.1 and support will drop in 1.0.0. Use self.log and .write from the LightningModule to log metrics and write predictions. 
training_step can now only return a scalar (for the loss) or a dictionary with anything you want. Option 1: return loss Option 2: return {'loss': loss, 'anything_else': ...} Option 3: return {'loss': loss, 'hiddens': hiddens, 'anything_else': ...} """ rank_zero_warn(m) # don't allow EvalResult in the training_step if isinstance(training_step_output, EvalResult): raise MisconfigurationException( "training_step cannot return EvalResult, " "use a dict or TrainResult instead") training_step_output_for_epoch_end = copy(training_step_output) training_step_output_for_epoch_end.detach() return training_step_output_for_epoch_end def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): model_ref = self.trainer.get_model() is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) using_native_amp = self.trainer.amp_backend == AMPType.NATIVE # native amp + lbfgs is a no go right now if using_native_amp and is_lbfgs: raise MisconfigurationException( 'native PyTorch amp and lbfgs are not compatible.' ' To request, please file a Github issue in PyTorch and tag @mcarilli' ) # model hook model_ref.optimizer_step( self.trainer.current_epoch, batch_idx, optimizer, opt_idx, train_step_and_backward_closure, on_tpu=self.trainer.use_tpu and TPU_AVAILABLE, using_native_amp=using_native_amp, using_lbfgs=is_lbfgs, ) def on_before_zero_grad(self, optimizer): self.trainer.call_hook('on_before_zero_grad', optimizer) def track_and_norm_grad(self, optimizer): # track gradient norms grad_norm_dic = self._track_gradient_norm() # clip gradients self.trainer.accelerator_backend.clip_gradients(optimizer) self._cur_grad_norm_dict = grad_norm_dic def _track_gradient_norm(self): grad_norm_dict = {} if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: if float(self.trainer.track_grad_norm) > 0: model = self.trainer.get_model() grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) return grad_norm_dict def process_hiddens(self, opt_closure_result): hiddens = opt_closure_result.hiddens if isinstance(opt_closure_result.training_step_output, Result): opt_closure_result.training_step_output_for_epoch_end.drop_hiddens( ) return hiddens def tbptt_split_batch(self, batch): splits = [batch] if self.trainer.truncated_bptt_steps is not None: model_ref = self.trainer.get_model() with self.trainer.profiler.profile("tbptt_split_batch"): splits = model_ref.tbptt_split_batch( batch, self.trainer.truncated_bptt_steps) return splits def run_training_epoch(self): # get model model = self.trainer.get_model() # modify dataloader if needed (ddp, etc...) 
train_dataloader = self.trainer.accelerator_backend.process_dataloader( self.trainer.train_dataloader) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] # enable profiling for the dataloader train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader( train_dataloader) dataloader_idx = 0 should_check_val = False for batch_idx, (batch, is_last_batch) in train_dataloader: self.trainer.batch_idx = batch_idx # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ with self.trainer.profiler.profile("run_training_batch"): batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx) # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: break # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory epoch_end_outputs = self.process_train_step_outputs( batch_output.training_step_output_for_epoch_end, self.early_stopping_accumulator, self.checkpoint_accumulator, ) # hook # TODO: add outputs to batches self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- self.trainer.logger_connector.log_train_step_metrics(batch_output) # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- should_check_val = self.should_check_val_fx( batch_idx, is_last_batch) if should_check_val: self.trainer.run_evaluation(test_mode=False) # reset stage to train self.trainer.logger_connector.set_stage("train") # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
# ----------------------------------------- self.save_loggers_on_train_batch_end() # update LR schedulers monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) self.trainer.checkpoint_connector.has_trained = True # max steps reached, end training if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1: accumulation_done = self._accumulated_batches_reached() # Ensure accumulation across batches has completed before breaking loop if accumulation_done: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if self.trainer.should_stop: break self.trainer.total_batch_idx += 1 # stop epoch if we limited the number of training batches if (batch_idx + 1) >= self.trainer.num_training_batches: break # progress global step according to grads progress self.increment_accumulated_grad_global_step() # epoch end hook self.run_on_epoch_end_hook(epoch_output) # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers) # when no val loop is present or fast-dev-run still need to call checkpoints self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model))) # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dic = {} # bookkeeping using_results_obj = False self.trainer.hiddens = None # track all outputs across time and num of optimizers batch_outputs = [[] for _ in range(len(self.get_optimizers_iterable()))] if batch is None: return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) # hook response = self.trainer.call_hook("on_batch_start") if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # hook response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # lightning module hook splits = self.tbptt_split_batch(batch) for split_idx, split_batch in enumerate(splits): # create an iterable for optimizers and loop over them for opt_idx, optimizer in self.prepare_optimizers(): # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) if self.should_accumulate(): # For gradient accumulation # ------------------- # calculate loss (train step + train step end) # ------------------- # automatic_optimization=True: perform ddp sync only when performing optimizer_step # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) batch_outputs = self._process_closure_result(batch_outputs=batch_outputs, opt_idx=opt_idx) # ------------------------------ # BACKWARD PASS # ------------------------------ # gradient update with accumulated gradients else: if self.automatic_optimization: def train_step_and_backward_closure(): result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) return None if result is None else result.loss # optimizer step self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
else: self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens) if self._curr_step_result is None: # user decided to skip optimization # make sure to zero grad. continue batch_outputs = self._process_closure_result(batch_outputs=batch_outputs, opt_idx=opt_idx) # todo: Properly aggregate grad_norm across opt_idx and split_idx grad_norm_dic = self._cur_grad_norm_dict self._cur_grad_norm_dict = None # update running loss + reset accumulated loss self.update_running_loss() result = AttributeDict(signal=0, grad_norm_dic=grad_norm_dic, training_step_output_for_epoch_end=batch_outputs) return result @contextmanager def block_ddp_sync_behaviour(self): """ automatic_optimization = True Blocks ddp sync gradients behaviour on backwards pass. This is useful for skipping sync when accumulating gradients, reducing communication overhead automatic_optimization = False do not block ddp gradient sync when using manual optimization as gradients are needed within the training step Returns: context manager with sync behaviour off """ if self.trainer.accelerator_backend is not None and self.automatic_optimization: yield self.trainer.accelerator_backend.block_ddp_plugin_sync_behaviour() else: yield None def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: opt_closure_result = self._curr_step_result if opt_closure_result is not None: # cache metrics self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result) # track hiddens self.trainer.hiddens = self.process_hiddens(opt_closure_result) # check if loss or model weights are nan if self.trainer.terminate_on_nan: self.trainer.detect_nan_tensors(opt_closure_result.loss) # track all the outputs across all steps batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end) if self.automatic_optimization: # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(opt_closure_result.loss) self._curr_step_result = None return batch_outputs def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): """ wrap the forward step in a closure so second order methods work """ with self.trainer.profiler.profile("training_step_and_backward"): # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) self._curr_step_result = result if result is None: self.warning_cache.warn("training_step returned None. If it was on purpose, ignore this warning..."
) return None if self.trainer.train_loop.automatic_optimization: # backward pass with self.trainer.profiler.profile("model_backward"): self.backward(result, optimizer, opt_idx) # hook - call this hook only # when gradients have finished accumulating if not self.should_accumulate(): self.on_after_backward(result.training_step_output, batch_idx, result.loss) # check if loss or model weights are nan if self.trainer.terminate_on_nan: self.trainer.detect_nan_tensors(result.loss) return result def backward(self, result, optimizer, opt_idx, *args, **kwargs): self.trainer.dev_debugger.track_event("backward_call") # backward can be called manually in the training loop if isinstance(result, torch.Tensor): self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs) else: result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx, *args, **kwargs) if not self.should_accumulate(): # track gradients self.track_and_norm_grad(optimizer=optimizer) def update_train_loop_lr_schedulers(self, monitor_metrics=None): num_accumulated_batches_reached = self._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() if num_accumulated_batches_reached or num_training_batches_reached: # update lr self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self, epoch_output): # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() self.trainer.call_hook('on_epoch_end') self.trainer.call_hook('on_train_epoch_end', epoch_output) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: self.trainer.global_step += 1 def _accumulated_batches_reached(self): return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self): return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches def should_accumulate(self): # checks if backward or backward + optimizer step (via closure) accumulation_done = self._accumulated_batches_reached() is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) def should_check_val_fx(self, batch_idx, is_last_batch): # decide if we should run validation is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 can_check_val = self.trainer.enable_validation and is_val_check_epoch should_check_val = is_val_check_batch or self.trainer.should_stop is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset) return should_check_val def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step args = [batch, batch_idx] if len(self.trainer.optimizers) > 1: if self.trainer.has_arg("training_step", "optimizer_idx"): args.append(opt_idx) else: num_opts = len(self.trainer.optimizers) raise ValueError(f"Your LightningModule defines {num_opts} optimizers but " f'training_step is missing the "optimizer_idx" argument.') # pass hiddens if
using tbptt if self.trainer.truncated_bptt_steps is not None: args.append(hiddens) return args def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs or self.trainer.fast_dev_run is True: if self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator): """ Figure out what needs to be tracked/logged at the end of the epoch """ # the training step outputs a list per optimizer. The list contains the outputs at each time step # when no TBPTT is used, then the list has 1 item per batch # when TBPTT IS used, then the list has n items (1 per time step) epoch_end_outputs = [] for optimizer_idx_outputs in all_train_step_outputs: # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer if len(optimizer_idx_outputs) == 0: continue sample_output = optimizer_idx_outputs[-1] # pull out callback info if available (ie: Results object) if isinstance(sample_output, dict) and "early_stop_on" in sample_output: early_stopping_accumulator.accumulate( sample_output["early_stop_on"]) if isinstance(sample_output, dict) and "checkpoint_on" in sample_output: checkpoint_accumulator.accumulate( sample_output["checkpoint_on"]) # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance( sample_output, Result) and sample_output.should_reduce_on_epoch_end # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if is_overridden( "training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result: epoch_end_outputs.append(optimizer_idx_outputs) return epoch_end_outputs def prepare_optimizers(self): # in manual optimization we loop over all optimizers at once optimizers = self.get_optimizers_iterable() if not self.automatic_optimization: optimizers = [optimizers[0]] return optimizers def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # set split_idx to trainer for tracking self.trainer.split_idx = split_idx # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. if self.automatic_optimization and len(self.trainer.optimizers) > 1: model = self.trainer.get_model() model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally self.trainer.logger_connector.on_train_split_start( split_idx, opt_idx, split_batch) def update_running_loss(self): accumulated_loss = self.accumulated_loss.mean() if accumulated_loss is not None: # calculate running loss for display self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) # reset for next set of accumulated grads self.accumulated_loss.reset()
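# A minimal sketch of the append/mean/reset surface that TensorRunningAccum exposes
# in these loops (the real class lives in pytorch_lightning and tracks more state);
# shown here only to make the running-loss bookkeeping concrete.
import torch

class RunningAccumSketch:
    def __init__(self, window_length: int):
        self.window_length = window_length
        self.memory = None        # lazily allocated, like the original
        self.current_idx = 0
        self.rotated = False      # becomes True once the window has wrapped around

    def append(self, x: torch.Tensor) -> None:
        if self.memory is None:
            self.memory = torch.zeros(self.window_length, *x.shape)
        self.memory[self.current_idx] = x
        self.current_idx = (self.current_idx + 1) % self.window_length
        self.rotated = self.rotated or self.current_idx == 0

    def mean(self):
        if self.memory is None:
            return None
        window = self.memory if self.rotated else self.memory[:self.current_idx]
        return window.mean()

    def reset(self) -> None:
        self.memory, self.current_idx, self.rotated = None, 0, False

acc = RunningAccumSketch(window_length=3)
for v in (1.0, 2.0, 3.0, 4.0):   # the window keeps only the last 3 values
    acc.append(torch.tensor(v))
assert torch.isclose(acc.mean(), torch.tensor(3.0))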
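# A standalone sketch of the frequency-based optimizer selection in
# get_optimizers_iterable above: with frequencies [2, 1] the batches cycle
# A, A, B, A, A, B, ... The cumsum trick picks the first cumulative frequency
# strictly greater than the position inside the cycle.
import numpy as np

def pick_optimizer_idx(frequencies, total_batch_idx):
    cumsum = np.cumsum(frequencies)
    place = total_batch_idx % cumsum[-1]
    return int(np.argmax(cumsum > place))

assert [pick_optimizer_idx([2, 1], i) for i in range(6)] == [0, 0, 1, 0, 0, 1]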
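# The accumulation predicates above, restated with plain arguments instead of
# trainer state (a sketch, not loop API): backward-only steps happen while neither
# the accumulation window nor the final batch of the epoch has been reached.
def should_accumulate(batch_idx: int, accumulate_grad_batches: int, num_training_batches: int) -> bool:
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    is_final_batch = (batch_idx + 1) == num_training_batches
    return not (accumulation_done or is_final_batch)

# with 10 batches and accumulation over 4, optimizer steps fire at indices 3, 7 and 9
step_indices = [i for i in range(10) if not should_accumulate(i, 4, 10)]
assert step_indices == [3, 7, 9]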
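# The validation trigger in should_check_val_fx above, as a pure function over the
# trainer attributes it reads (same names, but this helper is illustrative only):
def should_check_val(batch_idx, current_epoch, val_check_batch, check_val_every_n_epoch,
                     enable_validation, should_stop=False, is_last_batch=False):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    is_val_check_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
    can_check_val = enable_validation and is_val_check_epoch
    infinite_dataset_end = is_last_batch and val_check_batch == float("inf")
    return can_check_val and (is_val_check_batch or should_stop or infinite_dataset_end)

# check every 50 batches, but only on every 2nd epoch
assert should_check_val(49, 1, 50, 2, enable_validation=True)
assert not should_check_val(49, 0, 50, 2, enable_validation=True)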
class TrainerTrainLoopMixin(ABC): # this is just a summary of the variables used in this abstract class, # the proper values/initialisation should be done in child class max_epochs: int min_epochs: int on_gpu: bool root_gpu: ... use_ddp: bool use_dp: bool use_ddp2: bool use_horovod: bool use_single_gpu: bool use_tpu: bool data_parallel_device_ids: ... check_val_every_n_epoch: ... num_training_batches: int val_check_batch: ... disable_validation: bool fast_dev_run: ... accumulation_scheduler: ... lr_schedulers: ... early_stop_callback: ... callback_metrics: ... logger: Union[LightningLoggerBase, bool] global_step: int testing: bool log_save_interval: float global_rank: int row_log_interval: float truncated_bptt_steps: ... optimizers: ... optimizer_frequencies: ... accumulate_grad_batches: int track_grad_norm: ... model: LightningModule interrupted: bool running_loss: ... progress_bar_dict: ... reduce_lr_on_plateau_scheduler: ... profiler: ... batch_idx: int precision: ... train_dataloader: DataLoader reload_dataloaders_every_epoch: bool max_steps: int min_steps: int total_batch_idx: int terminate_on_nan: bool tpu_id: int interactive_ddp_procs: ... _state: TrainerState amp_backend: AMPType on_tpu: bool accelerator_backend: ... val_dataloaders: ... # Callback system callbacks: List[Callback] on_train_start: Callable on_train_end: Callable on_batch_start: Callable on_batch_end: Callable on_train_batch_start: Callable on_train_batch_end: Callable on_epoch_start: Callable on_epoch_end: Callable on_validation_end: Callable on_keyboard_interrupt: Callable on_train_epoch_start: Callable on_train_epoch_end: Callable @abstractmethod def get_model(self) -> LightningModule: """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def is_function_implemented(self, *args, **kwargs): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def run_evaluation(self, *args, **kwargs): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def transfer_batch_to_gpu(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def clip_gradients(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def detect_nan_tensors(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def add_progress_bar_metrics(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def log_metrics(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def process_output(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def reset_train_dataloader(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def reset_val_dataloader(self, model): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def call_hook(self, hook_name, *args, **kwargs): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def has_arg(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" @abstractmethod def run_sanity_check(self, *args): """Warning: this is just an empty shell for code implemented in another class.""" def train(self): self.run_sanity_check(self.get_model()) # TODO: shrink # clear
cache before training if self.on_gpu and self.root_gpu is not None: # use context because of: # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 with torch.cuda.device(f'cuda:{self.root_gpu}'): torch.cuda.empty_cache() # get model model = self.get_model() # enable train mode model.train() # enable gradients torch.set_grad_enabled(True) # load data # if reload_dataloaders_every_epoch, this is moved to the epoch loop if not self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) if self.val_dataloaders is None and not self.reload_dataloaders_every_epoch: self.reset_val_dataloader(model) # Train start events with self.profiler.profile('on_train_start'): # callbacks self.on_train_start() # model hooks model.on_train_start() try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_horovod or self.on_tpu) \ and hasattr(self.train_dataloader, 'sampler') \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) # update training progress in trainer and model model.current_epoch = epoch self.current_epoch = epoch # changing gradient according accumulation_scheduler self.accumulation_scheduler.on_epoch_start(self, self.get_model()) # stores accumulated grad fractions per batch self.batch_loss_value = TensorRunningAccum( window_length=self.accumulate_grad_batches ) # ----------------- # RUN TNG EPOCH # ----------------- self.run_training_epoch() if self.max_steps and self.max_steps <= self.global_step: self.run_training_teardown() return # update LR schedulers self.update_learning_rates(interval='epoch') # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if self.should_stop: if (met_min_epochs and met_min_steps): self.run_training_teardown() return else: log.info('Trainer was signaled to stop but required minimum epochs' f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' ' not been met. Training will continue...') self.run_training_teardown() except KeyboardInterrupt: rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True self._state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() self.run_training_teardown() def run_on_epoch_start_hook(self, model): # Epoch start events with self.profiler.profile('on_epoch_start'): # callbacks self.on_epoch_start() # model hooks if self.is_function_implemented('on_epoch_start'): model.on_epoch_start() # Epoch start events with self.profiler.profile('on_train_epoch_start'): # callbacks self.on_train_epoch_start() # model hooks if self.is_function_implemented('on_train_epoch_start'): model.on_train_epoch_start() def run_training_epoch(self): # get model model = self.get_model() # Epoch start events self.run_on_epoch_start_hook(model) # modify dataloader if needed (ddp, etc...) 
train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader) # bookkeeping num_optimizers = len(self._get_optimizers_iterable()) epoch_output = [[] for _ in range(num_optimizers)] should_check_val = False # structured result accumulators for callbacks early_stopping_accumulator = Accumulator() checkpoint_accumulator = Accumulator() # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( enumerate(_with_is_last(train_dataloader)), "get_train_batch" ): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: break self.batch_idx = batch_idx model.global_step = self.global_step # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ batch_output = self.run_training_batch(batch, batch_idx) # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory epoch_end_outputs = self.process_train_step_outputs( batch_output.training_step_output_for_epoch_end, early_stopping_accumulator, checkpoint_accumulator ) # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(epoch_end_outputs): # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) # when returning -1 from train_step, we end epoch early self.should_stop = batch_output.signal == -1 # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- should_check_val = self.should_check_val(batch_idx, is_last_batch) if should_check_val: self.run_evaluation(test_mode=False) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) # ----------------------------------------- self.save_loggers_in_training_loop(batch_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- self.save_train_loop_metrics_to_loggers(batch_idx, batch_output) # update LR schedulers monitor_metrics = deepcopy(self.callback_metrics) monitor_metrics.update(batch_output.batch_log_metrics) self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) # progress global step according to grads progress self.increment_accumulated_grad_global_step() # max steps reached, end training if self.max_steps is not None and self.max_steps == self.global_step: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if self.should_stop: break # let ddp devices catch up when using horovod self.sync_horovod() # process epoch outputs self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers) # checkpoint callback self.check_checkpoint_callback(should_check_val) # epoch end hook self.run_on_epoch_end_hook(model) def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator): """ Figure out what needs to be tracked/logged at the end of the epoch """ # the training step outputs a list per optimizer. 
The list contains the outputs at each time step # when no TBPTT is used, then the list has 1 item per batch # when TBPTT IS used, then the list has n items (1 per time step) epoch_end_outputs = [] for optimizer_idx_outputs in all_train_step_outputs: # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer sample_output = optimizer_idx_outputs[-1] # pull out callback info if available (ie: Results object) if isinstance(sample_output, dict) and 'early_stop_on' in sample_output: early_stopping_accumulator.accumulate(sample_output['early_stop_on']) if isinstance(sample_output, dict) and 'checkpoint_on' in sample_output: checkpoint_accumulator.accumulate(sample_output['checkpoint_on']) # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if is_overridden('training_epoch_end', model=self.get_model()) or auto_reduce_tng_result: epoch_end_outputs.append(optimizer_idx_outputs) return epoch_end_outputs def check_checkpoint_callback(self, should_check_val): # when no val loop is present or fast-dev-run still need to call checkpoints # TODO bake this logic into the checkpoint callback should_activate = not is_overridden('validation_step', self.get_model()) and not should_check_val if should_activate: checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] [c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks] def update_train_loop_lr_schedulers(self, monitor_metrics=None): if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0 or (self.batch_idx + 1) == self.num_training_batches): # update lr self.update_learning_rates(interval='step', monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self, model): with self.profiler.profile('on_epoch_end'): # callbacks self.on_epoch_end() # model hooks if self.is_function_implemented('on_epoch_end'): model.on_epoch_end() with self.profiler.profile('on_train_epoch_end'): # callbacks self.on_train_epoch_end() # model hooks if self.is_function_implemented('on_train_epoch_end'): model.on_train_epoch_end() def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers): # epoch output is a list. 
    def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
        # epoch_output is a list; each item holds all the outputs per optimizer:
        #   epoch_output[optimizer_idx][training_step_idx][tbptt_index]
        # remember that not using truncated backprop is equivalent to truncated backprop of length 1
        model = self.get_model()

        epoch_log_metrics = {}
        epoch_callback_metrics = {}
        epoch_progress_bar_metrics = {}

        # -----------------------
        # Calculate epoch callback values if given
        # -----------------------
        if checkpoint_accumulator.num_values > 0:
            epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()

        if early_stopping_accumulator.num_values > 0:
            epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()

        # ------------------------
        # determine if using a result obj
        # ------------------------
        # [optimizer_idx][training_step_idx][tbptt_index]
        opt_idx_outputs = epoch_output[0]

        try:
            sample_obj = opt_idx_outputs[0][0] if isinstance(opt_idx_outputs[0], list) else opt_idx_outputs[0]
            is_result_obj = len(epoch_output) > 0 and isinstance(sample_obj, Result)
        except IndexError:
            is_result_obj = False

        # --------------------------
        # EPOCH END STEP IF DEFINED
        # --------------------------
        if is_overridden('training_epoch_end', model=model):
            self.global_step += 1

            if is_result_obj:
                # with result object gather across time and training steps so each opt idx has a single result obj
                epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)

            if num_optimizers == 1:
                epoch_output = epoch_output[0]

            # run training_epoch_end
            # a list with a result per optimizer index
            epoch_output = model.training_epoch_end(epoch_output)

            if isinstance(epoch_output, Result):
                epoch_log_metrics = epoch_output.epoch_log_metrics
                epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
            else:
                _processed_outputs = self.process_output(epoch_output)
                epoch_progress_bar_metrics = _processed_outputs[1]
                epoch_log_metrics = _processed_outputs[2]
                epoch_callback_metrics = _processed_outputs[3]

        # --------------------------
        # Structured Result (auto epoch end)
        # --------------------------
        elif is_result_obj:
            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

        # --------------------------
        # track results
        # --------------------------
        # add the metrics to the loggers
        if epoch_log_metrics:
            self.log_metrics(epoch_log_metrics, {})

        # add metrics to callbacks
        self.callback_metrics.update(epoch_callback_metrics)

        # add metrics to progress_bar
        if len(epoch_progress_bar_metrics) > 0:
            self.add_progress_bar_metrics(epoch_progress_bar_metrics)

    def __auto_reduce_results_on_epoch_end(self, epoch_output):
        epoch_log_metrics = {}
        epoch_progress_bar_metrics = {}
        for opt_outputs in epoch_output:
            # reduce across time first
            time_reduced_outputs = []
            for train_step_idx in range(len(opt_outputs)):
                tbptt_outs = opt_outputs[train_step_idx]
                tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
                time_reduced_outputs.append(tbptt_outs)

            # reduce across training steps
            opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)
            opt_outputs.minimize = opt_outputs.minimize.mean()
            epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
            epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics)

        return epoch_log_metrics, epoch_progress_bar_metrics
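    # NOTE (illustrative sketch): the auto-reduction above runs in two stages.
    # Given 2 training steps with 3 tbptt splits each, the per-metric tensors are
    # first reduced over the 3 splits (reduce_across_time) and then over the 2
    # steps (reduce_on_epoch_end), leaving one value per logged metric for the
    # epoch. The exact reduction is up to the Result class; the .mean() on
    # `minimize` above suggests averaging for the loss.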
    def __gather_result_across_time_and_optimizers(self, epoch_output):
        """
        Gather results into a single padded tensor per metric where each tensor is gathered across
        time and across training steps.

        Returns:
            a list where each element is a Result with the tensors gathered
        """
        gathered_epoch_outputs = []
        for opt_outputs in epoch_output:
            # gather across time first
            time_gathered_outputs = []
            for train_step_idx in range(len(opt_outputs)):
                tbptt_outs = opt_outputs[train_step_idx]
                tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs)
                time_gathered_outputs.append(tbptt_outs)

            # gather across training steps
            # each metric has dimensions (training_steps, seq_len) (seq_len=1 when no tbptt is used)
            gathered_opt_output = time_gathered_outputs[0].__class__.padded_gather(time_gathered_outputs)
            gathered_epoch_outputs.append(gathered_opt_output)

        return gathered_epoch_outputs

    def sync_horovod(self):
        if self.use_horovod:
            hvd.join(hvd.local_rank() if self.on_gpu else -1)

    def increment_accumulated_grad_global_step(self):
        # progress global step according to grads progress
        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0 or (self.batch_idx + 1) == self.num_training_batches:
            self.global_step += 1
        self.total_batch_idx += 1

    def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
        # when metrics should be logged
        should_log_metrics = (batch_idx + 1) % self.row_log_interval == 0 or self.should_stop
        if should_log_metrics or self.fast_dev_run:
            # logs user requested information to logger
            metrics = batch_output.batch_log_metrics
            grad_norm_dic = batch_output.grad_norm_dic
            if len(metrics) > 0 or len(grad_norm_dic) > 0:
                self.log_metrics(metrics, grad_norm_dic)

    def save_loggers_in_training_loop(self, batch_idx):
        # when loggers should save to disk
        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or self.should_stop
        if should_save_log or self.fast_dev_run:
            if self.is_global_zero and self.logger is not None:
                self.logger.save()

    def should_check_val(self, batch_idx, is_last_batch):
        # decide if we should run validation
        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
        can_check_val = self.enable_validation and can_check_epoch
        should_check_val = is_val_check_batch or self.should_stop
        is_last_batch_for_infinite_dataset = is_last_batch and self.val_check_batch == float('inf')
        should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)

        return should_check_val
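    # NOTE (worked example, hypothetical values): with val_check_batch=100,
    # check_val_every_n_epoch=1 and enable_validation=True, batch_idx=99 gives
    # (99 + 1) % 100 == 0, so validation runs after the 100th batch. For an
    # infinite dataset (val_check_batch == float('inf')) the modulo check never
    # fires, so validation only runs on the last batch or when should_stop is set.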
    def run_training_batch(self, batch, batch_idx):
        # track grad norms
        grad_norm_dic = {}

        # track all metrics for callbacks
        batch_callback_metrics = []

        # track metrics to log
        batch_log_metrics = []

        using_results_obj = False

        # track all outputs across time and num of optimizers
        batch_outputs = [[] for _ in range(len(self._get_optimizers_iterable()))]

        if batch is None:
            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

        # Batch start events
        # TODO: deprecate 1.0
        with self.profiler.profile('on_batch_start'):
            # callbacks
            self.on_batch_start()
            # hooks
            if self.is_function_implemented('on_batch_start'):
                response = self.get_model().on_batch_start(batch)
                if response == -1:
                    return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        with self.profiler.profile('on_train_batch_start'):
            # forward support for multiple loaders
            dataloader_idx = 0
            self.on_train_batch_start(batch, batch_idx, dataloader_idx)
            # hooks
            if self.is_function_implemented('on_train_batch_start'):
                response = self.get_model().on_train_batch_start(batch, batch_idx, dataloader_idx)
                if response == -1:
                    return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        splits = [batch]
        if self.truncated_bptt_steps is not None:
            model_ref = self.get_model()
            with self.profiler.profile('tbptt_split_batch'):
                splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)

        self.hiddens = None
        for split_idx, split_batch in enumerate(splits):
            self.split_idx = split_idx

            for opt_idx, optimizer in self._get_optimizers_iterable():
                # make sure only the gradients of the current optimizer's parameters are calculated
                # in the training step to prevent dangling gradients in multiple-optimizer setup.
                if len(self.optimizers) > 1:
                    for param in self.get_model().parameters():
                        param.requires_grad = False
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.requires_grad = True

                # -------------------
                # calculate loss (train step + train step end)
                # -------------------
                opt_closure_result = self.optimizer_closure(
                    split_batch,
                    batch_idx,
                    opt_idx,
                    optimizer,
                    self.hiddens
                )
                using_results_obj = isinstance(opt_closure_result.training_step_output, Result)

                # ------------------------------
                # POST forward bookkeeping
                # ------------------------------
                batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)

                # add metrics to loggers
                if using_results_obj:
                    metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
                    step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
                else:
                    metrics_to_log = opt_closure_result.training_step_output.log_metrics
                    step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end

                # track metrics
                batch_log_metrics.append(metrics_to_log)
                if len(step_pbar_metrics) > 0:
                    self.add_progress_bar_metrics(step_pbar_metrics)

                # track hiddens
                self.hiddens = opt_closure_result.hiddens
                if using_results_obj:
                    opt_closure_result.training_step_output_for_epoch_end.drop_hiddens()

                # check if loss or model weights are nan
                if self.terminate_on_nan:
                    self.detect_nan_tensors(opt_closure_result.loss)

                # track total loss for logging (avoid mem leaks)
                self.batch_loss_value.append(opt_closure_result.loss)

                # track all the outputs across all steps
                batch_outputs[opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)

                # ------------------------------
                # BACKWARD PASS
                # ------------------------------
                # gradient update with accumulated gradients
                if (self.batch_idx + 1) % self.accumulate_grad_batches == 0 or (self.batch_idx + 1) == self.num_training_batches:
                    # backward
                    grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)

                    # calculate running loss for display
                    self.running_loss.append(self.batch_loss_value.mean() * self.accumulate_grad_batches)

                    # reset for next set of accumulated grads
                    self.batch_loss_value.reset()

        # Batch end events
        with self.profiler.profile('on_batch_end'):
            # callbacks
            self.on_batch_end()
            # model hooks
            if self.is_function_implemented('on_batch_end'):
                self.get_model().on_batch_end()

        with self.profiler.profile('on_train_batch_end'):
            # forward support for multiple loaders
            dataloader_idx = 0
            self.on_train_batch_end(batch, batch_idx, dataloader_idx)
            # model hooks
            if self.is_function_implemented('on_train_batch_end'):
                self.get_model().on_train_batch_end(batch, batch_idx, dataloader_idx)

        # collapse all metrics into one dict
        batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

        # track all metrics for callbacks
        if not using_results_obj:
            self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})

        result = AttributeDict(
            signal=0,
            grad_norm_dic=grad_norm_dic,
            batch_log_metrics=batch_log_metrics,
            training_step_output_for_epoch_end=batch_outputs
        )
        return result
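    # NOTE (worked example, hypothetical values): with accumulate_grad_batches=4 and
    # num_training_batches=10, the optimizer-step branch above fires at batch_idx
    # 3, 7 and 9 (the final batch flushes a partial accumulation window). The
    # displayed running loss multiplies the mean by 4 to undo the per-step division
    # by accumulate_grad_batches performed in optimizer_closure.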
    def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
        # ------------------
        # GRAD NORMS
        # ------------------
        # track gradient norms when requested
        grad_norm_dic = {}
        if batch_idx % self.row_log_interval == 0:
            if float(self.track_grad_norm) > 0:
                model = self.get_model()
                grad_norm_dic = model.grad_norm(self.track_grad_norm)

        # ------------------
        # CLIP GRADS
        # ------------------
        if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
            self.scaler.unscale_(optimizer)
        self.clip_gradients(optimizer)

        # ------------------
        # .STEP + ZERO_GRAD
        # ------------------
        self.call_optimizer_step(optimizer, opt_idx, batch_idx, split_batch)

        return grad_norm_dic

    def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
        # calls .step(), .zero_grad()
        # override function to modify this behavior
        model = self.get_model()

        with self.profiler.profile('optimizer_step'):
            lambda_closure = lambda: self.optimizer_closure(
                split_batch,
                batch_idx,
                opt_idx,
                optimizer,
                self.hiddens,
            ).loss

            # apply TPU optimizer
            if self.use_tpu and XLA_AVAILABLE:
                model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure, on_tpu=True)

            # for LBFGS do something a bit different
            elif isinstance(optimizer, torch.optim.LBFGS):
                # native amp + lbfgs is a no go right now
                if self.amp_backend == AMPType.NATIVE:
                    raise MisconfigurationException(
                        'native PyTorch amp and lbfgs are not compatible.'
                        ' To request, please file a Github issue in PyTorch and tag @mcarilli')
                model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
                                     using_lbfgs=True)

            # when using 16-bit
            else:
                native_amp = self.amp_backend == AMPType.NATIVE
                model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
                                     using_native_amp=native_amp)

            # in native 16-bit we need to update scaler after optimizer step
            if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                self.scaler.update()

            # model hook
            model.on_before_zero_grad(optimizer)

            # clear gradients
            model.optimizer_zero_grad(self.current_epoch, batch_idx, optimizer, opt_idx)
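    # NOTE (illustrative sketch): the closure exists because optimizers such as
    # torch.optim.LBFGS re-evaluate the loss several times per step. In plain
    # PyTorch the equivalent pattern looks like this (model, loss_fn, x, y are
    # hypothetical):
    #
    #   def closure():
    #       optimizer.zero_grad()
    #       loss = loss_fn(model(x), y)
    #       loss.backward()
    #       return loss
    #
    #   optimizer.step(closure)
    #
    # Here optimizer_closure plays that role, bundling training_step,
    # training_step_end and the backward pass into one callable.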
    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
        """Wrap the forward step in a closure so second order methods work."""
        # ---------------------------
        # FORWARD (TRAINING STEP + TRAIN STEP END)
        # ---------------------------
        with self.profiler.profile('model_forward'):
            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
            training_step_output = self.accelerator_backend.training_step(args)
            training_step_output = self.call_hook('training_step_end', training_step_output)

        # ----------------------------
        # PROCESS THE RESULT
        # ----------------------------
        # format and reduce outputs accordingly
        training_step_output_for_epoch_end = training_step_output
        is_result_obj = isinstance(training_step_output, Result)

        # track batch size for weighted average
        if is_result_obj:
            training_step_output.track_batch_size(len(split_batch))

        # don't allow EvalResult in the training_step
        if isinstance(training_step_output, EvalResult):
            raise MisconfigurationException('training_step cannot return EvalResult, '
                                            'use a dict or TrainResult instead')

        # handle regular dicts
        if not is_result_obj:
            training_step_output = self.process_output(training_step_output, train=True)

            training_step_output = AttributeDict(
                batch_loss=training_step_output[0],
                pbar_on_batch_end=training_step_output[1],
                log_metrics=training_step_output[2],
                callback_metrics=training_step_output[3],
                hiddens=training_step_output[4],
            )

        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
        elif is_result_obj:
            training_step_output_for_epoch_end = copy(training_step_output)
            training_step_output_for_epoch_end.detach()
        else:
            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
        closure_loss = closure_loss / self.accumulate_grad_batches

        # the loss will get scaled for amp. avoid any modifications to it
        untouched_loss = closure_loss.detach().clone()

        # backward pass
        model_ref = self.get_model()
        with self.profiler.profile('model_backward'):
            # scale loss for 16 bit
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_backend=self.amp_backend)

                # enter amp context
                if self.amp_backend == AMPType.APEX:
                    self.dev_debugger.track_event('AMP', str(AMPType.APEX))
                    context = closure_loss
                    closure_loss = closure_loss.__enter__()

            # do backward pass
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # exit amp context
            if self.precision == 16 and self.amp_backend == AMPType.APEX and not self.on_tpu:
                error = context.__exit__(None, None, None)
                if error:
                    rank_zero_warn('apex unscale error')
                    raise Exception('apex unscale error')

            # once backward has been applied, release graph
            closure_loss = closure_loss.detach()

            if is_result_obj:
                training_step_output.detach()
            else:
                training_step_output.batch_loss = training_step_output.batch_loss.detach()

        if self.use_horovod:
            # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
            optimizer.synchronize()

        # insert after-backward hook
        if self.is_function_implemented('on_after_backward'):
            model_ref = self.get_model()
            with self.profiler.profile('on_after_backward'):
                model_ref.on_after_backward()

        # when in dev debugging track the losses
        self.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

        result = AttributeDict(
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result

    def _get_optimizers_iterable(self):
        if not self.optimizer_frequencies:
            # call training_step once per optimizer
            return list(enumerate(self.optimizers))

        optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies)
        optimizers_loop_length = optimizer_freq_cumsum[-1]
        current_place_in_loop = self.total_batch_idx % optimizers_loop_length

        # find optimizer index by looking for the first {item > current_place} in the cumsum list
        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
        return [(opt_idx, self.optimizers[opt_idx])]
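    # NOTE (worked example, hypothetical values): with optimizer_frequencies=[2, 1]
    # the cumsum is [2, 3], so total_batch_idx % 3 in {0, 1} selects optimizer 0
    # and total_batch_idx % 3 == 2 selects optimizer 1, i.e. a 2:1 round-robin in
    # which exactly one optimizer steps per batch.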
    # @atexit.register
    def run_training_teardown(self):
        if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
            return

        self._teardown_already_run = True

        # Save latest checkpoint
        log.info('Saving latest checkpoint...')
        self.check_checkpoint_callback(should_check_val=False)

        # Train end events
        with self.profiler.profile('on_train_end'):
            # callbacks
            self.on_train_end()
            # model hooks
            if self.is_function_implemented('on_train_end'):
                self.get_model().on_train_end()

        if self.logger is not None:
            self.logger.finalize("success")

        # summarize profile results
        if self.global_rank == 0:
            self.profiler.describe()

        if self.global_rank == 0:
            for proc in self.interactive_ddp_procs:
                subprocess.Popen.kill(proc)

        # clean up dist group
        if self.use_ddp or self.use_ddp2:
            torch_distrib.destroy_process_group()

        # clear mem
        if self.on_gpu:
            model = self.get_model()
            model.cpu()
            torch.cuda.empty_cache()

    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
        # enable not needing to add opt_idx to training_step
        args = [batch, batch_idx]

        if len(self.optimizers) > 1:
            if self.has_arg('training_step', 'optimizer_idx'):
                args.append(opt_idx)
            else:
                num_opts = len(self.optimizers)
                raise ValueError(
                    f'Your LightningModule defines {num_opts} optimizers but '
                    f'training_step is missing the "optimizer_idx" argument.'
                )

        # pass hiddens if using tbptt
        if self.truncated_bptt_steps is not None:
            args.append(hiddens)

        return args

    def update_learning_rates(self, interval: str, monitor_metrics=None):
        """Update learning rates.

        Args:
            interval: either 'epoch' or 'step'.
            monitor_metrics: dict of possible values to monitor
        """
        if not self.lr_schedulers:
            return

        for scheduler_idx, lr_scheduler in enumerate(self.lr_schedulers):
            current_idx = self.batch_idx if interval == 'step' else self.current_epoch
            current_idx += 1  # account for both batch and epoch starting from 0

            # Take step if the call to update_learning_rates matches the interval key and
            # the current step modulo the scheduler's frequency is zero
            if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
                # If instance of ReduceLROnPlateau, we need to pass the validation loss
                if lr_scheduler['reduce_on_plateau']:
                    monitor_key = lr_scheduler['monitor']

                    if monitor_metrics is not None:
                        monitor_val = monitor_metrics.get(monitor_key)
                    else:
                        monitor_val = self.callback_metrics.get(monitor_key)

                    if monitor_val is None:
                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
                        raise MisconfigurationException(
                            f'ReduceLROnPlateau conditioned on metric {monitor_key}'
                            f' which is not available. Available metrics are: {avail_metrics}.'
                            ' Condition can be set using `monitor` key in lr scheduler dict'
                        )

                    if self.dev_debugger.enabled:
                        old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

                    # update LR
                    lr_scheduler['scheduler'].step(monitor_val)

                    if self.dev_debugger.enabled:
                        new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
                        self.dev_debugger.track_lr_schedulers_update(
                            self.batch_idx,
                            interval,
                            scheduler_idx,
                            old_lr,
                            new_lr,
                            monitor_key,
                        )
                else:
                    if self.dev_debugger.enabled:
                        old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

                    # update LR
                    lr_scheduler['scheduler'].step()

                    if self.dev_debugger.enabled:
                        new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
                        self.dev_debugger.track_lr_schedulers_update(
                            self.batch_idx,
                            interval,
                            scheduler_idx,
                            old_lr,
                            new_lr
                        )
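    # NOTE (illustrative sketch): each entry in self.lr_schedulers is a dict the
    # trainer builds from what configure_optimizers returns, roughly shaped like
    # the following (values hypothetical; 'reduce_on_plateau' is filled in
    # internally for plateau schedulers rather than supplied by the user):
    #
    #   {
    #       'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
    #       'interval': 'epoch',        # or 'step'
    #       'frequency': 1,             # step the scheduler every N intervals
    #       'monitor': 'val_loss',      # metric watched by ReduceLROnPlateau
    #       'reduce_on_plateau': True,
    #   }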