def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8.

    Forward the "test batch begins" event to every callback in ``self.callbacks``.

    Each callback is invoked with the trainer and ``self.lightning_module`` followed by
    ``batch``, ``batch_idx`` and ``dataloader_idx``.
    """
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8."
    )
    hook_args = (batch, batch_idx, dataloader_idx)
    for cb in self.callbacks:
        cb.on_test_batch_start(self, self.lightning_module, *hook_args)
def on_before_optimizer_step(self, optimizer, optimizer_idx):
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8.

    Notify every callback that an optimizer step is about to happen: after ``on_after_backward()``,
    once the gradient is accumulated, and before ``optimizer.step()``.
    """
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8."
    )
    hook_args = (optimizer, optimizer_idx)
    for cb in self.callbacks:
        cb.on_before_optimizer_step(self, self.lightning_module, *hook_args)
def close(self) -> None:
    """Close the experiment (deprecated alias that warns and then delegates to ``save()``).

    .. deprecated:: v1.5
        This method is deprecated in v1.5 and will be removed in v1.7.
        Please use `LightningLoggerBase.finalize` instead.
    """
    message = (
        "`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
        " Please use `LightningLoggerBase.finalize` instead."
    )
    rank_zero_deprecation(message)
    self.save()
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx):
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8.

    Forward the "training batch ended" event, together with the step ``outputs``, to every callback.
    """
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8."
    )
    hook_args = (outputs, batch, batch_idx)
    for cb in self.callbacks:
        cb.on_train_batch_end(self, self.lightning_module, *hook_args)
def max_steps(self, value: int) -> None:
    """Set the maximum number of steps by forwarding ``value`` to ``self.epoch_loop.max_steps``.

    ``-1`` means "no limit". ``None`` is still accepted but deprecated: it triggers a deprecation
    warning and is stored as ``-1``. Any value below ``-1`` raises ``MisconfigurationException``.
    """
    # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
    if value is not None and value < -1:
        raise MisconfigurationException(
            f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
        )
    if value is None:
        rank_zero_deprecation(
            "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
            " Use `max_steps = -1` instead."
        )
        value = -1
    self.epoch_loop.max_steps = value
def on_before_accelerator_backend_setup(self) -> None:
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 and will be removed in v1.8.

    Notify every callback at the beginning of fit (train + validate), validate, test, predict, or tune,
    before the accelerator backend is set up.
    """
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 "
        "and will be removed in v1.8."
    )
    for cb in self.callbacks:
        cb.on_before_accelerator_backend_setup(self, self.lightning_module)
def _check_precision_plugin_checkpoint_hooks(trainer: "pl.Trainer") -> None:
    """Warn when the trainer's precision plugin overrides a deprecated checkpoint hook.

    ``on_save_checkpoint`` should become ``state_dict``, and ``on_load_checkpoint`` should become
    ``load_state_dict``. Overrides are detected relative to ``PrecisionPlugin``.
    """
    deprecated_to_replacement = (
        ("on_save_checkpoint", "state_dict"),
        ("on_load_checkpoint", "load_state_dict"),
    )
    for hook_name, replacement in deprecated_to_replacement:
        if not is_overridden(method_name=hook_name, instance=trainer.precision_plugin, parent=PrecisionPlugin):
            continue
        rank_zero_deprecation(
            f"`PrecisionPlugin.{hook_name}` was deprecated in"
            f" v1.6 and will be removed in v1.8. Use `{replacement}` instead."
        )
def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
    """Initialise the connector state.

    Args:
        trainer: Trainer instance, stored as ``self.trainer``.
        resume_from_checkpoint: Deprecated checkpoint path. It is stored as
            ``resume_from_checkpoint_fit_path``, and a deprecation warning is emitted when it is not ``None``.
    """
    self.trainer = trainer
    self.resume_checkpoint_path: Optional[_PATH] = None
    # TODO: remove resume_from_checkpoint_fit_path in v2.0
    self.resume_from_checkpoint_fit_path: Optional[_PATH] = resume_from_checkpoint

    uses_deprecated_flag = resume_from_checkpoint is not None
    if uses_deprecated_flag:
        rank_zero_deprecation(
            "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
            " will be removed in v1.7. Please pass `Trainer.fit(ckpt_path=)` directly instead."
        )

    self._loaded_checkpoint: Dict[str, Any] = {}
def on_predict_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.8.

    Forward the "predict batch ended" event, with its ``outputs``, to every callback.
    """
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.8."
    )
    hook_args = (outputs, batch, batch_idx, dataloader_idx)
    for cb in self.callbacks:
        cb.on_predict_batch_end(self, self.lightning_module, *hook_args)
def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
    """Run the main checkpoint-saving logic: the top-k checkpoint first, then the "last" checkpoint.

    This runs on every rank. Saving only on rank 0 for data-parallel setups is the responsibility
    of ``trainer.save_checkpoint``.

    .. deprecated:: v1.6
        Use ``trainer.save_checkpoint()`` to save a checkpoint manually.
    """
    rank_zero_deprecation(
        f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
        " Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
    )
    # Both savers receive the same snapshot of monitored values.
    candidates = self._monitor_candidates(trainer)
    self._save_topk_checkpoint(trainer, candidates)
    self._save_last_checkpoint(trainer, candidates)
def _configure_model_summary_callback(
    self, enable_model_summary: bool, weights_summary: Optional[str] = None
) -> None:
    """Append a default model-summary callback to ``self.trainer.callbacks`` when appropriate.

    Args:
        enable_model_summary: When false, no summary callback is added.
        weights_summary: Deprecated way of selecting the summary depth.

            - ``None`` emits a deprecation warning and returns without adding a callback.
            - ``"top"`` maps to ``max_depth=1``.
            - Any other value emits a deprecation warning and is translated with
              ``ModelSummaryMode.get_max_depth``.

    Raises:
        MisconfigurationException: If ``weights_summary`` is neither ``"top"`` nor one of
            ``ModelSummaryMode.supported_types()``.
    """
    if weights_summary is None:
        rank_zero_deprecation(
            "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
            " in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
        )
        return
    if not enable_model_summary:
        return

    # A user-provided summary callback takes precedence over the default one.
    model_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, ModelSummary)]
    if model_summary_cbs:
        rank_zero_info(
            f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
            " Skipping setting a default `ModelSummary` callback."
        )
        return

    if weights_summary == "top":
        # special case the default value for weights_summary to preserve backward compatibility
        max_depth = 1
    else:
        rank_zero_deprecation(
            f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
            " in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
            " `max_depth` directly to the Trainer's `callbacks` argument instead."
        )
        if weights_summary not in ModelSummaryMode.supported_types():
            # Build one message string. Passing two positional arguments (as before) made the
            # exception render as a tuple of fragments instead of a readable sentence.
            raise MisconfigurationException(
                f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}"
                f" but got {weights_summary}"
            )
        max_depth = ModelSummaryMode.get_max_depth(weights_summary)

    # Pair a rich progress bar with the rich summary variant.
    progress_bar_callback = self.trainer.progress_bar_callback
    is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)

    if progress_bar_callback is not None and is_progress_bar_rich:
        model_summary = RichModelSummary(max_depth=max_depth)
    else:
        model_summary = ModelSummary(max_depth=max_depth)
    self.trainer.callbacks.append(model_summary)
    self.trainer._weights_summary = weights_summary
def _check_add_get_queue(model: "pl.LightningModule") -> None:
    r"""
    Emit a deprecation warning for each of ``add_to_queue`` / ``get_from_queue`` that ``model`` overrides.

    Args:
        model: The lightning module
    """
    for queue_method in ("add_to_queue", "get_from_queue"):
        if is_overridden(queue_method, model):
            rank_zero_deprecation(
                f"The `LightningModule.{queue_method}` method was deprecated in v1.5 and will be removed in v1.7."
            )
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.8.

    Collect callback states while a model checkpoint is being saved.

    Returns:
        A mapping from ``callback.state_key`` to the value returned by that callback's
        ``on_save_checkpoint``. Callbacks that return a falsy state are left out.
    """
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.8."
    )
    collected: Dict[str, dict] = {}
    for cb in self.callbacks:
        cb_state = cb.on_save_checkpoint(self, self.lightning_module, checkpoint)
        if not cb_state:
            continue
        collected[cb.state_key] = cb_state
    return collected
def _prepare_outputs_training_epoch_end( batch_outputs: _OUTPUTS_TYPE, lightning_module: "pl.LightningModule", num_optimizers: int, ) -> Union[List[List[List[Dict[str, Any]]]], List[List[Dict[str, Any]]], List[Dict[str, Any]]]: """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook.""" # `batch_outputs` (plural) is the same as `epoch_end_output` (singular) if not batch_outputs: return [] # convert optimizer dicts to list if lightning_module.automatic_optimization: batch_outputs = apply_to_collection(batch_outputs, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers) array = _recursive_pad(batch_outputs) # TODO: remove in v1.8 if (num_optimizers > 1 and lightning_module.truncated_bptt_steps > 0 and not _v1_8_output_format(lightning_module.on_train_epoch_end)): rank_zero_deprecation( "You are training with multiple optimizers AND truncated backpropagation through time enabled." " The current format of the `training_epoch_end(outputs)` is a 3d list with sizes" " (n_optimizers, n_batches, tbptt_steps), however, this has been deprecated and will change in version" " v1.8 to (n_batches, tbptt_steps, n_optimizers). You can update your code by adding the following" " parameter to your hook signature: `training_epoch_end(outputs, new_format=True)`." ) # (n_batches, tbptt_steps, n_opt) -> (n_opt, n_batches, tbptt_steps) if array.ndim == 2: array = np.expand_dims(array, 2) array = array.transpose((2, 0, 1)) # squeeze all single-element dimensions array = array.squeeze() array = array.tolist() array = _recursive_unpad(array) # in case we squeezed from 1-element array to a 0-dim array array = array if isinstance(array, list) else [array] # remove residual empty lists array = [ item for item in array if not isinstance(item, list) or len(item) ] return array
def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None:
    """Check that ``trainer`` and ``model`` can run the evaluation loop for ``stage``.

    Args:
        trainer: Trainer whose data connector provides the ``_<stage>_dataloader_source``.
        model: Module checked for the relevant step / hook overrides.
        stage: Stage prefix used to build hook names. ``"val"`` maps to ``validation_step`` and
            ``Trainer.validate``; ``"predict"`` gets its own step requirements.

    Raises:
        MisconfigurationException: If no dataloader is defined for the stage, or if the required
            step (or, for predict, a usable ``predict_step`` / ``forward``) is missing.
    """
    is_val = stage == "val"
    loader_name = f"{stage}_dataloader"
    step_name = "validation_step" if is_val else f"{stage}_step"
    trainer_method = "validate" if is_val else stage
    on_eval_hook = f"on_{loader_name}"

    loader_source = getattr(trainer._data_connector, f"_{stage}_dataloader_source")
    has_loader = loader_source.is_defined()
    has_step = is_overridden(step_name, model)

    # The `on_<stage>_dataloader` hooks are deprecated in favour of the dataloader methods.
    if is_overridden(on_eval_hook, model):
        rank_zero_deprecation(
            f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
            f" be removed in v1.7.0. Please use `{loader_name}()` directly."
        )

    if not has_loader:
        raise MisconfigurationException(f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`.")

    if stage != "predict":
        if not has_step:
            raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")
        return

    # predict_step is not required to be overridden: an overridden `forward` also satisfies the check
    if model.predict_step is None:
        raise MisconfigurationException("`predict_step` cannot be None to run `Trainer.predict`")
    if not has_step and not is_overridden("forward", model):
        raise MisconfigurationException("`Trainer.predict` requires `forward` method to run.")
def on_trainer_init(
    self,
    logger: Union[bool, Logger, Iterable[Logger]],
    log_every_n_steps: int,
    move_metrics_to_cpu: bool,
) -> None:
    """Configure the trainer's logging state during ``Trainer`` construction.

    Sets up ``logger`` via ``configure_logger`` and copies ``log_every_n_steps`` and
    ``move_metrics_to_cpu`` onto the trainer. If any configured logger overrides the deprecated
    ``agg_and_log_metrics``, ``_override_agg_and_log_metrics`` is set and a single deprecation
    warning is emitted.
    """
    self.configure_logger(logger)
    self.trainer.log_every_n_steps = log_every_n_steps
    self.trainer.move_metrics_to_cpu = move_metrics_to_cpu

    # `any` stops at the first logger that overrides the hook, so the warning is emitted at most once.
    if any(is_overridden("agg_and_log_metrics", lg, Logger) for lg in self.trainer.loggers):
        self._override_agg_and_log_metrics = True
        rank_zero_deprecation(
            "`Logger.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
            " in v1.8. `Trainer` will directly call `Logger.log_metrics` so custom"
            " loggers should not implement `Logger.agg_and_log_metrics`."
        )
def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], enable_checkpointing: bool) -> None:
    """Reconcile checkpointing flags with the checkpoint callbacks already on the trainer.

    Args:
        checkpoint_callback: Deprecated flag. When not ``None`` it triggers a warning and is
            combined with ``enable_checkpointing`` (checkpointing stays on only if both are truthy).
        enable_checkpointing: Whether checkpointing is requested.

    Raises:
        MisconfigurationException: If checkpointing is disabled but a checkpoint callback is present.
    """
    if checkpoint_callback is not None:
        rank_zero_deprecation(
            f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
            f"be removed in v1.7. Please consider using `Trainer(enable_checkpointing={checkpoint_callback})`."
        )
        # if both are set then checkpoint only if both are True
        enable_checkpointing = checkpoint_callback and enable_checkpointing

    if not self.trainer.checkpoint_callbacks:
        # No user-provided checkpoint callback: add the default one if checkpointing is on.
        if enable_checkpointing:
            self.trainer.callbacks.append(ModelCheckpoint())
        return

    if not enable_checkpointing:
        raise MisconfigurationException(
            "Trainer was configured with `enable_checkpointing=False`"
            " but found `ModelCheckpoint` in callbacks list."
        )
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8.

    Called when the training batch begins.

    Callbacks whose ``on_train_batch_start`` explicitly declares a ``dataloader_idx`` parameter
    receive this method's ``dataloader_idx``. All other callbacks are called without it.
    """
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8."
    )
    for callback in self.callbacks:
        if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True):
            # Forward the caller's value instead of a hard-coded 0 so the argument is not silently dropped.
            callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)
        else:
            callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx)
def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None) -> None:
    """Initialise the connector's bookkeeping for logged, progress-bar and callback metrics.

    Args:
        trainer: Trainer instance, stored as ``self.trainer``.
        log_gpu_memory: Deprecated trainer flag. It is stored as-is, and a deprecation warning is
            emitted when it is not ``None``.
    """
    self.trainer = trainer

    uses_deprecated_flag = log_gpu_memory is not None
    if uses_deprecated_flag:
        rank_zero_deprecation(
            "Setting `log_gpu_memory` with the trainer flag is deprecated in v1.5 and will be removed in v1.7. "
            "Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
        )
    self.log_gpu_memory = log_gpu_memory

    # per-stage logging step counters
    self._val_log_step: int = 0
    self._test_log_step: int = 0

    # metric stores
    self._progress_bar_metrics: _PBAR_DICT = {}
    self._logged_metrics: _OUT_DICT = {}
    self._callback_metrics: _OUT_DICT = {}
    self._gpus_metrics: Dict[str, float] = {}

    # loop-position state
    self._epoch_end_reached = False
    self._current_fx: Optional[str] = None
    self._batch_idx: Optional[int] = None
    self._split_idx: Optional[int] = None

    self._override_agg_and_log_metrics: bool = False
def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
    """Profiles over each value of an iterable. See deprecation message below.

    Each ``next()`` on the underlying iterator is wrapped in ``self.start(action_name)`` /
    ``self.stop(action_name)``, and the fetched value is then yielded. Because this is a generator,
    the deprecation warning is emitted when iteration starts, not when the method is called.

    .. deprecated:: v1.6
        `Profiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.
    """
    rank_zero_deprecation(
        f"`{self.__class__.__name__}.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
    )
    iterator = iter(iterable)
    while True:
        self.start(action_name)
        try:
            value = next(iterator)
        except StopIteration:
            break
        finally:
            # Always close the action, including when `next()` raises something other than
            # StopIteration, so the profiler is never left with an unmatched `start`.
            self.stop(action_name)
        yield value
def __init__(
    self,
    agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
    agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
):
    """Set up metric-aggregation state.

    Args:
        agg_key_funcs: Deprecated per-metric aggregation functions. When truthy they are stored and a
            deprecation warning is emitted; otherwise an empty dict is used.
        agg_default_func: Deprecated fallback aggregation function. When truthy it is stored and a
            deprecation warning is emitted; otherwise ``np.mean`` is used.
    """
    self._prev_step: int = -1
    self._metrics_to_agg: List[Dict[str, float]] = []

    if not agg_key_funcs:
        self._agg_key_funcs = {}
    else:
        self._agg_key_funcs = agg_key_funcs
        rank_zero_deprecation(
            "The `agg_key_funcs` parameter for `Logger` was deprecated in v1.6"
            " and will be removed in v1.8."
        )

    if not agg_default_func:
        self._agg_default_func = np.mean
    else:
        self._agg_default_func = agg_default_func
        rank_zero_deprecation(
            "The `agg_default_func` parameter for `Logger` was deprecated in v1.6"
            " and will be removed in v1.8."
        )
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8.

    Restore callback states from a loaded checkpoint dictionary.

    Returns early when the checkpoint has no ``"callbacks"`` entry. Checkpoints written before
    ``1.5.0dev`` are matched through each callback's ``_legacy_state_key``. A warning lists saved
    states that no current callback claims. Each callback with a truthy saved state receives a
    deep copy of it.
    """
    # Todo: the `callback_states` are dropped with TPUSpawn as they
    # can't be saved using `xm.save`
    # https://github.com/pytorch/xla/issues/2773
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8."
    )
    callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
    if callback_states is None:
        return

    saved_version = Version(checkpoint["pytorch-lightning_version"])
    is_legacy_ckpt = saved_version < Version("1.5.0dev")
    if is_legacy_ckpt:
        current_callbacks_keys = {cb._legacy_state_key for cb in self.callbacks}
    else:
        current_callbacks_keys = {cb.state_key for cb in self.callbacks}

    unclaimed = callback_states.keys() - current_callbacks_keys
    if unclaimed:
        rank_zero_warn(
            "Be aware that when using `ckpt_path`,"
            " callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
            f" Please add the following callbacks: {list(unclaimed)}.",
        )

    for cb in self.callbacks:
        saved = callback_states.get(cb.state_key, callback_states.get(cb._legacy_state_key))
        if not saved:
            continue
        # copy so the callback cannot mutate the checkpoint dictionary in place
        saved_copy = deepcopy(saved)
        cb.on_load_checkpoint(self, self.lightning_module, saved_copy)
def _prepare_outputs_training_batch_end(
    batch_output: _BATCH_OUTPUTS_TYPE,
    lightning_module: "pl.LightningModule",
    num_optimizers: int,
) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
    """Processes the outputs from the batch loop into the format passed to the ``on_train_batch_end`` hook.

    Args:
        batch_output: Outputs of a single batch from the batch loop. A falsy value short-circuits to ``[]``.
        lightning_module: Module whose ``automatic_optimization``, ``truncated_bptt_steps`` and
            ``on_train_batch_end`` attributes select the conversion path.
        num_optimizers: Number of optimizers. It is forwarded to ``_convert_optim_dict`` and decides
            (together with truncated BPTT) whether the deprecated transposed layout is produced.

    Returns:
        The outputs with single-element dimensions squeezed out.
    """
    if not batch_output:
        return []

    # convert optimizer dicts to list
    if lightning_module.automatic_optimization:
        batch_output = apply_to_collection(
            batch_output, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers
        )

    # object dtype keeps the dicts as opaque elements instead of letting numpy try to unpack them
    array = np.array(batch_output, dtype=object)
    # TODO: remove in v1.8
    if (
        num_optimizers > 1
        and lightning_module.truncated_bptt_steps > 0
        and not _v1_8_output_format(lightning_module.on_train_batch_end)
    ):
        rank_zero_deprecation(
            "You are training with multiple optimizers AND truncated backpropagation through time enabled."
            " The current format of the `on_train_batch_end(outputs, ...)` is a 2d list with sizes"
            " (n_optimizers, tbptt_steps), however, this has been deprecated and will change in version v1.8 to"
            " (tbptt_steps, n_optimizers). You can update your code by adding the following parameter to your"
            " hook signature: `on_train_batch_end(outputs, ..., new_format=True)`."
        )
        # (tbptt_steps, n_opt) -> (n_opt, tbptt_steps)
        if array.ndim == 1:
            # restore the missing optimizer axis so the 2-axis transpose below is valid
            array = np.expand_dims(array, 1)
        array = array.transpose((1, 0))

    # squeeze all single-element dimensions
    # NOTE(review): if the squeeze yields a 0-d array, `tolist()` returns the single element itself
    # (not a list), so callers may receive a bare element despite the annotated list return type.
    # Unlike the epoch-end variant, this is not re-wrapped -- confirm that is intended.
    array = array.squeeze()
    array = array.tolist()
    array = _recursive_unpad(array)
    return array
def __init__(
    self,
    model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
    datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
    save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
    save_config_filename: str = "config.yaml",
    save_config_overwrite: bool = False,
    save_config_multifile: bool = False,
    trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
    trainer_defaults: Optional[Dict[str, Any]] = None,
    seed_everything_default: Union[bool, int] = True,
    description: str = "pytorch-lightning trainer command line tool",
    env_prefix: str = "PL",
    env_parse: bool = False,
    parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None,
    subclass_mode_model: bool = False,
    subclass_mode_data: bool = False,
    run: bool = True,
    auto_registry: bool = False,
) -> None:
    """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes),
    which are called / instantiated using a parsed configuration file and / or command line args.

    Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
    A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are also
    parsed from variables named, for example, ``PL_TRAINER__MAX_EPOCHS``.

    For more info, read :ref:`the CLI docs <lightning-cli>`.

    .. warning:: ``LightningCLI`` is in beta and subject to change.

    Construction does all of the work: it builds the parser, parses the arguments, seeds, instantiates
    the classes and, when ``run`` is true and a subcommand was given, runs that subcommand.

    Args:
        model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a
            callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when
            called. If ``None``, you can pass a registered model with ``--model=MyModel``.
        datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
            callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
            called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
        save_config_callback: A callback class to save the training config.
        save_config_filename: Filename for the config file.
        save_config_overwrite: Whether to overwrite an existing config file.
        save_config_multifile: When input is multiple config files, saved config preserves this structure.
        trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
            callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
        trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
            this argument will not be configurable from a configuration file and will always be present for
            this particular CLI. Alternatively, configurable callbacks can be added as explained in
            :ref:`the CLI docs <lightning-cli>`.
        seed_everything_default: Value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
            seed argument. Set to True to automatically choose a valid seed.
            Setting it to False will not call seed_everything. Passing ``None`` is deprecated
            (a warning is emitted) and is treated as ``False``.
        description: Description of the tool shown when running ``--help``.
        env_prefix: Prefix for environment variables.
        env_parse: Whether environment variable parsing is enabled.
        parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
        subclass_mode_model: Whether model can be any `subclass
            <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
            of the given class. Always enabled when ``model_class`` is ``None``.
        subclass_mode_data: Whether datamodule can be any `subclass
            <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
            of the given class. Always enabled when ``datamodule_class`` is ``None``.
        run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
            method. If set to ``False``, the trainer and model classes will be instantiated only.
        auto_registry: Whether to automatically fill up the registries with all defined subclasses.
    """
    self.save_config_callback = save_config_callback
    self.save_config_filename = save_config_filename
    self.save_config_overwrite = save_config_overwrite
    self.save_config_multifile = save_config_multifile
    self.trainer_class = trainer_class
    self.trainer_defaults = trainer_defaults or {}
    self.seed_everything_default = seed_everything_default
    # `None` is a deprecated spelling of `False`: warn and normalise it.
    if self.seed_everything_default is None:
        rank_zero_deprecation(
            "Setting `LightningCLI.seed_everything_default` to `None` is deprecated in v1.7 "
            "and will be removed in v1.9. Set it to `False` instead."
        )
        self.seed_everything_default = False

    self.model_class = model_class
    # used to differentiate between the original value and the processed value
    self._model_class = model_class or LightningModule
    # without an explicit class, any registered subclass must be selectable (e.g. `--model=MyModel`)
    self.subclass_mode_model = (model_class is None) or subclass_mode_model

    self.datamodule_class = datamodule_class
    # used to differentiate between the original value and the processed value
    self._datamodule_class = datamodule_class or LightningDataModule
    # without an explicit class, any registered subclass must be selectable (e.g. `--data=MyDataModule`)
    self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

    from pytorch_lightning.utilities.cli import _populate_registries

    _populate_registries(auto_registry)

    main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
        parser_kwargs or {},  # type: ignore  # github.com/python/mypy/issues/6463
        {"description": description, "env_prefix": env_prefix, "default_env": env_parse},
    )
    self.setup_parser(run, main_kwargs, subparser_kwargs)
    self.parse_arguments(self.parser)

    # subcommands are only part of the parsed config when `run` is enabled
    self.subcommand = self.config["subcommand"] if run else None

    self._set_seed()

    self.before_instantiate_classes()
    self.instantiate_classes()

    if self.subcommand is not None:
        self._run_subcommand(self.subcommand)
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
    """Check that ``trainer`` and ``model`` are configured for the fit (train + validation) loops.

    Raises ``MisconfigurationException`` if ``training_step``, a train dataloader or
    ``configure_optimizers`` is missing. Records on ``trainer`` whether ``optimizer_step`` /
    ``optimizer_zero_grad`` are overridden. Emits warnings for inconsistent validation setup and
    deprecation warnings for the ``on_train_dataloader`` / ``on_val_dataloader`` hooks.
    """
    requirements = (
        " Lightning `Trainer` expects as minimum a"
        " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
    )

    # minimum requirements: training step, train dataloader and optimizers
    if not is_overridden("training_step", model):
        raise MisconfigurationException(f"No `training_step()` method defined.{requirements}")
    if not trainer._data_connector._train_dataloader_source.is_defined():
        raise MisconfigurationException(f"No `train_dataloader()` method defined.{requirements}")
    if not is_overridden("configure_optimizers", model):
        raise MisconfigurationException(f"No `configure_optimizers()` method defined.{requirements}")

    if is_overridden("on_train_dataloader", model):
        rank_zero_deprecation(
            "Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
            " Please use `train_dataloader()` directly."
        )

    trainer.overridden_optimizer_step = is_overridden("optimizer_step", model)
    trainer.overridden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model)
    automatic_optimization = model.automatic_optimization
    going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

    has_overridden_optimization_functions = (
        trainer.overridden_optimizer_step or trainer.overridden_optimizer_zero_grad
    )
    if has_overridden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
        rank_zero_warn(
            "When using `Trainer(accumulate_grad_batches != 1)` and overriding"
            " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
            " (rather, they are called on every optimization step)."
        )

    # validation loop: a loader without a step (or vice versa) means the loop is skipped
    has_val_loader = trainer._data_connector._val_dataloader_source.is_defined()
    has_val_step = is_overridden("validation_step", model)
    if has_val_loader and not has_val_step:
        rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
    elif has_val_step and not has_val_loader:
        rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")

    if is_overridden("on_val_dataloader", model):
        rank_zero_deprecation(
            "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
            " Please use `val_dataloader()` directly."
        )
def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
    """Emit a deprecation warning for every deprecated hook overridden by each of ``trainer.callbacks``.

    Callbacks are checked in order, and within each callback the hooks are checked in this order:
    init hooks, setup-replaced hooks, ``on_load_checkpoint`` (signature change), renamed batch/epoch
    hooks, and the pretrain-routine hooks.
    """
    renamed_hooks = (
        ("on_batch_start", "on_train_batch_start"),
        ("on_batch_end", "on_train_batch_end"),
        ("on_epoch_start", "on_<train/validation/test>_epoch_start"),
        ("on_epoch_end", "on_<train/validation/test>_epoch_end"),
    )

    for callback in trainer.callbacks:
        for hook in ("on_init_start", "on_init_end"):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(f"The `{hook}` callback hook was deprecated in v1.6 and will be removed in v1.8.")

        for hook in ("on_configure_sharded_model", "on_before_accelerator_backend_setup"):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `{hook}` callback hook was deprecated in"
                    " v1.6 and will be removed in v1.8. Use `setup()` instead."
                )

        if is_overridden(method_name="on_load_checkpoint", instance=callback):
            rank_zero_deprecation(
                f"`{callback.__class__.__name__}.on_load_checkpoint` will change its signature and behavior in v1.8."
                " If you wish to load the state of the callback, use `load_state_dict` instead."
                " In v1.8 `on_load_checkpoint(..., checkpoint)` will receive the entire loaded"
                " checkpoint dictionary instead of callback state."
            )

        for hook, alternative_hook in renamed_hooks:
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook was deprecated in v1.6 and"
                    f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
                )

        for hook in ("on_pretrain_routine_start", "on_pretrain_routine_end"):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook has been deprecated in v1.6 and"
                    " will be removed in v1.8. Please use `Callback.on_fit_start` instead."
                )
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn rank_zero_deprecation( "Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, " "and will be removed in v1.7. It has been replaced by automatic parameters tying with " "`pytorch_lightning.utilities.params_tying.set_shared_parameters`") from functools import wraps # noqa: E402 from typing import Callable # noqa: E402 def parameter_validation(fn: Callable) -> Callable: """Validates that the module parameter lengths match after moving to the device. It is useful when tying weights on TPU's. Args: fn: ``model_to_device`` method Note: TPU's require weights to be tied/shared after moving the module to the device.
def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
    """Deprecated alias constructor: warn, then pass all arguments to the parent ``__init__``."""
    message = "`BaseProfiler` was deprecated in v1.6 and will be removed in v1.8. Please use `Profiler` instead."
    rank_zero_deprecation(message)
    super().__init__(*args, **kwargs)
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.warnings import WarningCache

standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"  # the checks below only run when this env var is "1"
if standalone:
    stderr = StringIO()
    # recording
    with redirect_stderr(stderr):
        _warn("test1")
        _warn("test2", category=DeprecationWarning)

        rank_zero_warn("test3")
        rank_zero_warn("test4", category=DeprecationWarning)

        rank_zero_deprecation("test5")

        cache = WarningCache()
        cache.warn("test6")
        cache.deprecation("test7")

    output = stderr.getvalue()  # NOTE: asserts below hard-code "file:line" of the calls above; keep in sync
    assert "test_warnings.py:31: UserWarning: test1" in output
    assert "test_warnings.py:32: DeprecationWarning: test2" in output
    assert "test_warnings.py:34: UserWarning: test3" in output
    assert "test_warnings.py:35: DeprecationWarning: test4" in output
    assert "test_warnings.py:37: LightningDeprecationWarning: test5" in output
    assert "test_warnings.py:40: UserWarning: test6" in output
def _deprecation(self, show_deprecation: bool = True) -> None:
    """Emit the registry deprecation message at most once per instance.

    Does nothing when ``show_deprecation`` is falsy or when the instance's ``deprecation_shown``
    attribute is already truthy. Otherwise it warns and sets ``deprecation_shown = True``.
    """
    if not show_deprecation or getattr(self, "deprecation_shown", False):
        return
    rank_zero_deprecation(_deprecate_registry_message)
    self.deprecation_shown = True