# Imports assumed by these excerpts (module paths follow the
# pytorch_lightning 1.5-era layout and may differ in other versions):
import os
from copy import deepcopy
from typing import List

import pytest
import torch
from torch import tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data.sampler import RandomSampler

from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    CaptureMapDataset,
    FastForwardSampler,
    MergedIteratorState,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
    AbstractDataFetcher,
    DataFetcher,
    DataLoaderIterDataFetcher,
    InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature


# parametrization assumed from the test body; the dataset fixtures are sketched below
@pytest.mark.parametrize("use_combined_loader", [False, True])
@pytest.mark.parametrize("dataset_cls", [IterDataset, SizedDataset])
@pytest.mark.parametrize("prefetch_batches", list(range(5)))
def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches):
    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
    assert fetcher.prefetch_batches == prefetch_batches

    if use_combined_loader:
        loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())])
    else:
        loader = DataLoader(dataset_cls())
    fetcher.setup(loader)

    def generate():
        generated = [(fetcher.fetched, data, fetcher.done) for data in fetcher]
        assert fetcher.fetched == 3
        assert fetcher.done
        return generated

    # we can only know the last batch with sized iterables or when we prefetch
    is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset]
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))
    batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
    expected = list(zip(fetched, batches, is_last_batch))
    assert len(expected) == 3

    assert generate() == expected
    # validate reset works properly.
    assert generate() == expected
    assert fetcher.fetched == 3
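
# A minimal sketch (an assumption, not part of the excerpt) of the dataset
# fixtures the parameterized test above expects: a 3-element iterable-style
# dataset and its sized, map-style counterpart. In a real module these would
# be defined before the test.
class IterDataset(IterableDataset):
    def __iter__(self):
        yield 1
        yield 2
        yield 3


class SizedDataset(Dataset):
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return idx + 1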

# parametrization assumed: one empty iterable-style and one empty sized dataset
@pytest.mark.parametrize("dataset_cls", [EmptyIterDataset, EmptySizedDataset])
@pytest.mark.parametrize("prefetch_batches", [0, 1])
def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
    loader = DataLoader(dataset_cls())
    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
    fetcher.setup(loader)

    assert not fetcher.done
    assert not list(fetcher)
    assert fetcher.done
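
# Hypothetical empty-dataset fixtures assumed by the parametrization above;
# they are not defined anywhere in this excerpt.
class EmptyIterDataset(IterableDataset):
    def __iter__(self):
        return iter([])


class EmptySizedDataset(Dataset):
    def __len__(self):
        return 0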

def test_misconfiguration_error():
    fetcher = DataFetcher()
    loader = DataLoader(range(10))
    fetcher.setup(loader)
    assert fetcher.loaders == loader
    with pytest.raises(
        MisconfigurationException,
        match="The `dataloader_iter` isn't available outside the __iter__ context.",
    ):
        fetcher.loader_iters

    iter(fetcher)
    assert fetcher.loader_iters
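
# Usage sketch (an assumption, not from the source): the pattern the test
# above enforces is that `loader_iters` may only be read once iteration has
# started, i.e. inside the fetcher's __iter__ context.
fetcher = DataFetcher()
fetcher.setup(DataLoader(range(10)))
for data in fetcher:
    loader_iters = fetcher.loader_iters  # valid here; raises before iteration begins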

def _select_data_fetcher(self) -> AbstractDataFetcher:
    if not self.trainer.training:
        return DataFetcher()

    training_step_fx = getattr(self.trainer.lightning_module, "training_step")
    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher()
    elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        if not isinstance(self.trainer.accelerator, GPUAccelerator):
            raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
        return InterBatchParallelDataFetcher()
    return DataFetcher()
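
# Sketch (an assumption, not from the excerpt): a `training_step` that
# declares `dataloader_iter` explicitly, which is the signature probed by
# `is_param_in_hook_signature` above and selects `DataLoaderIterDataFetcher`.
# The module and its body are illustrative only.
import torch
import pytorch_lightning as pl


class IteratorStepModel(pl.LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, dataloader_iter):
        # the hook receives the raw iterator and fetches batches itself
        batch = next(dataloader_iter)
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)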

@pytest.mark.parametrize("use_combined_loader", [False, True])  # assumed parametrization
def test_prefetch_iterator(use_combined_loader):
    """Test the DataFetcher with PyTorch IterableDataset."""

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(5):
        iterator = DataFetcher(prefetch_batches=prefetch_batches)
        assert iterator.prefetch_batches == prefetch_batches

        if use_combined_loader:
            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
        else:
            loader = DataLoader(IterDataset())
        iterator.setup(loader)

        def generate():
            generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)]
            assert iterator.fetched == 3
            assert iterator.done
            return generated

        is_last_batch = [False, False, prefetch_batches > 0]
        fetched = list(range(prefetch_batches + 1, 4))
        fetched += [3] * (3 - len(fetched))
        if use_combined_loader:
            batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
        else:
            batches = [1, 2, 3]
        expected = list(zip(fetched, batches, is_last_batch))
        assert len(expected) == 3

        assert generate() == expected
        # validate reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    loader = DataLoader(EmptyIterDataset())
    iterator = DataFetcher()
    iterator.setup(loader)
    assert not list(iterator)

@pytest.mark.parametrize("use_combined_loader", [False, True])  # assumed parametrization
def test_prefetch_iterator(use_combined_loader):
    """Test the DataFetcher with PyTorch IterableDataset."""

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(0, 4):
        if use_combined_loader:
            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
            expected = [
                ([tensor([1]), tensor([1])], False),
                ([tensor([2]), tensor([2])], False),
                ([tensor([3]), tensor([3])], True),
            ]
        else:
            loader = DataLoader(IterDataset())
            expected = [(1, False), (2, False), (3, True)]

        iterator = DataFetcher(prefetch_batches=prefetch_batches)
        prefetch_batches += 1  # this DataFetcher version stores one more than the requested prefetch count
        assert iterator.prefetch_batches == prefetch_batches
        iterator.setup(loader)

        def generate():
            generated = []
            for idx, data in enumerate(iterator, 1):
                if iterator.done:
                    assert iterator.fetched == 3
                else:
                    assert iterator.fetched == (idx + prefetch_batches)
                generated.append(data)
            return generated

        assert generate() == expected
        # validate reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    dataloader = DataLoader(EmptyIterDataset())
    iterator = DataFetcher()
    iterator.setup(dataloader)
    assert list(iterator) == []

def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size):
    """Test that the sequence of batches coming from a random number generator continues with the correct
    sequence after reloading the state."""

    def create_dataset_sampler():
        dset = CaptureMapDataset(dataset_class(16, 8))
        random_sampler = RandomSampler(dset, generator=torch.Generator())
        return dset, random_sampler

    def create_dataloader_sampler(dset, sampler):
        sampler = FastForwardSampler(sampler)
        sampler.setup(batch_size)
        dl = DataLoader(dset, num_workers=num_workers, sampler=sampler, batch_size=batch_size)
        _add_capture_metadata_collate(dl)
        return dl, sampler

    def fetch(fetcher, prefetch_iter, num_batches_fetched):
        batch, _ = next(prefetch_iter)

        state: List[MergedIteratorState] = fetcher.state
        assert len(state) == 1
        assert isinstance(state[0], MergedIteratorState)

        assert len(fetcher.dataloader_iter.cache_states) == 1
        if num_workers == 0:
            assert state[0].state[0].num_batches_fetched == num_batches_fetched
        return state

    dataset, random_sampler = create_dataset_sampler()
    dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)

    fetcher = DataFetcher()
    fetcher.setup(dataloader)
    prefetch_iter = iter(fetcher)

    # fetch 4 batches
    fetch(fetcher, prefetch_iter, 1)
    fetch(fetcher, prefetch_iter, 2)
    fetch(fetcher, prefetch_iter, 3)

    # (A) capture the state after fetching 4 batches
    state = fetch(fetcher, prefetch_iter, 4)
    state = deepcopy(state[0])

    # (B) simulate 2 additional batches
    batch05, _ = next(prefetch_iter)
    batch06, _ = next(prefetch_iter)

    # start reloading
    dataset, random_sampler = create_dataset_sampler()
    dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)

    # load the state dict saved at (A)
    ff_sampler.load_state_dict(state.sampler_states)
    dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers)

    prefetcher = DataFetcher()
    prefetcher.setup(dataloader)
    prefetch_iter = iter(prefetcher)

    # fetch 2 random batches, these should match exactly the batches seen at (B)
    batch05_restart, _ = next(prefetch_iter)
    batch06_restart, _ = next(prefetch_iter)

    assert torch.equal(batch05, batch05_restart)
    assert torch.equal(batch06, batch06_restart)
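
# A minimal sketch (an assumption, not part of the excerpt) of the
# `dataset_class(16, 8)` fixture used above: a map-style dataset of `length`
# random vectors of dimension `size`, so a correct restart can be verified by
# comparing the random batches produced before and after reloading the state.
class RandomGetItemDataset(Dataset):
    def __init__(self, length, size):
        self.size = size
        self.len = length

    def __getitem__(self, index):
        # a fresh random tensor per item; its value depends on the RNG state
        return torch.rand(self.size)

    def __len__(self):
        return self.len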