def update_replica_device_attributes(self, inputs: Any) -> None:
    """Updates the device information of LightningModule by reading the device from the inputs.

    In :class:`~torch.nn.parallel.DataParallel` changes to the state during the `forward` pass
    are lost when the replicas get discarded. The only way to know the current device is from the
    inputs passed into the model.

    Args:
        inputs: A collection of inputs (typically a tuple). If the inputs don't contain tensors,
            a warning is shown that accessing ``self.device`` will not return the correct device.
    """
    replica_device = None

    def find_tensor_with_device(tensor: Tensor) -> Tensor:
        nonlocal replica_device
        if replica_device is None and tensor.device != torch.device("cpu"):
            replica_device = tensor.device
        return tensor

    apply_to_collection(inputs, dtype=Tensor, function=find_tensor_with_device)

    if replica_device is not None:
        # by calling .to() we force the update to the self.device property
        self.module.to(device=replica_device)
    else:
        rank_zero_warn(
            "Could not determine on which device the inputs are."
            " When using DataParallel (strategy='dp'), be aware that in case you are using self.device"
            " in your code, it will reference only the root device."
        )
def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
    if self._requires_distributed_sampler(dataloader):
        sampler = self._get_distributed_sampler(
            dataloader,
            shuffle,
            mode=mode,
            overfit_batches=self.trainer.overfit_batches,
            **self.trainer.distributed_sampler_kwargs,
        )

        # update docs too once this is resolved
        trainer_fn = self.trainer.state.fn
        if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
            rank_zero_warn(
                f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
                " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
                " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
                " some samples to make sure all devices have same batch size in case of uneven inputs.",
                category=PossibleUserWarning,
            )
        return sampler

    return dataloader.sampler
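# Numeric sketch (not part of the Lightning source) of the duplication the warning above
# describes: `DistributedSampler` pads the index list so every replica sees the same number
# of samples, so 10 samples spread over 4 devices yield 12 indices and 2 duplicates.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))
per_rank = [list(DistributedSampler(dataset, num_replicas=4, rank=r, shuffle=False)) for r in range(4)]
assert sum(len(indices) for indices in per_rank) == 12  # 12 > 10: two samples are evaluated twice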
def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
    all_have_sequential_sampler = True

    def resolve_has_no_sequential_sampler(dataloader: DataLoader):
        nonlocal all_have_sequential_sampler
        all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
            dataloader.sampler, SequentialSampler
        )

    apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler)

    if not all_have_sequential_sampler:
        rank_zero_warn(
            "You requested to overfit but enabled training dataloader shuffling."
            " We are turning off the training dataloader shuffling for you."
        )

        def replace_sampler(dataloader: DataLoader) -> DataLoader:
            return _update_dataloader(dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING)

        dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler)

    return dataloader
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: dict containing model and trainer state
        path: write-target path
        storage_options: not used in ``TorchCheckpointIO.save_checkpoint``

    Raises:
        TypeError:
            If ``storage_options`` arg is passed in
    """
    if storage_options is not None:
        raise TypeError(
            "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
            f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
            " to define how you'd like to use `storage_options`."
        )
    fs = get_filesystem(path)
    fs.makedirs(os.path.dirname(path), exist_ok=True)
    try:
        # write the checkpoint dictionary to the file
        atomic_save(checkpoint, path)
    except AttributeError as err:
        # todo (sean): is this try catch necessary still?
        # https://github.com/Lightning-AI/lightning/pull/431
        key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
        checkpoint.pop(key, None)
        rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
        atomic_save(checkpoint, path)
def _check_eval_shuffling(dataloader, mode):
    if _is_dataloader_shuffled(dataloader):
        rank_zero_warn(
            f"Your `{mode.dataloader_prefix}_dataloader`'s sampler has shuffling enabled,"
            " it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.",
            category=PossibleUserWarning,
        )
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    r"""
    .. deprecated:: v1.6
        `TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8.

    Called when loading a model checkpoint.
    """
    # Todo: the `callback_states` are dropped with TPUSpawn as they
    # can't be saved using `xm.save`
    # https://github.com/pytorch/xla/issues/2773
    rank_zero_deprecation(
        "`TrainerCallbackHookMixin.on_load_checkpoint` was deprecated in v1.6 and will be removed in v1.8."
    )
    callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")

    if callback_states is None:
        return

    is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
    current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in self.callbacks}
    difference = callback_states.keys() - current_callbacks_keys
    if difference:
        rank_zero_warn(
            "Be aware that when using `ckpt_path`,"
            " callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
            f" Please add the following callbacks: {list(difference)}.",
        )

    for callback in self.callbacks:
        state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
        if state:
            state = deepcopy(state)
            callback.on_load_checkpoint(self, self.lightning_module, state)
def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
    r"""
    Combines progress bar metrics collected from the trainer with standard metrics from
    ``get_standard_metrics``. Implement this to override the items displayed in the progress bar.

    Here is an example of how to override the defaults:

    .. code-block:: python

        def get_metrics(self, trainer, model):
            # don't show the version number
            items = super().get_metrics(trainer, model)
            items.pop("v_num", None)
            return items

    Return:
        Dictionary with the items to be displayed in the progress bar.
    """
    standard_metrics = get_standard_metrics(trainer, pl_module)
    pbar_metrics = trainer.progress_bar_metrics
    duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
    if duplicates:
        rank_zero_warn(
            f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
            f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value."
            " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
        )
    return {**standard_metrics, **pbar_metrics}
def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
    valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
    extra_keys = optim_conf.keys() - valid_keys
    if extra_keys:
        rank_zero_warn(
            f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
        )
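# Minimal sketch (hypothetical config, not a Lightning test) of what trips the warning in
# `_validate_optim_conf`: any key outside the valid set is reported as unsupported.
valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
optim_conf = {"optimizer": "<optimizer instance>", "interval": "step"}
extra_keys = optim_conf.keys() - valid_keys
assert extra_keys == {"interval"}  # -> "Found unsupported keys in the optimizer configuration: {'interval'}"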
def _get_short_description(component: object) -> Optional[str]:
    parse = import_docstring_parse("LightningCLI(run=True)")
    try:
        docstring = parse(component.__doc__)
        return docstring.short_description
    except ValueError:
        rank_zero_warn(f"Failed parsing docstring for {component}")
def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
    key = "callbacks"
    if key in config:
        if config[key] is None:
            config[key] = []
        elif not isinstance(config[key], list):
            config[key] = [config[key]]
        config[key].extend(callbacks)
        if key in self.trainer_defaults:
            value = self.trainer_defaults[key]
            config[key] += value if isinstance(value, list) else [value]
        if self.save_config_callback and not config.get("fast_dev_run", False):
            config_callback = self.save_config_callback(
                self._parser(self.subcommand),
                self.config.get(str(self.subcommand), self.config),
                self.save_config_filename,
                overwrite=self.save_config_overwrite,
                multifile=self.save_config_multifile,
            )
            config[key].append(config_callback)
    else:
        rank_zero_warn(
            f"The `{self.trainer_class.__qualname__}` class does not expose the `{key}` argument so they will"
            " not be included."
        )
    return self.trainer_class(**config)
def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
    """Convert each scheduler into an `LRSchedulerConfig` structure with relevant information, when using manual
    optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
            keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]

            if keys_to_warn:
                rank_zero_warn(
                    f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
                    " You need to call `lr_scheduler.step()` manually in manual optimization.",
                    category=RuntimeWarning,
                )

            config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs
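# Hedged usage sketch for `_configure_schedulers_manual_opt`, assuming the function above is
# in scope: in manual optimization, a key like "interval" is stripped with a RuntimeWarning
# because the user is expected to call `lr_scheduler.step()` themselves.
import torch

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
# configs = _configure_schedulers_manual_opt([{"scheduler": scheduler, "interval": "step"}])
# -> warns about ['interval']; the resulting LRSchedulerConfig keeps only the scheduler itself.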
def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]:
    """Load hparams from a file.

    Args:
        config_yaml: Path to config yaml file
        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
            the hparams will be converted to ``DictConfig`` if possible.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    fs = get_filesystem(config_yaml)
    if not fs.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", category=RuntimeWarning)
        return {}

    with fs.open(config_yaml, "r") as fp:
        hparams = yaml.full_load(fp)

    if _OMEGACONF_AVAILABLE and use_omegaconf:
        try:
            return OmegaConf.create(hparams)
        except (UnsupportedValueType, ValidationError):
            pass
    return hparams
def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
    """This function is used to exclude any parameter which already exists in this optimizer.

    Args:
        optimizer: Optimizer used for parameter exclusion
        params: Iterable of parameters used to check against the provided optimizer

    Returns:
        List of parameters not contained in this optimizer param groups
    """
    out_params = []
    removed_params = []
    for param in params:
        if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]):
            out_params.append(param)
        else:
            removed_params.append(param)

    if removed_params:
        rank_zero_warn(
            "The provided params to be frozen already exist within another group of this optimizer."
            " Those parameters will be skipped.\n"
            "HINT: Did you init your optimizer in `configure_optimizers` as such:\n"
            f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
        )
    return out_params
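# Self-contained sketch (synthetic parameters, not the library's own test) of the membership
# test `filter_on_optimizer` performs: a parameter already in a param group would be skipped
# with the warning above, while a new parameter survives.
import torch
from torch.optim import SGD

shared = torch.nn.Parameter(torch.ones(2))
new = torch.nn.Parameter(torch.zeros(2))
optimizer = SGD([shared], lr=0.1)
already_present = any(torch.equal(p, shared) for group in optimizer.param_groups for p in group["params"])
assert already_present  # `shared` would be filtered out; filter_on_optimizer(optimizer, [shared, new]) -> [new]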
def reset(self) -> None:
    """Resets the internal state of the loop for a new run."""
    if self.restarting:
        self.batch_progress.reset_on_restart()
        self.scheduler_progress.reset_on_restart()
        self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

        trainer = self.trainer
        if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"):
            expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
            if self.global_step % expected_steps != 0:
                rank_zero_warn(
                    "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
                    " results if further training is done. Consider using an end-of-epoch checkpoint or enabling"
                    " fault-tolerant training:"
                    " https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html"
                )
    else:
        self.batch_progress.reset_on_run()
        self.scheduler_progress.reset_on_run()
        self.batch_loop.optimizer_loop.optim_progress.reset_on_run()
        # when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches
        # seen per epoch, this is useful for tracking when validation is run multiple times per epoch
        self.val_loop.epoch_loop.batch_progress.total.reset()

    self._outputs = []
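# Numeric sketch of the mid-epoch check above: with 100 training batches and
# `accumulate_grad_batches=4`, a full epoch contributes ceil(100 / 4) = 25 optimizer steps,
# so a `global_step` that is not a multiple of 25 implies the checkpoint was taken mid-epoch.
import math

expected_steps = math.ceil(100 / 4)
assert expected_steps == 25
assert 30 % expected_steps != 0  # e.g. global_step=30 -> resumed mid-epoch -> warning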
def test_v1_8_0_rank_zero_imports():
    import warnings

    from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_info
    from pytorch_lightning.utilities.warnings import LightningDeprecationWarning, rank_zero_deprecation, rank_zero_warn

    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_debug("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_info("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_warn has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_warn("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        rank_zero_deprecation("foo")
    with pytest.deprecated_call(
        match="pytorch_lightning.utilities.warnings.LightningDeprecationWarning has been deprecated in v1.6"
        " and will be removed in v1.8."
    ):
        warnings.warn("foo", LightningDeprecationWarning, stacklevel=5)
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
    """Gets parallel devices for the Accelerator."""
    if isinstance(devices, int):
        return [torch.device("cpu")] * devices
    rank_zero_warn(
        f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead."
    )
    return []
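# Usage sketch, assuming the function above and its `rank_zero_warn` import are in scope:
# an int fans out to that many CPU devices, while any other type triggers the warning and
# yields an empty list.
import torch

assert get_parallel_devices(3) == [torch.device("cpu")] * 3
assert get_parallel_devices("3") == []  # warns: `devices` must be an int with accelerator='cpu'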
def main_address(self) -> str:
    if "MASTER_ADDR" not in os.environ:
        rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
        os.environ["MASTER_ADDR"] = "127.0.0.1"
    log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
    return os.environ["MASTER_ADDR"]
@contextmanager
def init_meta_context() -> Generator:
    rank_zero_warn(
        "Be aware this feature is highly experimental and there are a number of weird edge cases"
        " where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11."
    )
    _set_meta_device()
    yield
    _unset_meta_device()
def _get_short_description(component: object) -> Optional[str]:
    if component.__doc__ is None:
        return None
    try:
        docstring = docstring_parser.parse(component.__doc__)
        return docstring.short_description
    except (ValueError, docstring_parser.ParseError) as ex:
        rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
def __init__(
    self,
    name: Optional[str] = None,
    save_dir: Optional[str] = None,
    offline: Optional[bool] = False,
    id: Optional[str] = None,
    anonymous: Optional[bool] = None,
    version: Optional[str] = None,
    project: Optional[str] = None,
    log_model: Union[str, bool] = False,
    experiment=None,
    prefix: Optional[str] = "",
    **kwargs,
):
    if wandb is None:
        raise ModuleNotFoundError(
            "You want to use `wandb` logger which is not installed yet,"
            " install it with `pip install wandb`."  # pragma: no-cover
        )

    if offline and log_model:
        raise MisconfigurationException(
            f"Providing log_model={log_model} and offline={offline} is an invalid configuration"
            " since model checkpoints cannot be uploaded in offline mode.\n"
            "Hint: Set `offline=False` to log your model."
        )

    if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
        rank_zero_warn(
            f"Providing log_model={log_model} requires wandb version >= 0.10.22"
            " for logging associated model metadata.\n"
            "Hint: Upgrade with `pip install --upgrade wandb`."
        )

    super().__init__()
    self._offline = offline
    self._log_model = log_model
    self._prefix = prefix
    self._experiment = experiment
    self._logged_model_time = {}
    self._checkpoint_callback = None

    # set wandb init arguments
    anonymous_lut = {True: "allow", False: None}
    self._wandb_init = dict(
        name=name,
        project=project,
        id=version or id,
        dir=save_dir,
        resume="allow",
        anonymous=anonymous_lut.get(anonymous, anonymous),
    )
    self._wandb_init.update(**kwargs)
    # extract parameters
    self._save_dir = self._wandb_init.get("dir")
    self._name = self._wandb_init.get("name")
    self._id = self._wandb_init.get("id")
    # start wandb run (to create an attach_id for distributed modes)
    if _WANDB_GREATER_EQUAL_0_12_10:
        wandb.require("service")
        _ = self.experiment
def main_port(self) -> int:
    if "MASTER_PORT" not in os.environ:
        rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
        os.environ["MASTER_PORT"] = "12910"
    log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
    return int(os.environ["MASTER_PORT"])
def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
    """Called before training, determines unique names for all lr schedulers in the case of multiple of the same
    type or in the case of multiple parameter groups.

    Raises:
        MisconfigurationException:
            If ``Trainer`` has no ``logger``.
    """
    if not trainer.loggers:
        raise MisconfigurationException(
            "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
        )

    if self.log_momentum:

        def _check_no_key(key: str) -> bool:
            if trainer.lr_scheduler_configs:
                return any(
                    key not in config.scheduler.optimizer.defaults for config in trainer.lr_scheduler_configs
                )
            return any(key not in optimizer.defaults for optimizer in trainer.optimizers)

        if _check_no_key("momentum") and _check_no_key("betas"):
            rank_zero_warn(
                "You have set log_momentum=True, but some optimizers do not"
                " have momentum. This will log a value 0 for the momentum.",
                category=RuntimeWarning,
            )

    # Find names for schedulers
    names: List[List[str]] = []
    (
        sched_hparam_keys,
        optimizers_with_scheduler,
        optimizers_with_scheduler_types,
    ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs)
    names.extend(sched_hparam_keys)

    # Find names for leftover optimizers
    optimizer_hparam_keys, _ = self._find_names_from_optimizers(
        trainer.optimizers,
        seen_optimizers=optimizers_with_scheduler,
        seen_optimizer_types=optimizers_with_scheduler_types,
    )
    names.extend(optimizer_hparam_keys)

    # Initialize for storing values
    names_flatten = list(itertools.chain.from_iterable(names))
    self.lrs = {name: [] for name in names_flatten}
    self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
    if not self._update_called:
        rank_zero_warn(
            f"The ``compute`` method of metric {self.__class__.__name__}"
            " was called before the ``update`` method which may lead to errors,"
            " as metric states have not yet been updated.",
        )

    # return cached value
    if self._computed is not None:
        return self._computed
    self._computed = compute(*args, **kwargs)
    return self._computed
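# Generic sketch (hypothetical class, not the actual Metric implementation) of the
# compute-caching pattern wrapped above: the first `compute()` result is memoized until
# the next `update()` invalidates it.
class CachedSum:
    def __init__(self) -> None:
        self._update_called = False
        self._computed = None
        self._total = 0.0

    def update(self, value: float) -> None:
        self._total += value
        self._update_called = True
        self._computed = None  # new data invalidates the cache

    def compute(self) -> float:
        if not self._update_called:
            print("warning: compute() called before update()")  # stand-in for the warning above
        if self._computed is None:
            self._computed = self._total
        return self._computed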
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
    params = _convert_params(params)
    params = _flatten_dict(params)
    for k, v in params.items():
        if len(str(v)) > 250:
            rank_zero_warn(
                f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}",
                category=RuntimeWarning,
            )
            continue
        self.experiment.log_param(self.run_id, k, v)
def log_graph(self, model: "pl.LightningModule", input_array=None):
    if self._log_graph:
        if input_array is None:
            input_array = model.example_input_array

        if input_array is not None:
            self.experiment.add_graph(model, model._apply_batch_transfer_handler(input_array))
        else:
            rank_zero_warn(
                "Could not log computational graph since neither the"
                " `model.example_input_array` attribute is set nor"
                " `input_array` was given",
            )
def __init__(self, log_dir: str) -> None:
    self.hparams = {}
    self.metrics = []

    self.log_dir = log_dir
    if os.path.exists(self.log_dir) and os.listdir(self.log_dir):
        rank_zero_warn(
            f"Experiment logs directory {self.log_dir} exists and is not empty."
            " Previous log files in this directory will be deleted when the new ones are saved!"
        )
    os.makedirs(self.log_dir, exist_ok=True)

    self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
    training_step_fx = getattr(trainer.lightning_module, "training_step")
    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            "Found `dataloader_iter` argument in the `training_step`. Note that the support for"
            " this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher
    elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        if not isinstance(trainer.accelerator, CUDAAccelerator):
            raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
        return InterBatchParallelDataFetcher
    return DataFetcher
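# Hedged sketch of the experimental `training_step` signature that the check above detects:
# with an explicit `dataloader_iter` parameter, the step receives the raw iterator instead
# of a pre-fetched batch (behavior subject to change, per the warning). Meant as a method
# on a LightningModule subclass.
def training_step(self, dataloader_iter):
    batch = next(dataloader_iter)
    # ... compute and return the loss for `batch`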
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
    """Removes all unpicklable entries from hparams."""
    hparams_dict = hparams
    if isinstance(hparams, Namespace):
        hparams_dict = hparams.__dict__

    del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]

    for k in del_attrs:
        rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
        del hparams_dict[k]
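# Minimal stand-in (hypothetical, mirroring the `is_picklable` helper the function above
# relies on) plus a demo: a lambda cannot be pickled, so `clean_namespace` would warn
# about it and drop it from the namespace.
import pickle
from argparse import Namespace


def is_picklable(obj: object) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


hparams = Namespace(lr=0.1, activation=lambda x: x)
assert is_picklable(hparams.lr)
assert not is_picklable(hparams.activation)  # would be removed with a warning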
def log_graph(self, model: "pl.LightningModule", input_array=None):
    if self._log_graph:
        if input_array is None:
            input_array = model.example_input_array

        if input_array is not None:
            input_array = model._apply_batch_transfer_handler(input_array)
            model._running_torchscript = True
            self.experiment.add_graph(model, input_array)
            model._running_torchscript = False
        else:
            rank_zero_warn(
                "Could not log computational graph since the"
                " `model.example_input_array` attribute is not set"
                " or `input_array` was not given",
            )
def inner_fn(self, *args, **kwargs):
    pre_layer_count = len(list(self.model.parameters()))
    module = fn(self, *args, **kwargs)
    self.model.on_post_move_to_device()
    post_layer_count = len(list(self.model.parameters()))

    if not pre_layer_count == post_layer_count:
        rank_zero_warn(
            "The model layers do not match after moving to the target device."
            " If your model employs weight sharing on TPU,"
            " please tie your weights using the `on_post_move_to_device` model hook.\n"
            f"Layer count: [Before: {pre_layer_count} After: {post_layer_count}]"
        )

    return module