def __init__(
    self,
    parallel_devices: Optional[List[torch.device]] = None,
    num_nodes: Optional[int] = None,
    cluster_environment: ClusterEnvironment = None,
    sync_batchnorm: Optional[bool] = None,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[callable] = None,
    ddp_comm_wrapper: Optional[callable] = None,
    **kwargs: Any,
):
    super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
    if num_nodes is not None:
        rank_zero_deprecation(
            "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
            "Notice that it will be overridden by the trainer setting."
        )
    self._num_nodes = num_nodes or 1
    if sync_batchnorm is not None:
        rank_zero_deprecation(
            "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
            "Notice that it will be overridden by the trainer setting."
        )
    self._sync_batchnorm = sync_batchnorm or False
    self._ddp_kwargs = kwargs
    self.dist = LightningDistributed()
    self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
    self.mp_queue = None
    self._ddp_comm_state = ddp_comm_state
    self._ddp_comm_hook = ddp_comm_hook
    self._ddp_comm_wrapper = ddp_comm_wrapper
    self._local_rank = 0
    self.set_world_ranks()
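For context, a minimal usage sketch (not part of the snippet above), assuming PyTorch Lightning ~1.4 and PyTorch >= 1.8 for the communication hook; extra keyword arguments are forwarded to DistributedDataParallel:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# sketch only: compress gradients to fp16 during all-reduce and forward a DDP kwarg
plugin = DDPPlugin(
    ddp_comm_hook=default_hooks.fp16_compress_hook,
    find_unused_parameters=False,
)
trainer = Trainer(gpus=2, accelerator="ddp", plugins=[plugin])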
Example #2
def metrics_to_scalars(self, metrics: dict) -> dict:
    rank_zero_deprecation(
        "Internal: TrainerLoggingMixin.metrics_to_scalars is deprecated in v1.3"
        " and will be removed in v1.5."
        " Use `pytorch_lightning.utilities.metrics.metrics_to_scalars` instead."
    )
    return new_metrics_to_scalars(metrics)
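A quick sketch of the replacement utility named in the message above, assuming pytorch_lightning ~1.3+ where it exists:

import torch
from pytorch_lightning.utilities.metrics import metrics_to_scalars

# single-element tensors are converted to plain Python numbers
metrics = {"loss": torch.tensor(0.25), "acc": 0.9}
print(metrics_to_scalars(metrics))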
        def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
            name = fn.__name__
            has_run = False

            # If calling setup, we check the stage and assign stage-specific bool args
            if name in ("setup", "teardown"):

                # Get stage either by grabbing from args or checking kwargs.
                # If not provided, set call status of 'fit', 'validate', and 'test' to True.
                # We do this so __attach_datamodule in trainer.py doesn't mistakenly call
                # setup('test') on trainer.test()
                stage = args[0] if len(args) else kwargs.get("stage", None)

                if stage is None:
                    has_run = True
                    for s in ("fit", "validate", "test"):
                        attr = f"_has_{name}_{s}"
                        has_run &= getattr(obj, attr)
                        setattr(obj, attr, True)
                else:
                    attr = f"_has_{name}_{stage}"
                    has_run = getattr(obj, attr)
                    setattr(obj, attr, True)

            elif name == "prepare_data":
                has_run = obj._has_prepared_data
                obj._has_prepared_data = True

            if has_run:
                rank_zero_deprecation(
                    f"DataModule.{name} has already been called, so it will not be called again. "
                    f"In v1.6 this behavior will change to always call DataModule.{name}."
                )
            else:
                fn(*args, **kwargs)
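For illustration, a standalone re-implementation of the run-once guard above; the decorator name and attribute scheme here are hypothetical, not the library's:

import functools

def call_once(fn):
    # Hypothetical decorator mirroring the guard above: skip repeated calls on the same object.
    attr = f"_has_run_{fn.__name__}"

    @functools.wraps(fn)
    def wrapped(self, *args, **kwargs):
        if getattr(self, attr, False):
            print(f"{fn.__name__} has already been called, skipping.")
            return None
        setattr(self, attr, True)
        return fn(self, *args, **kwargs)

    return wrapped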
Example #4
    def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
        """Compute each parameter's gradient's norm and their overall norm.

        .. deprecated:: v1.3
            Will be removed in v1.5.0. Use :func:`pytorch_lightning.utilities.grads.grad_norm` instead.
        """
        rank_zero_deprecation(
            "LightningModule.grad_norm is deprecated in v1.3 and will be removed in v1.5."
            " Use grad_norm from pytorch_lightning.utilities.grads instead.")
        return new_grad_norm(self, norm_type)
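A sketch of the suggested replacement, assuming pytorch_lightning ~1.3+ where pytorch_lightning.utilities.grads.grad_norm is available:

import torch
from pytorch_lightning.utilities.grads import grad_norm

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
# returns a dict of per-parameter norms plus a total, e.g. keys like "grad_2.0_norm/weight"
norms = grad_norm(model, norm_type=2)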
    def has_setup_fit(self) -> bool:
        """Return bool letting you know if ``datamodule.setup(stage='fit')`` has been called or not.

        Returns:
            bool: True if ``datamodule.setup(stage='fit')`` has been called. False by default.

        .. deprecated:: v1.4
            Will be removed in v1.6.0.
        """
        rank_zero_deprecation('DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.')
        return self._has_setup_fit
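A small sketch of how this flag behaves, assuming pytorch_lightning ~1.4; MyDataModule is a hypothetical minimal datamodule:

from pytorch_lightning import LightningDataModule

class MyDataModule(LightningDataModule):
    def setup(self, stage=None):
        pass

dm = MyDataModule()
assert not dm.has_setup_fit   # False until setup('fit') has run (accessing it warns in v1.4)
dm.setup(stage="fit")
assert dm.has_setup_fit       # flipped by the call-tracking wrapper shown earlier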
Example #6
def has_arg(self, f_name: str, arg_name: str) -> bool:
    rank_zero_deprecation(
        "Internal: TrainerModelHooksMixin.has_arg is deprecated in v1.4"
        " and will be removed in v1.6."
        " Use `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature` instead."
    )
    model = self.lightning_module
    f_op = getattr(model, f_name, None)
    if not f_op:
        return False
    return is_param_in_hook_signature(f_op, arg_name)
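A sketch of the replacement utility named in the message, assuming pytorch_lightning ~1.4 where it lives in utilities.signature_utils:

from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

def training_step(batch, batch_idx, dataloader_idx=0):
    ...

assert is_param_in_hook_signature(training_step, "dataloader_idx")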
Example #7
def is_using_torchelastic(self) -> bool:
    """
    .. deprecated:: v1.3
        Will be removed in v1.5.0.

    Returns:
        ``True`` if the current process was launched using the torchelastic command.
    """
    rank_zero_deprecation(
        "The property `AcceleratorConnector.is_using_torchelastic` was deprecated in v1.3"
        " and will be removed in 1.5. Use `TorchElasticEnvironment.is_using_torchelastic()` instead.",
    )
    return TorchElasticEnvironment.is_using_torchelastic()
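The replacement call suggested above, sketched under the assumption of pytorch_lightning ~1.3+ where TorchElasticEnvironment lives under plugins.environments:

from pytorch_lightning.plugins.environments import TorchElasticEnvironment

if TorchElasticEnvironment.is_using_torchelastic():
    print("launched via the torchelastic command")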
Example #8
def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool:
    rank_zero_deprecation(
        "Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4"
        " and will be removed in v1.6."
    )
    # note: currently unused - kept as it is public
    if model is None:
        model = self.lightning_module
    f_op = getattr(model, f_name, None)
    return callable(f_op)
    def has_teardown_validate(self) -> bool:
        """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not.

        Returns:
            bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.

        .. deprecated:: v1.4
            Will be removed in v1.6.0.
        """
        rank_zero_deprecation(
            'DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6.'
        )
        return self._has_teardown_validate
    def has_prepared_data(self) -> bool:
        """Return bool letting you know if ``datamodule.prepare_data()`` has been called or not.

        Returns:
            bool: True if ``datamodule.prepare_data()`` has been called. False by default.

        .. deprecated:: v1.4
            Will be removed in v1.6.0.
        """
        rank_zero_deprecation(
            'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'
        )
        return self._has_prepared_data
Example #11
    def __deprecation_check(
        self,
        profiled_functions: Optional[List[str]],
        record_functions: Optional[Set[str]],
    ) -> Set[str]:
        if record_functions is None:
            record_functions = set()

        if profiled_functions is not None:
            rank_zero_deprecation(
                "`PyTorchProfiler.profiled_functions` has been renamed to"
                " `record_functions` in v1.3 and will be removed in v1.5"
            )
            if not record_functions:
                record_functions |= set(profiled_functions)
            else:
                raise MisconfigurationException(
                    "You set `PyTorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
                    " Please use only the latter."
                )

        return record_functions
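A sketch of the renamed argument in use, assuming pytorch_lightning ~1.3/1.4 where record_functions replaced profiled_functions on PyTorchProfiler:

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler

# record only these named regions in the PyTorch profiler traces
profiler = PyTorchProfiler(record_functions={"training_step", "validation_step"})
trainer = Trainer(profiler=profiler)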
    def on_trainer_init(
        self,
        gradient_clip_val: float,
        gradient_clip_algorithm: str,
        track_grad_norm: Union[int, float, str],
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
        truncated_bptt_steps: Optional[int],
        terminate_on_nan: bool,
    ):
        self.trainer.terminate_on_nan = terminate_on_nan

        # gradient clipping
        if gradient_clip_algorithm not in list(GradClipAlgorithmType):
            raise MisconfigurationException(
                f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
            )
        self.trainer.gradient_clip_val = gradient_clip_val
        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)

        # gradient norm tracking
        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
            raise MisconfigurationException(
                "track_grad_norm can be an int, a float or 'inf' (infinity norm)."
            )
        self.trainer.track_grad_norm = float(track_grad_norm)

        # accumulated grads
        self.trainer.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        if truncated_bptt_steps is not None and truncated_bptt_steps > 0:
            rank_zero_deprecation(
                "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5."
                " Set truncated_bptt_steps directly on the LightningModule instead."
            )
        self.trainer.truncated_bptt_steps = truncated_bptt_steps
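A sketch of the Trainer arguments validated above, assuming pytorch_lightning ~1.3/1.4 (truncated_bptt_steps on the Trainer is the deprecated spelling, so it is omitted here):

from pytorch_lightning import Trainer

trainer = Trainer(
    gradient_clip_val=0.5,
    gradient_clip_algorithm="norm",  # must be a valid GradClipAlgorithmType value
    track_grad_norm=2,               # int, float, or 'inf'
    accumulate_grad_batches=4,
    terminate_on_nan=True,
)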
def test_v1_6_0_rank_zero_warnings_moved():
    with pytest.deprecated_call(match='in v1.3.7 and will be removed in v1.6'):
        rank_zero_warn('test')
    with pytest.deprecated_call(match='in v1.3.7 and will be removed in v1.6'):
        rank_zero_deprecation('test')
Example #14
from pytorch_lightning.utilities.distributed import rank_zero_deprecation

rank_zero_deprecation(
    "Using ``import pytorch_lightning.profiler.profilers`` is deprecated in v1.4, and will be removed in v1.6. "
    "HINT: Use ``import pytorch_lightning.profiler`` directly."
)

from pytorch_lightning.profiler.advanced import AdvancedProfiler  # noqa E402
from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler, PassThroughProfiler  # noqa E402
from pytorch_lightning.profiler.pytorch import PyTorchProfiler  # noqa E402
from pytorch_lightning.profiler.simple import SimpleProfiler  # noqa E402

__all__ = [
    'AbstractProfiler',
    'BaseProfiler',
    'AdvancedProfiler',
    'PassThroughProfiler',
    'PyTorchProfiler',
    'SimpleProfiler',
]
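The import path the hint above recommends, as a short sketch assuming pytorch_lightning ~1.4:

from pytorch_lightning.profiler import PyTorchProfiler, SimpleProfiler

profiler = SimpleProfiler()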
def deprecation(self, m, *args, **kwargs):
    if m not in self:
        self.add(m)
        rank_zero_deprecation(m, *args, **kwargs)
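An illustrative, standalone version of the deduplication pattern above; the class name is hypothetical, not the library's:

from pytorch_lightning.utilities.distributed import rank_zero_deprecation

class DeprecationCache(set):
    def deprecation(self, message: str) -> None:
        # emit each unique deprecation message at most once
        if message not in self:
            self.add(message)
            rank_zero_deprecation(message)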