def test_get_len():
    assert get_len(DataLoader(RandomDataset(1, 1))) == 1

    value = get_len(DataLoader(RandomIterableDataset(1, 1)))
    assert isinstance(value, float)
    assert value == float("inf")
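# For reference, a minimal sketch of the helper under test, assuming `get_len` follows
# `pytorch_lightning.utilities.data`: it returns `len(dataloader)` when the underlying
# dataset is sized and `float("inf")` when `__len__` is unavailable (e.g. for
# `IterableDataset`-backed loaders). Illustration only, not the canonical implementation:
#
#     def get_len(dataloader: DataLoader) -> Union[int, float]:
#         if has_len(dataloader):  # True when `__len__` exists and succeeds
#             return len(dataloader)
#         return float("inf")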
def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_sampler_ddp):
    """This test makes sure the distributed sampler has been properly injected into the dataloaders when using
    `CombinedLoader` with DDP and `max_size_cycle` mode."""
    trainer = Trainer(strategy="ddp", accelerator=accelerator, devices=2, replace_sampler_ddp=replace_sampler_ddp)

    dataloader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader) == (4 if replace_sampler_ddp else 8)

    for a_length in [6, 8, 10]:
        dataloader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )

        length = max(a_length, 8)
        assert len(dataloader) == length
        dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
        if replace_sampler_ddp:
            last_batch = list(dataloader)[-1]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    dataloader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    assert get_len(dataloader) == float("inf")
    assert len(dataloader.loaders["b"].loader) == 8
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.loaders["b"].loader) == (4 if replace_sampler_ddp else 8)
    assert get_len(dataloader) == float("inf")
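# A hedged illustration of the `max_size_cycle` semantics exercised above: the shorter
# loader is wrapped in a `CycleIterator` and restarts from its first batch until the
# longest loader is exhausted. The values below are assumed, mirroring the assertions
# in the test rather than quoting library documentation:
#
#     combined = CombinedLoader(
#         {"a": DataLoader(range(10), batch_size=1), "b": DataLoader(range(4), batch_size=1)},
#         mode="max_size_cycle",
#     )
#     assert len(combined) == 10  # max of the two loader lengths
#     batches = list(combined)
#     assert batches[4]["b"] == torch.tensor([0])  # "b" wrapped around after 4 batches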
def _apply_cycle_iterator_length(self) -> None:
    """When the mode is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to all
    dataloaders."""
    if self.mode != "max_size_cycle":
        return

    def set_len(cycle_iterator: CycleIterator, length: int) -> None:
        cycle_iterator.length = length

    all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader))
    max_length = _nested_calc_num_data(all_lengths, max)
    apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length)
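# For context, a sketch of what the two helpers used above are assumed to do
# (illustration only): `apply_to_collection` maps a function over every `CycleIterator`
# found in the possibly nested `self.loaders` collection, and `_nested_calc_num_data`
# folds the resulting collection of lengths with the given operator, e.g.:
#
#     lengths = {"a": 8, "b": float("inf")}
#     assert _nested_calc_num_data(lengths, max) == float("inf")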
def _get_dataloader_init_kwargs(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    # get the dataloader instance attributes
    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
    # not part of `vars`
    attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)
    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if has_variadic_kwargs:
        # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
        params.update(inspect.signature(DataLoader.__init__).parameters)
        del params["self"]

    # keep only the params whose default is different to the current attr value
    non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
    # add `dataset` as it might have been replaced with `*args`
    non_defaults.add("dataset")

    # kwargs to re-construct the dataloader
    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
    dl_kwargs.update(TrainerDataLoadingMixin._dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in dl_kwargs
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        raise MisconfigurationException(
            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. "
            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
            "manually add the `DistributedSampler` as: "
            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
        )

    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = dl_kwargs.keys() - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

    if isinstance(dl_kwargs["dataset"], IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None

    if _fault_tolerant_training():
        dataset = dl_kwargs["dataset"]
        if isinstance(dataset, IterableDataset):
            # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
            dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
        elif get_len(dataset) != float("inf"):
            dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
        else:
            raise MisconfigurationException(
                "This shouldn't happen, please open an issue on the Lightning GitHub repository."
            )

    return dl_kwargs