    def advance(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            kwargs: The kwargs passed down to the hooks.
        """
        assert self.trainer is not None

        kwargs = self._build_kwargs(kwargs, self._hiddens)

        # manually capture logged metrics
        training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
        del kwargs  # release the batch from memory
        self.trainer.strategy.post_training_step()

        model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
        strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
        training_step_output = strategy_output if model_output is None else model_output
        self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(training_step_output)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            # the user might need them on the correct device for an operation in `training_epoch_end`
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        self._done = True
        self._output = result.asdict()
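
Both `training_step_end` hooks run above, but only one return value survives: the LightningModule's result takes precedence, and the strategy's result is used only as a fallback. A minimal sketch of that precedence rule (not Lightning code):

def _resolve_step_end(model_output, strategy_output):
    # the module hook's return value wins; fall back to the strategy's
    return strategy_output if model_output is None else model_output

assert _resolve_step_end(None, "strategy") == "strategy"
assert _resolve_step_end("model", "strategy") == "model"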
Example #2
    def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            kwargs: the kwargs passed down to the hooks.

        Returns:
            A ``ClosureResult`` containing the training step output.
        """
        # manually capture logged metrics
        training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
        self.trainer.strategy.post_training_step()

        model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
        strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
        training_step_output = strategy_output if model_output is None else model_output

        self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(
            training_step_output, self.trainer.accumulate_grad_batches
        )

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        return result
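
Unlike the manual-optimization `advance` above, this variant passes `self.trainer.accumulate_grad_batches` into `from_training_step_output`, so the closure loss can be normalized by the accumulation factor and accumulated gradients average rather than sum. A hedged sketch of that normalization (the helper name is hypothetical):

import torch

def _normalize_closure_loss(loss: torch.Tensor, accumulate_grad_batches: int) -> torch.Tensor:
    # hypothetical stand-in for the scaling `from_training_step_output` is
    # expected to apply to the loss when gradient accumulation is enabled
    return loss / accumulate_grad_batches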
import pytest
import torch

from pytorch_lightning.loops.utilities import _extract_hiddens
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_extract_hiddens():
    # tbptt not enabled, no hiddens returned
    training_step_output = 1  # anything
    hiddens = _extract_hiddens(training_step_output, 0)
    assert hiddens is None

    # tbptt enabled, hiddens returned
    hiddens = torch.tensor(321.12, requires_grad=True)
    training_step_output = {"hiddens": hiddens}
    hiddens = _extract_hiddens(training_step_output, 2)
    assert "hiddens" in training_step_output
    assert not hiddens.requires_grad

    # tbptt not enabled, but hiddens returned
    with pytest.raises(
            MisconfigurationException,
            match='returned "hiddens" .* but `truncated_bptt_steps` is disabled'
    ):
        _extract_hiddens(training_step_output, 0)
    # tbptt enabled, but no hiddens returned
    with pytest.raises(
            MisconfigurationException,
            match='enabled `truncated_bptt_steps` but did not return "hiddens"'
    ):
        _extract_hiddens(None, 1)
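
The assertions above pin down the contract of `_extract_hiddens`. A simplified sketch inferred from the test (the real implementation also detaches nested structures recursively):

from typing import Any, Optional

import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _extract_hiddens_sketch(training_step_output: Any, truncated_bptt_steps: int) -> Optional[Any]:
    has_hiddens = isinstance(training_step_output, dict) and "hiddens" in training_step_output
    if truncated_bptt_steps == 0:
        if has_hiddens:
            raise MisconfigurationException(
                'You returned "hiddens" in your `training_step` but `truncated_bptt_steps` is disabled'
            )
        return None
    if not has_hiddens:
        raise MisconfigurationException(
            'You enabled `truncated_bptt_steps` but did not return "hiddens" in your `training_step`'
        )
    # detach so the autograd graph of the previous split is not kept alive
    return training_step_output["hiddens"].detach()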
Example #4
    def advance(self, batch: Any,
                batch_idx: int) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(lightning_module,
                                                      self.trainer.optimizers,
                                                      batch,
                                                      batch_idx,
                                                      opt_idx=None,
                                                      hiddens=self._hiddens)

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(
                    step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            self._hiddens = _extract_hiddens(
                training_step_output, lightning_module.truncated_bptt_steps)

            # TODO: do not use `ClosureResult`
            result = ClosureResult.from_training_step_output(
                training_step_output, self.trainer.accumulate_grad_batches)

            if self.trainer.terminate_on_nan:
                check_finite_loss(result.closure_loss)

            if self.trainer.move_metrics_to_cpu:
                # hiddens and the training step output are not moved as they are not considered "metrics"
                # the user might need them on the correct device for an operation in `training_epoch_end`
                assert self.trainer._results is not None
                self.trainer._results.cpu()

        self._done = True
        self._output = result
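
The `OrderedDict` returned by `_build_training_step_kwargs` is order-sensitive: other variants splat it positionally via `*step_kwargs.values()`. A simplified sketch of its expected shape (the real helper also inspects the `training_step` signature to decide which keys to include):

from collections import OrderedDict
from typing import Any, Optional


def _build_step_kwargs_sketch(
    batch: Any, batch_idx: int, opt_idx: Optional[int] = None, hiddens: Optional[Any] = None
) -> OrderedDict:
    # insertion order mirrors the positional order of
    # `training_step(batch, batch_idx, optimizer_idx, hiddens)`
    kwargs = OrderedDict(batch=batch, batch_idx=batch_idx)
    if opt_idx is not None:
        kwargs["optimizer_idx"] = opt_idx  # multiple optimizers with automatic optimization
    if hiddens is not None:
        kwargs["hiddens"] = hiddens  # only when truncated_bptt_steps > 0
    return kwargs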
    def _training_step(self, split_batch: Any, batch_idx: int,
                       opt_idx: int) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer

        Returns:
            A ``ClosureResult`` containing the training step output.
        """
        # give the PL module a result for logging
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(lightning_module,
                                                      self.trainer.optimizers,
                                                      split_batch, batch_idx,
                                                      opt_idx, self._hiddens)

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(
                    step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            self._hiddens = _extract_hiddens(
                training_step_output, lightning_module.truncated_bptt_steps)

            result = ClosureResult.from_training_step_output(
                training_step_output, self.trainer.accumulate_grad_batches)

            if self.trainer.terminate_on_nan:
                check_finite_loss(result.closure_loss)

            if self.trainer.move_metrics_to_cpu:
                # hiddens and the training step output are not moved as they are not considered "metrics"
                assert self.trainer._results is not None
                self.trainer._results.cpu()

        return result
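
The `profiler.profile(...)` blocks wrapping the forward and the training step are plain context managers that time the enclosed region. A self-contained usage sketch with Lightning's `SimpleProfiler`:

from pytorch_lightning.profiler import SimpleProfiler

profiler = SimpleProfiler()
with profiler.profile("model_forward"):
    pass  # the forward pass would run inside this timed region
print(profiler.summary())  # per-action mean and total durations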
Example #6
    def _training_step(self, split_batch: Any, batch_idx: int,
                       opt_idx: int) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer

        Returns:
            A ``ClosureResult`` containing the training step output.
        """
        # give the PL module a result for logging
        lightning_module = self.trainer.lightning_module

        step_kwargs = _build_training_step_kwargs(lightning_module,
                                                  self.trainer.optimizers,
                                                  split_batch, batch_idx,
                                                  opt_idx, self._hiddens)

        # manually capture logged metrics
        training_step_output = self.trainer._call_strategy_hook(
            "training_step", *step_kwargs.values())
        self.trainer.strategy.post_training_step()

        model_output = self.trainer._call_lightning_module_hook(
            "training_step_end", training_step_output)
        strategy_output = self.trainer._call_strategy_hook(
            "training_step_end", training_step_output)
        training_step_output = strategy_output if model_output is None else model_output

        self._hiddens = _extract_hiddens(training_step_output,
                                         lightning_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(
            training_step_output, self.trainer.accumulate_grad_batches)

        if self.trainer._terminate_on_nan:
            check_finite_loss(result.closure_loss)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        return result
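
`check_finite_loss` backs the `terminate_on_nan` / `_terminate_on_nan` paths above by failing fast on NaN or infinite losses. A minimal sketch of the check it is expected to perform:

from typing import Optional

import torch


def _check_finite_loss_sketch(loss: Optional[torch.Tensor]) -> None:
    # hypothetical stand-in for Lightning's `check_finite_loss`
    if loss is not None and not torch.isfinite(loss).all():
        raise ValueError(f"The loss returned in `training_step` is {loss}.")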
Example #7
    def advance(self, batch: Any,
                batch_idx: int) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(lightning_module,
                                                      self.trainer.optimizers,
                                                      batch,
                                                      batch_idx,
                                                      opt_idx=None,
                                                      hiddens=self._hiddens)

            # manually capture logged metrics
            training_step_output = self.trainer._call_strategy_hook(
                "training_step", *step_kwargs.values())
            self.trainer.strategy.post_training_step()

            del step_kwargs

            model_output = self.trainer._call_lightning_module_hook(
                "training_step_end", training_step_output)
            strategy_output = self.trainer._call_strategy_hook(
                "training_step_end", training_step_output)
            training_step_output = strategy_output if model_output is None else model_output
            self._hiddens = _extract_hiddens(
                training_step_output, lightning_module.truncated_bptt_steps)

            result = self.output_result_cls.from_training_step_output(
                training_step_output)

            if self.trainer.move_metrics_to_cpu:
                # hiddens and the training step output are not moved as they are not considered "metrics"
                # the user might need them on the correct device for an operation in `training_epoch_end`
                assert self.trainer._results is not None
                self.trainer._results.cpu()

        self._done = True
        self._output = result.asdict()