def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
    if isinstance(dataset, IterableDataset):
        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
        return CaptureIterableDataset(dataset=dataset)
    if get_len(dataset) != float("inf"):
        return CaptureMapDataset(dataset=dataset)
    raise RuntimeError("This shouldn't happen, please open an issue on Lightning Github repository.")
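# A minimal usage sketch for `_wrap_with_capture_dataset`, assuming the capture wrappers and
# `get_len` are already in the module scope (as the function above requires). `_RangeDataset`
# is a toy map-style dataset invented only for this illustration.
def _example_wrap_with_capture_dataset() -> None:
    from torch.utils.data import Dataset

    class _RangeDataset(Dataset):
        """A toy finite map-style dataset, used only for this sketch."""

        def __init__(self, length: int) -> None:
            self.length = length

        def __getitem__(self, index: int) -> int:
            return index

        def __len__(self) -> int:
            return self.length

    # a finite map-style dataset should come back wrapped in a `CaptureMapDataset`
    wrapped = _wrap_with_capture_dataset(_RangeDataset(8))
    assert isinstance(wrapped, CaptureMapDataset)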
def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
    dataset = dl_kwargs["dataset"]
    if isinstance(dataset, IterableDataset):
        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
    elif get_len(dataset) != float("inf"):
        dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
    else:
        raise MisconfigurationException(
            "This shouldn't happen, please open an issue on Lightning Github repository."
        )
    return dl_kwargs
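# A short sketch of how `_apply_fault_tolerant_automatic_capture_dataset_wrapper` rewrites the
# dataloader kwargs. The `TensorDataset` here is a stand-in chosen for illustration; any finite
# map-style dataset behaves the same way.
def _example_apply_fault_tolerant_wrapper() -> None:
    import torch
    from torch.utils.data import TensorDataset

    dl_kwargs = {"dataset": TensorDataset(torch.arange(8)), "batch_size": 4}
    dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)
    # the map-style dataset has been swapped for a `CaptureMapDataset`; the other
    # kwargs pass through untouched
    assert isinstance(dl_kwargs["dataset"], CaptureMapDataset)
    assert dl_kwargs["batch_size"] == 4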
def _get_dataloader_init_kwargs(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    # get the dataloader instance attributes
    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
    # not part of `vars`
    attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)

    # keep only the params whose default is different to the current attr value
    non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
    # add `dataset` as it might have been replaced with `*args`
    non_defaults.add("dataset")

    # kwargs to re-construct the dataloader
    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
    dl_kwargs.update(TrainerDataLoadingMixin._resolve_batch_sampler(dataloader, sampler, mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in dl_kwargs
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        raise MisconfigurationException(
            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. "
            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
            "manually add the `DistributedSampler` as: "
            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
        )

    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = dl_kwargs.keys() - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

    if isinstance(dl_kwargs["dataset"], IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None

    if _fault_tolerant_training():
        dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)

    return dl_kwargs
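# A hedged end-to-end sketch of the typical use of `_get_dataloader_init_kwargs`: extract the
# init kwargs from an existing `DataLoader` with a replacement sampler resolved into them, then
# rebuild a loader of the same class. `SequentialSampler` stands in for the `DistributedSampler`
# Lightning would normally inject; the toy dataset and sizes are assumptions for illustration.
def _example_reinstantiate_dataloader() -> None:
    import torch
    from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(16))
    dataloader = DataLoader(dataset, batch_size=4)

    # build the kwargs with the new sampler already resolved into them
    dl_kwargs = _get_dataloader_init_kwargs(dataloader, SequentialSampler(dataset))

    # re-create a dataloader of the same class with the new sampler in place
    new_dataloader = type(dataloader)(**dl_kwargs)
    assert isinstance(new_dataloader.sampler, SequentialSampler)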
def create_dataset_sampler():
    dset = CaptureMapDataset(dataset_class(16, 8))
    random_sampler = RandomSampler(dset, generator=torch.Generator())
    return dset, random_sampler
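# A sketch of how the `create_dataset_sampler` test helper is typically consumed. Note that
# `dataset_class` is a free variable supplied by the surrounding test scope, so this assumes a
# map-style dataset factory is available there; the explicit `torch.Generator` is what lets
# fault-tolerant training snapshot and restore the sampler's RNG state.
def _example_consume_dataset_sampler() -> None:
    from torch.utils.data import DataLoader

    dataset, sampler = create_dataset_sampler()
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=4)
    _ = next(iter(dataloader))  # iterating drives the capture machinery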