Example 1
    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8.

        Called when the test batch begins.
        """
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_test_batch_start` was deprecated in v1.6 and will be removed in v1.8."
        )
        for callback in self.callbacks:
            callback.on_test_batch_start(self, self.lightning_module, batch,
                                         batch_idx, dataloader_idx)
Example 2
    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8.

        Called after on_after_backward() once the gradient is accumulated and before optimizer.step().
        """
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_before_optimizer_step` was deprecated in v1.6 and will be removed in v1.8."
        )
        for callback in self.callbacks:
            callback.on_before_optimizer_step(self, self.lightning_module,
                                              optimizer, optimizer_idx)
Example 3
    def close(self) -> None:
        """Do any cleanup that is necessary to close an experiment.

        See deprecation warning below.

        .. deprecated:: v1.5
            This method is deprecated in v1.5 and will be removed in v1.7.
            Please use `LightningLoggerBase.finalize` instead.
        """
        rank_zero_deprecation(
            "`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
            " Please use `LightningLoggerBase.finalize` instead.")
        self.save()
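The suggested replacement is a one-line change at the call site. Below is a minimal sketch using the bundled CSVLogger (any LightningLoggerBase subclass behaves the same); the directory name is a placeholder.

from pytorch_lightning.loggers import CSVLogger

logger = CSVLogger(save_dir="logs")
logger.log_metrics({"loss": 0.25}, step=0)
logger.finalize("success")  # replaces the deprecated logger.close()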
Example 4
    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx):
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8.

        Called when the training batch ends.
        """
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8."
        )
        for callback in self.callbacks:
            callback.on_train_batch_end(self, self.lightning_module, outputs,
                                        batch, batch_idx)
Example 5
 def max_steps(self, value: int) -> None:
     """Sets the maximum number of steps (forwards to epoch_loop)"""
     # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
     if value is None:
         rank_zero_deprecation(
             "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
             " Use `max_steps = -1` instead.")
         value = -1
     elif value < -1:
         raise MisconfigurationException(
             f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
         )
     self.epoch_loop.max_steps = value
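The setter above simply maps the deprecated `None` to `-1`, so the migration is a literal substitution in the Trainer arguments, e.g.:

from pytorch_lightning import Trainer

trainer = Trainer(max_steps=-1)     # unlimited steps; stopping is governed by max_epochs or other criteria
# Trainer(max_steps=None)           # deprecated spelling of the same configuration, triggers the warning above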
Example 6
    def on_before_accelerator_backend_setup(self) -> None:
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6
            and will be removed in v1.8.

        Called at the beginning of fit (train + validate), validate, test, predict, or tune.
        """
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_before_accelerator_backend_setup` was deprecated in v1.6 "
            "and will be removed in v1.8."
        )
        for callback in self.callbacks:
            callback.on_before_accelerator_backend_setup(self, self.lightning_module)
Example 7
def _check_precision_plugin_checkpoint_hooks(trainer: "pl.Trainer") -> None:
    if is_overridden(method_name="on_save_checkpoint",
                     instance=trainer.precision_plugin,
                     parent=PrecisionPlugin):
        rank_zero_deprecation(
            "`PrecisionPlugin.on_save_checkpoint` was deprecated in"
            " v1.6 and will be removed in v1.8. Use `state_dict` instead.")
    if is_overridden(method_name="on_load_checkpoint",
                     instance=trainer.precision_plugin,
                     parent=PrecisionPlugin):
        rank_zero_deprecation(
            "`PrecisionPlugin.on_load_checkpoint` was deprecated in"
            " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
        )
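A minimal sketch (the plugin name is hypothetical) of moving custom checkpoint state from the deprecated hooks to the `state_dict`/`load_state_dict` pair named in the messages above:

from typing import Any, Dict

from pytorch_lightning.plugins import PrecisionPlugin


class StatefulPrecisionPlugin(PrecisionPlugin):
    def __init__(self) -> None:
        super().__init__()
        self.extra_state: Dict[str, Any] = {}

    def state_dict(self) -> Dict[str, Any]:
        # replaces the deprecated `on_save_checkpoint` hook
        return dict(self.extra_state)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # replaces the deprecated `on_load_checkpoint` hook
        self.extra_state.update(state_dict)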
Example 8
 def __init__(self,
              trainer: "pl.Trainer",
              resume_from_checkpoint: Optional[_PATH] = None) -> None:
     self.trainer = trainer
     self.resume_checkpoint_path: Optional[_PATH] = None
     # TODO: remove resume_from_checkpoint_fit_path in v2.0
     self.resume_from_checkpoint_fit_path: Optional[
         _PATH] = resume_from_checkpoint
     if resume_from_checkpoint is not None:
         rank_zero_deprecation(
             "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
             " will be removed in v1.7. Please pass `Trainer.fit(ckpt_path=)` directly instead."
         )
     self._loaded_checkpoint: Dict[str, Any] = {}
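The replacement moves the checkpoint path from the Trainer constructor to the `fit` call. In this sketch, `model` and "last.ckpt" are placeholders for your own LightningModule instance and checkpoint file:

from pytorch_lightning import Trainer

trainer = Trainer()
trainer.fit(model, ckpt_path="last.ckpt")  # replaces Trainer(resume_from_checkpoint="last.ckpt")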
Example 9
    def on_predict_batch_end(self, outputs: STEP_OUTPUT, batch: Any,
                             batch_idx: int, dataloader_idx: int) -> None:
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.8.

        Called when the predict batch ends.
        """
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_predict_batch_end` was deprecated in v1.6 and will be removed in v1.8."
        )
        for callback in self.callbacks:
            callback.on_predict_batch_end(self, self.lightning_module, outputs,
                                          batch, batch_idx, dataloader_idx)
Example 10
    def save_checkpoint(self,
                        trainer: "pl.Trainer") -> None:  # pragma: no-cover
        """Performs the main logic around saving a checkpoint.

        This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
        behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
        """
        rank_zero_deprecation(
            f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
            " Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
        )
        monitor_candidates = self._monitor_candidates(trainer)
        self._save_topk_checkpoint(trainer, monitor_candidates)
        self._save_last_checkpoint(trainer, monitor_candidates)
Example 11
    def _configure_model_summary_callback(
            self,
            enable_model_summary: bool,
            weights_summary: Optional[str] = None) -> None:
        if weights_summary is None:
            rank_zero_deprecation(
                "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
                " in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
            )
            return
        if not enable_model_summary:
            return

        model_summary_cbs = [
            type(cb) for cb in self.trainer.callbacks
            if isinstance(cb, ModelSummary)
        ]
        if model_summary_cbs:
            rank_zero_info(
                f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
                " Skipping setting a default `ModelSummary` callback.")
            return

        if weights_summary == "top":
            # special case the default value for weights_summary to preserve backward compatibility
            max_depth = 1
        else:
            rank_zero_deprecation(
                f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
                " in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
                " `max_depth` directly to the Trainer's `callbacks` argument instead."
            )
            if weights_summary not in ModelSummaryMode.supported_types():
                raise MisconfigurationException(
                    f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
                    f" but got {weights_summary}",
                )
            max_depth = ModelSummaryMode.get_max_depth(weights_summary)

        progress_bar_callback = self.trainer.progress_bar_callback
        is_progress_bar_rich = isinstance(progress_bar_callback,
                                          RichProgressBar)

        if progress_bar_callback is not None and is_progress_bar_rich:
            model_summary = RichModelSummary(max_depth=max_depth)
        else:
            model_summary = ModelSummary(max_depth=max_depth)
        self.trainer.callbacks.append(model_summary)
        self.trainer._weights_summary = weights_summary
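Both deprecation branches above point to the same migration: drop the `weights_summary` flag and either disable the summary or pass a `ModelSummary` callback explicitly, e.g.:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])  # was: Trainer(weights_summary="full")
trainer = Trainer(enable_model_summary=False)              # was: Trainer(weights_summary=None)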
Example 12
def _check_add_get_queue(model: "pl.LightningModule") -> None:
    r"""
    Checks if ``add_to_queue`` or ``get_from_queue`` is overridden and emits a deprecation warning.

    Args:
        model: The lightning module
    """
    if is_overridden("add_to_queue", model):
        rank_zero_deprecation(
            "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7."
        )
    if is_overridden("get_from_queue", model):
        rank_zero_deprecation(
            "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7."
        )
Example 13
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.8.

        Called when saving a model checkpoint.
        """
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_save_checkpoint` was deprecated in v1.6 and will be removed in v1.8."
        )
        callback_states = {}
        for callback in self.callbacks:
            state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
            if state:
                callback_states[callback.state_key] = state
        return callback_states
Example 14
    def _prepare_outputs_training_epoch_end(
        batch_outputs: _OUTPUTS_TYPE,
        lightning_module: "pl.LightningModule",
        num_optimizers: int,
    ) -> Union[List[List[List[Dict[str, Any]]]], List[List[Dict[str, Any]]],
               List[Dict[str, Any]]]:
        """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook."""
        # `batch_outputs` (plural) is the same as `epoch_end_output` (singular)
        if not batch_outputs:
            return []

        # convert optimizer dicts to list
        if lightning_module.automatic_optimization:
            batch_outputs = apply_to_collection(batch_outputs,
                                                dtype=dict,
                                                function=_convert_optim_dict,
                                                num_optimizers=num_optimizers)

        array = _recursive_pad(batch_outputs)
        # TODO: remove in v1.8
        if (num_optimizers > 1 and lightning_module.truncated_bptt_steps > 0
                and
                not _v1_8_output_format(lightning_module.on_train_epoch_end)):
            rank_zero_deprecation(
                "You are training with multiple optimizers AND truncated backpropagation through time enabled."
                " The current format of the `training_epoch_end(outputs)` is a 3d list with sizes"
                " (n_optimizers, n_batches, tbptt_steps), however, this has been deprecated and will change in version"
                " v1.8 to (n_batches, tbptt_steps, n_optimizers). You can update your code by adding the following"
                " parameter to your hook signature: `training_epoch_end(outputs, new_format=True)`."
            )
            # (n_batches, tbptt_steps, n_opt) -> (n_opt, n_batches, tbptt_steps)
            if array.ndim == 2:
                array = np.expand_dims(array, 2)
            array = array.transpose((2, 0, 1))
        # squeeze all single-element dimensions
        array = array.squeeze()
        array = array.tolist()
        array = _recursive_unpad(array)
        # in case we squeezed from 1-element array to a 0-dim array
        array = array if isinstance(array, list) else [array]
        # remove residual empty lists
        array = [
            item for item in array if not isinstance(item, list) or len(item)
        ]
        return array
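A sketch (the module name is hypothetical and the rest of the class is omitted) of opting into the v1.8 layout described by the warning above, by adding the `new_format=True` parameter to the hook signature:

import pytorch_lightning as pl


class TBPTTModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.truncated_bptt_steps = 2  # TBPTT enabled, as in the branch above

    def training_epoch_end(self, outputs, new_format=True):
        # with `new_format=True`, `outputs` is ordered (n_batches, tbptt_steps, n_optimizers)
        pass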
Example 15
def __verify_eval_loop_configuration(trainer: "pl.Trainer",
                                     model: "pl.LightningModule",
                                     stage: str) -> None:
    loader_name = f"{stage}_dataloader"
    step_name = "validation_step" if stage == "val" else f"{stage}_step"
    trainer_method = "validate" if stage == "val" else stage
    on_eval_hook = f"on_{loader_name}"

    has_loader = getattr(trainer._data_connector,
                         f"_{stage}_dataloader_source").is_defined()
    has_step = is_overridden(step_name, model)
    has_on_eval_dataloader = is_overridden(on_eval_hook, model)

    # ----------------------------------------------
    # verify model does not have on_eval_dataloader
    # ----------------------------------------------
    if has_on_eval_dataloader:
        rank_zero_deprecation(
            f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
            f" be removed in v1.7.0. Please use `{loader_name}()` directly.")

    # -----------------------------------
    # verify model has an eval_dataloader
    # -----------------------------------
    if not has_loader:
        raise MisconfigurationException(
            f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`."
        )

    # predict_step is not required to be overridden
    if stage == "predict":
        if model.predict_step is None:
            raise MisconfigurationException(
                "`predict_step` cannot be None to run `Trainer.predict`")
        elif not has_step and not is_overridden("forward", model):
            raise MisconfigurationException(
                "`Trainer.predict` requires `forward` method to run.")
    else:
        # -----------------------------------
        # verify model has an eval_step
        # -----------------------------------
        if not has_step:
            raise MisconfigurationException(
                f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`."
            )
Example 16
 def on_trainer_init(
     self,
     logger: Union[bool, Logger, Iterable[Logger]],
     log_every_n_steps: int,
     move_metrics_to_cpu: bool,
 ) -> None:
     self.configure_logger(logger)
     self.trainer.log_every_n_steps = log_every_n_steps
     self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
     for logger in self.trainer.loggers:
         if is_overridden("agg_and_log_metrics", logger, Logger):
             self._override_agg_and_log_metrics = True
             rank_zero_deprecation(
                 "`Logger.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
                 " in v1.8. `Trainer` will directly call `Logger.log_metrics` so custom"
                 " loggers should not implement `Logger.agg_and_log_metrics`."
             )
             break
Example 17
    def _configure_checkpoint_callbacks(self,
                                        checkpoint_callback: Optional[bool],
                                        enable_checkpointing: bool) -> None:
        if checkpoint_callback is not None:
            rank_zero_deprecation(
                f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
                f"be removed in v1.7. Please consider using `Trainer(enable_checkpointing={checkpoint_callback})`."
            )
            # if both are set then checkpoint only if both are True
            enable_checkpointing = checkpoint_callback and enable_checkpointing

        if self.trainer.checkpoint_callbacks:
            if not enable_checkpointing:
                raise MisconfigurationException(
                    "Trainer was configured with `enable_checkpointing=False`"
                    " but found `ModelCheckpoint` in callbacks list.")
        elif enable_checkpointing:
            self.trainer.callbacks.append(ModelCheckpoint())
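The corresponding user-facing change is a direct flag rename:

from pytorch_lightning import Trainer

trainer = Trainer(enable_checkpointing=False)  # was: Trainer(checkpoint_callback=False)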
Example 18
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8.

        Called when the training batch begins.
        """
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8."
        )
        for callback in self.callbacks:
            if is_param_in_hook_signature(callback.on_train_batch_start,
                                          "dataloader_idx",
                                          explicit=True):
                callback.on_train_batch_start(self, self.lightning_module,
                                              batch, batch_idx, 0)
            else:
                callback.on_train_batch_start(self, self.lightning_module,
                                              batch, batch_idx)
Example 19
 def __init__(self,
              trainer: "pl.Trainer",
              log_gpu_memory: Optional[str] = None) -> None:
     self.trainer = trainer
     if log_gpu_memory is not None:
         rank_zero_deprecation(
             "Setting `log_gpu_memory` with the trainer flag is deprecated in v1.5 and will be removed in v1.7. "
             "Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
         )
     self.log_gpu_memory = log_gpu_memory
     self._val_log_step: int = 0
     self._test_log_step: int = 0
     self._progress_bar_metrics: _PBAR_DICT = {}
     self._logged_metrics: _OUT_DICT = {}
     self._callback_metrics: _OUT_DICT = {}
     self._gpus_metrics: Dict[str, float] = {}
     self._epoch_end_reached = False
     self._current_fx: Optional[str] = None
     self._batch_idx: Optional[int] = None
     self._split_idx: Optional[int] = None
     self._override_agg_and_log_metrics: bool = False
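The message above points to the `DeviceStatsMonitor` callback as the replacement for the `log_gpu_memory` flag; a minimal sketch:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor

trainer = Trainer(accelerator="gpu", devices=1, callbacks=[DeviceStatsMonitor()])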
Example 20
    def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
        """Profiles over each value of an iterable.

        See deprecation message below.

        .. deprecated:: v1.6
            `Profiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.
        """
        rank_zero_deprecation(
            f"`{self.__class__.__name__}.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
        )
        iterator = iter(iterable)
        while True:
            try:
                self.start(action_name)
                value = next(iterator)
                self.stop(action_name)
                yield value
            except StopIteration:
                self.stop(action_name)
                break
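The deprecation message does not name a replacement; one way to obtain equivalent timings is to wrap the body of the loop in the profiler's `profile()` context manager, sketched here with the bundled SimpleProfiler:

from pytorch_lightning.profiler import SimpleProfiler

profiler = SimpleProfiler()
for value in range(3):
    with profiler.profile("consume_value"):  # starts/stops the action around each item
        _ = value * value
print(profiler.summary())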
Example 21
 def __init__(
     self,
     agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]],
                                                   float]]] = None,
     agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
 ):
     self._prev_step: int = -1
     self._metrics_to_agg: List[Dict[str, float]] = []
     if agg_key_funcs:
         self._agg_key_funcs = agg_key_funcs
         rank_zero_deprecation(
             "The `agg_key_funcs` parameter for `Logger` was deprecated in v1.6"
             " and will be removed in v1.8.")
     else:
         self._agg_key_funcs = {}
     if agg_default_func:
         self._agg_default_func = agg_default_func
         rank_zero_deprecation(
             "The `agg_default_func` parameter for `Logger` was deprecated in v1.6"
             " and will be removed in v1.8.")
     else:
         self._agg_default_func = np.mean
Example 22
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        r"""
        .. deprecated:: v1.6
            `TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8.

        Called when loading a model checkpoint.
        """
        # Todo: the `callback_states` are dropped with TPUSpawn as they
        # can't be saved using `xm.save`
        # https://github.com/pytorch/xla/issues/2773
        rank_zero_deprecation(
            "`TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8."
        )
        callback_states: Dict[Union[Type, str],
                              Dict] = checkpoint.get("callbacks")

        if callback_states is None:
            return

        is_legacy_ckpt = Version(
            checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
        current_callbacks_keys = {
            cb._legacy_state_key if is_legacy_ckpt else cb.state_key
            for cb in self.callbacks
        }
        difference = callback_states.keys() - current_callbacks_keys
        if difference:
            rank_zero_warn(
                "Be aware that when using `ckpt_path`,"
                " callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
                f" Please add the following callbacks: {list(difference)}.", )

        for callback in self.callbacks:
            state = callback_states.get(
                callback.state_key,
                callback_states.get(callback._legacy_state_key))
            if state:
                state = deepcopy(state)
                callback.on_load_checkpoint(self, self.lightning_module, state)
Example 23
    def _prepare_outputs_training_batch_end(
        batch_output: _BATCH_OUTPUTS_TYPE,
        lightning_module: "pl.LightningModule",
        num_optimizers: int,
    ) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
        """Processes the outputs from the batch loop into the format passed to the ``on_train_batch_end`` hook."""
        if not batch_output:
            return []

        # convert optimizer dicts to list
        if lightning_module.automatic_optimization:
            batch_output = apply_to_collection(batch_output,
                                               dtype=dict,
                                               function=_convert_optim_dict,
                                               num_optimizers=num_optimizers)

        array = np.array(batch_output, dtype=object)
        # TODO: remove in v1.8
        if (num_optimizers > 1 and lightning_module.truncated_bptt_steps > 0
                and
                not _v1_8_output_format(lightning_module.on_train_batch_end)):
            rank_zero_deprecation(
                "You are training with multiple optimizers AND truncated backpropagation through time enabled."
                " The current format of the `on_train_batch_end(outputs, ...)` is a 2d list with sizes"
                " (n_optimizers, tbptt_steps), however, this has been deprecated and will change in version v1.8 to"
                " (tbptt_steps, n_optimizers). You can update your code by adding the following parameter to your"
                " hook signature: `on_train_batch_end(outputs, ..., new_format=True)`."
            )
            # (tbptt_steps, n_opt) -> (n_opt, tbptt_steps)
            if array.ndim == 1:
                array = np.expand_dims(array, 1)
            array = array.transpose((1, 0))
        # squeeze all single-element dimensions
        array = array.squeeze()
        array = array.tolist()
        array = _recursive_unpad(array)
        return array
Example 24
    def __init__(
        self,
        model_class: Optional[Union[Type[LightningModule],
                                    Callable[..., LightningModule]]] = None,
        datamodule_class: Optional[Union[Type[LightningDataModule], Callable[
            ..., LightningDataModule]]] = None,
        save_config_callback: Optional[
            Type[SaveConfigCallback]] = SaveConfigCallback,
        save_config_filename: str = "config.yaml",
        save_config_overwrite: bool = False,
        save_config_multifile: bool = False,
        trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
        trainer_defaults: Optional[Dict[str, Any]] = None,
        seed_everything_default: Union[bool, int] = True,
        description: str = "pytorch-lightning trainer command line tool",
        env_prefix: str = "PL",
        env_parse: bool = False,
        parser_kwargs: Optional[Union[Dict[str, Any],
                                      Dict[str, Dict[str, Any]]]] = None,
        subclass_mode_model: bool = False,
        subclass_mode_data: bool = False,
        run: bool = True,
        auto_registry: bool = False,
    ) -> None:
        """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
        are called / instantiated using a parsed configuration file and / or command line args.

        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
        A full configuration YAML is parsed from ``PL_CONFIG`` if set.
        Individual settings are likewise parsed from variables named, for example, ``PL_TRAINER__MAX_EPOCHS``.

        For more info, read :ref:`the CLI docs <lightning-cli>`.

        .. warning:: ``LightningCLI`` is in beta and subject to change.

        Args:
            model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a
                callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when
                called. If ``None``, you can pass a registered model with ``--model=MyModel``.
            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
                callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
                called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
            save_config_callback: A callback class to save the training config.
            save_config_filename: Filename for the config file.
            save_config_overwrite: Whether to overwrite an existing config file.
            save_config_multifile: When input is multiple config files, saved config preserves this structure.
            trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
                callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
            trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
                this argument will not be configurable from a configuration file and will always be present for
                this particular CLI. Alternatively, configurable callbacks can be added as explained in
                :ref:`the CLI docs <lightning-cli>`.
            seed_everything_default: Value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
                seed argument. Set to True to automatically choose a valid seed.
                Setting it to False will not call seed_everything.
            description: Description of the tool shown when running ``--help``.
            env_prefix: Prefix for environment variables.
            env_parse: Whether environment variable parsing is enabled.
            parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
            subclass_mode_model: Whether model can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            subclass_mode_data: Whether datamodule can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                method. If set to ``False``, the trainer and model classes will be instantiated only.
            auto_registry: Whether to automatically fill up the registries with all defined subclasses.
        """
        self.save_config_callback = save_config_callback
        self.save_config_filename = save_config_filename
        self.save_config_overwrite = save_config_overwrite
        self.save_config_multifile = save_config_multifile
        self.trainer_class = trainer_class
        self.trainer_defaults = trainer_defaults or {}
        self.seed_everything_default = seed_everything_default

        if self.seed_everything_default is None:
            rank_zero_deprecation(
                "Setting `LightningCLI.seed_everything_default` to `None` is deprecated in v1.7 "
                "and will be removed in v1.9. Set it to `False` instead.")
            self.seed_everything_default = False

        self.model_class = model_class
        # used to differentiate between the original value and the processed value
        self._model_class = model_class or LightningModule
        self.subclass_mode_model = (model_class is None) or subclass_mode_model

        self.datamodule_class = datamodule_class
        # used to differentiate between the original value and the processed value
        self._datamodule_class = datamodule_class or LightningDataModule
        self.subclass_mode_data = (datamodule_class is
                                   None) or subclass_mode_data

        from pytorch_lightning.utilities.cli import _populate_registries

        _populate_registries(auto_registry)

        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
            parser_kwargs
            or {},  # type: ignore  # github.com/python/mypy/issues/6463
            {
                "description": description,
                "env_prefix": env_prefix,
                "default_env": env_parse
            },
        )
        self.setup_parser(run, main_kwargs, subparser_kwargs)
        self.parse_arguments(self.parser)

        self.subcommand = self.config["subcommand"] if run else None

        self._set_seed()

        self.before_instantiate_classes()
        self.instantiate_classes()

        if self.subcommand is not None:
            self._run_subcommand(self.subcommand)
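A small, self-contained sketch of wiring this constructor into an entry point. `DemoModel` and its fields are illustrative, not part of the library; in v1.7+ the class is also importable from `pytorch_lightning.cli`.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class DemoModel(LightningModule):
    def __init__(self, lr: float = 0.02):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    # e.g. `python demo.py fit --model.lr=0.1 --trainer.max_epochs=1`
    LightningCLI(DemoModel, seed_everything_default=42)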
Example 25
def __verify_train_val_loop_configuration(trainer: "pl.Trainer",
                                          model: "pl.LightningModule") -> None:
    # -----------------------------------
    # verify model has a training step
    # -----------------------------------
    has_training_step = is_overridden("training_step", model)
    if not has_training_step:
        raise MisconfigurationException(
            "No `training_step()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # -----------------------------------
    # verify model has a train dataloader
    # -----------------------------------
    has_train_dataloader = trainer._data_connector._train_dataloader_source.is_defined(
    )
    if not has_train_dataloader:
        raise MisconfigurationException(
            "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # -----------------------------------
    # verify model has optimizer
    # -----------------------------------
    has_optimizers = is_overridden("configure_optimizers", model)
    if not has_optimizers:
        raise MisconfigurationException(
            "No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # ----------------------------------------------
    # verify model does not have on_train_dataloader
    # ----------------------------------------------
    has_on_train_dataloader = is_overridden("on_train_dataloader", model)
    if has_on_train_dataloader:
        rank_zero_deprecation(
            "Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
            " Please use `train_dataloader()` directly.")

    trainer.overridden_optimizer_step = is_overridden("optimizer_step", model)
    trainer.overridden_optimizer_zero_grad = is_overridden(
        "optimizer_zero_grad", model)
    automatic_optimization = model.automatic_optimization
    going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches(
    )

    has_overridden_optimization_functions = trainer.overridden_optimizer_step or trainer.overridden_optimizer_zero_grad
    if has_overridden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
        rank_zero_warn(
            "When using `Trainer(accumulate_grad_batches != 1)` and overriding"
            " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
            " (rather, they are called on every optimization step).")

    # -----------------------------------
    # verify model for val loop
    # -----------------------------------

    has_val_loader = trainer._data_connector._val_dataloader_source.is_defined(
    )
    has_val_step = is_overridden("validation_step", model)

    if has_val_loader and not has_val_step:
        rank_zero_warn(
            "You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop."
        )
    if has_val_step and not has_val_loader:
        rank_zero_warn(
            "You defined a `validation_step` but have no `val_dataloader`. Skipping val loop."
        )

    # ----------------------------------------------
    # verify model does not have on_val_dataloader
    # ----------------------------------------------
    has_on_val_dataloader = is_overridden("on_val_dataloader", model)
    if has_on_val_dataloader:
        rank_zero_deprecation(
            "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
            " Please use `val_dataloader()` directly.")
Example 26
def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
    for callback in trainer.callbacks:
        if is_overridden(method_name="on_init_start", instance=callback):
            rank_zero_deprecation(
                "The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
            )
        if is_overridden(method_name="on_init_end", instance=callback):
            rank_zero_deprecation("The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.")

        if is_overridden(method_name="on_configure_sharded_model", instance=callback):
            rank_zero_deprecation(
                "The `on_configure_sharded_model` callback hook was deprecated in"
                " v1.6 and will be removed in v1.8. Use `setup()` instead."
            )
        if is_overridden(method_name="on_before_accelerator_backend_setup", instance=callback):
            rank_zero_deprecation(
                "The `on_before_accelerator_backend_setup` callback hook was deprecated in"
                " v1.6 and will be removed in v1.8. Use `setup()` instead."
            )
        if is_overridden(method_name="on_load_checkpoint", instance=callback):
            rank_zero_deprecation(
                f"`{callback.__class__.__name__}.on_load_checkpoint` will change its signature and behavior in v1.8."
                " If you wish to load the state of the callback, use `load_state_dict` instead."
                " In v1.8 `on_load_checkpoint(..., checkpoint)` will receive the entire loaded"
                " checkpoint dictionary instead of callback state."
            )

        for hook, alternative_hook in (
            ["on_batch_start", "on_train_batch_start"],
            ["on_batch_end", "on_train_batch_end"],
        ):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook was deprecated in v1.6 and"
                    f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
                )
        for hook, alternative_hook in (
            ["on_epoch_start", "on_<train/validation/test>_epoch_start"],
            ["on_epoch_end", "on_<train/validation/test>_epoch_end"],
        ):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook was deprecated in v1.6 and"
                    f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
                )
        for hook in ("on_pretrain_routine_start", "on_pretrain_routine_end"):
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook has been deprecated in v1.6 and"
                    " will be removed in v1.8. Please use `Callback.on_fit_start` instead."
                )
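A sketch (the callback name is hypothetical) that avoids the hooks flagged above: it uses the stage-specific epoch hook instead of `on_epoch_start`, and keeps its state via `state_dict`/`load_state_dict` rather than the checkpoint-bound `on_save_checkpoint`/`on_load_checkpoint` signatures.

from typing import Any, Dict

from pytorch_lightning.callbacks import Callback


class EpochCounter(Callback):
    def __init__(self) -> None:
        self.train_epochs_seen = 0

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        # replaces the deprecated, stage-agnostic `on_epoch_start`
        self.train_epochs_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        return {"train_epochs_seen": self.train_epochs_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.train_epochs_seen = state_dict.get("train_epochs_seen", 0)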
Example 27
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

rank_zero_deprecation(
    "Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, "
    "and will be removed in v1.7. It has been replaced by automatic parameters tying with "
    "`pytorch_lightning.utilities.params_tying.set_shared_parameters`")

from functools import wraps  # noqa: E402
from typing import Callable  # noqa: E402


def parameter_validation(fn: Callable) -> Callable:
    """Validates that the module parameter lengths match after moving to the device. It is useful when tying
    weights on TPUs.

    Args:
        fn: ``model_to_device`` method

    Note:
        TPUs require weights to be tied/shared after moving the module to the device.
Example 28
 def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
     rank_zero_deprecation(
         "`BaseProfiler` was deprecated in v1.6 and will be removed in v1.8. Please use `Profiler` instead."
     )
     super().__init__(*args, **kwargs)
Example 29
import os
from contextlib import redirect_stderr
from io import StringIO

from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.warnings import WarningCache

standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"
if standalone:

    stderr = StringIO()
    # recording
    with redirect_stderr(stderr):
        _warn("test1")
        _warn("test2", category=DeprecationWarning)

        rank_zero_warn("test3")
        rank_zero_warn("test4", category=DeprecationWarning)

        rank_zero_deprecation("test5")

        cache = WarningCache()
        cache.warn("test6")
        cache.deprecation("test7")

    output = stderr.getvalue()
    assert "test_warnings.py:31: UserWarning: test1" in output
    assert "test_warnings.py:32: DeprecationWarning: test2" in output

    assert "test_warnings.py:34: UserWarning: test3" in output
    assert "test_warnings.py:35: DeprecationWarning: test4" in output

    assert "test_warnings.py:37: LightningDeprecationWarning: test5" in output

    assert "test_warnings.py:40: UserWarning: test6" in output
Example 30
 def _deprecation(self, show_deprecation: bool = True) -> None:
     if show_deprecation and not getattr(self, "deprecation_shown", False):
         rank_zero_deprecation(_deprecate_registry_message)
         self.deprecation_shown = True
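Finally, because `rank_zero_deprecation` emits a `LightningDeprecationWarning` (a `DeprecationWarning` subclass, as the expected output in Example 29 shows), these messages can be filtered with the standard `warnings` machinery. The import location below is assumed from the module used in Example 29:

import warnings

from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning

warnings.filterwarnings("ignore", category=LightningDeprecationWarning)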