def on_load_checkpoint(self, state_dict: Dict) -> None:
    # cache the dataloader state dict until the dataloader objects are available
    # dataset states are collected across all ranks
    dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
    if not _fault_tolerant_training() or not dataloader_state_dict:
        return
    self._dataloader_state_dict = dataloader_state_dict[self.trainer.global_rank]
def register_signal_handlers(self) -> None:
    self._original_handlers = self._get_current_signal_handlers()

    sigusr1_handlers: List[_HANDLER] = []
    sigterm_handlers: List[_HANDLER] = []

    if _fault_tolerant_training():
        sigterm_handlers.append(self.fault_tolerant_sigterm_handler_fn)

    environment = self.trainer._accelerator_connector.cluster_environment
    if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
        log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
        sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)
        sigterm_handlers.append(self.sigterm_handler_fn)

    # signal.SIGUSR1 doesn't seem available on windows
    if not self._is_on_windows():
        if sigusr1_handlers and not self._has_already_handler(signal.SIGUSR1):
            self._register_signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))
        if sigterm_handlers and not self._has_already_handler(signal.SIGTERM):
            self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
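# Illustrative sketch only: what a handler-composition helper like `HandlersCompose` (used above)
# is assumed to do -- bundle several signal handlers behind one callable that invokes each of them
# in registration order. The class below is NOT the library implementation, just the composition idea.
import signal
from types import FrameType
from typing import Callable, List, Optional


class _ToyHandlersCompose:
    def __init__(self, signal_handlers: List[Callable]) -> None:
        self.signal_handlers = signal_handlers

    def __call__(self, signum: int, frame: Optional[FrameType]) -> None:
        # every composed handler sees the same signal number and frame
        for handler in self.signal_handlers:
            handler(signum, frame)


calls: List[str] = []
signal.signal(
    signal.SIGTERM,
    _ToyHandlersCompose([lambda s, f: calls.append("fault_tolerant"), lambda s, f: calls.append("slurm_requeue")]),
)
signal.raise_signal(signal.SIGTERM)  # both handlers run, in order
assert calls == ["fault_tolerant", "slurm_requeue"]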
def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
    """The state dict is determined by the state and progress of this loop and all its children.

    Args:
        destination: An existing dictionary to update with this loop's state. By default a new dictionary
            is returned.
        prefix: A prefix for each key in the state dictionary
    """
    if destination is None:
        destination = {}

    destination[prefix + "state_dict"] = self.on_save_checkpoint()

    # do not get the mode from `self.trainer` because it might not have been attached yet
    ft_enabled = _fault_tolerant_training()
    for k, v in self.__dict__.items():
        key = prefix + k
        if isinstance(v, BaseProgress):
            destination[key] = v.state_dict()
        elif isinstance(v, Loop):
            v.state_dict(destination, key + ".")
        elif ft_enabled and isinstance(v, _ResultCollection):
            # sync / unsync metrics
            v.sync()
            destination[key] = v.state_dict()
            v.unsync()

    return destination
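# Illustrative only: a tiny standalone sketch of the prefix-based nesting produced by the loop
# `state_dict` above. The loop/progress names below are made up; only the key layout
# ("<prefix><attr>." recursing into child loops) mirrors the code.
from typing import Dict, Optional


class _ToyProgress:
    def __init__(self, value: int = 0) -> None:
        self.value = value

    def state_dict(self) -> Dict:
        return {"value": self.value}


class _ToyLoop:
    def __init__(self, children: Optional[Dict[str, "_ToyLoop"]] = None) -> None:
        self.progress = _ToyProgress()
        self._children = children or {}
        for name, child in self._children.items():
            setattr(self, name, child)

    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        if destination is None:
            destination = {}
        destination[prefix + "state_dict"] = {}
        for k, v in self.__dict__.items():
            key = prefix + k
            if isinstance(v, _ToyProgress):
                destination[key] = v.state_dict()
            elif isinstance(v, _ToyLoop):
                v.state_dict(destination, key + ".")
        return destination


fit_loop = _ToyLoop(children={"epoch_loop": _ToyLoop(children={"batch_loop": _ToyLoop()})})
# keys come out flattened, e.g. "epoch_loop.batch_loop.progress"
assert "epoch_loop.batch_loop.progress" in fit_loop.state_dict()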
def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
    if isinstance(loader, CycleIterator):
        loader = loader.loader
        # cycle_iterator = iterator
        iterator = iterator._loader_iter
    if isinstance(loader, DataLoader) and _fault_tolerant_training():
        loader._lightning_fetcher = self
        patch_dataloader_iterator(loader, iterator, self)
def _prepare_dataloader(
    self, dataloader: Any, shuffle: Optional[bool] = None, mode: Optional[RunningStage] = None
) -> Any:
    """This function handles the following functionalities:

    - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
    - Wrapping the datasets and samplers into fault-tolerant components
    - Wrapping the dataloader based on strategy-specific logic
    """
    if isinstance(dataloader, CombinedLoader):
        # apply `_prepare_dataloader` to the whole collection of loaders
        dataloader.loaders = apply_to_collection(
            dataloader.loaders, (DataLoader, CycleIterator), self._prepare_dataloader, shuffle, mode=mode
        )
        # the length needs to be recomputed across all dataloaders in case of special behavior.
        dataloader._apply_cycle_iterator_length()
        return dataloader

    # don't do anything if it's not a dataloader
    if not isinstance(dataloader, (DataLoader, CycleIterator)):
        return dataloader

    cycle_iterator: Optional[CycleIterator] = None
    if isinstance(dataloader, CycleIterator):
        cycle_iterator = dataloader
        dataloader = dataloader.loader

    if (
        _fault_tolerant_training()  # injects components to track the state
        or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
        or mode == RunningStage.PREDICTING  # to track indices for the predictions
        # IPUs use a custom `poptorch.DataLoader` which we might need to convert to
        or isinstance(self.trainer.accelerator, IPUAccelerator)
    ):
        if shuffle is None:
            # for training, set to True always
            # for evaluation, decide based on existing sampler
            shuffle = True if mode == RunningStage.TRAINING else _is_dataloader_shuffled(dataloader)

        sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
        dataloader = _update_dataloader(dataloader, sampler, mode=mode)

    dataloader = self.trainer.strategy.process_dataloader(dataloader)

    if cycle_iterator is not None:
        cycle_iterator.loader = dataloader
        return cycle_iterator

    return dataloader
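# A small usage sketch for context (assumes `pytorch_lightning.trainer.supporters.CombinedLoader`
# is importable from the same codebase as the snippet above): `_prepare_dataloader` is applied to
# every loader wrapped by a `CombinedLoader`, so a dict of dataloaders is processed member by member.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.supporters import CombinedLoader

loader_a = DataLoader(TensorDataset(torch.arange(8.0)), batch_size=2)
loader_b = DataLoader(TensorDataset(torch.arange(4.0)), batch_size=2)

# "max_size_cycle" cycles the shorter loader until the longest one is exhausted
combined = CombinedLoader({"a": loader_a, "b": loader_b}, "max_size_cycle")
for batch in combined:
    # each batch is a dict with one entry per wrapped dataloader
    assert set(batch) == {"a", "b"}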
def _dataloader_init_kwargs_resolve_sampler(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    """This function handles the ``sampler`` and ``batch_sampler`` arguments of a DataLoader for its
    re-instantiation.

    If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`,
    so Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped
    into a `FastForwardSampler`.
    """
    batch_sampler = getattr(dataloader, "batch_sampler")
    is_predicting = mode == RunningStage.PREDICTING
    # check whether the batch sampler type differs from the PyTorch default
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if _fault_tolerant_training():
            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if _fault_tolerant_training():
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
def _resolve_batch_sampler(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    batch_sampler = getattr(dataloader, "batch_sampler")
    is_predicting = mode == RunningStage.PREDICTING
    # check whether the batch sampler type differs from the PyTorch default
    if (batch_sampler is not None and type(batch_sampler) is not BatchSampler) or is_predicting:
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if _fault_tolerant_training():
            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if _fault_tolerant_training():
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
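# Conceptual sketch only: both resolvers above wrap the (batch) sampler into a `FastForwardSampler`
# so that, after a failure, iteration can resume where it stopped. The toy sampler below is NOT the
# library class; it just illustrates the "skip what was already served" idea with a plain counter.
from typing import Dict, Iterator, List

from torch.utils.data import Sampler


class _ToyFastForwardSampler(Sampler):
    def __init__(self, indices: List[int]) -> None:
        self.indices = indices
        self.restarting = False
        self._counter = 0

    def state_dict(self) -> Dict:
        return {"counter": self._counter}

    def load_state_dict(self, state_dict: Dict) -> None:
        self._counter = state_dict["counter"]
        self.restarting = True

    def __iter__(self) -> Iterator[int]:
        start = self._counter if self.restarting else 0
        self.restarting = False
        for i in range(start, len(self.indices)):
            self._counter = i + 1
            yield self.indices[i]

    def __len__(self) -> int:
        return len(self.indices)


sampler = _ToyFastForwardSampler(list(range(6)))
it = iter(sampler)
_ = [next(it) for _ in range(4)]  # 4 indices consumed before a simulated failure
restored = _ToyFastForwardSampler(list(range(6)))
restored.load_state_dict(sampler.state_dict())
assert list(restored) == [4, 5]  # iteration fast-forwards past the consumed indices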
def next_fn(iterator: Iterator):
    batch = next(iterator)
    if not _fault_tolerant_training():
        return batch
    # when fault tolerant training is enabled, the iterator returns the `FastForwardSampler`
    # state_dict metadata alongside the user data.
    # the metadata is extracted and stored directly on the iterator
    # to simplify the collection on `state_dict` call.
    batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
    # store the `sampler_state_dict` on the iterator
    CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict)
    return batch
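# Illustrative toy only: the fetcher above assumes each batch carries sampler state alongside the
# user data when fault tolerance is on. The key name and batch layout below are invented for the
# example; the real format is whatever `CaptureIterableDataset` produces.
from typing import Any, Dict, Tuple

_TOY_META_KEY = "__sampler_state"  # hypothetical key, not the library's


def _toy_extract_samplers_state_dict_from_batch(batch: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict]:
    # split the user payload from the bookkeeping metadata
    samplers_state_dict = batch.pop(_TOY_META_KEY, {})
    return batch, samplers_state_dict


data, meta = _toy_extract_samplers_state_dict_from_batch(
    {"x": [1, 2, 3], _TOY_META_KEY: {"sampler0": {"current_iteration": 3}}}
)
assert data == {"x": [1, 2, 3]} and meta == {"sampler0": {"current_iteration": 3}}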
def state_dict(self, has_completed: bool = False) -> Dict:
    """The state dict includes all states from wrapped dataloaders and their samplers through the
    ``CaptureIterableDataset`` and fast-forward samplers.

    Args:
        has_completed: whether the current state of data fetching is considered completed or not. If it is, the
            current state gets returned, otherwise the previously cached state.
    """
    if not _fault_tolerant_training() or self._iterator is None:
        return {}
    return apply_to_collection(
        self._iterator.loader_iters,
        Iterator,
        self._state_dict_fn,
        has_completed=has_completed,
    )
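# Context for `apply_to_collection` as used above: it walks a (possibly nested) collection and
# applies the given function to every element of the requested type, keeping the structure intact.
# Import path assumed to match the pytorch_lightning 1.x utilities these snippets come from.
from pytorch_lightning.utilities.apply_func import apply_to_collection

nested = {"train": [1, 2], "val": (3, 4)}
doubled = apply_to_collection(nested, int, lambda x: x * 2)
assert doubled == {"train": [2, 4], "val": (6, 8)}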
def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
    metrics = (
        [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)]
        if _fault_tolerant_training()
        else []
    )

    for metric in metrics:
        metric.persistent(True)
        metric.sync()

    state_dict = self.trainer.strategy.lightning_module_state_dict()

    for metric in metrics:
        # sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check
        if metric._is_synced:
            metric.unsync()

    return state_dict
def register_signal_handlers(self) -> None:
    sigusr1_handlers: List[Callable] = []
    sigterm_handlers: List[Callable] = []

    if _fault_tolerant_training():
        sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn)

    if self._is_on_slurm():
        log.info("Set SLURM handle signals.")
        sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)
        sigterm_handlers.append(self.sigterm_handler_fn)

    # signal.SIGUSR1 doesn't seem available on windows
    if not self._is_on_windows():
        if not self._has_already_handler(signal.SIGUSR1):
            signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))
        if not self._has_already_handler(signal.SIGTERM):
            signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
    """This function handles the following functionalities:

    - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
    - Wrapping the datasets and samplers into fault-tolerant components
    """
    if isinstance(dataloader, CombinedLoader):
        # apply `prepare_dataloader` to the whole collection of loaders
        dataloader.loaders = apply_to_collection(
            dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode
        )
        # the length needs to be recomputed across all dataloaders in case of special behavior.
        dataloader._apply_cycle_iterator_length()
        return dataloader

    # don't do anything if it's not a dataloader
    if not isinstance(dataloader, (DataLoader, CycleIterator)):
        return dataloader

    cycle_iterator: Optional[CycleIterator] = None
    if isinstance(dataloader, CycleIterator):
        cycle_iterator = dataloader
        dataloader = dataloader.loader

    if (
        _fault_tolerant_training()  # injects components to track the state
        or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
        or mode == RunningStage.PREDICTING  # to track indices for the predictions
        or self._accelerator_connector.use_ipu  # IPUs use a custom `DataLoader`
    ):
        sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
        dataloader = _update_dataloader(dataloader, sampler, mode=mode)

    if cycle_iterator is not None:
        cycle_iterator.loader = dataloader
        return cycle_iterator

    return dataloader
def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
    cache = None
    if on_step and result_metric.meta.on_step:
        cache = result_metric._forward_cache
    elif not on_step and result_metric.meta.on_epoch:
        if result_metric._computed is None:
            should = result_metric.meta.sync.should
            if not result_metric.meta.sync.should and distributed_available():
                # ensure sync happens for FT since during a failure, the metrics are synced and saved to the
                # checkpoint, so during restart, metrics on rank 0 are from the accumulated ones from the previous
                # run, and on other ranks, they are 0. So we need to make sure they are synced in further training
                # to ensure correct calculation.
                if _fault_tolerant_training():
                    result_metric.meta.sync.should = True
                else:
                    warning_cache.warn(
                        f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
                        " when logging on epoch level in distributed setting to accumulate the metric across"
                        " devices.",
                        category=PossibleUserWarning,
                    )
            result_metric.compute()
            result_metric.meta.sync.should = should

        cache = result_metric._computed

    if cache is not None:
        if not isinstance(cache, Tensor):
            raise ValueError(
                f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
                f" Found {cache}"
            )
        if not result_metric.meta.enable_graph:
            return cache.detach()

    return cache
def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    """Utility to reload state_dict within dataloader for fault tolerance."""
    if not _fault_tolerant_training():
        return

    dataset = dataloader.dataset

    if isinstance(dataset, CaptureMapDataset):
        iterator_state = state_dict["state"][0]

        if not isinstance(iterator_state, IteratorState):
            iterator_state = IteratorState.from_state_dict(iterator_state)

        # reload sampler state
        ff_sampler = _find_fast_forward_samplers(dataloader)
        ff_sampler.load_state_dict(iterator_state.sampler_state)

        # reload dataset state
        dataset.load_state_dict(
            iterator_state.dataset_state,
            latest_worker_id=state_dict["latest_worker_id"],
            num_workers=iterator_state.num_workers,
        )

    elif isinstance(dataset, CaptureIterableDataset):
        dataset.load_state_dict(
            {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
        )

    else:
        raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
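# For orientation only: the shape of the `state_dict` argument as inferred from the accesses above
# (`state_dict["state"][0]`, `state_dict["latest_worker_id"]`, and the per-sampler entries in the
# iterable-dataset branch). The field values are placeholders, not real captured state.
_example_map_dataset_state = {
    "latest_worker_id": 0,
    "state": {
        0: {  # worker id -> serialized `IteratorState`
            "sampler_state": {"current_iteration": 12},
            "dataset_state": {},
            "num_workers": 1,
        }
    },
}

_example_iterable_dataset_state = {
    # sampler name -> worker id -> state; only "sampler_state" is read by the code above
    "state": {"sampler0": {0: {"sampler_state": {"current_iteration": 12}}}},
}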
def dump_checkpoint(self, weights_only: bool = False) -> dict:
    """Creating a model checkpoint dictionary object from various component states.

    Args:
        weights_only: saving model weights only

    Return:
        structured dictionary: {
            'epoch':                     training epoch
            'global_step':               training global step
            'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
            'callbacks':                 "callback specific state"[] # if not weights_only
            'optimizer_states':          "PT optim's state_dict"[]   # if not weights_only
            'lr_schedulers':             "PT sched's state_dict"[]   # if not weights_only
            'native_amp_scaling_state':  PT amp's state_dict         # if not weights_only and use native amp
            'amp_scaling_state':         Apex's state_dict           # if not weights_only and use apex amp
            'state_dict':                Model's state_dict (e.g. network weights)
            CHECKPOINT_HYPER_PARAMS_NAME:
            CHECKPOINT_HYPER_PARAMS_KEY:
            CHECKPOINT_HYPER_PARAMS_TYPE:
            something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
            LightningDataModule.__class__.__name__: pl DataModule's state
        }
    """
    # dump epoch/global_step/pytorch-lightning_version
    current_epoch = self.trainer.current_epoch
    global_step = self.trainer.global_step
    has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

    global_step += 1
    if not has_reached_max_steps:
        current_epoch += 1

    model = self.trainer.lightning_module

    checkpoint = {
        "epoch": current_epoch,
        "global_step": global_step,
        "pytorch-lightning_version": pl.__version__,
        "state_dict": self._get_lightning_module_state_dict(),
    }
    if _fault_tolerant_training():
        checkpoint["loops"] = self._get_loops_state_dict()

    if not weights_only:
        # dump callbacks
        checkpoint["callbacks"] = self.trainer.on_save_checkpoint(checkpoint)

        optimizer_states = []
        for i, optimizer in enumerate(self.trainer.optimizers):
            # Rely on accelerator to dump optimizer state
            optimizer_state = self.trainer.accelerator.optimizer_state(optimizer)
            optimizer_states.append(optimizer_state)

        checkpoint["optimizer_states"] = optimizer_states

        # dump lr schedulers
        lr_schedulers = []
        for scheduler in self.trainer.lr_schedulers:
            lr_schedulers.append(scheduler["scheduler"].state_dict())
        checkpoint["lr_schedulers"] = lr_schedulers

        self.trainer.precision_plugin.on_save_checkpoint(checkpoint)

    # dump hyper-parameters
    if model.hparams:
        if hasattr(model, "_hparams_name"):
            checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
        # dump arguments
        if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
            checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
            checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
        else:
            checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

    # give the model a chance to dump a few things
    model.on_save_checkpoint(checkpoint)
    if self.trainer.datamodule is not None:
        self.trainer.datamodule.on_save_checkpoint(checkpoint)

    return checkpoint
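# A small follow-up sketch: reading back a checkpoint produced by `dump_checkpoint` with plain
# `torch.load`. The file name is hypothetical; the keys follow the docstring above ("loops" is
# only present when fault-tolerant training was enabled).
import torch

checkpoint = torch.load("example.ckpt", map_location="cpu")
print(checkpoint["epoch"], checkpoint["global_step"], checkpoint["pytorch-lightning_version"])
model_weights = checkpoint["state_dict"]  # the LightningModule's state_dict
loop_states = checkpoint.get("loops")     # None unless fault tolerance was on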
def result_collection_reload(**kwargs):
    """This test validates that the ResultCollection is properly reloaded and that the final accumulation with
    Fault Tolerant Training is correct."""

    if not _fault_tolerant_training():
        pytest.skip("Fault tolerant not available")

    num_processes = kwargs.get("gpus", 1)

    class CustomException(Exception):
        pass

    class ExtendedBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.breaking_batch_idx = 3
            self.has_validated_sum = False
            self.dummy_metric = DummyMeanMetric()

        @property
        def results(self):
            return self.trainer.fit_loop._results

        def training_step(self, batch, batch_idx):
            # In the training step, we will accumulate metrics using batch_idx from 0 to 4
            # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size`
            # Therefore, compute on `epoch_end` should provide 2 as `10 / 5`.
            # However, below we will simulate a failure on `batch_idx=3`.

            if self.trainer.fit_loop.restarting:
                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

                value = self.results["training_step.tracking_metric"].value
                value_2 = self.results["training_step.tracking"].value

                # On failure, the Metric states are being accumulated on rank 0 and zeroed-out on other ranks.
                # The shift indicates we failed while the state was `shift=sign(is_global_zero > 0) * [0..3]`
                shift = 0
                if num_processes == 2:
                    shift = 3 if self.trainer.is_global_zero else -3
                expected = sum(range(batch_idx + 1)) + shift
                assert expected == value == value_2
            else:
                if batch_idx == self.breaking_batch_idx:
                    # simulate failure mid epoch
                    raise CustomException

                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

                value = self.results["training_step.tracking"].value
                assert value == sum(range(batch_idx + 1))

                value = self.results["training_step.tracking_2"].value
                assert value == sum(range(batch_idx + 1))

            return super().training_step(batch, batch_idx)

        def on_epoch_end(self) -> None:
            if self.trainer.fit_loop.restarting:
                total = sum(range(5)) * num_processes
                metrics = self.results.metrics(on_step=False)
                assert self.results["training_step.tracking"].value == total
                assert metrics["callback"]["tracking"] == self.dummy_metric.compute() == 2
                assert self.results["training_step.tracking_2"].value == total
                assert metrics["callback"]["tracking_2"] == self.dummy_metric.compute() == 2
                self.has_validated_sum = True

    model = ExtendedBoringModel()
    trainer_kwargs = {"max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0}
    trainer_kwargs.update(kwargs)
    trainer = Trainer(**trainer_kwargs)

    with suppress(CustomException):
        trainer.fit(model)
    assert not model.has_validated_sum

    tmpdir = (
        trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0)
        if num_processes >= 2
        else trainer_kwargs["default_root_dir"]
    )
    ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
    trainer_kwargs["resume_from_checkpoint"] = ckpt_path

    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert model.has_validated_sum
def test_fault_tolerant_not_supported():
    assert not _fault_tolerant_training()
def _attach_data_fetcher_fn(loader: DataLoader) -> None:
    if isinstance(loader, CycleIterator):
        loader = loader.loader

    if isinstance(loader, DataLoader) and _fault_tolerant_training():
        loader._lightning_fetcher = self
def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Resets the train dataloader and initialises required variables (number of batches, when to validate,
    etc.).

    Args:
        model: The `LightningModule` if calling this outside of the trainer scope.
    """
    self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)

    if self.overfit_batches > 0:
        if hasattr(self.train_dataloader, "sampler") and isinstance(self.train_dataloader.sampler, RandomSampler):
            rank_zero_warn(
                "You requested to overfit but enabled training dataloader shuffling."
                " We are turning off the training dataloader shuffling for you."
            )
            self.train_dataloader = self.replace_sampler(
                self.train_dataloader, SequentialSampler(self.train_dataloader.dataset), mode=RunningStage.TRAINING
            )

    # debugging
    self.dev_debugger.track_load_dataloader_call("train_dataloader", dataloaders=[self.train_dataloader])

    # automatically add samplers
    self.train_dataloader = apply_to_collection(
        self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True, mode=RunningStage.TRAINING
    )

    # check the workers recursively
    apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

    # add collate_fn to collect metadata for fault tolerant training
    if _fault_tolerant_training():
        apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

    # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
    self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

    self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")

    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float("inf"):
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        raise MisconfigurationException(
            "When using an IterableDataset for `limit_train_batches`,"
            " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
            " `num_training_batches` to use."
        )

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                f"to the number of the training batches ({self.num_training_batches}). "
                "If you want to disable validation set `limit_val_batches` to 0.0 instead."
            )
    else:
        if not has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float("inf")
            else:
                raise MisconfigurationException(
                    "When using an IterableDataset for `train_dataloader`,"
                    " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
                    " checking validation every k training batches."
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)

    if self.logger and self.num_training_batches < self.log_every_n_steps:
        rank_zero_warn(
            f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
            f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
            " you want to see logs for the training epoch."
        )
def _get_dataloader_init_kwargs(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    # get the dataloader instance attributes
    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
    # not part of `vars`
    attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)

    # keep only the params whose default is different to the current attr value
    non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
    # add `dataset` as it might have been replaced with `*args`
    non_defaults.add("dataset")

    # kwargs to re-construct the dataloader
    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
    dl_kwargs.update(TrainerDataLoadingMixin._resolve_batch_sampler(dataloader, sampler, mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in dl_kwargs
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        raise MisconfigurationException(
            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. "
            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
            "manually add the `DistributedSampler` as: "
            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
        )

    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = dl_kwargs.keys() - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

    if isinstance(dl_kwargs["dataset"], IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None

    if _fault_tolerant_training():
        if isinstance(dl_kwargs["dataset"], IterableDataset):
            # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
            dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
        elif len(dl_kwargs["dataset"]):
            dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
        else:
            raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")

    return dl_kwargs
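# Standalone sketch of the introspection trick used above: `vars()` exposes the instance attributes
# a custom DataLoader stored in `__init__`, and `inspect.signature` tells us which constructor
# parameters they correspond to, so the loader can be re-instantiated with a new sampler.
# The `CustomLoader` class and `my_arg` parameter are made up for the example.
import inspect

import torch
from torch.utils.data import DataLoader, TensorDataset


class CustomLoader(DataLoader):
    def __init__(self, dataset, my_arg: int = 0, **kwargs):
        self.my_arg = my_arg  # stored as an attribute, so it can be recovered below
        super().__init__(dataset, **kwargs)


loader = CustomLoader(TensorDataset(torch.arange(10.0)), my_arg=7, batch_size=5)

attrs = {k: v for k, v in vars(loader).items() if not k.startswith("_")}
params = dict(inspect.signature(loader.__init__).parameters)
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}

assert "my_arg" in non_defaults and attrs["my_arg"] == 7
assert "dataset" in attrs  # DataLoader itself stores `dataset`, `batch_size`, etc. as attributes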