def on_run_start(self) -> None:  # type: ignore[override]
    """Calls the ``on_train_start`` hook."""
    # reset train dataloader and val dataloader
    self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

    ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
    if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
        self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
            self.trainer.current_epoch
        )
        expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches)

        # global_step is incremented during checkpointing (#11555)
        if (self.trainer.global_step - 1) % expected_steps != 0:
            rank_zero_warn(
                "You're resuming from a checkpoint that ended mid-epoch."
                " Training will start from the beginning of the next epoch."
                " This can cause unreliable results if further training is done,"
                " consider using an end of epoch checkpoint or use fault-tolerant training"
                " to restart as if training did not stop."
            )

    self._is_fresh_start_epoch = True
    self._results.to(device=self.trainer.lightning_module.device)

    self.trainer._call_callback_hooks("on_train_start")
    self.trainer._call_lightning_module_hook("on_train_start")
    self.trainer._call_strategy_hook("on_train_start")
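# A minimal sketch of the mid-epoch resume check above, with illustrative
# numbers (the 100 batches and accumulation of 4 are assumptions, not from the source):
import math

num_training_batches = 100
accumulate_grad_batches = 4
expected_steps = math.ceil(num_training_batches / accumulate_grad_batches)  # 25

# `global_step` is incremented during checkpointing (#11555), hence the `- 1`;
# an end-of-epoch checkpoint would land on a multiple of `expected_steps`.
global_step_from_checkpoint = 13
assert (global_step_from_checkpoint - 1) % expected_steps != 0  # mid-epoch -> warning fires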
def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
    all_have_sequential_sampler = True

    def resolve_has_no_sequential_sampler(dataloader: DataLoader):
        nonlocal all_have_sequential_sampler
        all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
            dataloader.sampler, SequentialSampler
        )

    apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler)

    if not all_have_sequential_sampler:
        rank_zero_warn(
            "You requested to overfit but enabled training dataloader shuffling."
            " We are turning off the training dataloader shuffling for you."
        )

        def replace_sampler(dataloader: DataLoader) -> DataLoader:
            return _update_dataloader(dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING)

        dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler)

    return dataloader
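# A self-contained sketch of the sampler swap above using plain PyTorch;
# `_update_dataloader` and `apply_to_collection` are internal helpers, so the
# DataLoader is rebuilt by hand here:
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
loader = DataLoader(dataset, batch_size=2, shuffle=True)  # RandomSampler under the hood

if not isinstance(loader.sampler, SequentialSampler):
    # same outcome as the warning + replacement above: force sequential order
    loader = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))

assert isinstance(loader.sampler, SequentialSampler)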
def experiment(self) -> Run:
    r"""Actual wandb object. To use wandb features in your
    :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

    Example::

        self.logger.experiment.some_wandb_function()
    """
    if self._experiment is None:
        if self._offline:
            os.environ["WANDB_MODE"] = "dryrun"
        if wandb.run is None:
            self._experiment = wandb.init(**self._wandb_init)
        else:
            rank_zero_warn(
                "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
                " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
            )
            self._experiment = wandb.run

        # define default x-axis (for latest wandb versions)
        if getattr(self._experiment, "define_metric", None):
            self._experiment.define_metric("trainer/global_step")
            self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

    return self._experiment
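# A hedged usage sketch of the property above (assumes wandb is installed;
# "my-project" is illustrative). First access triggers `wandb.init(**self._wandb_init)`
# or reuses an existing `wandb.run`:
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="my-project", offline=True)
logger.experiment.log({"manual_metric": 1.0})  # any wandb Run method works here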
def init_deepspeed(self):
    # deepspeed handles gradient clipping internally
    if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
        rank_zero_warn(
            "Since DeepSpeed handles gradient clipping internally, the default"
            " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
            " The hook will still be called. Consider setting"
            " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
            " which will use the internal mechanism."
        )

    if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
        raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")

    if not isinstance(self.accelerator, CUDAAccelerator):
        raise MisconfigurationException(
            f"DeepSpeed strategy is only supported on GPU but `{self.accelerator.__class__.__name__}` is used."
        )

    accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

    if accumulation_scheduler.epochs != [0]:
        raise MisconfigurationException(
            "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
        )

    model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision)

    if self.lightning_module.trainer and self.lightning_module.trainer.training:
        self._initialize_deepspeed_train(model)
    else:
        self._initialize_deepspeed_inference(model)
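# The configuration the warning above recommends, as a hedged sketch
# (the 0.5 value is illustrative; with the DeepSpeed strategy these
# arguments are forwarded to DeepSpeed's internal clipping mechanism):
import pytorch_lightning as pl

trainer = pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")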
def _check_eval_shuffling(dataloader, mode):
    if _is_dataloader_shuffled(dataloader):
        rank_zero_warn(
            f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
            " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
            category=PossibleUserWarning,
        )
def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: str) -> None:
    loader_name = f"{stage}_dataloader"
    step_name = "validation_step" if stage == "val" else "test_step"

    has_loader = is_overridden(loader_name, model)
    has_step = is_overridden(step_name, model)

    if has_loader and not has_step:
        rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop")
    if has_step and not has_loader:
        rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")

    # ----------------------------------------------
    # verify model does not have
    # - on_val_dataloader
    # - on_test_dataloader
    # ----------------------------------------------
    has_on_val_dataloader = is_overridden("on_val_dataloader", model)
    if has_on_val_dataloader:
        rank_zero_deprecation(
            "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
            " Please use `val_dataloader()` directly."
        )

    has_on_test_dataloader = is_overridden("on_test_dataloader", model)
    if has_on_test_dataloader:
        rank_zero_deprecation(
            "Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
            " Please use `test_dataloader()` directly."
        )
def lightning_restore_optimizer_and_schedulers(self) -> bool:
    # managed by DeepSpeed
    if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
        rank_zero_warn(
            "A single checkpoint file has been given. This means optimizer states and "
            "scheduler states can not be restored. If you'd like to restore these states, you must "
            "provide a path to the originally saved DeepSpeed checkpoint."
        )
    return False
def set_distributed_mode(self, distributed_backend):
    self.use_dp = False
    self.use_ddp = False
    self.use_ddp2 = False
    self.single_gpu = False

    if distributed_backend is None:
        if self.num_gpus == 0:
            if self.num_nodes > 1 or self.num_processes > 1:
                self.use_ddp = True  # ddp_cpu
        elif self.num_gpus == 1:
            self.single_gpu = True
        elif self.num_gpus > 1:
            rank_zero_warn(
                'You requested multiple GPUs but did not specify a backend, e.g.'
                ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
                ' Setting distributed_backend=dp for you.'
            )
            self.use_dp = True
    elif distributed_backend == "dp":
        # do nothing if num_gpus == 0
        if self.num_gpus == 1:
            self.single_gpu = True
            self.use_dp = True
        elif self.num_gpus > 1:
            self.use_dp = True
    elif distributed_backend == "ddp":
        if self.num_gpus == 0:
            if self.num_nodes > 1 or self.num_processes > 1:
                self.use_ddp = True  # ddp_cpu
        elif self.num_gpus == 1:
            self.single_gpu = True
            self.use_ddp = True
        elif self.num_gpus > 1:
            self.use_ddp = True
            self.num_processes = self.num_gpus
    elif distributed_backend == "ddp2":
        # do nothing if num_gpus == 0
        if self.num_gpus >= 1:
            self.use_ddp2 = True
    elif distributed_backend == "ddp_cpu":
        if self.num_gpus > 0:
            rank_zero_warn(
                'You requested one or more GPUs, but set the backend to `ddp_cpu`.'
                ' Training will not use GPUs.'
            )
        self.use_ddp = True
        self.data_parallel_device_ids = None
        self.on_gpu = False

    # throw error to force user ddp or ddp2 choice
    if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
        raise MisconfigurationException(
            'DataParallel does not support num_nodes > 1.'
            ' Set distributed_backend=ddp or distributed_backend=ddp2 to use DistributedDataParallel.'
        )

    log.info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
def lightning_restore_optimizer(self) -> bool:
    # managed by DeepSpeed
    if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
        rank_zero_warn(
            "A single checkpoint file has been given. This means optimizer states cannot be restored."
            " If you'd like to restore these states, you must provide a path to the originally saved DeepSpeed"
            " checkpoint. When using ZeRO 3, the original path should be a directory."
        )
    return False
def __init__(
    self,
    name: Optional[str] = None,
    save_dir: Optional[str] = None,
    offline: Optional[bool] = False,
    id: Optional[str] = None,
    anonymous: Optional[bool] = None,
    version: Optional[str] = None,
    project: Optional[str] = None,
    log_model: Optional[bool] = False,
    experiment=None,
    prefix: Optional[str] = "",
    **kwargs,
):
    if wandb is None:
        raise ImportError(
            "You want to use `wandb` logger which is not installed yet,"
            " install it with `pip install wandb`."  # pragma: no cover
        )

    if offline and log_model:
        raise MisconfigurationException(
            f"Providing log_model={log_model} and offline={offline} is an invalid configuration"
            " since model checkpoints cannot be uploaded in offline mode.\n"
            "Hint: Set `offline=False` to log your model."
        )

    if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
        rank_zero_warn(
            f"Providing log_model={log_model} requires wandb version >= 0.10.22"
            " for logging associated model metadata.\n"
            "Hint: Upgrade with `pip install --upgrade wandb`."
        )

    super().__init__()
    self._offline = offline
    self._log_model = log_model
    self._prefix = prefix
    self._experiment = experiment
    self._logged_model_time = {}
    self._checkpoint_callback = None
    # set wandb init arguments
    anonymous_lut = {True: "allow", False: None}
    self._wandb_init = dict(
        name=name,
        project=project,
        id=version or id,
        dir=save_dir,
        resume="allow",
        anonymous=anonymous_lut.get(anonymous, anonymous),
    )
    self._wandb_init.update(**kwargs)
    # extract parameters
    self._save_dir = self._wandb_init.get("dir")
    self._name = self._wandb_init.get("name")
    self._id = self._wandb_init.get("id")
def get_lr(self):
    if not self._get_lr_called_within_step:
        rank_zero_warn(
            "To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
        )
    return [base_lr * self.lr_lambda(self.last_epoch) for base_lr in self.base_lrs]
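# The `get_last_lr()` pattern the warning above recommends, shown with a stock
# LambdaLR (the 0.95 decay is illustrative):
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)

optimizer.step()
scheduler.step()
print(scheduler.get_last_lr())  # [0.095]; calling get_lr() here would trigger the warning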
def _check_eval_shuffling(dataloader, mode):
    if (
        hasattr(dataloader, "sampler")
        and not isinstance(dataloader.sampler, SequentialSampler)
        and not isinstance(dataloader.dataset, IterableDataset)
    ):
        rank_zero_warn(
            f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
            " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
            category=PossibleUserWarning,
        )
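# The heuristic above in isolation: `shuffle=True` surfaces as a `RandomSampler`,
# which is what trips the `not isinstance(..., SequentialSampler)` check:
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(4).float())
assert isinstance(DataLoader(dataset, shuffle=True).sampler, RandomSampler)  # would warn
assert isinstance(DataLoader(dataset).sampler, SequentialSampler)  # would not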
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
    """Removes all unpicklable entries from hparams."""
    hparams_dict = hparams
    if isinstance(hparams, Namespace):
        hparams_dict = hparams.__dict__

    del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]

    for k in del_attrs:
        rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
        del hparams_dict[k]
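# A self-contained sketch of the pickling check that `is_picklable` presumably
# wraps (`_is_picklable_sketch` is illustrative, not the library's implementation):
import pickle
from argparse import Namespace

def _is_picklable_sketch(obj) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False

hparams = Namespace(lr=0.01, activation=lambda x: x)  # lambdas cannot be pickled
removable = [k for k, v in vars(hparams).items() if not _is_picklable_sketch(v)]
assert removable == ["activation"]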
def _worker_check(self, dataloader: DataLoader, name: str) -> None:
    if not isinstance(dataloader, DataLoader):
        return

    using_spawn = self.trainer._accelerator_connector._strategy_type == _StrategyType.DDP_SPAWN
    num_cpus = multiprocessing.cpu_count()

    # ddp_spawn + num_workers > 0 don't mix! tell the user
    if dataloader.num_workers > 0 and using_spawn:
        # checks for the attr persistent_workers available in pytorch >= 1.7
        if hasattr(dataloader, "persistent_workers"):
            if not dataloader.persistent_workers:
                rank_zero_warn(
                    "num_workers>0, persistent_workers=False, and strategy=ddp_spawn"
                    " may result in data loading bottlenecks."
                    " Consider setting persistent_workers=True"
                    " (this is a limitation of Python .spawn() and PyTorch)"
                )
        else:
            rank_zero_warn(
                "num_workers>0 and strategy=ddp_spawn do not mix well"
                " and may result in data loading bottlenecks."
                " Consider setting strategy=ddp to use num_workers>0"
                " (this is a limitation of Python .spawn() and PyTorch)"
            )
    elif dataloader.num_workers == 0 and using_spawn:
        # checks for the attr persistent_workers available in pytorch >= 1.7
        if hasattr(dataloader, "persistent_workers"):
            if not dataloader.persistent_workers:
                rank_zero_warn(
                    "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
                    " Consider setting num_workers>0 and persistent_workers=True"
                )
        else:
            rank_zero_warn(
                "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
                " Consider setting strategy=ddp and num_workers>0"
            )
    elif dataloader.num_workers <= 2 < num_cpus and not using_spawn:
        # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers'
        rank_zero_warn(
            f"The dataloader, {name}, does not have many workers which may be a bottleneck."
            " Consider increasing the value of the `num_workers` argument"
            f" (try {num_cpus} which is the number of cpus on this machine)"
            " in the `DataLoader` init to improve performance.",
            category=PossibleUserWarning,
        )
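# The configurations the warnings above steer toward, side by side
# (num_workers=4 is illustrative; pick a value based on cpu_count()):
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16).float())

# with strategy=ddp_spawn: keep workers alive across epochs
spawn_friendly = DataLoader(dataset, num_workers=4, persistent_workers=True)

# with strategy=ddp (no spawn): simply raise num_workers toward cpu_count()
ddp_friendly = DataLoader(dataset, num_workers=4)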
def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
    from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn

    rank_zero_deprecation(
        '`pytorch_lightning.utilities.distributed.rank_zero_warn` has been moved to'
        ' `pytorch_lightning.utilities.rank_zero_warn` in v1.3.7 and will be removed in v1.6'
    )
    return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs)
def _select_data_fetcher(self) -> AbstractDataFetcher:
    if not self.trainer.training:
        return DataFetcher()

    training_step_fx = getattr(self.trainer.lightning_module, "training_step")
    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            "Found `dataloader_iter` argument in the `training_step`. Note that the support for"
            " this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher()
    elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        if not isinstance(self.trainer.accelerator, GPUAccelerator):
            raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
        return InterBatchParallelDataFetcher()
    return DataFetcher()
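# A sketch of the check `is_param_in_hook_signature` presumably performs
# (`_has_param_sketch` is illustrative, not the library's implementation):
import inspect

def training_step(self, dataloader_iter):  # the experimental signature from the warning above
    ...

def _has_param_sketch(fn, name: str) -> bool:
    return name in inspect.signature(fn).parameters

assert _has_param_sketch(training_step, "dataloader_iter")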
def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
    # skip for CPU
    if self.num_gpus == 0:
        return

    # single GPU case
    # in single gpu case we allow ddp so we can train on multiple
    # nodes, 1 gpu per node
    if self.num_gpus == 1:
        self.single_gpu = True

        if distributed_backend is not None:
            self.use_dp = distributed_backend == 'dp'
            self.use_ddp = distributed_backend == 'ddp'
            self.use_ddp2 = distributed_backend == 'ddp2'

            # disable single gpu when using ddp2
            if self.use_ddp2:
                self.single_gpu = False

    # multiple GPU case
    elif self.num_gpus > 1:
        if distributed_backend is not None:
            # DP, DDP case
            self.use_dp = distributed_backend == 'dp'
            self.use_ddp = distributed_backend == 'ddp'
            self.use_ddp2 = distributed_backend == 'ddp2'

        elif distributed_backend is None:
            rank_zero_warn(
                'You requested multiple GPUs but did not specify a backend, e.g.'
                ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
                ' Setting distributed_backend=dp for you.'
            )
            self.use_dp = True
            self.use_ddp = False
            self.use_ddp2 = False

    # throw error to force user ddp or ddp2 choice
    if num_gpu_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
        raise MisconfigurationException(
            'DataParallel does not support num_nodes > 1.'
            ' Set distributed_backend=ddp or distributed_backend=ddp2 to use DistributedDataParallel.'
        )

    log.info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
def _format_batch_size_and_grad_accum_config(self):
    if "gradient_accumulation_steps" in self.config:
        raise MisconfigurationException(
            "Within the DeepSpeed config, do not set gradient_accumulation_steps"
            " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
        )
    if "train_micro_batch_size_per_gpu" not in self.config:
        rank_zero_warn(
            "Inferring the batch size for internal deepspeed logging from the `train_dataloader()`."
            " If you require skipping this, please pass"
            " `Trainer(plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`"
        )
        batch_size = self._auto_select_batch_size()
        self.config["train_micro_batch_size_per_gpu"] = batch_size
    self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
    if "gradient_clipping" not in self.config:
        self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val
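# What the method above writes into `self.config`, as a literal dict
# (the 8/4/0.5 values are illustrative):
config = {
    "train_micro_batch_size_per_gpu": 8,  # inferred from train_dataloader() unless set
    "gradient_accumulation_steps": 4,  # mirrors Trainer(accumulate_grad_batches=4)
    "gradient_clipping": 0.5,  # mirrors Trainer(gradient_clip_val=0.5)
}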
def init_deepspeed(self):
    # deepspeed handles gradient clipping internally
    if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
        rank_zero_warn(
            "Since DeepSpeed handles gradient clipping internally, the default"
            " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
            " The hook will still be called. Consider setting"
            " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
            " which will use the internal mechanism."
        )

    if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
        raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")

    accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

    if accumulation_scheduler.epochs != [0]:
        raise MisconfigurationException(
            "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
        )

    precision = self.lightning_module.trainer.accelerator.precision
    model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

    if self.zero_stage_3 and self.partition_module:
        # Ensure the entire model has been moved to the appropriate device
        dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
        deepspeed.zero.Init(
            module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
        )

    if self.lightning_module.trainer and self.lightning_module.trainer.training:
        self._initialize_deepspeed_train(model)
    else:
        self._initialize_deepspeed_inference(model)
def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
    if self._requires_distributed_sampler(dataloader):
        if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
            raise MisconfigurationException(
                "You seem to have configured a sampler in your DataLoader. This will be replaced"
                " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                " distributed training. Either remove the sampler from your DataLoader or set"
                " `replace_sampler_ddp=False` if you want to use your custom sampler."
            )
        sampler = self._get_distributed_sampler(
            dataloader,
            shuffle,
            mode=mode,
            overfit_batches=self.trainer.overfit_batches,
            **self.trainer.distributed_sampler_kwargs,
        )

        # update docs too once this is resolved
        trainer_fn = self.trainer.state.fn
        if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
            rank_zero_warn(
                f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
                " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
                " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
                " some samples to make sure all devices have same batch size in case of uneven inputs.",
                category=PossibleUserWarning,
            )

        return sampler

    return dataloader.sampler
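# What `_get_distributed_sampler` boils down to, in plain PyTorch; passing
# num_replicas/rank explicitly avoids needing an initialized process group here:
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
# each of the 2 replicas sees 5 samples; an 11-sample dataset would be padded
# to 6 + 6 by replicating a sample, which is the duplication the warning describes
assert len(list(sampler)) == 5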
import os
from contextlib import redirect_stderr
from io import StringIO

from pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache

running_special = os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1"
if running_special:
    stderr = StringIO()
    # recording
    with redirect_stderr(stderr):
        _warn("test1")
        _warn("test2", DeprecationWarning)

        rank_zero_warn("test3")
        rank_zero_warn("test4", DeprecationWarning)

        rank_zero_deprecation("test5")

        cache = WarningCache()
        cache.warn("test6")
        cache.deprecation("test7")

    output = stderr.getvalue()
    assert "test_warnings.py:30: UserWarning: test1" in output
    assert "test_warnings.py:31: DeprecationWarning: test2" in output
    assert "test_warnings.py:33: UserWarning: test3" in output
    assert "test_warnings.py:34: DeprecationWarning: test4" in output
def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model
    https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain and update any state information that users
            would like to maintain as part of the training process. Examples: error feedback in gradient
            compression, peers to communicate with next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The hook can perform whatever processing is
            needed and return a Future indicating completion of any async work (ex: allreduce). If the hook doesn't
            perform any communication, it can also just return a completed Future. The Future should hold the new
            value of grad bucket's tensors. Once a bucket is ready, c10d reducer would call this hook and use the
            tensors returned by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such as FP16 compression as wrapper, which
            could be combined with ddp_comm_hook

    .. warning::
        DDP communication hook needs pytorch version at least 1.8.0

    .. warning::
        DDP communication wrapper needs pytorch version at least 1.9.0
        Post-localSGD hook needs pytorch version at least 1.9.0

    Example:

        from torch.distributed.algorithms.ddp_comm_hooks import (
            default_hooks as default,
            powerSGD_hook as powerSGD,
            post_localSGD_hook as post_localSGD,
        )

        # fp16_compress_hook for compress gradients
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_hook=default.fp16_compress_hook,
        )

        # powerSGD_hook
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=5000,
            ),
            ddp_comm_hook=powerSGD.powerSGD_hook,
        )

        # post_localSGD_hook
        subgroup, _ = torch.distributed.new_subgroups()
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=post_localSGD.PostLocalSGDState(
                process_group=None,
                subgroup=subgroup,
                start_localSGD_iter=1_000,
            ),
            ddp_comm_hook=post_localSGD.post_localSGD_hook,
        )

        # fp16_compress_wrapper combined with other communication hook
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=5000,
            ),
            ddp_comm_hook=powerSGD.powerSGD_hook,
            ddp_comm_wrapper=default.fp16_compress_wrapper,
        )
    """
    from pytorch_lightning.utilities import rank_zero_warn

    if not _TORCH_GREATER_EQUAL_1_8:
        rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
        return
    if ddp_comm_hook is None:
        return
    # inform mypy that ddp_comm_hook is callable
    ddp_comm_hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        if not _TORCH_GREATER_EQUAL_1_9:
            rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
        else:
            rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
            )
            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
import os
from contextlib import redirect_stderr
from io import StringIO

from pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache

standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"
if standalone:
    stderr = StringIO()
    # recording
    with redirect_stderr(stderr):
        _warn("test1")
        _warn("test2", category=DeprecationWarning)

        rank_zero_warn("test3")
        rank_zero_warn("test4", category=DeprecationWarning)

        rank_zero_deprecation("test5")

        cache = WarningCache()
        cache.warn("test6")
        cache.deprecation("test7")

    output = stderr.getvalue()
    assert "test_warnings.py:30: UserWarning: test1" in output
    assert "test_warnings.py:31: DeprecationWarning: test2" in output
    assert "test_warnings.py:33: UserWarning: test3" in output
    assert "test_warnings.py:34: DeprecationWarning: test4" in output
def _reset_eval_dataloader(
    self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
    """Generic method to reset a dataloader for evaluation.

    Args:
        mode: The running stage of the ``Trainer``
        model: The ``LightningModule`` if calling this outside of the trainer scope.

    Returns:
        Tuple (num_batches, dataloaders)
    """
    assert mode.evaluating or mode == RunningStage.PREDICTING

    # always get the loaders first so we can count how many there are
    dataloaders = self._request_dataloader(mode, model=model)

    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    if any(dl is None for dl in dataloaders):
        rank_zero_warn("One of given dataloaders is None and it will be skipped.")

    for loader in dataloaders:
        apply_to_collection(
            loader.loaders if isinstance(loader, CombinedLoader) else loader,
            DataLoader,
            self._check_eval_shuffling,
            mode=mode,
        )

    # add samplers
    dataloaders = [self._prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(
        dataloaders, dtype=DataLoader, function=_auto_add_worker_init_fn, rank=self.trainer.global_rank
    )

    loader_num_batches = []

    # determine number of batches
    # datasets could be none, 1 or 2+
    module = model or self.trainer.lightning_module or self.datamodule
    if len(dataloaders) != 0:
        for i, dataloader in enumerate(dataloaders):
            orig_num_batches = num_batches = (
                len(dataloader) if has_len_all_ranks(dataloader, self.trainer.strategy, module) else float("inf")
            )
            self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

            # percent or num_steps
            limit_eval_batches = getattr(self.trainer, f"limit_{mode.dataloader_prefix}_batches")

            # limit num batches either as a percent or num steps
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float("inf"):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                raise MisconfigurationException(
                    f"When using an IterableDataset for `limit_{mode}_batches`,"
                    f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
                    f" specifies `num_{mode.dataloader_prefix}_batches` to use."
                )

            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                min_pct = 1.0 / len(dataloader)
                raise MisconfigurationException(
                    f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
                    f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
                    f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                    f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                )

            loader_num_batches.append(num_batches)

    return loader_num_batches, dataloaders
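# The `limit_*_batches` arithmetic above, in isolation
# (the 100-batch loader is illustrative):
num_batches = 100

as_int = min(num_batches, int(5))  # limit_val_batches=5 -> 5 batches
as_float = int(num_batches * 0.25)  # limit_val_batches=0.25 -> 25 batches
min_pct = 1.0 / num_batches  # smallest float that still yields one batch
assert (as_int, as_float, min_pct) == (5, 25, 0.01)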
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
    # -----------------------------------
    # verify model has a training step
    # -----------------------------------
    has_training_step = is_overridden("training_step", model)
    if not has_training_step:
        raise MisconfigurationException(
            "No `training_step()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # -----------------------------------
    # verify model has a train dataloader
    # -----------------------------------
    has_train_dataloader = trainer._data_connector._train_dataloader_source.is_defined()
    if not has_train_dataloader:
        raise MisconfigurationException(
            "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # -----------------------------------
    # verify model has optimizer
    # -----------------------------------
    has_optimizers = is_overridden("configure_optimizers", model)
    if not has_optimizers:
        raise MisconfigurationException(
            "No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # ----------------------------------------------
    # verify model does not have on_train_dataloader
    # ----------------------------------------------
    has_on_train_dataloader = is_overridden("on_train_dataloader", model)
    if has_on_train_dataloader:
        rank_zero_deprecation(
            "Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
            " Please use `train_dataloader()` directly."
        )

    trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
    trainer.overriden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model)
    automatic_optimization = model.automatic_optimization
    going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

    has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
    if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
        rank_zero_warn(
            "When using `Trainer(accumulate_grad_batches != 1)` and overriding"
            " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
            " (rather, they are called on every optimization step)."
        )

    # -----------------------------------
    # verify model for val loop
    # -----------------------------------
    has_val_loader = trainer._data_connector._val_dataloader_source.is_defined()
    has_val_step = is_overridden("validation_step", model)
    if has_val_loader and not has_val_step:
        rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
    if has_val_step and not has_val_loader:
        rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")

    # ----------------------------------------------
    # verify model does not have on_val_dataloader
    # ----------------------------------------------
    has_on_val_dataloader = is_overridden("on_val_dataloader", model)
    if has_on_val_dataloader:
        rank_zero_deprecation(
            "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
            " Please use `val_dataloader()` directly."
        )
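# A sketch of the `is_overridden` idea used throughout the checks above
# (`_is_overridden_sketch` is illustrative, not the library's implementation):
class _BaseModule:
    def training_step(self):
        ...

class _UserModule(_BaseModule):
    def training_step(self):
        return 1

def _is_overridden_sketch(method_name: str, instance, parent) -> bool:
    return getattr(type(instance), method_name) is not getattr(parent, method_name)

assert _is_overridden_sketch("training_step", _UserModule(), _BaseModule)
assert not _is_overridden_sketch("training_step", _BaseModule(), _BaseModule)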