Example #1
0
 def check_fn(v: Tensor) -> Tensor:
     """Return ``v`` unchanged, warning first if it has a ``grad_fn``.

     The warning text lists the keys of ``extra``, a name that is not defined
     inside this function and must come from the surrounding scope.
     """
     if v.grad_fn is None:
         return v
     rank_zero_deprecation(
         f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
         " but this behaviour will change in v1.6. Please detach it manually:"
         " `return {'loss': ..., 'something': something.detach()}`")
     return v
    def __verify_eval_loop_configuration(self, model: "pl.LightningModule",
                                         stage: str) -> None:
        """Warn about a mismatched loader/step pair and about deprecated dataloader hooks.

        Args:
            model: The module whose overridden methods are inspected.
            stage: Prefix of the dataloader method name. ``"val"`` pairs with
                ``validation_step``; any other value pairs with ``test_step``.
        """
        loader_name = f"{stage}_dataloader"
        step_name = "test_step" if stage != "val" else "validation_step"

        loader_defined = is_overridden(loader_name, model)
        step_defined = is_overridden(step_name, model)

        # The two mismatch cases exclude each other, so at most one warning fires.
        if loader_defined and not step_defined:
            rank_zero_warn(
                f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop"
            )
        elif step_defined and not loader_defined:
            rank_zero_warn(
                f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop"
            )

        # Both deprecated hooks are checked regardless of ``stage``.
        deprecated_hooks = (
            (
                "on_val_dataloader",
                "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
                " Please use `val_dataloader()` directly.",
            ),
            (
                "on_test_dataloader",
                "Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
                " Please use `test_dataloader()` directly.",
            ),
        )
        for hook_name, message in deprecated_hooks:
            if is_overridden(hook_name, model):
                rank_zero_deprecation(message)
def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None:
    """Emit a deprecation warning for each trainer callback overriding ``on_configure_sharded_model``.

    Args:
        trainer: The trainer whose ``callbacks`` are inspected.
    """
    for cb in trainer.callbacks:
        if not is_overridden(method_name="on_configure_sharded_model", instance=cb):
            continue
        rank_zero_deprecation(
            "The `on_configure_sharded_model` callback hook was deprecated in"
            " v1.6 and will be removed in v1.8. Use `setup()` instead.")
def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
    """Deprecated alias: warn about the relocation, then forward the call.

    Args:
        *args: Positional arguments forwarded unchanged.
        stacklevel: Forwarded as the ``stacklevel`` keyword argument.
        **kwargs: Remaining keyword arguments forwarded unchanged.

    Returns:
        Whatever the forwarded call returns.
    """
    # Aliased so the imported function does not reuse this shim's own name locally.
    from pytorch_lightning.utilities.warnings import rank_zero_deprecation as _relocated

    _relocated(
        '`pytorch_lightning.utilities.distributed.rank_zero_deprecation` has been moved to'
        ' `pytorch_lightning.utilities.rank_zero_deprecation` in v1.3.7 and will be removed in v1.6'
    )
    return _relocated(*args, stacklevel=stacklevel, **kwargs)
Example #5
0
def test_v1_8_0_rank_zero_imports():
    """Check that each legacy rank-zero helper raises its v1.6 deprecation warning."""
    import warnings

    from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_info
    from pytorch_lightning.utilities.warnings import LightningDeprecationWarning, rank_zero_deprecation, rank_zero_warn

    # (callable, positional args, keyword args, expected warning pattern)
    cases = (
        (
            rank_zero_debug,
            ("foo",),
            {},
            "pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
            " and will be removed in v1.8.",
        ),
        (
            rank_zero_info,
            ("foo",),
            {},
            "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
            " and will be removed in v1.8.",
        ),
        (
            rank_zero_warn,
            ("foo",),
            {},
            "pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
            " and will be removed in v1.8.",
        ),
        (
            rank_zero_deprecation,
            ("foo",),
            {},
            "pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
            " and will be removed in v1.8.",
        ),
        (
            warnings.warn,
            ("foo", LightningDeprecationWarning),
            {"stacklevel": 5},
            "pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
            " and will be removed in v1.8.",
        ),
    )
    for func, call_args, call_kwargs, pattern in cases:
        # Call directly from this frame so the stack depth matches a plain call.
        with pytest.deprecated_call(match=pattern):
            func(*call_args, **call_kwargs)
def prefix_metric_keys(metrics_dict: Dict[str, float],
                       prefix: str) -> Dict[str, float]:
    """Deprecated wrapper around ``_prefix_metric_keys`` using an empty separator.

    Args:
        metrics_dict: Metrics passed through to ``_prefix_metric_keys``.
        prefix: Prefix passed through to ``_prefix_metric_keys``.

    Returns:
        The result of ``_prefix_metric_keys(metrics_dict, prefix, "")``.
    """
    message = (
        "`pytorch_lightning.callbacks.device_stats_monitor.prefix_metrics`"
        " is deprecated in v1.6 and will be removed in v1.8."
    )
    rank_zero_deprecation(message)
    return _prefix_metric_keys(metrics_dict, prefix, "")
Example #7
0
    def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
        """Set up step limits, progress trackers and internal caches.

        Args:
            min_steps: Stored as ``self.min_steps`` without validation.
            max_steps: Step limit. ``None`` is accepted with a deprecation
                warning and replaced by ``-1``.

        Raises:
            MisconfigurationException: If ``max_steps`` is smaller than ``-1``.
        """
        super().__init__()
        if max_steps is not None:
            if max_steps < -1:
                raise MisconfigurationException(
                    f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
                )
        else:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead."
            )
            max_steps = -1

        # step limits
        self.min_steps = min_steps
        self.max_steps = max_steps

        # progress tracking
        self.global_step: int = 0
        self.batch_progress = BatchProgress()
        self.scheduler_progress = SchedulerProgress()

        # child loops start unset
        self.batch_loop: Optional[TrainingBatchLoop] = None
        self.val_loop: Optional["loops.EvaluationLoop"] = None

        # results, outputs and caches
        self._results = ResultCollection(training=True)
        self._outputs: _OUTPUTS_TYPE = []
        self._warning_cache = WarningCache()
        self._dataloader_iter: Optional[Iterator] = None
        # holds the loaded dataloader state until dataloader objects are available
        self._dataloader_state_dict: Dict[str, Any] = {}
Example #8
0
    def batch_to(data: Any) -> Any:
        """Move ``data`` to ``device`` (taken from the surrounding scope) and return the result.

        A legacy torchtext ``Batch`` is shallow-copied and its fields are moved
        one by one. Anything else is moved through its own ``.to`` method.
        """
        # torchtext data gets handled first
        if _TORCHTEXT_LEGACY and isinstance(data, Batch):
            # TODO: also remove the torchtext dependency with Lightning 1.8
            rank_zero_deprecation(
                "The `torchtext.legacy.Batch` object is deprecated and Lightning will remove support for it in v1.8."
                " We recommend you to migrate away from Batch by following the TorchText README:"
                " https://github.com/pytorch/text#bc-breaking-legacy")
            # A shallow copy suffices: every Batch references the Dataset holding all examples
            moved_batch = copy(data)
            for name, spec in data.dataset.fields.items():
                if spec is not None:
                    setattr(moved_batch, name,
                            move_data_to_device(getattr(data, name), device))
            return moved_batch

        # Non-blocking transfers are requested only for tensors headed to a non-CPU device
        wants_non_blocking = isinstance(data, Tensor) and device not in _CPU_DEVICES
        transfer_kwargs = {"non_blocking": True} if wants_non_blocking else {}
        moved = data.to(device, **transfer_kwargs)
        # A ``None`` result means a user `TransferableDataType` forgot to return `self`;
        # fall back to the input object in that case.
        return data if moved is None else moved
Example #9
0
 def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
     """Warn that this class location is deprecated, then defer to the parent initializer."""
     message = (
         "The `pytorch_lightning.loggers.base.DummyExperiment` is deprecated in v1.7"
         " and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.DummyExperiment` instead."
     )
     rank_zero_deprecation(message)
     super().__init__(*args, **kwargs)
Example #10
0
def rank_zero_warn(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
    """Deprecated alias: warn about the relocation, then forward the call.

    Args:
        *args: Positional arguments forwarded unchanged.
        stacklevel: Forwarded as the ``stacklevel`` keyword argument.
        **kwargs: Remaining keyword arguments forwarded unchanged.

    Returns:
        Whatever the forwarded call returns.
    """
    # Aliased so the imported warn function does not reuse this shim's own name locally.
    from pytorch_lightning.utilities.warnings import rank_zero_deprecation as _deprecation
    from pytorch_lightning.utilities.warnings import rank_zero_warn as _relocated_warn

    _deprecation(
        "`pytorch_lightning.utilities.distributed.rank_zero_warn` has been moved to"
        " `pytorch_lightning.utilities.rank_zero_warn` in v1.3.7 and will be removed in v1.6"
    )
    return _relocated_warn(*args, stacklevel=stacklevel, **kwargs)
def _check_on_keyboard_interrupt(trainer: "pl.Trainer") -> None:
    """Emit a deprecation warning for each trainer callback overriding ``on_keyboard_interrupt``."""
    for cb in trainer.callbacks:
        if not is_overridden(method_name="on_keyboard_interrupt", instance=cb):
            continue
        rank_zero_deprecation(
            "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
            " Please use the `on_exception` callback hook instead.")
Example #12
0
 def torch_distributed_backend(self) -> str:
     """Deprecated accessor for the process group backend name.

     Returns the value from ``_get_process_group_backend_from_env()`` when it
     is truthy, otherwise the default backend for ``self.root_device``.
     """
     rank_zero_deprecation(
         "ParallelStrategy.torch_distributed_backend was deprecated in v1.6 and will be removed in v1.8."
     )
     env_backend = _get_process_group_backend_from_env()
     # ``or`` only evaluates the device-based default when the environment gave nothing.
     return env_backend or get_default_process_group_backend_for_device(self.root_device)
Example #13
0
    def on_predict_dataloader(self) -> None:
        """Hook invoked ahead of requesting the predict dataloader.

        The base implementation only emits a deprecation warning.

        .. deprecated:: v1.5
            :meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
            Please use :meth:`predict_dataloader()` directly.
        """
        message = (
            "Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
            " Please use `predict_dataloader()` directly."
        )
        rank_zero_deprecation(message)
def _check_on_init_start_end(trainer: "pl.Trainer") -> None:
    """Emit a deprecation warning for each callback overriding ``on_init_start`` or ``on_init_end``."""
    # Checked in this order for every callback.
    deprecated_hooks = (
        (
            "on_init_start",
            "The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8.",
        ),
        (
            "on_init_end",
            "The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.",
        ),
    )
    for cb in trainer.callbacks:
        for hook_name, message in deprecated_hooks:
            if is_overridden(method_name=hook_name, instance=cb):
                rank_zero_deprecation(message)
def _check_progress_bar(model: "pl.LightningModule") -> None:
    r"""
    Emits a deprecation warning when the model overrides get_progress_bar_dict.

    Args:
        model: The model whose get_progress_bar_dict method is inspected.
    """
    if not is_overridden("get_progress_bar_dict", model):
        return
    rank_zero_deprecation(
        "The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7."
        " Please use the `ProgressBarBase.get_metrics` instead.")
Example #16
0
 def close(self) -> None:
     """Warn about the deprecation, then call ``close()`` on every logger in ``self._logger_iterable``.

     .. deprecated:: v1.5
         This method is deprecated in v1.5 and will be removed in v1.7.
         Please use `LoggerCollection.finalize` instead.
     """
     message = (
         "`LoggerCollection.close` method is deprecated in v1.5 and will be removed in v1.7."
         " Please use `LoggerCollection.finalize` instead."
     )
     rank_zero_deprecation(message)
     for child_logger in self._logger_iterable:
         child_logger.close()
def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
    """Emit a deprecation warning for each of ``on_hpc_save`` / ``on_hpc_load`` the model overrides.

    Args:
        model: The module whose overridden methods are inspected.
    """
    deprecated_hooks = (
        (
            "on_hpc_save",
            "Method `LightningModule.on_hpc_save` is deprecated in v1.6 and"
            " will be removed in v1.8. Please use `LightningModule.on_save_checkpoint` instead.",
        ),
        (
            "on_hpc_load",
            "Method `LightningModule.on_hpc_load` is deprecated in v1.6 and"
            " will be removed in v1.8. Please use `LightningModule.on_load_checkpoint` instead.",
        ),
    )
    for hook_name, message in deprecated_hooks:
        if is_overridden(hook_name, model):
            rank_zero_deprecation(message)
Example #18
0
def merge_dicts(
    dicts: Sequence[Mapping],
    agg_key_funcs: Optional[Mapping] = None,
    default_func: Callable[[Sequence[float]], float] = np.mean,
) -> Dict:
    """Deprecated wrapper that warns and then delegates to ``logger.merge_dicts``.

    Args:
        dicts: Forwarded as ``dicts``.
        agg_key_funcs: Forwarded as ``agg_key_funcs``.
        default_func: Forwarded as ``default_func``; defaults to ``np.mean``.

    Returns:
        The result of ``logger.merge_dicts`` for the same arguments.
    """
    message = (
        "The `pytorch_lightning.loggers.base.merge_dicts` is deprecated in v1.7"
        " and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.merge_dicts` instead."
    )
    rank_zero_deprecation(message)
    merged = logger.merge_dicts(
        dicts=dicts,
        agg_key_funcs=agg_key_funcs,
        default_func=default_func,
    )
    return merged
def _check_on_batch_start_end(trainer: "pl.Trainer",
                              model: "pl.LightningModule") -> None:
    """Emit a deprecation warning for each callback overriding ``on_batch_start`` or ``on_batch_end``.

    Args:
        trainer: The trainer whose ``callbacks`` are inspected.
        model: Accepted but not used by this check.
    """
    # Deprecated hook -> suggested replacement; dict order keeps the start hook first.
    replacements = {
        "on_batch_start": "on_train_batch_start",
        "on_batch_end": "on_train_batch_end",
    }
    for hook, alternative_hook in replacements.items():
        for cb in trainer.callbacks:
            if not is_overridden(method_name=hook, instance=cb):
                continue
            rank_zero_deprecation(
                f"The `Callback.{hook}` hook was deprecated in v1.6 and"
                f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
            )
def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
    r"""
    Emits a deprecation warning when the model overrides `on_post_move_to_device`.

    Args:
        model: The model whose `on_post_move_to_device` method is inspected.
    """
    if not is_overridden("on_post_move_to_device", model):
        return
    rank_zero_deprecation(
        "Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7. "
        "We perform automatic parameters tying without the need of implementing `on_post_move_to_device`."
    )
Example #21
0
    def close(self) -> None:
        """Warn about the deprecation, then call ``self.save()``.

        .. deprecated:: v1.5
            This method is deprecated in v1.5 and will be removed in v1.7.
            Please use `LightningLoggerBase.finalize` instead.
        """
        message = (
            "`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
            " Please use `LightningLoggerBase.finalize` instead."
        )
        rank_zero_deprecation(message)
        self.save()
def _check_add_get_queue(model: "pl.LightningModule") -> None:
    r"""
    Emits a deprecation warning for each of add_to_queue / get_from_queue the model overrides.

    Args:
        model: The lightning module
    """
    deprecated_methods = (
        (
            "add_to_queue",
            "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
            "favor of `DDPSpawnStrategy.add_to_queue`",
        ),
        (
            "get_from_queue",
            "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
            "favor of `DDPSpawnStrategy.get_from_queue`",
        ),
    )
    for method_name, message in deprecated_methods:
        if is_overridden(method_name, model):
            rank_zero_deprecation(message)
    def _configure_model_summary_callback(
            self,
            enable_model_summary: bool,
            weights_summary: Optional[str] = None) -> None:
        """Append a default model summary callback to the trainer unless disabled or already present.

        Args:
            enable_model_summary: When falsy, no callback is added.
            weights_summary: Deprecated summary mode. ``None`` disables the summary
                (with a deprecation warning), ``"top"`` maps to ``max_depth=1``, and
                any other value is checked against ``ModelSummaryMode.supported_types()``.

        Raises:
            MisconfigurationException: If ``weights_summary`` is neither ``"top"`` nor
                one of ``ModelSummaryMode.supported_types()``.
        """
        if weights_summary is None:
            rank_zero_deprecation(
                "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
                " in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
            )
            return
        if not enable_model_summary:
            return

        model_summary_cbs = [
            type(cb) for cb in self.trainer.callbacks
            if isinstance(cb, ModelSummary)
        ]
        if model_summary_cbs:
            rank_zero_info(
                f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
                " Skipping setting a default `ModelSummary` callback.")
            return

        if weights_summary == "top":
            # special case the default value for weights_summary to preserve backward compatibility
            max_depth = 1
        else:
            rank_zero_deprecation(
                f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
                " in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
                " `max_depth` directly to the Trainer's `callbacks` argument instead."
            )
            if weights_summary not in ModelSummaryMode.supported_types():
                # The two fragments are joined into ONE message. A comma between them
                # would pass two arguments and render the error as a tuple.
                raise MisconfigurationException(
                    f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}"
                    f" but got {weights_summary}"
                )
            max_depth = ModelSummaryMode.get_max_depth(weights_summary)

        # `isinstance` is already False for `None`, so no separate `is not None` check is needed.
        if isinstance(self.trainer._progress_bar_callback, RichProgressBar):
            model_summary = RichModelSummary(max_depth=max_depth)
        else:
            model_summary = ModelSummary(max_depth=max_depth)
        self.trainer.callbacks.append(model_summary)
        self.trainer._weights_summary = weights_summary
 def __verify_predict_loop_configuration(
         self, model: "pl.LightningModule") -> None:
     """Require an overridden ``predict_dataloader`` and warn about ``on_predict_dataloader``.

     Args:
         model: The module whose overridden methods are inspected.

     Raises:
         MisconfigurationException: If ``predict_dataloader`` is not overridden.
     """
     if not is_overridden("predict_dataloader", model):
         raise MisconfigurationException(
             "Dataloader not found for `Trainer.predict`")

     # The deprecated hook is only checked once a dataloader is known to exist.
     if not is_overridden("on_predict_dataloader", model):
         return
     rank_zero_deprecation(
         "Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
         " Please use `predict_dataloader()` directly.")
def _check_dl_idx_in_on_train_batch_hooks(trainer: "pl.Trainer",
                                          model: "pl.LightningModule") -> None:
    """Warn when train-batch hooks on the model or callbacks still declare ``dataloader_idx``.

    Args:
        trainer: The trainer whose ``callbacks`` are inspected.
        model: The module whose hooks are inspected.
    """

    def _declares_dataloader_idx(owner: Any, hook_name: str) -> bool:
        # Same signature check for the model and for each callback.
        return is_param_in_hook_signature(getattr(owner, hook_name),
                                          "dataloader_idx",
                                          explicit=True)

    for hook in ("on_train_batch_start", "on_train_batch_end"):
        if _declares_dataloader_idx(model, hook):
            rank_zero_deprecation(
                f"Base `LightningModule.{hook}` hook signature has changed in v1.5."
                " The `dataloader_idx` argument will be removed in v1.7.")

        for callback in trainer.callbacks:
            if not _declares_dataloader_idx(callback, hook):
                continue
            rank_zero_deprecation(
                f"Base `Callback.{hook}` hook signature has changed in v1.5."
                " The `dataloader_idx` argument will be removed in v1.7.")
Example #26
0
    def on_trainer_init(
        self,
        callbacks: Optional[Union[List[Callback], Callback]],
        checkpoint_callback: bool,
        progress_bar_refresh_rate: Optional[int],
        process_position: int,
        default_root_dir: Optional[str],
        weights_save_path: Optional[str],
        stochastic_weight_avg: bool,
        max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
    ):
        """Store path settings on the trainer and build its callback list.

        Args:
            callbacks: A single callback, a list of callbacks, or a falsy value
                (which yields an empty list).
            checkpoint_callback: Passed to ``_configure_checkpoint_callbacks``.
            progress_bar_refresh_rate: Passed to ``configure_progress_bar``.
            process_position: Passed to ``configure_progress_bar``; any value
                other than 0 triggers a deprecation warning.
            default_root_dir: Root directory; the current working directory is
                used when this is falsy.
            weights_save_path: Weights directory; the resolved root directory is
                used when this is falsy.
            stochastic_weight_avg: Stored on the trainer as ``_stochastic_weight_avg``.
            max_time: Passed to ``_configure_timer_callback``.
        """
        # init folder paths for checkpoint + weights save callbacks
        self.trainer._default_root_dir = default_root_dir or os.getcwd()
        self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
        self.trainer._stochastic_weight_avg = stochastic_weight_avg

        # init callbacks
        # a single Callback instance is wrapped into a list
        if isinstance(callbacks, Callback):
            callbacks = [callbacks]
        self.trainer.callbacks = callbacks or []

        # configure checkpoint callback
        # pass through the required args to figure out defaults
        self._configure_checkpoint_callbacks(checkpoint_callback)

        # configure swa callback
        self._configure_swa_callbacks()

        # configure the timer callback.
        # responsible to stop the training when max_time is reached.
        self._configure_timer_callback(max_time)

        # init progress bar
        # the deprecation warning does not change the value that is passed on below
        if process_position != 0:
            rank_zero_deprecation(
                f"Setting `Trainer(process_position={process_position})` is deprecated in v1.5 and will be removed"
                " in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
                " `process_position` directly to the Trainer's `callbacks` argument instead."
            )
        self.trainer._progress_bar_callback = self.configure_progress_bar(
            progress_bar_refresh_rate, process_position)

        # push all checkpoint callbacks to the end
        # it is important that these are the last callbacks to run
        self.trainer.callbacks = self._reorder_callbacks(
            self.trainer.callbacks)
def __verify_eval_loop_configuration(trainer: "pl.Trainer",
                                     model: "pl.LightningModule",
                                     stage: str) -> None:
    """Validate that the loader and step needed for the ``stage`` loop are available.

    Args:
        trainer: The trainer whose data connector reports whether a loader source is defined.
        model: The module whose overridden methods are inspected.
        stage: Loop prefix. ``"val"`` maps to ``validation_step`` / ``Trainer.validate``;
            other values map to ``{stage}_step`` / ``Trainer.{stage}``.

    Raises:
        MisconfigurationException: If no loader source is defined, if the required
            step is missing, or (for ``"predict"``) if ``predict_step`` is ``None`` or
            neither the step nor ``forward`` is overridden.
    """
    is_val = stage == "val"
    loader_name = f"{stage}_dataloader"
    step_name = "validation_step" if is_val else f"{stage}_step"
    trainer_method = "validate" if is_val else stage
    on_eval_hook = f"on_{loader_name}"

    loader_source = getattr(trainer._data_connector, f"_{stage}_dataloader_source")
    has_loader = loader_source.is_defined()
    has_step = is_overridden(step_name, model)

    # The deprecated on_<stage>_dataloader hook only produces a warning.
    if is_overridden(on_eval_hook, model):
        rank_zero_deprecation(
            f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
            f" be removed in v1.7.0. Please use `{loader_name}()` directly.")

    # A dataloader source is mandatory for every stage.
    if not has_loader:
        raise MisconfigurationException(
            f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`."
        )

    if stage != "predict":
        # Non-predict stages require their step method to be overridden.
        if not has_step:
            raise MisconfigurationException(
                f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`."
            )
        return

    # predict_step is not required to be overridden
    if model.predict_step is None:
        raise MisconfigurationException(
            "`predict_step` cannot be None to run `Trainer.predict`")
    if not has_step and not is_overridden("forward", model):
        raise MisconfigurationException(
            "`Trainer.predict` requires `forward` method to run.")
Example #28
0
 def on_trainer_init(
     self,
     logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]],
     flush_logs_every_n_steps: Optional[int],
     log_every_n_steps: int,
     move_metrics_to_cpu: bool,
 ) -> None:
     """Configure the logger and copy the logging settings onto the trainer.

     Args:
         logger: Passed to ``configure_logger``.
         flush_logs_every_n_steps: Deprecated; any non-``None`` value triggers a
             deprecation warning, ``None`` is replaced by 100.
         log_every_n_steps: Stored on the trainer unchanged.
         move_metrics_to_cpu: Stored on the trainer unchanged.
     """
     self.configure_logger(logger)
     if flush_logs_every_n_steps is None:
         flush_logs_every_n_steps = 100  # original default parameter
     else:
         rank_zero_deprecation(
             f"Setting `Trainer(flush_logs_every_n_steps={flush_logs_every_n_steps})` is deprecated in v1.5 "
             "and will be removed in v1.7. Please configure flushing in the logger instead."
         )
     target = self.trainer
     target.flush_logs_every_n_steps = flush_logs_every_n_steps
     target.log_every_n_steps = log_every_n_steps
     target.move_metrics_to_cpu = move_metrics_to_cpu
Example #29
0
 def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None) -> None:
     """Store the trainer and initialise all metric containers and counters.

     Args:
         trainer: Kept as ``self.trainer``.
         log_gpu_memory: Deprecated flag; any non-``None`` value triggers a
             deprecation warning and is stored as ``self.log_gpu_memory``.
     """
     self.trainer = trainer
     if log_gpu_memory is not None:
         message = (
             "Setting `log_gpu_memory` with the trainer flag is deprecated in v1.5 and will be removed in v1.7. "
             "Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
         )
         rank_zero_deprecation(message)
     self.log_gpu_memory = log_gpu_memory

     # evaluation results and per-stage log step counters
     self.eval_loop_results: List[_OUT_DICT] = []
     self._val_log_step: int = 0
     self._test_log_step: int = 0

     # metric containers
     self._progress_bar_metrics: _PBAR_DICT = {}
     self._logged_metrics: _OUT_DICT = {}
     self._callback_metrics: _OUT_DICT = {}
     self._gpus_metrics: Dict[str, float] = {}

     # bookkeeping, all starting unset
     self._epoch_end_reached = False
     self._current_fx: Optional[str] = None
     self._batch_idx: Optional[int] = None
     self._split_idx: Optional[int] = None
    def _configure_checkpoint_callbacks(self,
                                        checkpoint_callback: Optional[bool],
                                        enable_checkpointing: bool) -> None:
        """Reconcile the checkpointing flags with the callbacks present on the trainer.

        Args:
            checkpoint_callback: Deprecated flag; when not ``None`` it is combined
                with ``enable_checkpointing`` using ``and``.
            enable_checkpointing: Whether a default ``ModelCheckpoint`` may be added.

        Raises:
            MisconfigurationException: If checkpointing resolves to ``False`` while
                the trainer already has checkpoint callbacks.
        """
        if checkpoint_callback is not None:
            message = (
                f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
                f"be removed in v1.7. Please consider using `Trainer(enable_checkpointing={checkpoint_callback})`."
            )
            rank_zero_deprecation(message)
            # if both are set then checkpoint only if both are True
            enable_checkpointing = checkpoint_callback and enable_checkpointing

        # Existing checkpoint callbacks conflict with an explicit ``False``.
        if self._trainer_has_checkpoint_callbacks():
            if enable_checkpointing is False:
                raise MisconfigurationException(
                    "Trainer was configured with `enable_checkpointing=False`"
                    " but found `ModelCheckpoint` in callbacks list.")

        # Add a default checkpoint callback only for an explicit ``True``.
        if not self._trainer_has_checkpoint_callbacks():
            if enable_checkpointing is True:
                self.trainer.callbacks.append(ModelCheckpoint())