def _patch_dataloader_get_iterators() -> None:
    """This function is used to replace the DataLoader iterators with their stateful versions."""
    if not _FaultTolerantMode.detect_current_mode().is_manual:
        return
    if not hasattr(DataLoader, "_ori_get_iterator"):
        DataLoader._ori_get_iterator = DataLoader._get_iterator
    DataLoader._get_iterator = _get_iterator
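# A minimal, self-contained sketch of the save-and-restore monkey-patching pattern used above.
# `Widget` and its methods are hypothetical; only the idempotent stash-under-a-sentinel-attribute
# idea mirrors `_patch_dataloader_get_iterators`.
class Widget:
    def render(self) -> str:
        return "original"


def _patched_render(self: Widget) -> str:
    # delegate to the stashed original so the patch composes with it
    return "patched:" + Widget._ori_render(self)


def patch_widget_render() -> None:
    if not hasattr(Widget, "_ori_render"):
        Widget._ori_render = Widget.render  # stash once; repeated calls stay idempotent
    Widget.render = _patched_render


def unpatch_widget_render() -> None:
    if hasattr(Widget, "_ori_render"):
        Widget.render = Widget._ori_render
        del Widget._ori_render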
def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
    """The state dict is determined by the state and progress of this loop and all its children.

    Args:
        destination: An existing dictionary to update with this loop's state. By default a new dictionary
            is returned.
        prefix: A prefix for each key in the state dictionary
    """
    if destination is None:
        destination = {}

    destination[prefix + "state_dict"] = self.on_save_checkpoint()

    # do not get the mode from `self.trainer` because it might not have been attached yet
    ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
    for k, v in self.__dict__.items():
        key = prefix + k
        if ft_enabled and isinstance(v, BaseProgress):
            destination[key] = v.state_dict()
        elif isinstance(v, Loop):
            v.state_dict(destination, key + ".")
        elif isinstance(v, _ResultCollection):
            # sync / unsync metrics
            v.sync()
            destination[key] = v.state_dict()
            v.unsync()

    return destination
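# A toy illustration of the dot-separated prefixing above: every nested loop writes into the one
# shared `destination` dictionary under its own prefix. The `_ToyLoop` class is hypothetical and
# mirrors only the recursion, not the real `Loop` API.
class _ToyLoop:
    def __init__(self, **children: "_ToyLoop") -> None:
        self.children = children

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = {}
        destination[prefix + "state_dict"] = {}  # stands in for `on_save_checkpoint()`
        for name, child in self.children.items():
            child.state_dict(destination, prefix + name + ".")
        return destination


fit_loop = _ToyLoop(epoch_loop=_ToyLoop(batch_loop=_ToyLoop()))
assert set(fit_loop.state_dict()) == {
    "state_dict",
    "epoch_loop.state_dict",
    "epoch_loop.batch_loop.state_dict",
}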
def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
    """This function is used to validate that fault tolerance is possible with the user data."""
    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

    if isinstance(dataloader, CombinedLoader):
        dataloaders = dataloader.loaders
    else:
        dataloaders = dataloader

    dl_loaders = []

    def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
        nonlocal dl_loaders
        if isinstance(dataloader, CycleIterator):
            dataloader = dataloader.loader
        dl_loaders.append(dataloader)

    apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader)

    if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
        raise ValueError("Fault-tolerance supports only a single dataloader.")

    for dataloader in dl_loaders:
        validator_fn = (
            _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
        )
        validator_fn(dataloader)
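# The flattening above relies on `apply_to_collection` visiting every element of a (possibly
# nested) collection whose type matches, so a callback with a side effect collects the leaves.
# A sketch of the same idiom (the import path may differ across Lightning versions):
#
#   found = []
#   apply_to_collection({"a": [loader_a, loader_b], "b": loader_c}, DataLoader, found.append)
#   # found == [loader_a, loader_b, loader_c]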
def patch_dataloader_iterator(
    dataloader: DataLoader,
    iterator: Iterator,
    data_fetcher: "pl.utilities.fetching.DataFetcher",
    num_batches_fetched: int = 0,
) -> None:
    """Patches the iterator of a PyTorch dataloader by injecting logic for fault-tolerant training that removes
    the sampler state dict from the provided data batch when necessary.

    The custom data has this format:

    .. code-block:: python

        {
            "batch": ...,  # data returned by DataLoader
            "__pl_restart_meta": {
                "sampler0": {
                    0: {"current_iteration": ...},
                    1: {"current_iteration": ...},
                },
                "sampler1": ...,
            },
        }

    Each sampler in the worker process tracks its current iteration. These are returned to the main process as
    part of the sample, and a special collate function, :func:`_capture_metadata_collate`, extracts the current
    iteration as part of the metadata returned with a custom batch.
    """
    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))
    iterator._next_data = _next_data_wrapper(
        iterator._next_data, iterator, dataloader, num_batches_fetched, data_fetcher
    )
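# A simplified sketch of what a `_next_data_wrapper`-style closure can look like, assuming the
# batch format documented above. It is an illustration, not the actual implementation: the real
# wrapper also threads the iterator, dataloader, and fetcher state through.
from functools import wraps


def _toy_next_data_wrapper(fn, store_sampler_states):
    @wraps(fn)
    def wrapper():
        combined = fn()  # {"batch": ..., "__pl_restart_meta": {...}}
        store_sampler_states(combined["__pl_restart_meta"])  # cache for the next state_dict()
        return combined["batch"]  # user code only ever sees the undecorated batch

    return wrapper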
def on_run_start(self) -> None:  # type: ignore[override]
    """Calls the ``on_train_start`` hook."""
    # reset train dataloader and val dataloader
    self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

    ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
    if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
        self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
            self.trainer.current_epoch
        )
        expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches)

        # global_step is incremented during checkpointing (#11555)
        if (self.trainer.global_step - 1) % expected_steps != 0:
            rank_zero_warn(
                "You're resuming from a checkpoint that ended mid-epoch."
                " Training will start from the beginning of the next epoch."
                " This can cause unreliable results if further training is done."
                " Consider using an end-of-epoch checkpoint or enabling fault-tolerant training"
                " to restart as if training did not stop."
            )

    self._is_fresh_start_epoch = True
    self._results.to(device=self.trainer.lightning_module.device)

    self.trainer._call_callback_hooks("on_train_start")
    self.trainer._call_lightning_module_hook("on_train_start")
    self.trainer._call_strategy_hook("on_train_start")
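# Worked example of the mid-epoch check above: with `num_training_batches == 100` and
# `accumulate_grad_batches == 4`, `expected_steps == math.ceil(100 / 4) == 25`. Resuming at
# `global_step == 51` gives `(51 - 1) % 25 == 0`, i.e. an epoch boundary, so no warning; resuming
# at `global_step == 60` gives `59 % 25 == 9 != 0` and emits the warning.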
def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
    """Wrap the default collate function to retrieve the captured dataset state dict when fault tolerant training
    is enabled."""
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
    collate_fn = dataloader.collate_fn
    if not fault_tolerant_mode.is_enabled or (
        isinstance(collate_fn, partial) and collate_fn.func is _capture_metadata_collate
    ):
        return
    dataloader.collate_fn = partial(
        _capture_metadata_collate,
        dataset=dataloader.dataset,
        collate_fn=collate_fn,
        fault_tolerant_mode=fault_tolerant_mode,
    )
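# A minimal sketch of the idempotent `partial`-wrapping idiom above: checking
# `isinstance(fn, partial) and fn.func is wrapper` guarantees a callable is never wrapped twice.
# `wrapper` and `wrap_once` are hypothetical stand-ins.
from functools import partial


def wrapper(batch, collate_fn):
    return collate_fn(batch)


def wrap_once(fn):
    if isinstance(fn, partial) and fn.func is wrapper:
        return fn  # already wrapped, nothing to do
    return partial(wrapper, collate_fn=fn)


wrapped = wrap_once(list)
assert wrap_once(wrapped) is wrapped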
def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    """Utility to reload the ``state_dict`` within the dataloader for fault tolerance."""
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
    if not fault_tolerant_mode.is_enabled:
        return
    if fault_tolerant_mode.is_automatic:
        _reload_dataloader_state_dict_automatic(dataloader, state_dict)
    elif fault_tolerant_mode.is_manual:
        _reload_dataloader_state_dict_manual(dataloader, state_dict)
    else:
        raise MisconfigurationException("This should not happen. Please open an issue.")
def _dataloader_init_kwargs_resolve_sampler(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    """This function is used to handle the sampler and batch_sampler arguments associated with a DataLoader for
    its re-instantiation.

    If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`,
    so Lightning can keep track of its indices. If fault-tolerant training is enabled, the sampler will be wrapped
    into a `FastForwardSampler`.
    """
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
    batch_sampler = getattr(dataloader, "batch_sampler")
    is_predicting = mode == RunningStage.PREDICTING
    # check whether the batch sampler type differs from the PyTorch default
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if fault_tolerant_mode.is_automatic:
            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if fault_tolerant_mode.is_automatic:
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
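# Hedged usage sketch for the resolver above: with PyTorch's default `BatchSampler`, the first
# branch is skipped and the plain sampler is returned with `shuffle` and `batch_sampler`
# neutralized so they cannot conflict on re-instantiation. The setup is illustrative only.
#
#   dataset = TensorDataset(torch.arange(10))
#   loader = DataLoader(dataset, batch_size=4)
#   kwargs = _dataloader_init_kwargs_resolve_sampler(loader, SequentialSampler(dataset))
#   # -> {"sampler": SequentialSampler(...), "shuffle": False, "batch_sampler": None}
#   # (automatic fault-tolerant mode would additionally wrap it in a `FastForwardSampler`)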
def _get_dataloader_init_kwargs(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    # get the dataloader instance attributes
    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
    # not part of `vars`
    attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)
    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if has_variadic_kwargs:
        # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
        params.update(inspect.signature(DataLoader.__init__).parameters)
        del params["self"]

    # keep only the params whose default is different to the current attr value
    non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
    # add `dataset` as it might have been replaced with `*args`
    non_defaults.add("dataset")

    # kwargs to re-construct the dataloader
    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
    if isinstance(dl_kwargs["dataset"], IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None
    else:
        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in dl_kwargs
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        raise MisconfigurationException(
            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. "
            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
            "manually add the `DistributedSampler` as: "
            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
        )

    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = dl_kwargs.keys() - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

    if _FaultTolerantMode.detect_current_mode().is_automatic:
        dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)

    return dl_kwargs
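# A standalone sketch of the signature-diffing trick above: compare `__init__` parameters against
# the instance attributes and keep only those whose current value differs from the default. The
# `MyLoader` subclass is hypothetical.
import inspect

from torch.utils.data import DataLoader


class MyLoader(DataLoader):
    def __init__(self, dataset, happy: bool = False, **kwargs):
        self.happy = happy  # exposed as an instance attribute, so it can be re-injected
        super().__init__(dataset, **kwargs)


loader = MyLoader(range(10), happy=True, batch_size=2)
attrs = {k: v for k, v in vars(loader).items() if not k.startswith("_")}
params = dict(inspect.signature(loader.__init__).parameters)
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
assert "happy" in non_defaults  # `batch_size` enters via **kwargs and DataLoader's own signature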
def _fault_tolerant_training() -> bool:
    from pytorch_lightning.utilities.enums import _FaultTolerantMode

    return _FaultTolerantMode.detect_current_mode().is_enabled
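# For context: `detect_current_mode` is driven by the `PL_FAULT_TOLERANT_TRAINING` environment
# variable (for example `1`/`"automatic"` or `2`/`"manual"`; the exact accepted values depend on
# the Lightning version), so enabling fault tolerance can look like:
#
#   PL_FAULT_TOLERANT_TRAINING=1 python train.py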
def _get_dataloader_init_args_and_kwargs(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Tuple[Tuple[Any], Dict[str, Any]]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    was_wrapped = hasattr(dataloader, "__pl_dl_args")
    if was_wrapped:
        dl_args = dataloader.__pl_dl_args
        dl_kwargs = dataloader.__pl_dl_kwargs
        arg_names = dataloader.__pl_dl_arg_names
        original_dataset = dataloader.__dataset  # we have this saved from _wrap_init
    else:
        # get the dataloader instance attributes
        attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
        # We cannot be 100% sure the class sets the dataset argument. Let's set it to None to be safe
        # and hope we can get it from the instance attributes
        original_dataset = None
        # not part of `vars`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context
        arg_names = ()

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)
    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if has_variadic_kwargs:
        # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
        if was_wrapped:
            # if the dataloader was wrapped in a hook, only take arguments with default values
            # and assume user passes their kwargs correctly
            params.update(
                {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
            )
        else:
            params.update(inspect.signature(DataLoader.__init__).parameters)
            params.pop("self", None)

    if not was_wrapped:
        # keep only the params whose default is different to the current attr value
        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}

        # add `dataset` as it might have been replaced with `*args`
        non_defaults.add("dataset")

        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_args = ()

    dataset = dl_kwargs.get("dataset", original_dataset)
    if isinstance(dataset, IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None
    else:
        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in dl_kwargs
        and p.name not in arg_names
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
        raise MisconfigurationException(
            f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
            "`*_dataloader` hook of your module, we will do this for you."
            f" Otherwise, define {missing_args_message} inside your `__init__`."
        )

    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
                "add the `__init__` arguments or allow passing `**kwargs`."
            )

    if _FaultTolerantMode.detect_current_mode().is_automatic:
        dl_args, dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(
            was_wrapped, arg_names, dl_args, dl_kwargs
        )

    return dl_args, dl_kwargs
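# Hedged usage sketch: the (args, kwargs) pair above is intended for re-instantiating the loader
# around a new sampler, along the lines of the caller shown below (illustrative, not verbatim):
#
#   sampler = DistributedSampler(dataloader.dataset)
#   dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
#   dataloader = type(dataloader)(*dl_args, **dl_kwargs)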