def __init__(
    self,
    parallel_devices: Optional[List[torch.device]] = None,
    num_nodes: Optional[int] = None,
    cluster_environment: Optional[ClusterEnvironment] = None,
    sync_batchnorm: Optional[bool] = None,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[callable] = None,
    ddp_comm_wrapper: Optional[callable] = None,
    **kwargs: Any,
):
    """Initialise the DDP plugin state.

    Args:
        parallel_devices: Devices handed to the parent plugin; their count becomes
            ``self.num_processes`` (0 when ``None``).
        num_nodes: Deprecated since v1.4 (emits a deprecation warning when given).
            Stored as ``self._num_nodes``, falling back to 1 when falsy.
        cluster_environment: Forwarded to the parent plugin's constructor.
        sync_batchnorm: Deprecated since v1.4 (emits a deprecation warning when given).
            Stored as ``self._sync_batchnorm``, falling back to ``False`` when falsy.
        ddp_comm_state: Stored as ``self._ddp_comm_state``.
        ddp_comm_hook: Stored as ``self._ddp_comm_hook``.
        ddp_comm_wrapper: Stored as ``self._ddp_comm_wrapper``.
        **kwargs: Any extra keyword arguments, stored as ``self._ddp_kwargs``.
    """
    super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
    if num_nodes is not None:
        rank_zero_deprecation(
            "Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
            "Notice that it will be overriden by the trainer setting."
        )
    # Fall back to a single node when the (deprecated) argument is not supplied.
    self._num_nodes = num_nodes or 1
    if sync_batchnorm is not None:
        rank_zero_deprecation(
            "Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4, and will be removed in v1.6. "
            "Notice that it will be overriden by the trainer setting."
        )
    self._sync_batchnorm = sync_batchnorm or False
    self._ddp_kwargs = kwargs
    self.dist = LightningDistributed()
    self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
    self.mp_queue = None
    self._ddp_comm_state = ddp_comm_state
    self._ddp_comm_hook = ddp_comm_hook
    self._ddp_comm_wrapper = ddp_comm_wrapper
    self._local_rank = 0
    # Called last, once every attribute above has been assigned.
    self.set_world_ranks()
def metrics_to_scalars(self, metrics: dict) -> dict:
    """Deprecated mixin entry point; delegates to ``new_metrics_to_scalars``.

    A deprecation warning is emitted first. ``new_metrics_to_scalars`` is presumably the
    ``pytorch_lightning.utilities.metrics.metrics_to_scalars`` utility that the warning
    points to; confirm this against the module imports.

    Args:
        metrics: passed through unchanged to the delegate.

    Returns:
        Whatever the delegate returns.
    """
    notice = (
        "Internal: TrainerLoggingMixin.metrics_to_scalars is deprecated in v1.3"
        " and will be removed in v1.5."
        " Use `pytorch_lightning.utilities.metrics.metrics_to_scalars` instead."
    )
    rank_zero_deprecation(notice)
    converted = new_metrics_to_scalars(metrics)
    return converted
def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
    """Run the tracked hook ``fn`` unless it has already been run on ``obj``.

    ``fn`` and ``obj`` come from the enclosing scope. Per-stage flags of the form
    ``obj._has_{setup,teardown}_{stage}``, plus ``obj._has_prepared_data``, record
    which hooks have run. A repeated call is skipped and a deprecation warning is
    emitted instead.

    Returns:
        The return value of ``fn``, or ``None`` when the call is skipped.
    """
    name = fn.__name__
    has_run = False

    # If calling setup/teardown, check the stage and update the stage-specific flags.
    if name in ("setup", "teardown"):
        # Get the stage either from the positional args or from kwargs.
        # If it is not provided, mark 'fit', 'validate' and 'test' as run.
        # We do this so __attach_datamodule in trainer.py doesn't mistakenly call
        # setup('test') on trainer.test()
        stage = args[0] if args else kwargs.get("stage", None)

        if stage is None:
            has_run = True
            for s in ("fit", "validate", "test"):
                attr = f"_has_{name}_{s}"
                # Counts as "already run" only if every one of these stages had run before.
                has_run &= getattr(obj, attr)
                setattr(obj, attr, True)
        else:
            attr = f"_has_{name}_{stage}"
            has_run = getattr(obj, attr)
            setattr(obj, attr, True)

    elif name == "prepare_data":
        has_run = obj._has_prepared_data
        obj._has_prepared_data = True

    if has_run:
        rank_zero_deprecation(
            f"DataModule.{name} has already been called, so it will not be called again. "
            f"In v1.6 this behavior will change to always call DataModule.{name}."
        )
        return None
    # Propagate the hook's result so the wrapper is transparent to callers.
    return fn(*args, **kwargs)
def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
    """Compute the norm of each parameter's gradient, plus the overall norm.

    .. deprecated:: v1.3
        Will be removed in v1.5.0.
        Use :func:`pytorch_lightning.utilities.grads.grad_norm` instead.

    Args:
        norm_type: forwarded unchanged to ``new_grad_norm`` together with this module.
    """
    warning_text = (
        "LightningModule.grad_norm is deprecated in v1.3 and will be removed in v1.5."
        " Use grad_norm from pytorch_lightning.utilities.grads instead."
    )
    rank_zero_deprecation(warning_text)
    norms = new_grad_norm(self, norm_type)
    return norms
def has_setup_fit(self) -> bool:
    """Report whether ``datamodule.setup(stage='fit')`` has been called.

    Emits a deprecation warning on every access.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.

    Returns:
        bool: True if ``datamodule.setup(stage='fit')`` has been called. False by default.
    """
    notice = 'DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.'
    rank_zero_deprecation(notice)
    return self._has_setup_fit
def has_arg(self, f_name: str, arg_name: str) -> bool:
    """Check whether the LightningModule hook ``f_name`` accepts a parameter ``arg_name``.

    .. deprecated:: v1.4
        Will be removed in v1.6. Use
        ``pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature`` instead.

    Args:
        f_name: Name of the attribute to look up on ``self.lightning_module``.
        arg_name: Parameter name to look for in that hook's signature.

    Returns:
        ``False`` if the module has no (or a falsy) attribute ``f_name``; otherwise the
        result of ``is_param_in_hook_signature(hook, arg_name)``.
    """
    # The message previously named `is_function_implemented` (copy-paste error); it now
    # names the method that is actually deprecated here.
    rank_zero_deprecation(
        "Internal: TrainerModelHooksMixin.has_arg is deprecated in v1.4"
        " and will be removed in v1.6."
        " Use `pytorch_lightning.utilities.signature_utils.is_param_in_hook_signature` instead."
    )
    model = self.lightning_module
    f_op = getattr(model, f_name, None)
    if not f_op:
        return False
    return is_param_in_hook_signature(f_op, arg_name)
def is_using_torchelastic(self) -> bool:
    """Report whether the current process was launched with the torchelastic command.

    Emits a deprecation warning, then delegates to
    ``TorchElasticEnvironment.is_using_torchelastic()``.

    .. deprecated:: v1.3
        Will be removed in v1.5.0.
    """
    notice = (
        "The property `AcceleratorConnector.is_using_torchelastic` was deprecated in v1.3"
        " and will be removed in 1.5. Use `TorchElasticEnvironment.is_using_torchelastic()` instead."
    )
    rank_zero_deprecation(notice)
    launched_by_elastic = TorchElasticEnvironment.is_using_torchelastic()
    return launched_by_elastic
def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool:
    """Return whether ``model`` has a callable attribute named ``f_name``.

    ``self.lightning_module`` is used when ``model`` is ``None``.

    .. deprecated:: v1.4
        Will be removed in v1.6.
    """
    rank_zero_deprecation(
        "Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4"
        " and will be removed in v1.6."
    )
    # NOTE: currently unused - kept because it is public.
    target = self.lightning_module if model is None else model
    return callable(getattr(target, f_name, None))
def has_teardown_validate(self) -> bool:
    """Report whether ``datamodule.teardown(stage='validate')`` has been called.

    Emits a deprecation warning on every access.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.

    Returns:
        bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.
    """
    notice = 'DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6.'
    rank_zero_deprecation(notice)
    return self._has_teardown_validate
def has_prepared_data(self) -> bool:
    """Report whether ``datamodule.prepare_data()`` has been called.

    Emits a deprecation warning on every access.

    .. deprecated:: v1.4
        Will be removed in v1.6.0.

    Returns:
        bool: True if ``datamodule.prepare_data()`` has been called. False by default.
    """
    notice = 'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'
    rank_zero_deprecation(notice)
    return self._has_prepared_data
def __deprecation_check(
    self,
    profiled_functions: Optional[List[str]],
    record_functions: Optional[Set[str]],
) -> Set[str]:
    """Fold the deprecated ``profiled_functions`` argument into ``record_functions``.

    Args:
        profiled_functions: Deprecated (v1.3) former name of ``record_functions``.
        record_functions: Functions to record; ``None`` is treated as an empty set.

    Returns:
        The resulting set of function names. When ``profiled_functions`` is used, a new
        set is returned, so a set passed in by the caller is never modified.

    Raises:
        MisconfigurationException: If ``profiled_functions`` is given while
            ``record_functions`` is non-empty.
    """
    if record_functions is None:
        record_functions = set()

    if profiled_functions is not None:
        rank_zero_deprecation(
            "`PyTorchProfiler.profiled_functions` has been renamed to"
            " `record_functions` in v1.3 and will be removed in v1.5",
        )
        if record_functions:
            raise MisconfigurationException(
                "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
                " Please use only the later."
            )
        # Build a fresh set instead of `|=`, which would mutate an empty set the caller passed in.
        record_functions = set(profiled_functions)

    return record_functions
def on_trainer_init(
    self,
    gradient_clip_val: float,
    gradient_clip_algorithm: str,
    track_grad_norm: Union[int, float, str],
    accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
    truncated_bptt_steps: Optional[int],
    terminate_on_nan: bool,
):
    """Validate the Trainer's training-trick arguments and store them on ``self.trainer``.

    Args:
        gradient_clip_val: Stored as-is.
        gradient_clip_algorithm: Must be a member of ``GradClipAlgorithmType``; stored
            converted to that type.
        track_grad_norm: An int, a float, or the string ``'inf'``; stored as ``float``.
        accumulate_grad_batches: Stored, then passed to ``self.configure_accumulated_gradients``.
        truncated_bptt_steps: Stored as-is; a positive value also emits a deprecation warning.
        terminate_on_nan: Stored as-is.

    Raises:
        MisconfigurationException: If ``gradient_clip_algorithm`` or ``track_grad_norm``
            is not an accepted value.
    """
    self.trainer.terminate_on_nan = terminate_on_nan

    # gradient clipping: reject unknown algorithms before storing any clipping settings
    allowed_algorithms = list(GradClipAlgorithmType)
    if gradient_clip_algorithm not in allowed_algorithms:
        raise MisconfigurationException(f"gradient_clip_algorithm should be in {allowed_algorithms}")
    self.trainer.gradient_clip_val = gradient_clip_val
    self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)

    # gradient norm tracking: any number, or the string 'inf'
    is_numeric = isinstance(track_grad_norm, (int, float))
    if not (is_numeric or track_grad_norm == 'inf'):
        raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
    self.trainer.track_grad_norm = float(track_grad_norm)

    # accumulated grads
    self.trainer.accumulate_grad_batches = accumulate_grad_batches
    self.configure_accumulated_gradients(accumulate_grad_batches)

    tbptt_requested = truncated_bptt_steps is not None and truncated_bptt_steps > 0
    if tbptt_requested:
        rank_zero_deprecation(
            "Trainer.truncated_bptt_steps is deprecated in v1.3 and will be removed in v1.5."
            " Set truncated_bptt_steps directly on the LightningModule instead."
        )
    self.trainer.truncated_bptt_steps = truncated_bptt_steps
def test_v1_6_0_rank_zero_warnings_moved():
    """Check that calling either rank-zero helper emits a matching deprecation warning.

    Each helper is checked inside its own ``pytest.deprecated_call`` context.
    """
    expected_match = 'in v1.3.7 and will be removed in v1.6'
    for warn_fn in (rank_zero_warn, rank_zero_deprecation):
        with pytest.deprecated_call(match=expected_match):
            warn_fn('test')
# Deprecated import location: importing this module emits a deprecation warning and
# re-exports the profilers from their current modules.
from pytorch_lightning.utilities.distributed import rank_zero_deprecation

rank_zero_deprecation(
    "Using ``import pytorch_lightning.profiler.profilers`` is deprecated in v1.4, and will be removed in v1.6. "
    "HINT: Use ``import pytorch_lightning.profiler`` directly."
)

# These imports deliberately follow the warning call, hence the E402 suppressions.
from pytorch_lightning.profiler.advanced import AdvancedProfiler  # noqa: E402
from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler, PassThroughProfiler  # noqa: E402
from pytorch_lightning.profiler.pytorch import PyTorchProfiler  # noqa: E402
from pytorch_lightning.profiler.simple import SimpleProfiler  # noqa: E402

__all__ = [
    'AbstractProfiler',
    'BaseProfiler',
    'AdvancedProfiler',
    'PassThroughProfiler',
    'PyTorchProfiler',
    'SimpleProfiler',
]
def deprecation(self, m, *args, **kwargs):
    """Emit deprecation message ``m`` through ``rank_zero_deprecation`` only the first time it is seen.

    Seen messages are remembered by adding them to ``self``. Later calls with the same
    message do nothing. Extra positional and keyword arguments are forwarded to
    ``rank_zero_deprecation``.
    """
    if m in self:
        return
    self.add(m)
    rank_zero_deprecation(m, *args, **kwargs)