Example #1
    def on_run_start(self) -> None:  # type: ignore[override]
        """Calls the ``on_train_start`` hook."""
        # reset train dataloader and val dataloader
        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
        if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (
                0, float("inf")):
            self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
                self.trainer.current_epoch)
            expected_steps = math.ceil(self.trainer.num_training_batches /
                                       self.trainer.accumulate_grad_batches)

            # global_step is incremented during checkpointing (#11555)
            if (self.trainer.global_step - 1) % expected_steps != 0:
                rank_zero_warn(
                    "You're resuming from a checkpoint that ended mid-epoch."
                    " Training will start from the beginning of the next epoch."
                    " This can cause unreliable results if further training is done,"
                    " consider using an end of epoch checkpoint or use fault-tolerant training"
                    " to restart as if training did not stop.")

        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)
        self.trainer._call_callback_hooks("on_train_start")
        self.trainer._call_lightning_module_hook("on_train_start")
        self.trainer._call_strategy_hook("on_train_start")
    def _resolve_overfit_batches(
            dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
        all_have_sequential_sampler = True

        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
            nonlocal all_have_sequential_sampler
            all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                dataloader.sampler, SequentialSampler)

        apply_to_collection(dataloader, DataLoader,
                            resolve_has_no_sequential_sampler)

        if not all_have_sequential_sampler:
            rank_zero_warn(
                "You requested to overfit but enabled training dataloader shuffling."
                " We are turning off the training dataloader shuffling for you."
            )

            def replace_sampler(dataloader: DataLoader) -> DataLoader:
                return _update_dataloader(dataloader,
                                          SequentialSampler(
                                              dataloader.dataset),
                                          mode=RunningStage.TRAINING)

            dataloader = apply_to_collection(dataloader, DataLoader,
                                             replace_sampler)

        return dataloader
Example #3
    def experiment(self) -> Run:
        r"""

        Actual wandb object. To use wandb features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_wandb_function()

        """
        if self._experiment is None:
            if self._offline:
                os.environ["WANDB_MODE"] = "dryrun"
            if wandb.run is None:
                self._experiment = wandb.init(**self._wandb_init)
            else:
                rank_zero_warn(
                    "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
                )
                self._experiment = wandb.run

        # define default x-axis (for latest wandb versions)
        if getattr(self._experiment, "define_metric", None):
            self._experiment.define_metric("trainer/global_step")
            self._experiment.define_metric("*",
                                           step_metric="trainer/global_step",
                                           step_sync=True)

        return self._experiment
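A minimal usage sketch to go with the property above (it assumes `wandb` is installed; the project name is a placeholder): attach a `WandbLogger` to the `Trainer` and reach the underlying `wandb` run through `logger.experiment`.

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# offline=True avoids needing a wandb account while experimenting
wandb_logger = WandbLogger(project="my-project", offline=True)
trainer = pl.Trainer(logger=wandb_logger, max_epochs=1)
# `experiment` lazily calls `wandb.init(...)` (or reuses an existing run) and returns the Run object,
# so wandb-native APIs can be used directly:
wandb_logger.experiment.log({"debug/custom_metric": 0.0})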
    def init_deepspeed(self):
        # deepspeed handles gradient clipping internally
        if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
            rank_zero_warn(
                "Since DeepSpeed handles gradient clipping internally, the default"
                " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
                " The hook will still be called. Consider setting"
                " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
                " which will use the internal mechanism."
            )

        if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")

        if not isinstance(self.accelerator, CUDAAccelerator):
            raise MisconfigurationException(
                f"DeepSpeed strategy is only supported on GPU but `{self.accelerator.__class__.__name__}` is used."
            )

        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

        if accumulation_scheduler.epochs != [0]:
            raise MisconfigurationException(
                "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
            )

        model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision)

        if self.lightning_module.trainer and self.lightning_module.trainer.training:
            self._initialize_deepspeed_train(model)
        else:
            self._initialize_deepspeed_inference(model)
    def _check_eval_shuffling(dataloader, mode):
        if _is_dataloader_shuffled(dataloader):
            rank_zero_warn(
                f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
                category=PossibleUserWarning,
            )
    def __verify_eval_loop_configuration(self, model: "pl.LightningModule",
                                         stage: str) -> None:
        loader_name = f"{stage}_dataloader"
        step_name = "validation_step" if stage == "val" else "test_step"

        has_loader = is_overridden(loader_name, model)
        has_step = is_overridden(step_name, model)

        if has_loader and not has_step:
            rank_zero_warn(
                f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop"
            )
        if has_step and not has_loader:
            rank_zero_warn(
                f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop"
            )

        # ----------------------------------------------
        # verify model does not have
        # - on_val_dataloader
        # - on_test_dataloader
        # ----------------------------------------------
        has_on_val_dataloader = is_overridden("on_val_dataloader", model)
        if has_on_val_dataloader:
            rank_zero_deprecation(
                "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
                " Please use `val_dataloader()` directly.")

        has_on_test_dataloader = is_overridden("on_test_dataloader", model)
        if has_on_test_dataloader:
            rank_zero_deprecation(
                "Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
                " Please use `test_dataloader()` directly.")
    def lightning_restore_optimizer_and_schedulers(self) -> bool:
        # managed by DeepSpeed
        if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
            rank_zero_warn(
                "A single checkpoint file has been given. This means optimizer states and "
                "scheduler states can not be restored. If you'd like to restore these states, you must "
                "provide a path to the originally saved DeepSpeed checkpoint.")
        return False
    def set_distributed_mode(self, distributed_backend):
        self.use_dp = False
        self.use_ddp = False
        self.use_ddp2 = False
        self.single_gpu = False

        if distributed_backend is None:
            if self.num_gpus == 0:
                if self.num_nodes > 1 or self.num_processes > 1:
                    self.use_ddp = True  # ddp_cpu
            elif self.num_gpus == 1:
                self.single_gpu = True
            elif self.num_gpus > 1:
                rank_zero_warn(
                    'You requested multiple GPUs but did not specify a backend, e.g.'
                    ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
                    ' Setting distributed_backend=dp for you.')
                self.use_dp = True
        elif distributed_backend == "dp":
            # do nothing if num_gpus == 0
            if self.num_gpus == 1:
                self.single_gpu = True
                self.use_dp = True
            elif self.num_gpus > 1:
                self.use_dp = True
        elif distributed_backend == "ddp":
            if self.num_gpus == 0:
                if self.num_nodes > 1 or self.num_processes > 1:
                    self.use_ddp = True  # ddp_cpu
            elif self.num_gpus == 1:
                self.single_gpu = True
                self.use_ddp = True
            elif self.num_gpus > 1:
                self.use_ddp = True
                self.num_processes = self.num_gpus
        elif distributed_backend == "ddp2":
            # do nothing if num_gpus == 0
            if self.num_gpus >= 1:
                self.use_ddp2 = True
        elif distributed_backend == "ddp_cpu":
            if self.num_gpus > 0:
                rank_zero_warn(
                    'You requested one or more GPUs, but set the backend to `ddp_cpu`.'
                    ' Training will not use GPUs.')
            self.use_ddp = True
            self.data_parallel_device_ids = None
            self.on_gpu = False

        # throw error to force user ddp or ddp2 choice
        if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
            raise MisconfigurationException(
                'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
                'To silence this warning set distributed_backend=ddp or distributed_backend=ddp2'
            )

        log.info(
            f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
    def lightning_restore_optimizer(self) -> bool:
        # managed by DeepSpeed
        if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
            rank_zero_warn(
                "A single checkpoint file has been given. This means optimizer states cannot be restored."
                " If you'd like to restore these states, you must provide a path to the originally saved DeepSpeed"
                " checkpoint. When using ZeRO 3, the original path should be a directory."
            )
        return False
Example #10
    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: Optional[str] = None,
        offline: Optional[bool] = False,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        version: Optional[str] = None,
        project: Optional[str] = None,
        log_model: Optional[bool] = False,
        experiment=None,
        prefix: Optional[str] = "",
        **kwargs,
    ):
        if wandb is None:
            raise ImportError(
                "You want to use `wandb` logger which is not installed yet,"
                " install it with `pip install wandb`."  # pragma: no-cover
            )

        if offline and log_model:
            raise MisconfigurationException(
                f"Providing log_model={log_model} and offline={offline} is an invalid configuration"
                " since model checkpoints cannot be uploaded in offline mode.\n"
                "Hint: Set `offline=False` to log your model."
            )

        if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
            rank_zero_warn(
                f"Providing log_model={log_model} requires wandb version >= 0.10.22"
                " for logging associated model metadata.\n"
                "Hint: Upgrade with `pip install --ugrade wandb`."
            )

        super().__init__()
        self._offline = offline
        self._log_model = log_model
        self._prefix = prefix
        self._experiment = experiment
        self._logged_model_time = {}
        self._checkpoint_callback = None
        # set wandb init arguments
        anonymous_lut = {True: "allow", False: None}
        self._wandb_init = dict(
            name=name,
            project=project,
            id=version or id,
            dir=save_dir,
            resume="allow",
            anonymous=anonymous_lut.get(anonymous, anonymous),
        )
        self._wandb_init.update(**kwargs)
        # extract parameters
        self._save_dir = self._wandb_init.get("dir")
        self._name = self._wandb_init.get("name")
        self._id = self._wandb_init.get("id")
    def get_lr(self):
        if not self._get_lr_called_within_step:
            rank_zero_warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.")

        return [
            base_lr * self.lr_lambda(self.last_epoch)
            for base_lr in self.base_lrs
        ]
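The warning above mirrors the stock PyTorch scheduler contract: `get_lr()` is meant for internal use during `step()`, while user code should read `get_last_lr()`. A small illustrative sketch with a plain `torch.optim.lr_scheduler.LambdaLR` (chosen here only as a stand-in):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)

optimizer.step()
scheduler.step()
# Read the learning rate the scheduler just computed without triggering the warning:
print(scheduler.get_last_lr())  # ~[0.095] after one decay step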
    def _check_eval_shuffling(dataloader, mode):
        if (
            hasattr(dataloader, "sampler")
            and not isinstance(dataloader.sampler, SequentialSampler)
            and not isinstance(dataloader.dataset, IterableDataset)
        ):
            rank_zero_warn(
                f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
                category=PossibleUserWarning,
            )
Example #13
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
    """Removes all unpicklable entries from hparams"""

    hparams_dict = hparams
    if isinstance(hparams, Namespace):
        hparams_dict = hparams.__dict__

    del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]

    for k in del_attrs:
        rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
        del hparams_dict[k]
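A quick sketch of how `clean_namespace` behaves, assuming the function defined above is in scope; the lambda is just a convenient example of an unpicklable value:

from argparse import Namespace

hparams = Namespace(lr=1e-3, batch_size=32, transform=lambda x: x)  # lambdas cannot be pickled
clean_namespace(hparams)  # warns on rank zero and drops the offending key in place
assert not hasattr(hparams, "transform")
assert hparams.lr == 1e-3 and hparams.batch_size == 32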
    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
        if not isinstance(dataloader, DataLoader):
            return

        using_spawn = self.trainer._accelerator_connector._strategy_type == _StrategyType.DDP_SPAWN
        num_cpus = multiprocessing.cpu_count()

        # ddp_spawn + num_workers > 0 don't mix! tell the user
        if dataloader.num_workers > 0 and using_spawn:
            # checks for the attr persistent_workers available in pytorch >= 1.7
            if hasattr(dataloader, "persistent_workers"):
                if not dataloader.persistent_workers:
                    rank_zero_warn(
                        "num_workers>0, persistent_workers=False, and strategy=ddp_spawn"
                        " may result in data loading bottlenecks."
                        " Consider setting persistent_workers=True"
                        " (this is a limitation of Python .spawn() and PyTorch)"
                    )
            else:
                rank_zero_warn(
                    "num_workers>0 and strategy=ddp_spawn do not mix well"
                    " and may result in data loading bottlenecks."
                    " Consider setting strategy=ddp to use num_workers>0"
                    " (this is a limitation of Python .spawn() and PyTorch)"
                )

        elif dataloader.num_workers == 0 and using_spawn:
            # checks for the attr persistent_workers available in pytorch >= 1.7
            if hasattr(dataloader, "persistent_workers"):
                if not dataloader.persistent_workers:
                    rank_zero_warn(
                        "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
                        " Consider setting num_workers>0 and persistent_workers=True"
                    )
            else:
                rank_zero_warn(
                    "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks."
                    " Consider setting strategy=ddp and set num_workers>0"
                )

        elif dataloader.num_workers <= 2 < num_cpus and not using_spawn:
            # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers'
            rank_zero_warn(
                f"The dataloader, {name}, does not have many workers which may be a bottleneck."
                " Consider increasing the value of the `num_workers` argument`"
                f" (try {num_cpus} which is the number of cpus on this machine)"
                " in the `DataLoader` init to improve performance.",
                category=PossibleUserWarning,
            )
def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
    from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
    rank_zero_deprecation(
        '`pytorch_lightning.utilities.distributed.rank_zero_warn` has been moved to'
        ' `pytorch_lightning.utilities.rank_zero_warn` in v1.3.7 and will be removed in v1.6'
    )
    return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs)
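The shim above only forwards to the relocated helper; below is a minimal sketch of the non-deprecated call path (the exact import location varies across pytorch_lightning versions, so treat it as an assumption):

from pytorch_lightning.utilities import rank_zero_warn

# Emits the warning on global rank 0 only, so multi-process (DDP) runs do not
# print the same message once per worker process.
rank_zero_warn("Example message shown a single time per run.", category=UserWarning)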
    def _select_data_fetcher(self) -> AbstractDataFetcher:
        if not self.trainer.training:
            return DataFetcher()

        training_step_fx = getattr(self.trainer.lightning_module, "training_step")
        if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
            rank_zero_warn(
                "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
                "this signature is experimental and the behavior is subject to change."
            )
            return DataLoaderIterDataFetcher()
        elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
            if not isinstance(self.trainer.accelerator, GPUAccelerator):
                raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
            return InterBatchParallelDataFetcher()
        return DataFetcher()
Example #17
    def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
        # skip for CPU
        if self.num_gpus == 0:
            return

        # single GPU case
        # in single gpu case we allow ddp so we can train on multiple
        # nodes, 1 gpu per node
        if self.num_gpus == 1:
            self.single_gpu = True

            if distributed_backend is not None:
                self.use_dp = distributed_backend == 'dp'
                self.use_ddp = distributed_backend == 'ddp'
                self.use_ddp2 = distributed_backend == 'ddp2'

                # disable single gpu when using ddp2
                if self.use_ddp2:
                    self.single_gpu = False

        # multiple GPU case
        elif self.num_gpus > 1:
            if distributed_backend is not None:
                # DP, DDP case
                self.use_dp = distributed_backend == 'dp'
                self.use_ddp = distributed_backend == 'ddp'
                self.use_ddp2 = distributed_backend == 'ddp2'

            elif distributed_backend is None:
                rank_zero_warn(
                    'You requested multiple GPUs but did not specify a backend, e.g.'
                    ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
                    ' Setting distributed_backend=dp for you.')
                self.use_dp = True
                self.use_ddp = False
                self.use_ddp2 = False

        # throw error to force user ddp or ddp2 choice
        if num_gpu_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
            raise MisconfigurationException(
                'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
                'To silence this warning set distributed_backend=ddp or distributed_backend=ddp2'
            )

        log.info(
            f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
Example #18
    def _format_batch_size_and_grad_accum_config(self):
        if "gradient_accumulation_steps" in self.config:
            raise MisconfigurationException(
                "Within the DeepSpeed config, do not set gradient_accumulation_steps"
                " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
            )
        if "train_micro_batch_size_per_gpu" not in self.config:
            rank_zero_warn(
                "Inferring the batch size for internal deepspeed logging from the `train_dataloader()`. "
                "If you require skipping this, please pass "
                "`Trainer(plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`"
            )
            batch_size = self._auto_select_batch_size()
            self.config["train_micro_batch_size_per_gpu"] = batch_size
        self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
        if "gradient_clipping" not in self.config:
            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val
Example #19
    def init_deepspeed(self):
        # deepspeed handles gradient clipping internally
        if is_overridden("configure_gradient_clipping", self.lightning_module,
                         pl.LightningModule):
            rank_zero_warn(
                "Since DeepSpeed handles gradient clipping internally, the default"
                " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
                " The hook will still be called. Consider setting"
                " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
                " which will use the internal mechanism.")

        if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            raise MisconfigurationException(
                "DeepSpeed does not support clipping gradients by value.")

        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler

        if accumulation_scheduler.epochs != [0]:
            raise MisconfigurationException(
                "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
            )

        precision = self.lightning_module.trainer.accelerator.precision
        model = LightningDeepSpeedModule(pl_module=self.model,
                                         precision=precision)

        if self.zero_stage_3 and self.partition_module:
            # Ensure the entire model has been moved to the appropriate device
            dtype = torch.float16 if self.precision in (
                16, "mixed") else torch.float32
            deepspeed.zero.Init(module=model,
                                remote_device=self.remote_device,
                                pin_memory=True,
                                config=self.config,
                                dtype=dtype)

        if self.lightning_module.trainer and self.lightning_module.trainer.training:
            self._initialize_deepspeed_train(model)
        else:
            self._initialize_deepspeed_inference(model)
    def _resolve_sampler(self,
                         dataloader: DataLoader,
                         shuffle: bool,
                         mode: Optional[RunningStage] = None) -> Sampler:
        if self._requires_distributed_sampler(dataloader):
            if not isinstance(dataloader.sampler,
                              (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    "You seem to have configured a sampler in your DataLoader. This will be replaced"
                    " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                    " distributed training. Either remove the sampler from your DataLoader or set"
                    " `replace_sampler_ddp=False` if you want to use your custom sampler."
                )
            sampler = self._get_distributed_sampler(
                dataloader,
                shuffle,
                mode=mode,
                overfit_batches=self.trainer.overfit_batches,
                **self.trainer.distributed_sampler_kwargs,
            )

            # update docs too once this is resolved
            trainer_fn = self.trainer.state.fn
            if isinstance(sampler, DistributedSampler) and trainer_fn in (
                    TrainerFn.VALIDATING, TrainerFn.TESTING):
                rank_zero_warn(
                    f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
                    " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
                    " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
                    " some samples to make sure all devices have same batch size in case of uneven inputs.",
                    category=PossibleUserWarning,
                )

            return sampler

        return dataloader.sampler
import os
from contextlib import redirect_stderr
from io import StringIO

from pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache

running_special = os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1"
if running_special:

    stderr = StringIO()
    # recording
    with redirect_stderr(stderr):
        _warn("test1")
        _warn("test2", DeprecationWarning)

        rank_zero_warn("test3")
        rank_zero_warn("test4", DeprecationWarning)

        rank_zero_deprecation("test5")

        cache = WarningCache()
        cache.warn("test6")
        cache.deprecation("test7")

    output = stderr.getvalue()
    assert "test_warnings.py:30: UserWarning: test1" in output
    assert "test_warnings.py:31: DeprecationWarning: test2" in output

    assert "test_warnings.py:33: UserWarning: test3" in output
    assert "test_warnings.py:34: DeprecationWarning: test4" in output
Example #22
def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned
            by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such
            as FP16 compression as wrapper, which could be combined with
            ddp_comm_hook

    .. warning ::
        DDP communication hook needs pytorch version at least 1.8.0

    .. warning ::
        DDP communication wrapper needs pytorch version at least 1.9.0
        Post-localSGD hook needs pytorch version at least 1.9.0

    Example:

        from torch.distributed.algorithms.ddp_comm_hooks import (
            default_hooks as default,
            powerSGD_hook as powerSGD,
            post_localSGD_hook as post_localSGD,
        )

        # fp16_compress_hook for compress gradients
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_hook=default.fp16_compress_hook,
        )

        # powerSGD_hook
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=5000,
            ),
            ddp_comm_hook=powerSGD.powerSGD_hook,
        )

        # post_localSGD_hook
        subgroup, _ = torch.distributed.new_subgroups()
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=post_localSGD.PostLocalSGDState(
                process_group=None,
                subgroup=subgroup,
                start_localSGD_iter=1_000,
            ),
            ddp_comm_hook=post_localSGD.post_localSGD_hook,
        )

        # fp16_compress_wrapper combined with other communication hook
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=5000,
            ),
            ddp_comm_hook=powerSGD.powerSGD_hook,
            ddp_comm_wrapper=default.fp16_compress_wrapper,
        )
    """
    from pytorch_lightning.utilities import rank_zero_warn

    if not _TORCH_GREATER_EQUAL_1_8:
        rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
        return
    if ddp_comm_hook is None:
        return
    # inform mypy that ddp_comm_hook is callable
    ddp_comm_hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        if not _TORCH_GREATER_EQUAL_1_9:
            rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
        else:
            rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
            )
            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
import os
from contextlib import redirect_stderr
from io import StringIO

from pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache

standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"
if standalone:

    stderr = StringIO()
    # recording
    with redirect_stderr(stderr):
        _warn("test1")
        _warn("test2", category=DeprecationWarning)

        rank_zero_warn("test3")
        rank_zero_warn("test4", category=DeprecationWarning)

        rank_zero_deprecation("test5")

        cache = WarningCache()
        cache.warn("test6")
        cache.deprecation("test7")

    output = stderr.getvalue()
    assert "test_warnings.py:30: UserWarning: test1" in output
    assert "test_warnings.py:31: DeprecationWarning: test2" in output

    assert "test_warnings.py:33: UserWarning: test3" in output
    assert "test_warnings.py:34: DeprecationWarning: test4" in output
    def _reset_eval_dataloader(
        self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            mode: The running stage of the ``Trainer``
            model: The ``LightningModule`` if calling this outside of the trainer scope.

        Returns:
            Tuple (num_batches, dataloaders)
        """
        assert mode.evaluating or mode == RunningStage.PREDICTING

        # always get the loaders first so we can count how many there are
        dataloaders = self._request_dataloader(mode, model=model)

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        if any(dl is None for dl in dataloaders):
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")

        for loader in dataloaders:
            apply_to_collection(
                loader.loaders if isinstance(loader, CombinedLoader) else loader,
                DataLoader,
                self._check_eval_shuffling,
                mode=mode,
            )

        # add samplers
        dataloaders = [self._prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(
            dataloaders, dtype=DataLoader, function=_auto_add_worker_init_fn, rank=self.trainer.global_rank
        )

        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+
        module = model or self.trainer.lightning_module or self.datamodule
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                orig_num_batches = num_batches = (
                    len(dataloader) if has_len_all_ranks(dataloader, self.trainer.strategy, module) else float("inf")
                )
                self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

                # percent or num_steps
                limit_eval_batches = getattr(self.trainer, f"limit_{mode.dataloader_prefix}_batches")

                # limit num batches either as a percent or num steps
                if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                    num_batches = min(num_batches, int(limit_eval_batches))
                elif num_batches != float("inf"):
                    num_batches = int(num_batches * limit_eval_batches)
                elif limit_eval_batches != 1.0:
                    raise MisconfigurationException(
                        f"When using an IterableDataset for `limit_{mode}_batches`,"
                        f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
                        f" specifies `num_{mode.dataloader_prefix}_batches` to use."
                    )

                if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                    min_pct = 1.0 / len(dataloader)
                    raise MisconfigurationException(
                        f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
                        f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the"
                        f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                        f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                    )

                loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders
def __verify_train_val_loop_configuration(trainer: "pl.Trainer",
                                          model: "pl.LightningModule") -> None:
    # -----------------------------------
    # verify model has a training step
    # -----------------------------------
    has_training_step = is_overridden("training_step", model)
    if not has_training_step:
        raise MisconfigurationException(
            "No `training_step()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # -----------------------------------
    # verify model has a train dataloader
    # -----------------------------------
    has_train_dataloader = trainer._data_connector._train_dataloader_source.is_defined()
    if not has_train_dataloader:
        raise MisconfigurationException(
            "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # -----------------------------------
    # verify model has optimizer
    # -----------------------------------
    has_optimizers = is_overridden("configure_optimizers", model)
    if not has_optimizers:
        raise MisconfigurationException(
            "No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a"
            " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
        )

    # ----------------------------------------------
    # verify model does not have on_train_dataloader
    # ----------------------------------------------
    has_on_train_dataloader = is_overridden("on_train_dataloader", model)
    if has_on_train_dataloader:
        rank_zero_deprecation(
            "Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
            " Please use `train_dataloader()` directly.")

    trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
    trainer.overriden_optimizer_zero_grad = is_overridden(
        "optimizer_zero_grad", model)
    automatic_optimization = model.automatic_optimization
    going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

    has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
    if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
        rank_zero_warn(
            "When using `Trainer(accumulate_grad_batches != 1)` and overriding"
            " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
            " (rather, they are called on every optimization step).")

    # -----------------------------------
    # verify model for val loop
    # -----------------------------------

    has_val_loader = trainer._data_connector._val_dataloader_source.is_defined()
    has_val_step = is_overridden("validation_step", model)

    if has_val_loader and not has_val_step:
        rank_zero_warn(
            "You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop."
        )
    if has_val_step and not has_val_loader:
        rank_zero_warn(
            "You defined a `validation_step` but have no `val_dataloader`. Skipping val loop."
        )

    # ----------------------------------------------
    # verify model does not have on_val_dataloader
    # ----------------------------------------------
    has_on_val_dataloader = is_overridden("on_val_dataloader", model)
    if has_on_val_dataloader:
        rank_zero_deprecation(
            "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
            " Please use `val_dataloader()` directly.")