def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches):
    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
    assert fetcher.prefetch_batches == prefetch_batches

    if use_combined_loader:
        loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())])
    else:
        loader = DataLoader(dataset_cls())
    fetcher.setup(loader)

    def generate():
        generated = [(fetcher.fetched, data, fetcher.done) for data in fetcher]
        assert fetcher.fetched == 3
        assert fetcher.done
        return generated

    # we can only know the last batch with sized iterables or when we prefetch
    is_last_batch = [
        False, False, prefetch_batches > 0 or dataset_cls is SizedDataset
    ]
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))
    batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
    expected = list(zip(fetched, batches, is_last_batch))
    assert len(expected) == 3

    assert generate() == expected
    # validate reset works properly.
    assert generate() == expected
    assert fetcher.fetched == 3


def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
    loader = DataLoader(dataset_cls())
    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
    fetcher.setup(loader)

    assert not fetcher.done
    assert not list(fetcher)
    assert fetcher.done


def test_misconfiguration_error():
    fetcher = DataFetcher()
    loader = DataLoader(range(10))
    fetcher.setup(loader)
    assert fetcher.loaders == loader
    with pytest.raises(
        MisconfigurationException,
        match="The `dataloader_iter` isn't available outside the __iter__ context.",
    ):
        fetcher.loader_iters
    iter(fetcher)
    assert fetcher.loader_iters
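Note: the listings on this page omit their imports. A plausible set for the snippets above and below, assuming the pytorch_lightning 1.4/1.5-era module layout (several of these internals have since been moved or renamed), would be:

import os
from copy import deepcopy
from typing import List

import pytest
import torch
from torch import tensor
from torch.utils.data import DataLoader, IterableDataset, RandomSampler

from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import (
    CaptureMapDataset,
    FastForwardSampler,
    MergedIteratorState,
    _add_capture_metadata_collate,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
    DataFetcher,
    DataLoaderIterDataFetcher,
    InterBatchParallelDataFetcher,
)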
Example #4
    def _select_data_fetcher(self) -> AbstractDataFetcher:
        if not self.trainer.training:
            return DataFetcher()

        training_step_fx = getattr(self.trainer.lightning_module, "training_step")
        if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
            rank_zero_warn(
                "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
                "this signature is experimental and the behavior is subject to change."
            )
            return DataLoaderIterDataFetcher()
        elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
            if not isinstance(self.trainer.accelerator, GPUAccelerator):
                raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
            return InterBatchParallelDataFetcher()
        return DataFetcher()
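For illustration, here is a hypothetical LightningModule whose `training_step` declares the experimental `dataloader_iter` argument; under the selection logic above, this signature is what makes `_select_data_fetcher` return a `DataLoaderIterDataFetcher` (the module name and body are made up for this sketch):

import pytorch_lightning as pl


class ManualFetchModel(pl.LightningModule):
    def training_step(self, dataloader_iter):
        # the loop hands over the iterator itself, so the step can pull
        # as many batches as it wants instead of receiving one per call
        batch = next(dataloader_iter)
        ...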
Example #5
def test_prefetch_iterator(use_combined_loader):
    """Test the DataFetcher with PyTorch IterableDataset."""
    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(5):
        iterator = DataFetcher(prefetch_batches=prefetch_batches)
        assert iterator.prefetch_batches == prefetch_batches

        if use_combined_loader:
            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
        else:
            loader = DataLoader(IterDataset())
        iterator.setup(loader)

        def generate():
            generated = [(iterator.fetched, *data) for data in iterator]
            assert iterator.fetched == 3
            assert iterator.done
            return generated

        is_last_batch = [False, False, prefetch_batches > 0]
        fetched = list(range(prefetch_batches + 1, 4))
        fetched += [3] * (3 - len(fetched))
        if use_combined_loader:
            batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
        else:
            batches = [1, 2, 3]
        expected = list(zip(fetched, batches, is_last_batch))
        assert len(expected) == 3

        assert generate() == expected
        # validate reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    loader = DataLoader(EmptyIterDataset())
    iterator = DataFetcher()
    iterator.setup(loader)
    assert not list(iterator)
Example #6
def test_prefetch_iterator(use_combined_loader):
    """Test the DataFetcher with PyTorch IterableDataset."""
    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(0, 4):
        if use_combined_loader:
            loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
            expected = [
                ([tensor([1]), tensor([1])], False),
                ([tensor([2]), tensor([2])], False),
                ([tensor([3]), tensor([3])], True),
            ]
        else:
            loader = DataLoader(IterDataset())
            expected = [(1, False), (2, False), (3, True)]
        iterator = DataFetcher(prefetch_batches=prefetch_batches)
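        # the assertion below reflects that this DataFetcher version keeps one
        # extra batch in flight: the attribute equals the requested count + 1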
        prefetch_batches += 1
        assert iterator.prefetch_batches == prefetch_batches
        iterator.setup(loader)

        def generate():
            generated = []
            for idx, data in enumerate(iterator, 1):
                if iterator.done:
                    assert iterator.fetched == 3
                else:
                    assert iterator.fetched == (idx + prefetch_batches)
                generated.append(data)
            return generated

        assert generate() == expected
        # validate reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    dataloader = DataLoader(EmptyIterDataset())
    iterator = DataFetcher()
    iterator.setup(dataloader)
    assert list(iterator) == []
Example #7
def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size):
    """Test that the sequence of batches coming from a random number generator continues with the correct sequence
    after reloading the state.
    """
    def create_dataset_sampler():
        dset = CaptureMapDataset(dataset_class(16, 8))
        random_sampler = RandomSampler(dset, generator=torch.Generator())
        return dset, random_sampler

    def create_dataloader_sampler(dset, sampler):
        sampler = FastForwardSampler(sampler)
        sampler.setup(batch_size)
        dl = DataLoader(dset,
                        num_workers=num_workers,
                        sampler=sampler,
                        batch_size=batch_size)
        _add_capture_metadata_collate(dl)
        return dl, sampler

    def fetch(fetcher, prefetch_iter, num_batches_fetched):
        batch, _ = next(prefetch_iter)

        state: List[MergedIteratorState] = fetcher.state
        assert len(state) == 1
        assert isinstance(state[0], MergedIteratorState)

        assert len(fetcher.dataloader_iter.cache_states) == 1
        if num_workers == 0:
            assert state[0].state[0].num_batches_fetched == num_batches_fetched
        return state

    dataset, random_sampler = create_dataset_sampler()
    dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)

    fetcher = DataFetcher()
    fetcher.setup(dataloader)
    prefetch_iter = iter(fetcher)

    # fetch the first 3 batches
    fetch(fetcher, prefetch_iter, 1)
    fetch(fetcher, prefetch_iter, 2)
    fetch(fetcher, prefetch_iter, 3)

    # (A) fetch a 4th batch and capture the state at that point
    state = fetch(fetcher, prefetch_iter, 4)
    state = deepcopy(state[0])

    # (B) simulate 2 additional batches
    batch05, _ = next(prefetch_iter)
    batch06, _ = next(prefetch_iter)

    # start reloading
    dataset, random_sampler = create_dataset_sampler()
    dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)

    # load the state dict saved at (A)
    ff_sampler.load_state_dict(state.sampler_states)
    dataset.load_state_dict(state.dataset_states,
                            latest_worker_id=state.latest_worker_id,
                            num_workers=num_workers)

    prefetcher = DataFetcher()
    prefetcher.setup(dataloader)
    prefetch_iter = iter(prefetcher)

    # fetch 2 batches; these should exactly match the batches seen at (B)
    batch05_restart, _ = next(prefetch_iter)
    batch06_restart, _ = next(prefetch_iter)

    assert torch.equal(batch05, batch05_restart)
    assert torch.equal(batch06, batch06_restart)
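The restart machinery exercised above ultimately rests on a simple property of torch.Generator: if you capture its state before an epoch, restoring it reproduces the same shuffled order, and "fast-forwarding" past already-consumed indices resumes the exact sequence. A minimal sketch of that idea using only plain PyTorch APIs (this is not the actual FastForwardSampler implementation):

import torch
from torch.utils.data import RandomSampler

data = list(range(16))
generator = torch.Generator().manual_seed(7)
sampler = RandomSampler(data, generator=generator)

# capture the generator state before the epoch starts
state = generator.get_state()
order = list(sampler)  # the full shuffled index order for this epoch

# "restart": restore the state, replay the epoch, and skip the
# four indices that were already consumed before the interruption
generator.set_state(state)
restarted = list(sampler)
assert restarted == order
assert restarted[4:] == order[4:]  # the continuation matches exactly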