def _build_training_step_kwargs(
    lightning_module: "pl.LightningModule",
    optimizers: Sequence[Optimizer],
    batch: Any,
    batch_idx: int,
    opt_idx: Optional[int],
    hiddens: Optional[Any],
) -> Dict[str, Any]:
    """Assemble the keyword arguments passed to ``LightningModule.training_step``.

    ``batch`` is always present. ``batch_idx`` is added when ``is_param_in_hook_signature`` reports it for the
    hook (called with ``min_args=2``). ``optimizer_idx`` is added when several optimizers are configured and the
    hook declares it. ``hiddens`` is added when ``truncated_bptt_steps > 0``.

    Args:
        lightning_module: the LightningModule with a `training_step` hook implementation
        optimizers: the list of optimizers from the Trainer
        batch: the batch to train on
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        an ordered mapping with the keyword arguments for the training step

    Raises:
        ValueError: if ``training_step`` declares ``optimizer_idx`` under manual optimization, or if it lacks
            ``optimizer_idx`` while several optimizers are used with automatic optimization.
    """
    kwargs = OrderedDict([("batch", batch)])
    step_fx = getattr(lightning_module, "training_step")

    if is_param_in_hook_signature(step_fx, "batch_idx", min_args=2):
        kwargs["batch_idx"] = batch_idx

    num_optimizers = len(optimizers)
    if num_optimizers > 1:
        wants_opt_idx = is_param_in_hook_signature(step_fx, "optimizer_idx")
        auto_opt = lightning_module.automatic_optimization
        if wants_opt_idx and not auto_opt:
            raise ValueError(
                "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                " argument or set `self.automatic_optimization = True`."
            )
        if wants_opt_idx:
            kwargs["optimizer_idx"] = opt_idx
        elif auto_opt:
            raise ValueError(
                f"Your LightningModule defines {num_optimizers} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )

    # truncated backpropagation through time carries the hidden state between splits
    if lightning_module.truncated_bptt_steps > 0:
        kwargs["hiddens"] = hiddens

    return kwargs
def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
    """Build the keyword arguments for ``training_step``.

    ``batch`` and ``batch_idx`` are always included. With more than one optimizer, ``optimizer_idx`` is included
    when the hook declares it. Under manual optimization, a deprecation warning is emitted in that case.
    ``hiddens`` is included when ``self._truncated_bptt_enabled()`` is truthy.

    Raises:
        ValueError: if several optimizers are used with automatic optimization and ``training_step`` lacks an
            ``optimizer_idx`` argument.
    """
    kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
    module = self.trainer.lightning_module

    num_optimizers = len(self.trainer.optimizers)
    if num_optimizers > 1:
        step_fx = getattr(module, "training_step")
        accepts_opt_idx = is_param_in_hook_signature(step_fx, "optimizer_idx")
        if not accepts_opt_idx:
            if module.automatic_optimization:
                raise ValueError(
                    f"Your LightningModule defines {num_optimizers} optimizers but"
                    ' `training_step` is missing the `optimizer_idx` argument.'
                )
        else:
            if not module.automatic_optimization:
                self.warning_cache.warn(
                    "`training_step` hook signature has changed in v1.3."
                    " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
                    " the old signature will be removed in v1.5",
                    DeprecationWarning,
                )
            kwargs['optimizer_idx'] = opt_idx

    # truncated backpropagation through time carries the hidden state between splits
    if self._truncated_bptt_enabled():
        kwargs['hiddens'] = hiddens
    return kwargs
def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
    """Run the epoch-end hooks of the current evaluation stage.

    The hook is ``on_test_epoch_end`` when ``trainer.testing`` is set, otherwise ``on_validation_epoch_end``.
    The trainer's hook (if the attribute exists) always receives ``outputs``. An overridden LightningModule hook
    receives ``outputs`` only when its signature declares it. Otherwise a deprecation warning is issued and the
    hook is called without arguments. Afterwards, logged metrics are cached and ``on_epoch_end`` is dispatched
    through ``trainer.call_hook``.

    Args:
        outputs: the evaluation outputs forwarded to the hooks
    """
    trainer = self.trainer
    module = trainer.lightning_module
    hook_name = "on_test_epoch_end" if trainer.testing else "on_validation_epoch_end"
    trainer._reset_result_and_set_hook_fx_name(hook_name)

    with trainer.profiler.profile(hook_name):
        if hasattr(trainer, hook_name):
            getattr(trainer, hook_name)(outputs)

        if is_overridden(hook_name, module):
            module_hook = getattr(module, hook_name)
            if not is_param_in_hook_signature(module_hook, "outputs"):
                self.warning_cache.warn(
                    f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
                    " Support for the old signature will be removed in v1.5",
                    DeprecationWarning,
                )
                module_hook()
            else:
                module_hook(outputs)

    trainer._cache_logged_metrics()
    trainer.call_hook('on_epoch_end')
def _check_dl_idx_in_on_train_batch_hooks(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
    """Report a deprecation for ``on_train_batch_{start,end}`` hooks that still declare ``dataloader_idx``.

    For each of the two hooks, the LightningModule is inspected first, followed by every callback of
    ``trainer``. Only parameters declared explicitly by name are considered.
    """
    for hook_name in ("on_train_batch_start", "on_train_batch_end"):
        owners = [("LightningModule", model)] + [("Callback", cb) for cb in trainer.callbacks]
        for owner_kind, owner in owners:
            if not is_param_in_hook_signature(getattr(owner, hook_name), "dataloader_idx", explicit=True):
                continue
            rank_zero_deprecation(
                f"Base `{owner_kind}.{hook_name}` hook signature has changed in v1.5."
                " The `dataloader_idx` argument will be removed in v1.7."
            )
def _check_setup_method(trainer: "pl.Trainer") -> None:
    """Ensure every overridden ``setup`` hook accepts a ``stage`` argument.

    The LightningModule, the datamodule and each callback attached to ``trainer`` are inspected in that order.

    Raises:
        MisconfigurationException: for the first object whose overridden ``setup`` lacks a ``stage`` parameter.
    """
    candidates = [trainer.lightning_module, trainer.datamodule] + trainer.callbacks
    for candidate in candidates:
        if not is_overridden("setup", candidate):
            continue
        if is_param_in_hook_signature(candidate.setup, "stage"):
            continue
        raise MisconfigurationException(f"`{candidate.__class__.__name__}.setup` does not have a `stage` argument.")
def has_arg(self, f_name: str, arg_name: str) -> bool:
    """Return whether the LightningModule attribute ``f_name`` accepts an argument named ``arg_name``.

    .. deprecated:: v1.4
        Will be removed in v1.6. Use
        ``pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature`` instead.

    Args:
        f_name: name of the method to look up on ``self.lightning_module``
        arg_name: the argument name to search for in that method's signature

    Returns:
        ``False`` when the module has no truthy attribute named ``f_name``. Otherwise, the result of
        ``is_param_in_hook_signature`` for that attribute and ``arg_name``.
    """
    # The message previously named `is_function_implemented` (copy-paste from the sibling method),
    # which pointed users at the wrong deprecated API.
    rank_zero_deprecation(
        "Internal: TrainerModelHooksMixin.has_arg is deprecated in v1.4"
        " and will be removed in v1.6."
        " Use `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature` instead."
    )
    method = getattr(self.lightning_module, f_name, None)
    if not method:
        return False
    return is_param_in_hook_signature(method, arg_name)
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
    """Called when the training batch begins.

    Each callback's ``on_train_batch_start`` receives the trainer, the LightningModule, ``batch`` and
    ``batch_idx``. A literal ``0`` is appended as the dataloader index only for callbacks whose hook explicitly
    declares ``dataloader_idx``. This method's own ``dataloader_idx`` argument is not forwarded.
    """
    for cb in self.callbacks:
        hook = cb.on_train_batch_start
        hook_args = (self, self.lightning_module, batch, batch_idx)
        if is_param_in_hook_signature(hook, "dataloader_idx", explicit=True):
            hook_args += (0,)
        hook(*hook_args)
def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
    """Choose the data-fetcher class that drives the training loop.

    Returns:
        ``DataLoaderIterDataFetcher`` when ``training_step`` explicitly declares ``dataloader_iter``.
        ``InterBatchParallelDataFetcher`` when the ``PL_INTER_BATCH_PARALLELISM`` environment variable is
        ``"1"``. ``DataFetcher`` otherwise.

    Raises:
        MisconfigurationException: when inter-batch parallelism is requested but ``trainer.accelerator`` is not
            a ``CUDAAccelerator``.
    """
    step_fx = getattr(trainer.lightning_module, "training_step")
    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher

    if os.getenv("PL_INTER_BATCH_PARALLELISM", "0") != "1":
        return DataFetcher

    if not isinstance(trainer.accelerator, CUDAAccelerator):
        raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
    return InterBatchParallelDataFetcher
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0):
    """Called when the training batch ends.

    Each callback's ``on_train_batch_end`` receives the trainer, the LightningModule, ``outputs``, ``batch`` and
    ``batch_idx``. A literal ``0`` is appended as the dataloader index only for callbacks whose hook explicitly
    declares ``dataloader_idx``. This method's own ``dataloader_idx`` argument is not forwarded.
    """
    for cb in self.callbacks:
        hook = cb.on_train_batch_end
        hook_args = (self, self.lightning_module, outputs, batch, batch_idx)
        if is_param_in_hook_signature(hook, "dataloader_idx", explicit=True):
            hook_args += (0,)
        hook(*hook_args)
def _should_add_batch_output_to_epoch_output(self) -> bool:
    """Decide whether per-batch outputs must be collected for the epoch.

    Returns:
        ``True`` if the LightningModule overrides ``training_epoch_end``, or overrides ``on_train_epoch_end``
        with an ``outputs`` parameter. ``False`` otherwise.
    """
    # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
    module = self.trainer.lightning_module
    if is_overridden("training_epoch_end", model=module):
        return True
    return bool(
        is_overridden("on_train_epoch_end", model=module)
        and is_param_in_hook_signature(getattr(module, "on_train_epoch_end"), "outputs")
    )
def on_test_epoch_end(self, outputs: List[Any]):
    """Called when the epoch ends.

    Callbacks whose ``on_test_epoch_end`` declares ``outputs`` receive it. For any other callback, a
    ``DeprecationWarning`` is issued through ``warning_cache`` and the hook is called without ``outputs``.

    Args:
        outputs: List of outputs on each ``test`` epoch
    """
    for cb in self.callbacks:
        if is_param_in_hook_signature(cb.on_test_epoch_end, "outputs"):
            cb.on_test_epoch_end(self, self.lightning_module, outputs)
            continue
        warning_cache.warn(
            "`Callback.on_test_epoch_end` signature has changed in v1.3."
            " `outputs` parameter has been added."
            " Support for the old signature will be removed in v1.5",
            DeprecationWarning,
        )
        cb.on_test_epoch_end(self, self.lightning_module)
def on_train_epoch_end(self, outputs: EPOCH_OUTPUT):
    """Called when the epoch ends.

    Callbacks whose ``on_train_epoch_end`` still declares ``outputs`` trigger a deprecation message and are
    called with ``outputs``. All other callbacks are called without it.

    Args:
        outputs: List of outputs on each ``train`` epoch
    """
    for cb in self.callbacks:
        if not is_param_in_hook_signature(cb.on_train_epoch_end, "outputs"):
            cb.on_train_epoch_end(self, self.lightning_module)
            continue
        warning_cache.deprecation(
            "The signature of `Callback.on_train_epoch_end` has changed in v1.3."
            " `outputs` parameter has been removed."
            " Support for the old signature will be removed in v1.5"
        )
        cb.on_train_epoch_end(self, self.lightning_module, outputs)
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8.

    Called when the training batch begins.

    After emitting the deprecation message, each callback's ``on_train_batch_start`` is called with the trainer,
    the LightningModule, ``batch`` and ``batch_idx``. A literal ``0`` is appended only when the callback's hook
    explicitly declares ``dataloader_idx``.
    """
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8."
    )
    for cb in self.callbacks:
        hook = cb.on_train_batch_start
        hook_args = (self, self.lightning_module, batch, batch_idx)
        if is_param_in_hook_signature(hook, "dataloader_idx", explicit=True):
            hook_args += (0,)
        hook(*hook_args)
def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
    """Runs ``on_train_epoch_end hook``.

    ``Trainer.call_hook`` cannot be reused because the LightningModule and callback signatures may differ. The
    module hook is inspected for an ``outputs`` parameter instead; this mirrors ``Trainer.call_hook``.

    Order inside the profiler section:

    1. The trainer hook (if the attribute exists) is called with ``processed_epoch_output``.
    2. The LightningModule hook is called, if overridden.
    3. The accelerator hook is called, if the attribute exists.

    The module's ``_current_fx_name`` is set to the hook name for the duration and restored afterwards.
    NOTE(review): the restore is not in a ``finally`` block, so it is skipped if a hook raises.

    Args:
        processed_epoch_output: the epoch outputs forwarded to the hooks
    """
    hook_name = "on_train_epoch_end"
    module = self.trainer.lightning_module
    previous_fx_name = module._current_fx_name
    module._current_fx_name = hook_name

    with self.trainer.profiler.profile(hook_name):
        if hasattr(self.trainer, hook_name):
            getattr(self.trainer, hook_name)(processed_epoch_output)

        if is_overridden(hook_name, module):
            if is_param_in_hook_signature(getattr(module, hook_name), "outputs"):
                self.warning_cache.deprecation(
                    "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
                    " `outputs` parameter has been deprecated."
                    " Support for the old signature will be removed in v1.5"
                )
                module.on_train_epoch_end(processed_epoch_output)
            else:
                module.on_train_epoch_end()

        if hasattr(self.trainer.accelerator, hook_name):
            getattr(self.trainer.accelerator, hook_name)()

    # restore current_fx when nested context
    module._current_fx_name = previous_fx_name
def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
    """Run ``on_train_epoch_end`` on the trainer and then on the LightningModule (or the accelerator).

    ``Trainer.call_hook`` cannot be reused because the module hook may or may not accept ``outputs``. Its
    signature is inspected instead; this mirrors ``Trainer.call_hook``.

    - If the LightningModule does not override the hook, the accelerator's ``on_train_epoch_end`` is called
      instead, when present.
    - Logged metrics are cached afterwards, unless ``_reset_result_and_set_fx_name`` returned a truthy value.

    Args:
        processed_epoch_output: the epoch outputs forwarded to the hooks
    """
    hook_name = "on_train_epoch_end"
    trainer = self.trainer

    # set hook_name to model + reset Result obj
    skip_metric_caching = trainer._reset_result_and_set_fx_name(hook_name)

    with trainer.profiler.profile(hook_name):
        if hasattr(trainer, hook_name):
            getattr(trainer, hook_name)(processed_epoch_output)

        module = trainer.lightning_module
        if not is_overridden(hook_name, module):
            # used to auto-reduce things for the user with Results obj
            if hasattr(trainer.accelerator, hook_name):
                getattr(trainer.accelerator, hook_name)()
        elif is_param_in_hook_signature(getattr(module, hook_name), "outputs"):
            self.warning_cache.warn(
                "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
                " `outputs` parameter has been deprecated."
                " Support for the old signature will be removed in v1.5",
                DeprecationWarning,
            )
            module.on_train_epoch_end(processed_epoch_output)
        else:
            module.on_train_epoch_end()

    if not skip_metric_caching:
        trainer._cache_logged_metrics()
def _build_training_step_kwargs(
    kwargs: OrderedDict,
    lightning_module: "pl.LightningModule",
    optimizers: Sequence[Optimizer],
    opt_idx: Optional[int],
    hiddens: Optional[Any],
) -> OrderedDict:
    """Add the ``training_step``-specific entries to ``kwargs`` in place and return it.

    With more than one optimizer, ``optimizer_idx`` is added when the hook declares it. ``hiddens`` is added
    when ``truncated_bptt_steps > 0``.

    Args:
        kwargs: The kwargs passed down to the hooks; mutated and returned.
        lightning_module: the LightningModule with a `training_step` hook implementation
        optimizers: the list of optimizers from the Trainer
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        the same ``kwargs`` object, updated with the training-step arguments

    Raises:
        ValueError: if ``training_step`` declares ``optimizer_idx`` under manual optimization, or if it lacks
            ``optimizer_idx`` while several optimizers are used with automatic optimization.
    """
    step_fx = getattr(lightning_module, "training_step")

    num_optimizers = len(optimizers)
    if num_optimizers > 1:
        wants_opt_idx = is_param_in_hook_signature(step_fx, "optimizer_idx")
        auto_opt = lightning_module.automatic_optimization
        if wants_opt_idx and not auto_opt:
            raise ValueError(
                "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                " argument or set `self.automatic_optimization = True`."
            )
        if wants_opt_idx:
            kwargs["optimizer_idx"] = opt_idx
        elif auto_opt:
            raise ValueError(
                f"Your LightningModule defines {num_optimizers} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )

    # truncated backpropagation through time carries the hidden state between splits
    if lightning_module.truncated_bptt_steps > 0:
        kwargs["hiddens"] = hiddens
    return kwargs
def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optional[Tensor]) -> Dict[str, Any]:
    """Builds the keyword arguments for training_step

    ``batch`` and ``batch_idx`` are always included. With several optimizers, ``optimizer_idx`` is included
    when the hook declares it. Under manual optimization, a deprecation message is emitted in that case.

    Args:
        batch: the batch to train on
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        the keyword arguments for the training step

    Raises:
        ValueError: if several optimizers are used with automatic optimization and ``training_step`` lacks an
            ``optimizer_idx`` argument.
    """
    kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
    module = self.trainer.lightning_module

    num_optimizers = len(self.trainer.optimizers)
    if num_optimizers > 1:
        step_fx = getattr(module, "training_step")
        accepts_opt_idx = is_param_in_hook_signature(step_fx, "optimizer_idx")
        if not accepts_opt_idx:
            if module.automatic_optimization:
                raise ValueError(
                    f"Your LightningModule defines {num_optimizers} optimizers but"
                    ' `training_step` is missing the `optimizer_idx` argument.'
                )
        else:
            if not module.automatic_optimization:
                self._warning_cache.deprecation(
                    "`training_step` hook signature has changed in v1.3."
                    " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
                    " the old signature will be removed in v1.5"
                )
            kwargs['optimizer_idx'] = opt_idx

    # truncated backpropagation through time carries the hidden state between splits
    if self._truncated_bptt_enabled():
        kwargs['hiddens'] = hiddens
    return kwargs
def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optional[Tensor]) -> Dict[str, Any]:
    """Builds the keyword arguments for training_step

    ``batch`` and ``batch_idx`` are always included. With several optimizers, ``optimizer_idx`` is included
    when the hook declares it. ``hiddens`` is included when ``truncated_bptt_steps > 0``.

    Args:
        batch: the batch to train on
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        the keyword arguments for the training step

    Raises:
        ValueError: if ``training_step`` declares ``optimizer_idx`` under manual optimization, or if it lacks
            ``optimizer_idx`` while several optimizers are used with automatic optimization.
    """
    kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
    module = self.trainer.lightning_module

    num_optimizers = len(self.trainer.optimizers)
    if num_optimizers > 1:
        step_fx = getattr(module, "training_step")
        wants_opt_idx = is_param_in_hook_signature(step_fx, "optimizer_idx")
        auto_opt = module.automatic_optimization
        if wants_opt_idx and not auto_opt:
            raise ValueError(
                "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                " argument or set `self.automatic_optimization = True`."
            )
        if wants_opt_idx:
            kwargs["optimizer_idx"] = opt_idx
        elif auto_opt:
            raise ValueError(
                f"Your LightningModule defines {num_optimizers} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )

    # truncated backpropagation through time carries the hidden state between splits
    if module.truncated_bptt_steps > 0:
        kwargs["hiddens"] = hiddens
    return kwargs
def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
    """Add ``batch`` (and ``batch_idx`` when wanted) to the step kwargs in place.

    Args:
        kwargs: The kwargs passed down to the hooks; mutated and returned.
        batch: The current batch to run through the step.
        batch_idx: The current batch idx.

    Returns:
        The same ``kwargs`` object passed down to the hooks.
    """
    kwargs["batch"] = batch
    step_fx = getattr(self.trainer.lightning_module, "training_step")
    # `batch_idx` is optional. With more than one argument, the name of the second one cannot be relied on
    # (it may be `optimizer_idx` or similar) because argument names are not enforced, hence `min_args=2`.
    wants_batch_idx = is_param_in_hook_signature(step_fx, "batch_idx", min_args=2)
    if wants_batch_idx:
        kwargs["batch_idx"] = batch_idx
    return kwargs
def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") -> None:
    """Check if the current `training_step` is requesting `dataloader_iter`.

    When ``training_step`` explicitly declares ``dataloader_iter``, the model must not override
    ``on_train_batch_start`` or ``on_train_batch_end``, and must not enable truncated BPTT.

    Raises:
        MisconfigurationException: on the first incompatibility found, checked in the order listed above.
    """
    if not is_param_in_hook_signature(model.training_step, "dataloader_iter", explicit=True):
        return

    for hook_name in ("on_train_batch_start", "on_train_batch_end"):
        if is_overridden(hook_name, model):
            raise MisconfigurationException(
                f"The model hook `{hook_name}` is not compatible with "
                "taking a `dataloader_iter` argument in your `training_step`."
            )

    if model.truncated_bptt_steps > 0:
        raise MisconfigurationException(
            "The model taking a `dataloader_iter` argument in your `training_step` "
            "is incompatible with `truncated_bptt_steps > 0`."
        )
def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
    """Returns a deep copy of the example input array in cases where it is expected that the input changes
    during the verification process.

    For a ``LightningModule``, a missing ``input_array`` falls back to ``example_input_array``. The copy is
    moved with the module's ``transfer_batch_to_device``, which receives ``dataloader_idx=0`` when its signature
    accepts it. Any other model has the copy moved to the device of its first parameter.

    Arguments:
        input_array: The input to clone.
    """
    is_lightning = isinstance(self.model, LightningModule)
    if input_array is None and is_lightning:
        input_array = self.model.example_input_array
    input_array = deepcopy(input_array)

    if not is_lightning:
        target_device = next(self.model.parameters()).device
        return move_data_to_device(input_array, device=target_device)

    transfer_kwargs = {}
    if is_param_in_hook_signature(self.model.transfer_batch_to_device, "dataloader_idx"):
        # Requires for Lightning 1.4 and above
        transfer_kwargs["dataloader_idx"] = 0
    return self.model.transfer_batch_to_device(input_array, self.model.device, **transfer_kwargs)
def _select_data_fetcher(self) -> AbstractDataFetcher:
    """Instantiate the data fetcher for the current trainer stage.

    Returns:
        A plain ``DataFetcher`` during sanity checking and outside training. While training, returns
        ``DataLoaderIterDataFetcher`` when ``training_step`` explicitly declares ``dataloader_iter``, or
        ``InterBatchParallelDataFetcher`` when ``PL_INTER_BATCH_PARALLELISM`` is ``"1"``. Otherwise returns a
        plain ``DataFetcher``.

    Raises:
        MisconfigurationException: when inter-batch parallelism is requested but the training type plugin is
            not on GPU.
    """
    if self.trainer.sanity_checking:
        return DataFetcher()

    step_fx = getattr(self.trainer.lightning_module, "training_step")
    if not self.trainer.training:
        return DataFetcher()

    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher()

    if os.getenv("PL_INTER_BATCH_PARALLELISM", "0") != "1":
        return DataFetcher()

    # note: this is an experimental feature
    if not self.trainer.training_type_plugin.on_gpu:
        raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
    return InterBatchParallelDataFetcher()
def advance(self, *args: Any, **kwargs: Any) -> None:
    """Runs a single training batch.

    Pulls the next ``(batch_idx, (batch, is_last_batch))`` item from ``self._dataloader_iter``. It then
    moves the batch to the device unless the train data fetcher already stores it on device, runs the
    batch-start hooks and the batch loop, steps the LR schedulers, runs the batch-end hooks, and finally
    updates the train-step metrics.

    Args:
        *args: not used by this implementation
        **kwargs: not used by this implementation

    Raises:
        StopIteration: When the epoch is canceled by the user returning -1
    """
    if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
        # skip training and run validation in `on_advance_end`
        return

    # the iterator yields the batch index together with the batch and the "is last batch" flag
    batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)

    if not self.trainer._data_connector.train_data_fetcher.store_on_device:
        with self.trainer.profiler.profile("training_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(batch)

    self.batch_progress.increment_ready()

    # cache the batch size value to avoid extracting it again after the batch loop runs as the value will be
    # different if tbptt is enabled
    batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch)

    if batch is None:
        self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
        batch_output = []
    else:
        # hook
        response = self.trainer.call_hook("on_batch_start")
        # a hook returning -1 cancels the rest of the epoch
        if response == -1:
            self.batch_progress.increment_processed()
            raise StopIteration

        # TODO: Update this in v1.7 (deprecation: #9816)
        # only pass `dataloader_idx` to modules whose hook still declares it explicitly
        model_fx = self.trainer.lightning_module.on_train_batch_start
        extra_kwargs = (
            {"dataloader_idx": 0}
            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
            else {}
        )

        # hook
        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
        if response == -1:
            self.batch_progress.increment_processed()
            raise StopIteration

        self.batch_progress.increment_started()

        with self.trainer.profiler.profile("run_training_batch"):
            batch_output = self.batch_loop.run(batch, batch_idx)

    # restore the batch size cached before the batch loop ran
    self.trainer._results.batch_size = batch_size

    self.batch_progress.increment_processed()

    # update non-plateau LR schedulers
    # update epoch-interval ones only when we are at the end of training epoch
    self.update_lr_schedulers("step", update_plateau_schedulers=False)
    if self._num_ready_batches_reached():
        self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

    batch_end_outputs = self._prepare_outputs_training_batch_end(
        batch_output,
        # NOTE(review): `lightning_module.trainer.lightning_module` looks like a redundant round-trip;
        # presumably equivalent to `self.trainer.lightning_module.automatic_optimization` — confirm and simplify.
        automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization,
        num_optimizers=len(self.trainer.optimizers),
    )

    # TODO: Update this in v1.7 (deprecation: #9816)
    model_fx = self.trainer.lightning_module.on_train_batch_end
    extra_kwargs = (
        {"dataloader_idx": 0}
        if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
        else {}
    )
    self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
    self.trainer.call_hook("on_batch_end")
    self.trainer.logger_connector.on_batch_end()

    self.batch_progress.increment_completed()

    # only keep per-batch outputs when the user will consume them in `training_epoch_end`
    if is_overridden("training_epoch_end", self.trainer.lightning_module):
        self._outputs.append(batch_output)

    # -----------------------------------------
    # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
    # -----------------------------------------
    self.trainer.logger_connector.update_train_step_metrics()
def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
    """Runs a single training batch.

    Fetches the next batch from ``data_fetcher``. For a ``DataLoaderIterDataFetcher``, the fetcher supplies the
    batch index; otherwise the index is ``self.batch_idx + 1``. It then runs the callback, LightningModule and
    strategy batch-start hooks, the batch loop, LR-scheduler updates and the batch-end hooks, and finally
    updates the train-step metrics.

    Args:
        data_fetcher: the fetcher producing the next batch; its ``done`` flag marks the last batch

    Raises:
        StopIteration: When the epoch is canceled by the user returning -1
    """
    if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
        # skip training and run validation in `on_advance_end`
        return
    # we are going to train first so the val loop does not need to restart
    self.val_loop.restarting = False

    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        batch_idx = self.batch_idx + 1
        batch = next(data_fetcher)
    else:
        # this fetcher yields the batch index alongside the batch
        batch_idx, batch = next(data_fetcher)
    self.batch_progress.is_last_batch = data_fetcher.done

    self.batch_progress.increment_ready()

    self.trainer._logger_connector.on_batch_start(batch, batch_idx)

    if batch is None:
        self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
        batch_output = []
    else:
        # hook
        self.trainer._call_callback_hooks("on_batch_start")

        # TODO: Update this in v1.7 (deprecation: #9816)
        # only pass `dataloader_idx` to modules whose hook still declares it explicitly
        model_fx = self.trainer.lightning_module.on_train_batch_start
        extra_kwargs = (
            {"dataloader_idx": 0}
            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
            else {}
        )

        # hook
        self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
        response = self.trainer._call_lightning_module_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
        self.trainer._call_strategy_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
        # only the LightningModule hook's return value can cancel the epoch (-1)
        if response == -1:
            self.batch_progress.increment_processed()
            raise StopIteration

        self.batch_progress.increment_started()

        with self.trainer.profiler.profile("run_training_batch"):
            batch_output = self.batch_loop.run(batch, batch_idx)

    self.batch_progress.increment_processed()

    # update non-plateau LR schedulers
    # update epoch-interval ones only when we are at the end of training epoch
    self.update_lr_schedulers("step", update_plateau_schedulers=False)
    if self._num_ready_batches_reached():
        self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

    batch_end_outputs = self._prepare_outputs_training_batch_end(
        batch_output,
        lightning_module=self.trainer.lightning_module,
        num_optimizers=len(self.trainer.optimizers),
    )

    # TODO: Update this in v1.7 (deprecation: #9816)
    model_fx = self.trainer.lightning_module.on_train_batch_end
    extra_kwargs = (
        {"dataloader_idx": 0}
        if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
        else {}
    )
    self.trainer._call_callback_hooks("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
    self.trainer._call_lightning_module_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
    self.trainer._call_callback_hooks("on_batch_end")
    self.trainer._logger_connector.on_batch_end()

    self.batch_progress.increment_completed()

    # only keep per-batch outputs when the user will consume them in `training_epoch_end`
    if is_overridden("training_epoch_end", self.trainer.lightning_module):
        self._outputs.append(batch_output)

    # -----------------------------------------
    # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
    # -----------------------------------------
    self.trainer._logger_connector.update_train_step_metrics()
def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
    """Emit a deprecation message for each deprecated hook or hook signature used by the trainer's callbacks.

    Callbacks are inspected in order. For each callback the checks always run in the same sequence:

    1. ``on_keyboard_interrupt``
    2. ``dataloader_idx`` in the train-batch hooks
    3. the hooks removed in v1.8
    4. ``on_load_checkpoint``
    5. the hooks replaced by more specific ones
    6. the pretrain-routine hooks

    Args:
        trainer: the Trainer whose ``callbacks`` are inspected
    """
    # hooks slated for removal in v1.8, with an optional hint appended to the message
    removed_in_v1_8 = (
        ("on_init_start", ""),
        ("on_init_end", ""),
        ("on_configure_sharded_model", " Use `setup()` instead."),
        ("on_before_accelerator_backend_setup", " Use `setup()` instead."),
    )
    # hooks superseded by a more specific alternative
    replaced_hooks = (
        ("on_batch_start", "on_train_batch_start"),
        ("on_batch_end", "on_train_batch_end"),
        ("on_epoch_start", "on_<train/validation/test>_epoch_start"),
        ("on_epoch_end", "on_<train/validation/test>_epoch_end"),
    )

    for cb in trainer.callbacks:
        if is_overridden(method_name="on_keyboard_interrupt", instance=cb):
            rank_zero_deprecation(
                "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
                " Please use the `on_exception` callback hook instead."
            )

        # TODO: Remove this in v1.7 (deprecation: #9816)
        for batch_hook in ("on_train_batch_start", "on_train_batch_end"):
            if is_param_in_hook_signature(getattr(cb, batch_hook), "dataloader_idx", explicit=True):
                rank_zero_deprecation(
                    f"Base `Callback.{batch_hook}` hook signature has changed in v1.5."
                    " The `dataloader_idx` argument will be removed in v1.7."
                )

        for old_hook, hint in removed_in_v1_8:
            if is_overridden(method_name=old_hook, instance=cb):
                rank_zero_deprecation(
                    f"The `{old_hook}` callback hook was deprecated in v1.6 and will be removed in v1.8.{hint}"
                )

        if is_overridden(method_name="on_load_checkpoint", instance=cb):
            rank_zero_deprecation(
                f"`{cb.__class__.__name__}.on_load_checkpoint` will change its signature and behavior in v1.8."
                " If you wish to load the state of the callback, use `load_state_dict` instead."
                " In v1.8 `on_load_checkpoint(..., checkpoint)` will receive the entire loaded"
                " checkpoint dictionary instead of callback state."
            )

        for old_hook, new_hook in replaced_hooks:
            if is_overridden(method_name=old_hook, instance=cb):
                rank_zero_deprecation(
                    f"The `Callback.{old_hook}` hook was deprecated in v1.6 and"
                    f" will be removed in v1.8. Please use `Callback.{new_hook}` instead."
                )

        for pretrain_hook in ("on_pretrain_routine_start", "on_pretrain_routine_end"):
            if is_overridden(method_name=pretrain_hook, instance=cb):
                rank_zero_deprecation(
                    f"The `Callback.{pretrain_hook}` hook has been deprecated in v1.6 and"
                    " will be removed in v1.8. Please use `Callback.on_fit_start` instead."
                )