def test_prefetch_iterator():
    """Test prefetch_iterator with PyTorch IterableDataset inputs.

    The assertions below pin the expected contract: prefetch_iterator yields
    ``(item, is_last)`` pairs, where ``is_last`` is True only for the final
    item, and an empty iterable yields nothing.
    """

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    dataset = IterDataset()
    iterator = prefetch_iterator(dataset)
    assert list(iterator) == [(1, False), (2, False), (3, True)]

    # Boundary case: with a single item, the first item is also the last one.
    class SingleIterDataset(IterableDataset):
        def __iter__(self):
            yield 1

    dataset = SingleIterDataset()
    iterator = prefetch_iterator(dataset)
    assert list(iterator) == [(1, True)]

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    dataset = EmptyIterDataset()
    iterator = prefetch_iterator(dataset)
    assert list(iterator) == []
def get_profiled_train_dataloader(self, train_dataloader):
    """Wrap the training dataloader for indexed, look-ahead, profiled iteration.

    The dataloader is passed through ``prefetch_iterator`` and then
    ``enumerate``. The result is handed to the trainer profiler's
    ``profile_iterable`` under the action name ``"get_train_batch"``.

    Args:
        train_dataloader: the iterable of training batches to wrap.

    Returns:
        Whatever ``self.trainer.profiler.profile_iterable`` returns.
        Presumably this is an iterable of ``(index, prefetched_item)`` pairs
        whose iteration is timed by the profiler; TODO confirm against the
        profiler implementation.
    """
    # Resolve the profiling hook first, matching the original evaluation order.
    profile = self.trainer.profiler.profile_iterable
    indexed_batches = enumerate(prefetch_iterator(train_dataloader))
    return profile(indexed_batches, "get_train_batch")