예제 #1
0
def test_prefetch_iterator():
    """Test ``prefetch_iterator`` with PyTorch ``IterableDataset`` inputs.

    ``prefetch_iterator`` is expected to yield ``(item, is_last)`` pairs in
    the dataset's order, where ``is_last`` is ``True`` only for the final
    item, and to yield nothing at all for an empty dataset.
    """

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    dataset = IterDataset()
    iterator = prefetch_iterator(dataset)
    # list() drains the generator directly; an identity comprehension adds nothing.
    assert list(iterator) == [(1, False), (2, False), (3, True)]

    # Boundary case: a lone item is also the last item, so it must be flagged True.
    class SingleIterDataset(IterableDataset):
        def __iter__(self):
            yield 1

    assert list(prefetch_iterator(SingleIterDataset())) == [(1, True)]

    # An empty dataset must produce no pairs (and must not raise).
    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    dataset = EmptyIterDataset()
    iterator = prefetch_iterator(dataset)
    assert list(iterator) == []
예제 #2
0
 def get_profiled_train_dataloader(self, train_dataloader):
     """Return ``train_dataloader`` wrapped for profiling under "get_train_batch".

     The dataloader is passed through ``prefetch_iterator`` and then
     ``enumerate``. The result is handed to the trainer profiler's
     ``profile_iterable`` under the action name "get_train_batch", and
     whatever that call returns is returned unchanged.

     NOTE(review): assuming ``prefetch_iterator`` yields ``(batch, is_last)``
     pairs, each element seen by the caller is ``(index, (batch, is_last))``.
     Confirm against the helper's definition.
     """
     # Look up the profiler hook first, matching the original evaluation order.
     profile_iterable = self.trainer.profiler.profile_iterable
     indexed_batches = enumerate(prefetch_iterator(train_dataloader))
     return profile_iterable(indexed_batches, "get_train_batch")