Example #1
def _build_training_step_kwargs(
    lightning_module: "pl.LightningModule",
    optimizers: Sequence[Optimizer],
    batch: Any,
    batch_idx: int,
    opt_idx: Optional[int],
    hiddens: Optional[Any],
) -> Dict[str, Any]:
    """Builds the keyword arguments for training_step.

    Args:
        lightning_module: the LightningModule with a `training_step` hook implementation
        optimizers: the list of optimizers from the Trainer
        batch: the batch to train on
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        the keyword arguments for the training step
    """
    # enable not needing to add opt_idx to training_step
    step_kwargs = OrderedDict([("batch", batch)])

    training_step_fx = getattr(lightning_module, "training_step")

    if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
        step_kwargs["batch_idx"] = batch_idx

    if len(optimizers) > 1:
        has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
        if has_opt_idx_in_train_step:
            if not lightning_module.automatic_optimization:
                raise ValueError(
                    "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                    " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                    " argument or set `self.automatic_optimization = True`."
                )
            step_kwargs["optimizer_idx"] = opt_idx
        elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
            raise ValueError(
                f"Your LightningModule defines {len(optimizers)} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )

    # pass hiddens if using tbptt
    if lightning_module.truncated_bptt_steps > 0:
        step_kwargs["hiddens"] = hiddens

    return step_kwargs
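Every example in this collection hinges on `is_param_in_hook_signature` from `pytorch_lightning.utilities.signature_utils`. For context, here is a minimal sketch of how such a helper can be built with the standard `inspect` module; the body is an assumption for illustration, not the library's actual implementation, although the call signature matches the usages shown throughout.

import inspect
from typing import Callable, Optional


def is_param_in_hook_signature(
    hook_fx: Callable, param: str, explicit: bool = False, min_args: Optional[int] = None
) -> bool:
    # Illustrative sketch only: the shipped helper may differ in detail.
    parameters = inspect.signature(hook_fx).parameters  # bound methods already exclude `self`
    named = [
        p.name
        for p in parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    ]
    has_var_positional = any(p.kind == p.VAR_POSITIONAL for p in parameters.values())
    return (
        # the parameter is spelled out in the signature
        param in named
        # or a bare `*args` could swallow it (unless an explicit name is required)
        or (not explicit and has_var_positional)
        # or the hook declares "enough" positional parameters (see the `batch_idx` / `min_args=2` usage above)
        or (min_args is not None and len(named) >= min_args)
    )


# Minimal check of the semantics relied on in the examples:
class _Demo:
    def training_step(self, batch, batch_idx):
        ...

    def on_train_batch_start(self, *args):
        ...


demo = _Demo()
assert is_param_in_hook_signature(demo.training_step, "batch_idx", min_args=2)
assert is_param_in_hook_signature(demo.on_train_batch_start, "dataloader_idx")  # swallowed by *args
assert not is_param_in_hook_signature(demo.on_train_batch_start, "dataloader_idx", explicit=True)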
Example #2
    def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
        # enable not needing to add opt_idx to training_step
        step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])

        lightning_module = self.trainer.lightning_module

        if len(self.trainer.optimizers) > 1:
            training_step_fx = getattr(lightning_module, "training_step")
            has_opt_idx_in_train_step = is_param_in_hook_signature(
                training_step_fx, "optimizer_idx")
            if has_opt_idx_in_train_step:
                if not lightning_module.automatic_optimization:
                    self.warning_cache.warn(
                        "`training_step` hook signature has changed in v1.3."
                        " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
                        " the old signature will be removed in v1.5",
                        DeprecationWarning)
                step_kwargs['optimizer_idx'] = opt_idx
            elif not has_opt_idx_in_train_step and self.trainer.lightning_module.automatic_optimization:
                raise ValueError(
                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
                    ' `training_step` is missing the `optimizer_idx` argument.'
                )

        # pass hiddens if using tbptt
        if self._truncated_bptt_enabled():
            step_kwargs['hiddens'] = hiddens

        return step_kwargs
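To show what the multi-optimizer branch above is checking for, here is a hypothetical module (assumed names and layers) whose `training_step` signature satisfies it under automatic optimization:

import torch
import pytorch_lightning as pl


class TwoOptimizerModule(pl.LightningModule):
    """Hypothetical module showing the signature the multi-optimizer check expects."""

    def __init__(self):
        super().__init__()
        self.gen = torch.nn.Linear(4, 4)
        self.disc = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        # two optimizers -> automatic optimization requires `optimizer_idx` in `training_step`
        return (
            torch.optim.Adam(self.gen.parameters(), lr=1e-3),
            torch.optim.Adam(self.disc.parameters(), lr=1e-3),
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        # accepted: `optimizer_idx` is in the signature, so `opt_idx` is forwarded here.
        # Removing it while keeping automatic optimization raises the ValueError shown above.
        ...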
Example #3
    def on_evaluation_epoch_end(
            self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
        model_ref = self.trainer.lightning_module
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"

        self.trainer._reset_result_and_set_hook_fx_name(hook_name)

        with self.trainer.profiler.profile(hook_name):

            if hasattr(self.trainer, hook_name):
                on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
                on_evaluation_epoch_end_hook(outputs)

            if is_overridden(hook_name, model_ref):
                model_hook_fx = getattr(model_ref, hook_name)
                if is_param_in_hook_signature(model_hook_fx, "outputs"):
                    model_hook_fx(outputs)
                else:
                    self.warning_cache.warn(
                        f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
                        " Support for the old signature will be removed in v1.5",
                        DeprecationWarning)
                    model_hook_fx()

        self.trainer._cache_logged_metrics()

        self.trainer.call_hook('on_epoch_end')
Example #4
def _check_dl_idx_in_on_train_batch_hooks(trainer: "pl.Trainer",
                                          model: "pl.LightningModule") -> None:
    for hook in ("on_train_batch_start", "on_train_batch_end"):
        if is_param_in_hook_signature(getattr(model, hook),
                                      "dataloader_idx",
                                      explicit=True):
            rank_zero_deprecation(
                f"Base `LightningModule.{hook}` hook signature has changed in v1.5."
                " The `dataloader_idx` argument will be removed in v1.7.")

        for cb in trainer.callbacks:
            if is_param_in_hook_signature(getattr(cb, hook),
                                          "dataloader_idx",
                                          explicit=True):
                rank_zero_deprecation(
                    f"Base `Callback.{hook}` hook signature has changed in v1.5."
                    " The `dataloader_idx` argument will be removed in v1.7.")
Example #5
def _check_setup_method(trainer: "pl.Trainer") -> None:
    for obj in [trainer.lightning_module, trainer.datamodule
                ] + trainer.callbacks:
        if is_overridden("setup", obj) and not is_param_in_hook_signature(
                obj.setup, "stage"):
            raise MisconfigurationException(
                f"`{obj.__class__.__name__}.setup` does not have a `stage` argument."
            )
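For illustration (hypothetical datamodules), `_check_setup_method` only accepts a `setup` override that declares a `stage` argument:

import pytorch_lightning as pl


class GoodDataModule(pl.LightningDataModule):
    # accepted by `_check_setup_method`: the override takes a `stage` argument
    def setup(self, stage=None):
        ...


class BadDataModule(pl.LightningDataModule):
    # rejected: no `stage` argument, so the check raises
    # MisconfigurationException("`BadDataModule.setup` does not have a `stage` argument.")
    def setup(self):
        ...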
Example #6
    def has_arg(self, f_name: str, arg_name: str) -> bool:
        rank_zero_deprecation(
            "Internal: TrainerModelHooksMixin.has_arg is deprecated in v1.4"
            " and will be removed in v1.6."
            " Use `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature` instead."
        )
        model = self.lightning_module
        f_op = getattr(model, f_name, None)
        if not f_op:
            return False
        return is_param_in_hook_signature(f_op, arg_name)
Example #7
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
        """Called when the training batch begins."""
        for callback in self.callbacks:
            if is_param_in_hook_signature(callback.on_train_batch_start,
                                          "dataloader_idx",
                                          explicit=True):
                callback.on_train_batch_start(self, self.lightning_module,
                                              batch, batch_idx, 0)
            else:
                callback.on_train_batch_start(self, self.lightning_module,
                                              batch, batch_idx)
Example #8
def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
    training_step_fx = getattr(trainer.lightning_module, "training_step")
    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher
    elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        if not isinstance(trainer.accelerator, CUDAAccelerator):
            raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
        return InterBatchParallelDataFetcher
    return DataFetcher
Example #9
    def on_train_batch_end(self,
                           outputs: STEP_OUTPUT,
                           batch,
                           batch_idx,
                           dataloader_idx=0):
        """Called when the training batch ends."""
        for callback in self.callbacks:
            if is_param_in_hook_signature(callback.on_train_batch_end,
                                          "dataloader_idx",
                                          explicit=True):
                callback.on_train_batch_end(self, self.lightning_module,
                                            outputs, batch, batch_idx, 0)
            else:
                callback.on_train_batch_end(self, self.lightning_module,
                                            outputs, batch, batch_idx)
Example #10
    def _should_add_batch_output_to_epoch_output(self) -> bool:
        # We add to the epoch outputs if
        # 1. The model defines training_epoch_end OR
        # 2. The model overrides on_train_epoch_end which has `outputs` in the signature
        # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
        lightning_module = self.trainer.lightning_module
        if is_overridden("training_epoch_end", model=lightning_module):
            return True

        if is_overridden("on_train_epoch_end", model=lightning_module):
            model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
            if is_param_in_hook_signature(model_hook_fx, "outputs"):
                return True

        return False
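For illustration (hypothetical module), either of the following overrides makes the check above keep the per-batch outputs for the epoch:

import pytorch_lightning as pl


class EpochOutputsModule(pl.LightningModule):
    def training_epoch_end(self, outputs):
        # condition 1: `training_epoch_end` is overridden
        ...

    def on_train_epoch_end(self, outputs):
        # condition 2: the (deprecated) `outputs`-taking form of `on_train_epoch_end`
        ...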
Example #11
    def on_test_epoch_end(self, outputs: List[Any]):
        """Called when the epoch ends.

        Args:
            outputs: List of outputs on each ``test`` epoch
        """
        for callback in self.callbacks:
            if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
                callback.on_test_epoch_end(self, self.lightning_module, outputs)
            else:
                warning_cache.warn(
                    "`Callback.on_test_epoch_end` signature has changed in v1.3."
                    " `outputs` parameter has been added."
                    " Support for the old signature will be removed in v1.5", DeprecationWarning
                )
                callback.on_test_epoch_end(self, self.lightning_module)
Example #12
    def on_train_epoch_end(self, outputs: EPOCH_OUTPUT):
        """Called when the epoch ends.

        Args:
            outputs: List of outputs on each ``train`` epoch
        """
        for callback in self.callbacks:
            if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"):
                warning_cache.deprecation(
                    "The signature of `Callback.on_train_epoch_end` has changed in v1.3."
                    " `outputs` parameter has been removed."
                    " Support for the old signature will be removed in v1.5"
                )
                callback.on_train_epoch_end(self, self.lightning_module, outputs)
            else:
                callback.on_train_epoch_end(self, self.lightning_module)
Example #13
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8.

        Called when the training batch begins.
        """
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8."
        )
        for callback in self.callbacks:
            if is_param_in_hook_signature(callback.on_train_batch_start,
                                          "dataloader_idx",
                                          explicit=True):
                callback.on_train_batch_start(self, self.lightning_module,
                                              batch, batch_idx, 0)
            else:
                callback.on_train_batch_start(self, self.lightning_module,
                                              batch, batch_idx)
Example #14
    def _on_train_epoch_end_hook(
            self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
        """Runs ``on_train_epoch_end hook``."""
        # We cannot rely on Trainer.call_hook because the signatures might be different across
        # lightning module and callback
        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`

        # This implementation is copied from Trainer.call_hook
        hook_name = "on_train_epoch_end"
        prev_fx_name = self.trainer.lightning_module._current_fx_name
        self.trainer.lightning_module._current_fx_name = hook_name

        # always profile hooks
        with self.trainer.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self.trainer, hook_name):
                trainer_hook = getattr(self.trainer, hook_name)
                trainer_hook(processed_epoch_output)

            # next call hook in lightningModule
            model_ref = self.trainer.lightning_module
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                if is_param_in_hook_signature(hook_fx, "outputs"):
                    self.warning_cache.deprecation(
                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
                        " `outputs` parameter has been deprecated."
                        " Support for the old signature will be removed in v1.5",
                    )
                    model_ref.on_train_epoch_end(processed_epoch_output)
                else:
                    model_ref.on_train_epoch_end()

            # call the accelerator hook
            if hasattr(self.trainer.accelerator, hook_name):
                accelerator_hook = getattr(self.trainer.accelerator, hook_name)
                accelerator_hook()

        # restore current_fx when nested context
        self.trainer.lightning_module._current_fx_name = prev_fx_name
Example #15
    def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
        # We cannot rely on Trainer.call_hook because the signatures might be different across
        # lightning module and callback
        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`

        # This implementation is copied from Trainer.call_hook
        hook_name = "on_train_epoch_end"

        # set hook_name to model + reset Result obj
        skip = self.trainer._reset_result_and_set_fx_name(hook_name)

        # always profile hooks
        with self.trainer.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self.trainer, hook_name):
                trainer_hook = getattr(self.trainer, hook_name)
                trainer_hook(processed_epoch_output)

            # next call hook in lightningModule
            model_ref = self.trainer.lightning_module
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                if is_param_in_hook_signature(hook_fx, "outputs"):
                    self.warning_cache.warn(
                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
                        " `outputs` parameter has been deprecated."
                        " Support for the old signature will be removed in v1.5",
                        DeprecationWarning)
                    model_ref.on_train_epoch_end(processed_epoch_output)
                else:
                    model_ref.on_train_epoch_end()

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.trainer.accelerator, hook_name):
                accelerator_hook = getattr(self.trainer.accelerator, hook_name)
                accelerator_hook()

        if not skip:
            self.trainer._cache_logged_metrics()
Example #16
def _build_training_step_kwargs(
    kwargs: OrderedDict,
    lightning_module: "pl.LightningModule",
    optimizers: Sequence[Optimizer],
    opt_idx: Optional[int],
    hiddens: Optional[Any],
) -> OrderedDict:
    """Builds the keyword arguments for training_step.

    Args:
        kwargs: The kwargs passed down to the hooks.
        lightning_module: the LightningModule with a `training_step` hook implementation
        optimizers: the list of optimizers from the Trainer
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        the keyword arguments for the training step
    """
    training_step_fx = getattr(lightning_module, "training_step")
    if len(optimizers) > 1:
        has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
        if has_opt_idx_in_train_step:
            if not lightning_module.automatic_optimization:
                raise ValueError(
                    "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                    " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                    " argument or set `self.automatic_optimization = True`."
                )
            kwargs["optimizer_idx"] = opt_idx
        elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
            raise ValueError(
                f"Your LightningModule defines {len(optimizers)} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )

    # pass hiddens if using tbptt
    if lightning_module.truncated_bptt_steps > 0:
        kwargs["hiddens"] = hiddens

    return kwargs
Example #17
    def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int,
                      hiddens: Optional[Tensor]) -> Dict[str, Any]:
        """Builds the keyword arguments for training_step

        Args:
            batch: the batch to train on
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer
            hiddens: the hidden state of the previous RNN iteration

        Returns:
            the keyword arguments for the training step
        """
        # enable not needing to add opt_idx to training_step
        step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])

        lightning_module = self.trainer.lightning_module

        if len(self.trainer.optimizers) > 1:
            training_step_fx = getattr(lightning_module, "training_step")
            has_opt_idx_in_train_step = is_param_in_hook_signature(
                training_step_fx, "optimizer_idx")
            if has_opt_idx_in_train_step:
                if not lightning_module.automatic_optimization:
                    self._warning_cache.deprecation(
                        "`training_step` hook signature has changed in v1.3."
                        " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
                        " the old signature will be removed in v1.5")
                step_kwargs['optimizer_idx'] = opt_idx
            elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
                raise ValueError(
                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
                    ' `training_step` is missing the `optimizer_idx` argument.'
                )

        # pass hiddens if using tbptt
        if self._truncated_bptt_enabled():
            step_kwargs['hiddens'] = hiddens

        return step_kwargs
Example #18
    def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optional[Tensor]) -> Dict[str, Any]:
        """Builds the keyword arguments for training_step

        Args:
            batch: the batch to train on
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer
            hiddens: the hidden state of the previous RNN iteration

        Returns:
            the keyword arguments for the training step
        """
        # enable not needing to add opt_idx to training_step
        step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

        lightning_module = self.trainer.lightning_module

        if len(self.trainer.optimizers) > 1:
            training_step_fx = getattr(lightning_module, "training_step")
            has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
            if has_opt_idx_in_train_step:
                if not lightning_module.automatic_optimization:
                    raise ValueError(
                        "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                        " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                        " argument or set `self.automatic_optimization = True`."
                    )
                step_kwargs["optimizer_idx"] = opt_idx
            elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
                raise ValueError(
                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
                    " `training_step` is missing the `optimizer_idx` argument."
                )

        # pass hiddens if using tbptt
        if self.trainer.lightning_module.truncated_bptt_steps > 0:
            step_kwargs["hiddens"] = hiddens

        return step_kwargs
Example #19
    def _build_kwargs(self, kwargs: OrderedDict, batch: Any,
                      batch_idx: int) -> OrderedDict:
        """Helper method to build the arguments for the current step.

        Args:
            kwargs: The kwargs passed down to the hooks.
            batch: The current batch to run through the step.
            batch_idx: The current batch idx.

        Returns:
            The kwargs passed down to the hooks.
        """
        kwargs["batch"] = batch
        training_step_fx = getattr(self.trainer.lightning_module,
                                   "training_step")
        # the `batch_idx` is optional, however, when there's more than 1 argument we cannot differentiate whether the
        # user wants the `batch_idx` or another key like `optimizer_idx` as we are not strict about the argument names
        if is_param_in_hook_signature(training_step_fx,
                                      "batch_idx",
                                      min_args=2):
            kwargs["batch_idx"] = batch_idx
        return kwargs
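To illustrate the comment above (hypothetical modules, only the signatures matter): with `min_args=2`, a second positional parameter is treated as `batch_idx` regardless of its name, while a single-argument `training_step` receives only the batch:

import pytorch_lightning as pl


class BatchOnlyModule(pl.LightningModule):
    def training_step(self, batch):
        # one positional argument -> `_build_kwargs` passes only `batch`
        ...


class TwoArgModule(pl.LightningModule):
    def training_step(self, batch, idx):
        # two positional arguments -> `batch_idx` is also passed, whatever the second name is
        ...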
Example #20
def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") -> None:
    """Check if the current `training_step` is requesting `dataloader_iter`."""
    training_step_fx = model.training_step
    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):

        if is_overridden("on_train_batch_start", model):
            raise MisconfigurationException(
                "The model hook `on_train_batch_start` is not compatible with "
                "taking a `dataloader_iter` argument in your `training_step`."
            )

        if is_overridden("on_train_batch_end", model):
            raise MisconfigurationException(
                "The model hook `on_train_batch_end` is not compatible with "
                "taking a `dataloader_iter` argument in your `training_step`."
            )

        if model.truncated_bptt_steps > 0:
            raise MisconfigurationException(
                "The model taking a `dataloader_iter` argument in your `training_step` "
                "is incompatible with `truncated_bptt_steps > 0`."
            )
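For reference, a minimal sketch (assumed shape of the experimental signature) of a `training_step` that requests `dataloader_iter`, which the check above guards against incompatible batch hooks and truncated BPTT:

import pytorch_lightning as pl


class IterStepModule(pl.LightningModule):
    def training_step(self, dataloader_iter):
        # experimental: the loop hands over the iterator and the step pulls batches itself
        # (assumed usage; the behavior is explicitly subject to change)
        batch = next(dataloader_iter)
        ...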
Example #21
    def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
        """Returns a deep copy of the example input array in cases where it is expected that the input changes
        during the verification process.

        Arguments:
            input_array: The input to clone.
        """
        if input_array is None and isinstance(self.model, LightningModule):
            input_array = self.model.example_input_array
        input_array = deepcopy(input_array)

        if isinstance(self.model, LightningModule):
            kwargs = {}
            if is_param_in_hook_signature(self.model.transfer_batch_to_device,
                                          "dataloader_idx"):
                # Required for Lightning 1.4 and above
                kwargs["dataloader_idx"] = 0
            input_array = self.model.transfer_batch_to_device(
                input_array, self.model.device, **kwargs)
        else:
            input_array = move_data_to_device(
                input_array, device=next(self.model.parameters()).device)

        return input_array
Example #22
    def _select_data_fetcher(self) -> AbstractDataFetcher:
        if self.trainer.sanity_checking:
            return DataFetcher()

        training_step_fx = getattr(self.trainer.lightning_module,
                                   "training_step")
        if self.trainer.training and is_param_in_hook_signature(
                training_step_fx, "dataloader_iter", explicit=True):
            rank_zero_warn(
                "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
                "this signature is experimental and the behavior is subject to change."
            )
            return DataLoaderIterDataFetcher()

        elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM",
                                                 "0") == "1":
            # note: this is an experimental feature
            if not self.trainer.training_type_plugin.on_gpu:
                raise MisconfigurationException(
                    "Inter batch parallelism is available only when using Nvidia GPUs."
                )
            return InterBatchParallelDataFetcher()

        return DataFetcher()
Example #23
    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Runs a single training batch.

        Args:
            dataloader_iter: the iterator over the dataloader producing the new batch

        Raises:
            StopIteration: When the epoch is canceled by the user returning -1
        """
        if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
            # skip training and run validation in `on_advance_end`
            return

        batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)

        if not self.trainer._data_connector.train_data_fetcher.store_on_device:
            with self.trainer.profiler.profile("training_batch_to_device"):
                batch = self.trainer.accelerator.batch_to_device(batch)

        self.batch_progress.increment_ready()

        # cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
        # different if tbptt is enabled
        batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)

        if batch is None:
            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
            batch_output = []
        else:
            # hook
            response = self.trainer.call_hook("on_batch_start")
            if response == -1:
                self.batch_progress.increment_processed()
                raise StopIteration

            # TODO: Update this in v1.7 (deprecation: #9816)
            model_fx = self.trainer.lightning_module.on_train_batch_start
            extra_kwargs = (
                {"dataloader_idx": 0}
                if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
                else {}
            )

            # hook
            response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
            if response == -1:
                self.batch_progress.increment_processed()
                raise StopIteration

            self.batch_progress.increment_started()

            with self.trainer.profiler.profile("run_training_batch"):
                batch_output = self.batch_loop.run(batch, batch_idx)

        self.trainer._results.batch_size = batch_size

        self.batch_progress.increment_processed()

        # update non-plateau LR schedulers
        # update epoch-interval ones only when we are at the end of training epoch
        self.update_lr_schedulers("step", update_plateau_schedulers=False)
        if self._num_ready_batches_reached():
            self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

        batch_end_outputs = self._prepare_outputs_training_batch_end(
            batch_output,
            automatic=self.trainer.lightning_module.automatic_optimization,
            num_optimizers=len(self.trainer.optimizers),
        )

        # TODO: Update this in v1.7 (deprecation: #9816)
        model_fx = self.trainer.lightning_module.on_train_batch_end
        extra_kwargs = (
            {"dataloader_idx": 0}
            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
            else {}
        )
        self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
        self.trainer.call_hook("on_batch_end")
        self.trainer.logger_connector.on_batch_end()

        self.batch_progress.increment_completed()

        if is_overridden("training_epoch_end", self.trainer.lightning_module):
            self._outputs.append(batch_output)

        # -----------------------------------------
        # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
        # -----------------------------------------
        self.trainer.logger_connector.update_train_step_metrics()
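A small illustration (hypothetical module) of the `-1` convention handled above: returning `-1` from `on_train_batch_start` makes the loop raise `StopIteration` and skip the rest of the epoch:

import pytorch_lightning as pl


class EarlyExitModule(pl.LightningModule):
    def on_train_batch_start(self, batch, batch_idx):
        # returning -1 tells the training epoch loop to stop consuming batches
        if batch_idx >= 10:
            return -1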
Example #24
    def advance(
            self, data_fetcher: AbstractDataFetcher
    ) -> None:  # type: ignore[override]
        """Runs a single training batch.

        Raises:
            StopIteration: When the epoch is canceled by the user returning -1
        """
        if self.restarting and self._should_check_val_fx(
                self.batch_idx, self.batch_progress.is_last_batch):
            # skip training and run validation in `on_advance_end`
            return
        # we are going to train first so the val loop does not need to restart
        self.val_loop.restarting = False

        if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
            batch_idx = self.batch_idx + 1
            batch = next(data_fetcher)
        else:
            batch_idx, batch = next(data_fetcher)
        self.batch_progress.is_last_batch = data_fetcher.done

        self.batch_progress.increment_ready()

        self.trainer._logger_connector.on_batch_start(batch, batch_idx)

        if batch is None:
            self._warning_cache.warn(
                "train_dataloader yielded None. If this was on purpose, ignore this warning..."
            )
            batch_output = []
        else:
            # hook
            self.trainer._call_callback_hooks("on_batch_start")

            # TODO: Update this in v1.7 (deprecation: #9816)
            model_fx = self.trainer.lightning_module.on_train_batch_start
            extra_kwargs = ({
                "dataloader_idx": 0
            } if callable(model_fx) and is_param_in_hook_signature(
                model_fx, "dataloader_idx", explicit=True) else {})

            # hook
            self.trainer._call_callback_hooks("on_train_batch_start", batch,
                                              batch_idx, **extra_kwargs)
            response = self.trainer._call_lightning_module_hook(
                "on_train_batch_start", batch, batch_idx, **extra_kwargs)
            self.trainer._call_strategy_hook("on_train_batch_start", batch,
                                             batch_idx, **extra_kwargs)
            if response == -1:
                self.batch_progress.increment_processed()
                raise StopIteration

            self.batch_progress.increment_started()

            with self.trainer.profiler.profile("run_training_batch"):
                batch_output = self.batch_loop.run(batch, batch_idx)

        self.batch_progress.increment_processed()

        # update non-plateau LR schedulers
        # update epoch-interval ones only when we are at the end of training epoch
        self.update_lr_schedulers("step", update_plateau_schedulers=False)
        if self._num_ready_batches_reached():
            self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

        batch_end_outputs = self._prepare_outputs_training_batch_end(
            batch_output,
            lightning_module=self.trainer.lightning_module,
            num_optimizers=len(self.trainer.optimizers),
        )

        # TODO: Update this in v1.7 (deprecation: #9816)
        model_fx = self.trainer.lightning_module.on_train_batch_end
        extra_kwargs = ({
            "dataloader_idx": 0
        } if callable(model_fx) and is_param_in_hook_signature(
            model_fx, "dataloader_idx", explicit=True) else {})
        self.trainer._call_callback_hooks("on_train_batch_end",
                                          batch_end_outputs, batch, batch_idx,
                                          **extra_kwargs)
        self.trainer._call_lightning_module_hook("on_train_batch_end",
                                                 batch_end_outputs, batch,
                                                 batch_idx, **extra_kwargs)
        self.trainer._call_callback_hooks("on_batch_end")
        self.trainer._logger_connector.on_batch_end()

        self.batch_progress.increment_completed()

        if is_overridden("training_epoch_end", self.trainer.lightning_module):
            self._outputs.append(batch_output)

        # -----------------------------------------
        # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
        # -----------------------------------------
        self.trainer._logger_connector.update_train_step_metrics()
Example #25
def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
    for callback in trainer.callbacks:
        if is_overridden(method_name="on_keyboard_interrupt",
                         instance=callback):
            rank_zero_deprecation(
                "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
                " Please use the `on_exception` callback hook instead.")
        # TODO: Remove this in v1.7 (deprecation: #9816)
        for hook in ("on_train_batch_start", "on_train_batch_end"):
            if is_param_in_hook_signature(getattr(callback, hook),
                                          "dataloader_idx",
                                          explicit=True):
                rank_zero_deprecation(
                    f"Base `Callback.{hook}` hook signature has changed in v1.5."
                    " The `dataloader_idx` argument will be removed in v1.7.")
        if is_overridden(method_name="on_init_start", instance=callback):
            rank_zero_deprecation(
                "The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
            )
        if is_overridden(method_name="on_init_end", instance=callback):
            rank_zero_deprecation(
                "The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8."
            )

        if is_overridden(method_name="on_configure_sharded_model",
                         instance=callback):
            rank_zero_deprecation(
                "The `on_configure_sharded_model` callback hook was deprecated in"
                " v1.6 and will be removed in v1.8. Use `setup()` instead.")
        if is_overridden(method_name="on_before_accelerator_backend_setup",
                         instance=callback):
            rank_zero_deprecation(
                "The `on_before_accelerator_backend_setup` callback hook was deprecated in"
                " v1.6 and will be removed in v1.8. Use `setup()` instead.")
        if is_overridden(method_name="on_load_checkpoint", instance=callback):
            rank_zero_deprecation(
                f"`{callback.__class__.__name__}.on_load_checkpoint` will change its signature and behavior in v1.8."
                " If you wish to load the state of the callback, use `load_state_dict` instead."
                " In v1.8 `on_load_checkpoint(..., checkpoint)` will receive the entire loaded"
                " checkpoint dictionary instead of callback state.")

        for hook, alternative_hook in (
            ["on_batch_start", "on_train_batch_start"],
            ["on_batch_end", "on_train_batch_end"],
        ):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook was deprecated in v1.6 and"
                    f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
                )
        for hook, alternative_hook in (
            ["on_epoch_start", "on_<train/validation/test>_epoch_start"],
            ["on_epoch_end", "on_<train/validation/test>_epoch_end"],
        ):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook was deprecated in v1.6 and"
                    f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
                )
        for hook in ("on_pretrain_routine_start", "on_pretrain_routine_end"):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook has been deprecated in v1.6 and"
                    " will be removed in v1.8. Please use `Callback.on_fit_start` instead."
                )
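As a closing illustration (hypothetical callback), overriding any of the hooks or legacy signatures checked above emits the corresponding `rank_zero_deprecation` message:

import pytorch_lightning as pl


class LegacyCallback(pl.Callback):
    # deprecated signature: `dataloader_idx` was removed from the base hook in v1.5
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        ...

    # deprecated hook: `on_exception` replaces it (removal planned for v1.7)
    def on_keyboard_interrupt(self, trainer, pl_module):
        ...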