Example 1
    def advance(self, batch: Any,
                batch_idx: int) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(lightning_module,
                                                      self.trainer.optimizers,
                                                      batch,
                                                      batch_idx,
                                                      opt_idx=None,
                                                      hiddens=self._hiddens)

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(
                    step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            self._hiddens = _extract_hiddens(
                training_step_output, lightning_module.truncated_bptt_steps)

            # TODO: do not use `ClosureResult`
            result = ClosureResult.from_training_step_output(
                training_step_output, self.trainer.accumulate_grad_batches)

            if self.trainer.terminate_on_nan:
                check_finite_loss(result.closure_loss)

            if self.trainer.move_metrics_to_cpu:
                # hiddens and the training step output are not moved as they are not considered "metrics"
                # the user might need them on the correct device for an operation in `training_epoch_end`
                assert self.trainer._results is not None
                self.trainer._results.cpu()

        self._done = True
        self._output = result
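
For orientation, a minimal user-side module that this manual-optimization `advance` would drive could look like the sketch below. `ManualOptModule` and its single linear layer are hypothetical; the real switch that routes batches through this loop is `automatic_optimization = False`, and `self.optimizers()` / `self.manual_backward()` are the documented Lightning APIs for manual stepping.

import torch
from torch import nn
import pytorch_lightning as pl

class ManualOptModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # select the manual-optimization loop
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()        # trainer-managed optimizer
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)     # replaces loss.backward()
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
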
    def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int,
                       hiddens: Tensor) -> Optional[AttributeDict]:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer
            hiddens: the model's hidden state of the previous iteration

        Returns:
            an AttributeDict containing the loss value and the training step output.
        """
        # give the PL module a result for logging
        model_ref = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):
            step_kwargs = _build_training_step_kwargs(model_ref,
                                                      self.trainer.optimizers,
                                                      split_batch, batch_idx,
                                                      opt_idx, hiddens)

            # manually capture logged metrics
            model_ref._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(
                    step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            _check_training_step_output(self.trainer.lightning_module,
                                        training_step_output)

            result_collection, self._hiddens = _process_training_step_output(
                self.trainer, training_step_output)
            if result_collection is None:
                return None

        closure_loss = None
        loss = None
        if self.trainer.lightning_module.automatic_optimization:
            # accumulate loss. if accumulate_grad_batches==1, no effect
            closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
            # the loss will get scaled for amp. avoid any modifications to it
            loss = closure_loss.detach().clone()
        return AttributeDict(closure_loss=closure_loss,
                             loss=loss,
                             result_collection=result_collection)
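
To make the scaling in the automatic-optimization branch concrete: with `accumulate_grad_batches=4`, each micro-batch loss is divided by 4 before `backward()`, so the gradients summed over four steps match a single step on the full batch. A tiny illustrative sketch (the tensor value is arbitrary):

import torch

accumulate_grad_batches = 4

# stand-in for the loss produced by `training_step`
minimize = torch.tensor(2.0, requires_grad=True)

# dividing the loss divides its gradients by the same factor, so summing
# gradients over 4 micro-batches is equivalent to one full-batch step
closure_loss = minimize / accumulate_grad_batches  # tensor(0.5000, grad_fn=<DivBackward0>)

# detach + clone: the logged value must not keep the graph alive and must
# not change when AMP later scales `closure_loss` in place
loss = closure_loss.detach().clone()
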
Example 3
    def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict:
        """Helper method to build the arguments for the current step.

        Args:
            kwargs: The kwargs passed down to the hooks.
            hiddens: the hidden state of the previous RNN iteration.

        Returns:
            The kwargs passed down to the hooks.
        """
        return _build_training_step_kwargs(
            kwargs, self.trainer.lightning_module, self.trainer.optimizers, None, hiddens
        )
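
`_build_training_step_kwargs` is a private Lightning helper whose exact signature has shifted across releases; conceptually, it inspects the user's `training_step` signature and forwards only the optional arguments (`optimizer_idx`, `hiddens`) that the user declared. A simplified, hypothetical re-implementation of that idea, not the actual helper:

import inspect
from collections import OrderedDict
from typing import Any, Optional

def build_step_kwargs_sketch(training_step, batch: Any, batch_idx: int,
                             opt_idx: Optional[int], hiddens: Optional[Any]) -> OrderedDict:
    """Hypothetical sketch: forward only the arguments that the user's
    ``training_step`` actually declares."""
    params = inspect.signature(training_step).parameters
    kwargs = OrderedDict(batch=batch, batch_idx=batch_idx)
    if opt_idx is not None and "optimizer_idx" in params:
        kwargs["optimizer_idx"] = opt_idx
    if "hiddens" in params:
        kwargs["hiddens"] = hiddens
    return kwargs
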
    def _get_generator(self, split_batch, batch_idx, opt_idx):
        step_kwargs = _build_training_step_kwargs(
            self.trainer.lightning_module,
            self.trainer.optimizers,
            split_batch,
            batch_idx,
            opt_idx,
            hiddens=None)

        # Here we are basically calling `lightning_module.training_step()`,
        # and this returns a generator! The `training_step` call is routed
        # through the strategy to enable distributed training.
        return self.trainer.strategy.training_step(*step_kwargs.values())
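
The generator comment above matches Lightning's loop-customization examples, where `training_step` is written with `yield` so that execution can pause after each optimizer's loss. A runnable sketch of such a step; the class, its layers, and the loss wiring are all hypothetical stand-ins for a real `LightningModule`:

import torch
from torch import nn

class GANSketch:  # stand-in for a LightningModule; only the step shape matters
    def __init__(self):
        self.generator = nn.Linear(8, 8)
        self.discriminator = nn.Linear(8, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def training_step(self, batch, batch_idx):
        # first advance: generator loss; execution pauses at the yield
        fake = self.generator(torch.randn(batch.size(0), 8))
        g_loss = self.loss_fn(self.discriminator(fake), torch.ones(batch.size(0), 1))
        yield g_loss

        # second advance: discriminator loss, computed only after the loop
        # has already stepped the generator's optimizer
        d_loss = self.loss_fn(self.discriminator(batch), torch.ones(batch.size(0), 1)) \
               + self.loss_fn(self.discriminator(fake.detach()), torch.zeros(batch.size(0), 1))
        yield d_loss
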
    def _training_step(self, split_batch: Any, batch_idx: int,
                       opt_idx: int) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer

        Returns:
            A ``ClosureResult`` containing the training step output.
        """
        # give the PL module a result for logging
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(lightning_module,
                                                      self.trainer.optimizers,
                                                      split_batch, batch_idx,
                                                      opt_idx, self._hiddens)

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(
                    step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            self._hiddens = _extract_hiddens(
                training_step_output, lightning_module.truncated_bptt_steps)

            result = ClosureResult.from_training_step_output(
                training_step_output, self.trainer.accumulate_grad_batches)

            if self.trainer.terminate_on_nan:
                check_finite_loss(result.closure_loss)

            if self.trainer.move_metrics_to_cpu:
                # hiddens and the training step output are not moved as they are not considered "metrics"
                assert self.trainer._results is not None
                self.trainer._results.cpu()

        return result
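
`from_training_step_output` normalizes the two output shapes `training_step` may return under automatic optimization, a bare loss tensor or a dict containing a "loss" key, and pre-scales the loss for gradient accumulation. A hypothetical simplification of that normalization (not the actual `ClosureResult` code):

from typing import Any, Optional
import torch

def normalize_step_output_sketch(output: Any, accumulate_grad_batches: int) -> Optional[torch.Tensor]:
    """Hypothetical sketch: extract the loss from either supported output
    shape and pre-scale it for gradient accumulation."""
    if isinstance(output, dict):
        loss = output.get("loss")
    elif isinstance(output, torch.Tensor):
        loss = output
    else:  # e.g. ``None`` when the user chooses to skip this batch
        loss = None
    if loss is not None:
        loss = loss / accumulate_grad_batches
    return loss
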
    def _training_step(self, split_batch: Any, batch_idx: int,
                       hiddens: Tensor) -> Optional[AttributeDict]:
        """Performs the training step for manual optimization.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            hiddens: the model's hidden state of the previous iteration

        Returns:
            an AttributeDict containing the training step output.
        """
        # give the PL module a result for logging
        model_ref = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):
            step_kwargs = _build_training_step_kwargs(model_ref,
                                                      self.trainer.optimizers,
                                                      split_batch,
                                                      batch_idx,
                                                      opt_idx=None,
                                                      hiddens=hiddens)

            # manually capture logged metrics
            model_ref._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(
                    step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            _check_training_step_output(self.trainer.lightning_module,
                                        training_step_output)

            result_collection, self._hiddens = _process_training_step_output(
                self.trainer, training_step_output)
            if result_collection is None:
                return None

        return AttributeDict(closure_loss=None,
                             loss=None,
                             result_collection=result_collection)
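
The `hiddens` argument threaded through this variant exists for truncated backpropagation through time: when `truncated_bptt_steps > 0`, the user's `training_step` receives the previous split's hidden state and must return the new one. A sketch of such a step, written as it would appear on a `LightningModule`; `self.rnn` and `self.loss_fn` are hypothetical attributes:

def training_step(self, batch, batch_idx, hiddens):
    # `hiddens` is None on the first split of a batch, then carries the
    # (detached) state produced by the previous split
    x, y = batch
    out, new_hiddens = self.rnn(x, hiddens)
    loss = self.loss_fn(out, y)
    # returning "hiddens" lets the loop thread the state into the next split
    return {"loss": loss, "hiddens": new_hiddens}
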
Example 7
    def _training_step(self, split_batch: Any, batch_idx: int,
                       opt_idx: int) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer

        Returns:
            A ``ClosureResult`` containing the training step output.
        """
        # give the PL module a result for logging
        lightning_module = self.trainer.lightning_module

        step_kwargs = _build_training_step_kwargs(lightning_module,
                                                  self.trainer.optimizers,
                                                  split_batch, batch_idx,
                                                  opt_idx, self._hiddens)

        # manually capture logged metrics
        training_step_output = self.trainer._call_strategy_hook(
            "training_step", *step_kwargs.values())
        self.trainer.strategy.post_training_step()

        model_output = self.trainer._call_lightning_module_hook(
            "training_step_end", training_step_output)
        strategy_output = self.trainer._call_strategy_hook(
            "training_step_end", training_step_output)
        training_step_output = strategy_output if model_output is None else model_output

        self._hiddens = _extract_hiddens(training_step_output,
                                         lightning_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(
            training_step_output, self.trainer.accumulate_grad_batches)

        if self.trainer._terminate_on_nan:
            check_finite_loss(result.closure_loss)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        return result
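
The `model_output` / `strategy_output` selection above gives a user-defined `training_step_end` precedence over the strategy's default behavior. The hook exists mainly for DP-style strategies, where `training_step` runs once per device and the per-device outputs need reducing on the main process. A minimal sketch of such a user hook, shown as it would appear on a `LightningModule`:

def training_step_end(self, step_outputs):
    # under DP, `step_outputs["loss"]` arrives with one entry per GPU;
    # reduce it to a single scalar so the rest of the loop sees one loss
    return {"loss": step_outputs["loss"].mean()}
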
Example 8
    def advance(self, batch: Any,
                batch_idx: int) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(lightning_module,
                                                      self.trainer.optimizers,
                                                      batch,
                                                      batch_idx,
                                                      opt_idx=None,
                                                      hiddens=self._hiddens)

            # manually capture logged metrics
            training_step_output = self.trainer._call_strategy_hook(
                "training_step", *step_kwargs.values())
            self.trainer.strategy.post_training_step()

            del step_kwargs

            model_output = self.trainer._call_lightning_module_hook(
                "training_step_end", training_step_output)
            strategy_output = self.trainer._call_strategy_hook(
                "training_step_end", training_step_output)
            training_step_output = strategy_output if model_output is None else model_output
            self._hiddens = _extract_hiddens(
                training_step_output, lightning_module.truncated_bptt_steps)

            result = self.output_result_cls.from_training_step_output(
                training_step_output)

            if self.trainer.move_metrics_to_cpu:
                # hiddens and the training step output are not moved as they are not considered "metrics"
                # the user might need them on the correct device for an operation in `training_epoch_end`
                assert self.trainer._results is not None
                self.trainer._results.cpu()

        self._done = True
        self._output = result.asdict()
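
`_extract_hiddens`, used throughout these examples, essentially detaches the returned state so the next TBPTT split does not backpropagate through the previous one. A simplified, hypothetical version of what it does (the real helper also validates the output shape and handles nested collections of tensors):

from typing import Any, Optional

def extract_hiddens_sketch(training_step_output: Any, truncated_bptt_steps: int) -> Optional[Any]:
    """Hypothetical sketch of ``_extract_hiddens``."""
    if truncated_bptt_steps == 0:
        # TBPTT disabled: there is no state to carry over
        return None
    # with TBPTT enabled, Lightning expects a dict output with a "hiddens" key
    hiddens = training_step_output["hiddens"]
    # detach so gradients do not flow across split boundaries
    return hiddens.detach()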