def backward_fn(loss: Tensor):
    self.backward(loss, optimizer, opt_idx)

    # check if loss or model weights are nan
    if self.trainer.terminate_on_nan:
        check_finite_loss(self.trainer.lightning_module, loss)

    return loss
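The closure above belongs to Lightning's automatic-optimization loop: `self.backward` delegates to the trainer's backward machinery, and the optional NaN check runs after the backward pass. Below is a minimal, framework-free sketch of the same pattern in plain PyTorch; the factory `make_backward_fn` and its parameters are hypothetical and only illustrate the closure structure.

import torch


def make_backward_fn(model: torch.nn.Module, terminate_on_nan: bool = True):
    # hypothetical factory mirroring the closure pattern in the snippet above
    def backward_fn(loss: torch.Tensor) -> torch.Tensor:
        loss.backward()  # stands in for self.backward(loss, optimizer, opt_idx)

        # check if loss or model weights are nan, as the snippet's comment describes
        if terminate_on_nan:
            if not torch.isfinite(loss).all():
                raise ValueError(f"The loss is not finite: {loss}")
            for name, param in model.named_parameters():
                if not torch.isfinite(param).all():
                    raise ValueError(f"Parameter `{name}` contains nan or inf.")
        return loss

    return backward_fn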
Example 2
    def advance(self, batch: Any,
                batch_idx: int) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(lightning_module,
                                                      self.trainer.optimizers,
                                                      batch,
                                                      batch_idx,
                                                      opt_idx=None,
                                                      hiddens=self._hiddens)

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(
                    step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            self._hiddens = _extract_hiddens(
                training_step_output, lightning_module.truncated_bptt_steps)

            # TODO: do not use `ClosureResult`
            result = ClosureResult.from_training_step_output(
                training_step_output, self.trainer.accumulate_grad_batches)

            if self.trainer.terminate_on_nan:
                check_finite_loss(result.closure_loss)

            if self.trainer.move_metrics_to_cpu:
                # hiddens and the training step output are not moved as they are not considered "metrics"
                # the user might need them on the correct device for an operation in `training_epoch_end`
                assert self.trainer._results is not None
                self.trainer._results.cpu()

        self._done = True
        self._output = result
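This loop is the trainer-side half of manual optimization; the user-side half lives in the LightningModule. Below is a minimal sketch of a module this loop would drive, using the public 1.x-era API (`automatic_optimization`, `self.optimizers()`, `manual_backward`); the class name, layer sizes, and loss are illustrative.

import torch
import pytorch_lightning as pl


class ManualOptimModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # hand control of stepping to training_step
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        opt.zero_grad()
        self.manual_backward(loss)  # used instead of loss.backward()
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)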
Example 3

    def _training_step(self, split_batch: Any, batch_idx: int,
                       opt_idx: int) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer

        Returns:
            A ``ClosureResult`` containing the training step output.
        """
        # give the PL module a result for logging
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(lightning_module,
                                                      self.trainer.optimizers,
                                                      split_batch, batch_idx,
                                                      opt_idx, self._hiddens)

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(
                    step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            self._hiddens = _extract_hiddens(
                training_step_output, lightning_module.truncated_bptt_steps)

            result = ClosureResult.from_training_step_output(
                training_step_output, self.trainer.accumulate_grad_batches)

            if self.trainer.terminate_on_nan:
                check_finite_loss(result.closure_loss)

            if self.trainer.move_metrics_to_cpu:
                # hiddens and the training step output are not moved as they are not considered "metrics"
                assert self.trainer._results is not None
                self.trainer._results.cpu()

        return result
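`_extract_hiddens` only does something when `truncated_bptt_steps` is set on the module. The sketch below shows the user-side contract it relies on, under the 1.x API these snippets appear to come from (an assumption): with tbptt enabled, `training_step` receives a `hiddens` argument and returns the new hidden state alongside the loss. The class name and layer sizes are illustrative.

import torch
import pytorch_lightning as pl


class TBPTTModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.truncated_bptt_steps = 2  # enables tbptt splitting of each batch
        self.rnn = torch.nn.LSTM(10, 20, batch_first=True)
        self.head = torch.nn.Linear(20, 1)

    def training_step(self, batch, batch_idx, hiddens):
        # `hiddens` is None on the first split of a batch
        x, y = batch
        out, hiddens = self.rnn(x, hiddens)
        loss = torch.nn.functional.mse_loss(self.head(out), y)
        # returning "hiddens" lets the loop carry state into the next split
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())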
Example 4
    def _training_step(self, split_batch: Any, batch_idx: int,
                       opt_idx: int) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer

        Returns:
            A ``ClosureResult`` containing the training step output.
        """
        # give the PL module a result for logging
        lightning_module = self.trainer.lightning_module

        step_kwargs = _build_training_step_kwargs(lightning_module,
                                                  self.trainer.optimizers,
                                                  split_batch, batch_idx,
                                                  opt_idx, self._hiddens)

        # manually capture logged metrics
        training_step_output = self.trainer._call_strategy_hook(
            "training_step", *step_kwargs.values())
        self.trainer.strategy.post_training_step()

        model_output = self.trainer._call_lightning_module_hook(
            "training_step_end", training_step_output)
        strategy_output = self.trainer._call_strategy_hook(
            "training_step_end", training_step_output)
        # the LightningModule hook's return value takes precedence; the strategy hook is the fallback
        training_step_output = strategy_output if model_output is None else model_output

        self._hiddens = _extract_hiddens(training_step_output,
                                         lightning_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(
            training_step_output, self.trainer.accumulate_grad_batches)

        if self.trainer._terminate_on_nan:
            check_finite_loss(result.closure_loss)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        return result
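`check_finite_loss` and the `terminate_on_nan` flag appear in every example above. Below is a minimal re-implementation of what the single-argument form presumably does; it is an approximation for illustration, and the real helper may differ.

from typing import Optional

import torch


def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
    # raise as soon as the loss stops being finite, ending training early
    if loss is not None and not torch.isfinite(loss).all():
        raise ValueError(f"The loss returned in `training_step` is {loss}.")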