def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
    """Run a single training step under manual optimization.

    Args:
        batch: the current tbptt split of the current batch
        batch_idx: the index of the current batch
    """
    assert self.trainer is not None
    pl_module = self.trainer.lightning_module

    with self.trainer.profiler.profile("model_forward"):
        step_kwargs = _build_training_step_kwargs(
            pl_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
        )

        # manually capture logged metrics
        pl_module._current_fx_name = "training_step"
        with self.trainer.profiler.profile("training_step"):
            step_output = self.trainer.accelerator.training_step(step_kwargs)
            self.trainer.accelerator.post_training_step()

        # drop references to the split batch held by the kwargs
        del step_kwargs

        step_output = self.trainer.call_hook("training_step_end", step_output)

        _check_training_step_output(pl_module, step_output)

        self._hiddens = _extract_hiddens(step_output, pl_module.truncated_bptt_steps)

        # TODO: do not use `ClosureResult`
        closure_result = ClosureResult.from_training_step_output(step_output, self.trainer.accumulate_grad_batches)

        if self.trainer.terminate_on_nan:
            check_finite_loss(closure_result.closure_loss)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            # the user might need them on the correct device for an operation in `training_epoch_end`
            assert self.trainer._results is not None
            self.trainer._results.cpu()

    self._done = True
    self._output = closure_result
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
    """Run the forward pass together with the tied ``training_step`` hooks.

    Args:
        split_batch: the current tbptt split of the current batch
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer
        hiddens: the model's hidden state of the previous iteration

    Returns:
        an AttributeDict containing the loss value and the training step output.
    """
    # give the PL module a result for logging
    module = self.trainer.lightning_module

    with self.trainer.profiler.profile("model_forward"):
        step_kwargs = _build_training_step_kwargs(
            module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens
        )

        # manually capture logged metrics
        module._current_fx_name = "training_step"
        with self.trainer.profiler.profile("training_step"):
            step_output = self.trainer.accelerator.training_step(step_kwargs)
            self.trainer.accelerator.post_training_step()

        del step_kwargs

        step_output = self.trainer.call_hook("training_step_end", step_output)

        _check_training_step_output(self.trainer.lightning_module, step_output)

        result_collection, self._hiddens = _process_training_step_output(self.trainer, step_output)
        if result_collection is None:
            return None

    closure_loss = None
    loss = None
    if self.trainer.lightning_module.automatic_optimization:
        # accumulate loss. if accumulate_grad_batches == 1, no effect
        closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
        # the loss will get scaled for amp. avoid any modifications to it
        loss = closure_loss.detach().clone()
    return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)
def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict:
    """Assemble the keyword arguments for the upcoming ``training_step`` call.

    Args:
        kwargs: The kwargs passed down to the hooks.
        hiddens: the hidden state of the previous RNN iteration.

    Returns:
        The kwargs passed down to the hooks.
    """
    trainer = self.trainer
    return _build_training_step_kwargs(kwargs, trainer.lightning_module, trainer.optimizers, None, hiddens)
def _get_generator(self, split_batch, batch_idx, opt_idx):
    """Call ``training_step`` and return the generator it yields."""
    kwargs = _build_training_step_kwargs(
        self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens=None
    )
    # Effectively `lightning_module.training_step()`, routed through the
    # strategy so distributed training is handled — the step returns a generator!
    return self.trainer.strategy.training_step(*kwargs.values())
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
    """Run the forward pass together with the tied ``training_step`` hooks.

    Args:
        split_batch: the current tbptt split of the current batch
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer

    Returns:
        A ``ClosureResult`` containing the training step output.
    """
    # give the PL module a result for logging
    pl_module = self.trainer.lightning_module

    with self.trainer.profiler.profile("model_forward"):
        step_kwargs = _build_training_step_kwargs(
            pl_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
        )

        # manually capture logged metrics
        pl_module._current_fx_name = "training_step"
        with self.trainer.profiler.profile("training_step"):
            step_output = self.trainer.accelerator.training_step(step_kwargs)
            self.trainer.accelerator.post_training_step()

        del step_kwargs

        step_output = self.trainer.call_hook("training_step_end", step_output)

        _check_training_step_output(pl_module, step_output)

        self._hiddens = _extract_hiddens(step_output, pl_module.truncated_bptt_steps)

        result = ClosureResult.from_training_step_output(step_output, self.trainer.accumulate_grad_batches)

        if self.trainer.terminate_on_nan:
            check_finite_loss(result.closure_loss)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            assert self.trainer._results is not None
            self.trainer._results.cpu()

    return result
def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
    """Run a single training step under manual optimization.

    Args:
        split_batch: the current tbptt split of the current batch
        batch_idx: the index of the current batch
        hiddens: the model's hidden state of the previous iteration

    Returns:
        an AttributeDict containing the training step output.
    """
    # give the PL module a result for logging
    module = self.trainer.lightning_module

    with self.trainer.profiler.profile("model_forward"):
        step_kwargs = _build_training_step_kwargs(
            module, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
        )

        # manually capture logged metrics
        module._current_fx_name = "training_step"
        with self.trainer.profiler.profile("training_step"):
            step_output = self.trainer.accelerator.training_step(step_kwargs)
            self.trainer.accelerator.post_training_step()

        del step_kwargs

        step_output = self.trainer.call_hook("training_step_end", step_output)

        _check_training_step_output(self.trainer.lightning_module, step_output)

        result_collection, self._hiddens = _process_training_step_output(self.trainer, step_output)
        if result_collection is None:
            return None

    return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
    """Run the forward pass together with the tied ``training_step`` hooks.

    Args:
        split_batch: the current tbptt split of the current batch
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer

    Returns:
        A ``ClosureResult`` containing the training step output.
    """
    # give the PL module a result for logging
    pl_module = self.trainer.lightning_module

    step_kwargs = _build_training_step_kwargs(
        pl_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
    )

    # manually capture logged metrics
    step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
    self.trainer.strategy.post_training_step()

    # both the LightningModule and the strategy may post-process the output;
    # the module's return value wins when it provides one
    model_output = self.trainer._call_lightning_module_hook("training_step_end", step_output)
    strategy_output = self.trainer._call_strategy_hook("training_step_end", step_output)
    step_output = strategy_output if model_output is None else model_output

    self._hiddens = _extract_hiddens(step_output, pl_module.truncated_bptt_steps)

    result = self.output_result_cls.from_training_step_output(step_output, self.trainer.accumulate_grad_batches)

    if self.trainer._terminate_on_nan:
        check_finite_loss(result.closure_loss)

    if self.trainer.move_metrics_to_cpu:
        # hiddens and the training step output are not moved as they are not considered "metrics"
        assert self.trainer._results is not None
        self.trainer._results.cpu()

    return result
def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
    """Run a single training step under manual optimization.

    Args:
        batch: the current tbptt split of the current batch
        batch_idx: the index of the current batch
    """
    assert self.trainer is not None
    pl_module = self.trainer.lightning_module

    with self.trainer.profiler.profile("model_forward"):
        step_kwargs = _build_training_step_kwargs(
            pl_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
        )

        # manually capture logged metrics
        step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
        self.trainer.strategy.post_training_step()

        # drop references to the split batch held by the kwargs
        del step_kwargs

        # both the LightningModule and the strategy may post-process the output;
        # the module's return value wins when it provides one
        model_output = self.trainer._call_lightning_module_hook("training_step_end", step_output)
        strategy_output = self.trainer._call_strategy_hook("training_step_end", step_output)
        step_output = strategy_output if model_output is None else model_output

        self._hiddens = _extract_hiddens(step_output, pl_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(step_output)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            # the user might need them on the correct device for an operation in `training_epoch_end`
            assert self.trainer._results is not None
            self.trainer._results.cpu()

    self._done = True
    self._output = result.asdict()