def check_fn(v: Tensor) -> Tensor:
    # `extra` comes from the enclosing scope (the extra values returned from `training_step`)
    if v.grad_fn is not None:
        rank_zero_deprecation(
            f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
            " but this behaviour will change in v1.6. Please detach it manually:"
            " `return {'loss': ..., 'something': something.detach()}`"
        )
        return v.detach()  # detach automatically, as promised in the warning above
    return v
def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: str) -> None:
    loader_name = f"{stage}_dataloader"
    step_name = "validation_step" if stage == "val" else "test_step"

    has_loader = is_overridden(loader_name, model)
    has_step = is_overridden(step_name, model)
    if has_loader and not has_step:
        rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop")
    if has_step and not has_loader:
        rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")

    # ----------------------------------------------
    # verify model does not have
    # - on_val_dataloader
    # - on_test_dataloader
    # ----------------------------------------------
    has_on_val_dataloader = is_overridden("on_val_dataloader", model)
    if has_on_val_dataloader:
        rank_zero_deprecation(
            "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
            " Please use `val_dataloader()` directly."
        )

    has_on_test_dataloader = is_overridden("on_test_dataloader", model)
    if has_on_test_dataloader:
        rank_zero_deprecation(
            "Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
            " Please use `test_dataloader()` directly."
        )
def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None:
    for callback in trainer.callbacks:
        if is_overridden(method_name="on_configure_sharded_model", instance=callback):
            rank_zero_deprecation(
                "The `on_configure_sharded_model` callback hook was deprecated in"
                " v1.6 and will be removed in v1.8. Use `setup()` instead."
            )
def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
    # the local import shadows this deprecated shim with the relocated implementation
    from pytorch_lightning.utilities.warnings import rank_zero_deprecation

    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.rank_zero_deprecation` has been moved to"
        " `pytorch_lightning.utilities.rank_zero_deprecation` in v1.3.7 and will be removed in v1.6"
    )
    return rank_zero_deprecation(*args, stacklevel=stacklevel, **kwargs)
def test_v1_8_0_rank_zero_imports():
    import warnings

    from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_info
    from pytorch_lightning.utilities.warnings import LightningDeprecationWarning, rank_zero_deprecation, rank_zero_warn

    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_debug("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_info("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_warn("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_deprecation("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        warnings.warn("foo", LightningDeprecationWarning, stacklevel=5)
def prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
    rank_zero_deprecation(
        "`pytorch_lightning.callbacks.device_stats_monitor.prefix_metrics`"
        " is deprecated in v1.6 and will be removed in v1.8."
    )
    sep = ""
    return _prefix_metric_keys(metrics_dict, prefix, sep)
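# Minimal usage sketch (illustrative, not from the source): calling the deprecated helper
# emits the deprecation above and forwards to the private `_prefix_metric_keys`. The exact
# key format is determined by that private helper; with the empty separator passed here the
# prefix is presumably prepended directly to each metric key.
metrics = prefix_metric_keys({"batteries.level": 55.0}, prefix="on_train_start")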
def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
    super().__init__()
    if max_steps is None:
        rank_zero_deprecation(
            "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
            " Use `max_steps = -1` instead."
        )
        max_steps = -1
    elif max_steps < -1:
        raise MisconfigurationException(
            f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
        )
    self.min_steps = min_steps
    self.max_steps = max_steps

    self.global_step: int = 0
    self.batch_progress = BatchProgress()
    self.scheduler_progress = SchedulerProgress()

    self.batch_loop: Optional[TrainingBatchLoop] = None
    self.val_loop: Optional["loops.EvaluationLoop"] = None

    self._results = ResultCollection(training=True)
    self._outputs: _OUTPUTS_TYPE = []
    self._warning_cache = WarningCache()
    self._dataloader_iter: Optional[Iterator] = None
    # caches the loaded dataloader state until dataloader objects are available
    self._dataloader_state_dict: Dict[str, Any] = {}
def batch_to(data: Any) -> Any:
    # `device` comes from the enclosing `move_data_to_device` scope
    # try to move torchtext data first
    if _TORCHTEXT_LEGACY and isinstance(data, Batch):
        # TODO: also remove the torchtext dependency with Lightning 1.8
        rank_zero_deprecation(
            "The `torchtext.legacy.Batch` object is deprecated and Lightning will remove support for it in v1.8."
            " We recommend you to migrate away from Batch by following the TorchText README:"
            " https://github.com/pytorch/text#bc-breaking-legacy"
        )
        # Shallow copy because each Batch has a reference to Dataset which contains all examples
        device_data = copy(data)
        for field, field_value in data.dataset.fields.items():
            if field_value is None:
                continue
            device_field = move_data_to_device(getattr(data, field), device)
            setattr(device_data, field, device_field)
        return device_data

    kwargs = {}
    # Don't issue non-blocking transfers to CPU
    if isinstance(data, Tensor) and device not in _CPU_DEVICES:
        kwargs["non_blocking"] = True
    data_output = data.to(device, **kwargs)
    if data_output is not None:
        return data_output
    # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
    return data
def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
    rank_zero_deprecation(
        "The `pytorch_lightning.loggers.base.DummyExperiment` is deprecated in v1.7"
        " and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.DummyExperiment` instead."
    )
    super().__init__(*args, **kwargs)
def rank_zero_warn(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
    from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn

    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.rank_zero_warn` has been moved to"
        " `pytorch_lightning.utilities.rank_zero_warn` in v1.3.7 and will be removed in v1.6"
    )
    return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs)
def _check_on_keyboard_interrupt(trainer: "pl.Trainer") -> None:
    """Checks if on_keyboard_interrupt is overridden and sends a deprecation warning."""
    for callback in trainer.callbacks:
        if is_overridden(method_name="on_keyboard_interrupt", instance=callback):
            rank_zero_deprecation(
                "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
                " Please use the `on_exception` callback hook instead."
            )
def torch_distributed_backend(self) -> str:
    """Deprecated property."""
    rank_zero_deprecation(
        "ParallelStrategy.torch_distributed_backend was deprecated in v1.6 and will be removed in v1.8."
    )
    pg_backend = _get_process_group_backend_from_env()
    if pg_backend:
        return pg_backend
    return get_default_process_group_backend_for_device(self.root_device)
def on_predict_dataloader(self) -> None:
    """Called before requesting the predict dataloader.

    .. deprecated:: v1.5
        :meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
        Please use :meth:`predict_dataloader()` directly.
    """
    rank_zero_deprecation(
        "Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
        " Please use `predict_dataloader()` directly."
    )
def _check_on_init_start_end(trainer: "pl.Trainer") -> None:
    """Checks if on_init_start/end are overridden and sends a deprecation warning."""
    for callback in trainer.callbacks:
        if is_overridden(method_name="on_init_start", instance=callback):
            rank_zero_deprecation(
                "The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
            )
        if is_overridden(method_name="on_init_end", instance=callback):
            rank_zero_deprecation(
                "The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8."
            )
def _check_progress_bar(model: "pl.LightningModule") -> None:
    r"""
    Checks if get_progress_bar_dict is overridden and sends a deprecation warning.

    Args:
        model: The model to check the get_progress_bar_dict method.
    """
    if is_overridden("get_progress_bar_dict", model):
        rank_zero_deprecation(
            "The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7."
            " Please use the `ProgressBarBase.get_metrics` instead."
        )
def close(self) -> None:
    """
    .. deprecated:: v1.5
        This method is deprecated in v1.5 and will be removed in v1.7.
        Please use `LoggerCollection.finalize` instead.
    """
    rank_zero_deprecation(
        "`LoggerCollection.close` method is deprecated in v1.5 and will be removed in v1.7."
        " Please use `LoggerCollection.finalize` instead."
    )
    for logger in self._logger_iterable:
        logger.close()
def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
    if is_overridden("on_hpc_save", model):
        rank_zero_deprecation(
            "Method `LightningModule.on_hpc_save` is deprecated in v1.6 and"
            " will be removed in v1.8. Please use `LightningModule.on_save_checkpoint` instead."
        )
    if is_overridden("on_hpc_load", model):
        rank_zero_deprecation(
            "Method `LightningModule.on_hpc_load` is deprecated in v1.6 and"
            " will be removed in v1.8. Please use `LightningModule.on_load_checkpoint` instead."
        )
def merge_dicts(
    dicts: Sequence[Mapping],
    agg_key_funcs: Optional[Mapping] = None,
    default_func: Callable[[Sequence[float]], float] = np.mean,
) -> Dict:
    rank_zero_deprecation(
        "The `pytorch_lightning.loggers.base.merge_dicts` is deprecated in v1.7"
        " and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.merge_dicts` instead."
    )
    return logger.merge_dicts(dicts=dicts, agg_key_funcs=agg_key_funcs, default_func=default_func)
def _check_on_batch_start_end(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
    hooks = (["on_batch_start", "on_train_batch_start"], ["on_batch_end", "on_train_batch_end"])
    for hook, alternative_hook in hooks:
        for callback in trainer.callbacks:
            if is_overridden(method_name=hook, instance=callback):
                rank_zero_deprecation(
                    f"The `Callback.{hook}` hook was deprecated in v1.6 and"
                    f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
                )
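# Hypothetical callback (not from the source) that would trigger the deprecation above,
# because it overrides the legacy `on_batch_end` hook instead of `on_train_batch_end`.
class LegacyBatchCallback(Callback):
    def on_batch_end(self, trainer, pl_module):  # deprecated in v1.6, removed in v1.8
        pass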
def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
    r"""
    Checks if `on_post_move_to_device` method is overridden and sends a deprecation warning.

    Args:
        model: The model to check the `on_post_move_to_device` method.
    """
    if is_overridden("on_post_move_to_device", model):
        rank_zero_deprecation(
            "Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7. "
            "We perform automatic parameters tying without the need of implementing `on_post_move_to_device`."
        )
def close(self) -> None:
    """Do any cleanup that is necessary to close an experiment.

    See deprecation warning below.

    .. deprecated:: v1.5
        This method is deprecated in v1.5 and will be removed in v1.7.
        Please use `LightningLoggerBase.finalize` instead.
    """
    rank_zero_deprecation(
        "`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
        " Please use `LightningLoggerBase.finalize` instead."
    )
    self.save()
def _check_add_get_queue(model: "pl.LightningModule") -> None:
    r"""
    Checks if add_to_queue or get_from_queue is overridden and sends a deprecation warning.

    Args:
        model: The lightning module
    """
    if is_overridden("add_to_queue", model):
        rank_zero_deprecation(
            "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
            "favor of `DDPSpawnStrategy.add_to_queue`"
        )
    if is_overridden("get_from_queue", model):
        rank_zero_deprecation(
            "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7 in "
            "favor of `DDPSpawnStrategy.get_from_queue`"
        )
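# Hypothetical LightningModule (not from the source) that would hit the deprecation above
# by overriding `add_to_queue`, which moved to `DDPSpawnStrategy.add_to_queue`.
class LegacyQueueModule(pl.LightningModule):
    def add_to_queue(self, queue) -> None:  # deprecated in v1.5, removed in v1.7
        queue.put("extra_state")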
def _configure_model_summary_callback(
    self, enable_model_summary: bool, weights_summary: Optional[str] = None
) -> None:
    if weights_summary is None:
        rank_zero_deprecation(
            "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
            " in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
        )
        return
    if not enable_model_summary:
        return

    model_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, ModelSummary)]
    if model_summary_cbs:
        rank_zero_info(
            f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
            " Skipping setting a default `ModelSummary` callback."
        )
        return

    if weights_summary == "top":
        # special case the default value for weights_summary to preserve backward compatibility
        max_depth = 1
    else:
        rank_zero_deprecation(
            f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
            " in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
            " `max_depth` directly to the Trainer's `callbacks` argument instead."
        )
        if weights_summary not in ModelSummaryMode.supported_types():
            raise MisconfigurationException(
                f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
                f" but got {weights_summary}",
            )
        max_depth = ModelSummaryMode.get_max_depth(weights_summary)

    is_progress_bar_rich = isinstance(self.trainer._progress_bar_callback, RichProgressBar)
    if self.trainer._progress_bar_callback is not None and is_progress_bar_rich:
        model_summary = RichModelSummary(max_depth=max_depth)
    else:
        model_summary = ModelSummary(max_depth=max_depth)
    self.trainer.callbacks.append(model_summary)
    self.trainer._weights_summary = weights_summary
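# Illustrative replacement (assumed from the deprecation message above): instead of passing
# `weights_summary` to the Trainer, pass a `ModelSummary` callback with `max_depth` directly.
trainer = pl.Trainer(callbacks=[ModelSummary(max_depth=2)])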
def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> None:
    has_predict_dataloader = is_overridden("predict_dataloader", model)
    if not has_predict_dataloader:
        raise MisconfigurationException("Dataloader not found for `Trainer.predict`")

    # ----------------------------------------------
    # verify model does not have
    # - on_predict_dataloader
    # ----------------------------------------------
    has_on_predict_dataloader = is_overridden("on_predict_dataloader", model)
    if has_on_predict_dataloader:
        rank_zero_deprecation(
            "Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
            " Please use `predict_dataloader()` directly."
        )
def _check_dl_idx_in_on_train_batch_hooks(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
    for hook in ("on_train_batch_start", "on_train_batch_end"):
        if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True):
            rank_zero_deprecation(
                f"Base `LightningModule.{hook}` hook signature has changed in v1.5."
                " The `dataloader_idx` argument will be removed in v1.7."
            )
        for cb in trainer.callbacks:
            if is_param_in_hook_signature(getattr(cb, hook), "dataloader_idx", explicit=True):
                rank_zero_deprecation(
                    f"Base `Callback.{hook}` hook signature has changed in v1.5."
                    " The `dataloader_idx` argument will be removed in v1.7."
                )
def on_trainer_init(
    self,
    callbacks: Optional[Union[List[Callback], Callback]],
    checkpoint_callback: bool,
    progress_bar_refresh_rate: Optional[int],
    process_position: int,
    default_root_dir: Optional[str],
    weights_save_path: Optional[str],
    stochastic_weight_avg: bool,
    max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
):
    # init folder paths for checkpoint + weights save callbacks
    self.trainer._default_root_dir = default_root_dir or os.getcwd()
    self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
    self.trainer._stochastic_weight_avg = stochastic_weight_avg

    # init callbacks
    if isinstance(callbacks, Callback):
        callbacks = [callbacks]
    self.trainer.callbacks = callbacks or []

    # configure checkpoint callback
    # pass through the required args to figure out defaults
    self._configure_checkpoint_callbacks(checkpoint_callback)

    # configure swa callback
    self._configure_swa_callbacks()

    # configure the timer callback.
    # responsible to stop the training when max_time is reached.
    self._configure_timer_callback(max_time)

    # init progress bar
    if process_position != 0:
        rank_zero_deprecation(
            f"Setting `Trainer(process_position={process_position})` is deprecated in v1.5 and will be removed"
            " in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
            " `process_position` directly to the Trainer's `callbacks` argument instead."
        )
    self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

    # push all checkpoint callbacks to the end
    # it is important that these are the last callbacks to run
    self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks)
def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None:
    loader_name = f"{stage}_dataloader"
    step_name = "validation_step" if stage == "val" else f"{stage}_step"
    trainer_method = "validate" if stage == "val" else stage
    on_eval_hook = f"on_{loader_name}"

    has_loader = getattr(trainer._data_connector, f"_{stage}_dataloader_source").is_defined()
    has_step = is_overridden(step_name, model)
    has_on_eval_dataloader = is_overridden(on_eval_hook, model)

    # ----------------------------------------------
    # verify model does not have on_eval_dataloader
    # ----------------------------------------------
    if has_on_eval_dataloader:
        rank_zero_deprecation(
            f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
            f" be removed in v1.7.0. Please use `{loader_name}()` directly."
        )

    # -----------------------------------
    # verify model has an eval_dataloader
    # -----------------------------------
    if not has_loader:
        raise MisconfigurationException(f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`.")

    # predict_step is not required to be overridden
    if stage == "predict":
        if model.predict_step is None:
            raise MisconfigurationException("`predict_step` cannot be None to run `Trainer.predict`")
        elif not has_step and not is_overridden("forward", model):
            raise MisconfigurationException("`Trainer.predict` requires `forward` method to run.")
    else:
        # -----------------------------------
        # verify model has an eval_step
        # -----------------------------------
        if not has_step:
            raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")
def on_trainer_init(
    self,
    logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]],
    flush_logs_every_n_steps: Optional[int],
    log_every_n_steps: int,
    move_metrics_to_cpu: bool,
) -> None:
    self.configure_logger(logger)
    if flush_logs_every_n_steps is not None:
        rank_zero_deprecation(
            f"Setting `Trainer(flush_logs_every_n_steps={flush_logs_every_n_steps})` is deprecated in v1.5 "
            "and will be removed in v1.7. Please configure flushing in the logger instead."
        )
    else:
        flush_logs_every_n_steps = 100  # original default parameter
    self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
    self.trainer.log_every_n_steps = log_every_n_steps
    self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None) -> None:
    self.trainer = trainer
    if log_gpu_memory is not None:
        rank_zero_deprecation(
            "Setting `log_gpu_memory` with the trainer flag is deprecated in v1.5 and will be removed in v1.7. "
            "Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
        )
    self.log_gpu_memory = log_gpu_memory
    self.eval_loop_results: List[_OUT_DICT] = []
    self._val_log_step: int = 0
    self._test_log_step: int = 0
    self._progress_bar_metrics: _PBAR_DICT = {}
    self._logged_metrics: _OUT_DICT = {}
    self._callback_metrics: _OUT_DICT = {}
    self._gpus_metrics: Dict[str, float] = {}
    self._epoch_end_reached = False
    self._current_fx: Optional[str] = None
    self._batch_idx: Optional[int] = None
    self._split_idx: Optional[int] = None
def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], enable_checkpointing: bool) -> None:
    if checkpoint_callback is not None:
        rank_zero_deprecation(
            f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
            f"be removed in v1.7. Please consider using `Trainer(enable_checkpointing={checkpoint_callback})`."
        )
        # if both are set then checkpoint only if both are True
        enable_checkpointing = checkpoint_callback and enable_checkpointing

    if self._trainer_has_checkpoint_callbacks() and enable_checkpointing is False:
        raise MisconfigurationException(
            "Trainer was configured with `enable_checkpointing=False`"
            " but found `ModelCheckpoint` in callbacks list."
        )

    if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True:
        self.trainer.callbacks.append(ModelCheckpoint())
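# Illustrative usage (assumed Trainer API based on the messages above): the legacy
# `checkpoint_callback` flag still works but triggers the deprecation handled above;
# `enable_checkpointing` is the supported replacement.
trainer = pl.Trainer(checkpoint_callback=False)   # deprecated spelling
trainer = pl.Trainer(enable_checkpointing=False)  # preferred spelling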