Example 1
 def teardown(self) -> None:
     self.reset()
     if isinstance(self.dataloader, CombinedLoader):
         self.dataloader.reset()
     if isinstance(self.dataloader, DataLoader):
         CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
     self.dataloader_iter = None
Example 2
def test_combined_loader_calc_length_mode_error():
    """Test the ValueError when calculating the number of batches."""
    with pytest.raises(
            TypeError,
            match=
            "Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader._calc_num_batches(None)
Example 3
def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_sampler_ddp):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader
    with ddp and `max_size_cycle` mode."""
    trainer = Trainer(strategy="ddp", accelerator=accelerator, devices=2, replace_sampler_ddp=replace_sampler_ddp)

    dataloader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader) == (4 if replace_sampler_ddp else 8)

    for a_length in [6, 8, 10]:
        dataloader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )

        length = max(a_length, 8)
        assert len(dataloader) == length
        dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
        if replace_sampler_ddp:
            last_batch = list(dataloader)[-1]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    dataloader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    assert get_len(dataloader) == float("inf")
    assert len(dataloader.loaders["b"].loader) == 8
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.loaders["b"].loader) == 4 if replace_sampler_ddp else 8
    assert get_len(dataloader) == float("inf")
Example 4
def test_combined_dataloader_for_training_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, use_combined_loader: bool
):
    """When providing a CombinedLoader as the training data, it should be correctly receive the distributed
    samplers."""
    mode = "min_size" if is_min_size_mode else "max_size_cycle"
    dim = 3
    n1 = 8
    n2 = 6
    dataloader = {
        "a": DataLoader(RandomDataset(dim, n1), batch_size=1),
        "b": DataLoader(RandomDataset(dim, n2), batch_size=1),
    }
    if use_combined_loader:
        dataloader = CombinedLoader(dataloader, mode=mode)
    expected_length_before_ddp = min(n1, n2) if is_min_size_mode else max(n1, n2)
    expected_length_after_ddp = expected_length_before_ddp // 2 if replace_sampler_ddp else expected_length_before_ddp
    model = BoringModel()
    trainer = Trainer(
        strategy="ddp",
        accelerator="auto",
        devices=2,
        replace_sampler_ddp=replace_sampler_ddp,
        multiple_trainloader_mode="max_size_cycle" if use_combined_loader else mode,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert isinstance(trainer.train_dataloader, CombinedLoader)
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp
Example 5
def test_combined_loader_sequence_with_map_and_iterable(lengths):
    class MyIterableDataset(IterableDataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __iter__(self):
            self.sampler = SequentialSampler(range(self.size))
            self.iter_sampler = iter(self.sampler)
            return self

        def __next__(self):
            return next(self.iter_sampler)

    class MyMapDataset(Dataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __getitem__(self, index):
            return index

        def __len__(self):
            return self.size

    x, y = lengths
    loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
    dataloader = CombinedLoader(loaders, mode="max_size_cycle")
    counter = 0
    for _ in dataloader:
        counter += 1
    assert counter == max(x, y)
Example 6
    def val_dataloader(self):
        val_dataloader_head = DataLoader(
            TestDataset(
                self.val_triples,
                self.train_triples + self.val_triples + self.test_triples,
                len(self.entity2id),
                len(self.relation2id),
                "head-batch",
            ),
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=TestDataset.collate_fn,
            drop_last=True,
            pin_memory=True,
        )

        val_dataloader_tail = DataLoader(
            TestDataset(
                self.val_triples,
                self.train_triples + self.val_triples + self.test_triples,
                len(self.entity2id),
                len(self.relation2id),
                "tail-batch",
            ),
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=TestDataset.collate_fn,
            drop_last=True,
            pin_memory=True,
        )

        return CombinedLoader([val_dataloader_head, val_dataloader_tail])
Example 7
def test_combined_loader_sequence_iterable_dataset(mode,
                                                   use_multiple_dataloaders):
    """Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
    if use_multiple_dataloaders:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
            torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
        ]
    else:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
        ]

    combined_loader = CombinedLoader(loaders, mode)

    has_break = False

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == (2 if use_multiple_dataloaders else 1)
        if not use_multiple_dataloaders and idx == 4:
            has_break = True
            break

    if mode == "max_size_cycle":
        assert combined_loader.loaders[0].state.done == (not has_break)
    expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
    assert (expected - 1) == idx, (mode, use_multiple_dataloaders)
Example 8
def test_combined_loader_loader_type_error():
    """Test the ValueError when wrapping the loaders."""
    with pytest.raises(
            TypeError,
            match=
            "Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader(None, "max_size_cycle")
Example 9
 def predict_dataloader(self):
     return CombinedLoader({
         "a":
         DataLoader(RandomDataset(32, 8), batch_size=2),
         "b":
         DataLoader(RandomDataset(32, 8), batch_size=4),
     })
Example 10
def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches):
    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
    assert fetcher.prefetch_batches == prefetch_batches

    if use_combined_loader:
        loader = CombinedLoader(
            [DataLoader(dataset_cls()),
             DataLoader(dataset_cls())])
    else:
        loader = DataLoader(dataset_cls())
    fetcher.setup(loader)

    def generate():
        generated = [(fetcher.fetched, data, fetcher.done) for data in fetcher]
        assert fetcher.fetched == 3
        assert fetcher.done
        return generated

    # we can only know the last batch with sized iterables or when we prefetch
    is_last_batch = [
        False, False, prefetch_batches > 0 or dataset_cls is SizedDataset
    ]
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))
    batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
    expected = list(zip(fetched, batches, is_last_batch))
    assert len(expected) == 3

    assert generate() == expected
    # validate reset works properly.
    assert generate() == expected
    assert fetcher.fetched == 3
Example 11
def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using
    CombinedLoader."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    dataloader = CombinedLoader(
        {
            "a": DataLoader(CustomDataset(range(10))),
            "b": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
            "e": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
        }
    )

    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)
    _count = 0

    def _assert_distributed_sampler(v):
        nonlocal _count
        _count += 1
        assert isinstance(v, DistributedSampler)

    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
    assert _count == 5
Example 12
 def train_dataloader(self):
     loader_l = DataLoader(self.train_l, self.batch_size, shuffle=True)
     loader_u = DataLoader(self.train_u, self.batch_size, shuffle=True)
     loader_real = DataLoader(self.train, self.batch_size, shuffle=True)
     loaders = {"u": loader_u, "l": loader_l, "real": loader_real}
     combined_loaders = CombinedLoader(loaders, "max_size_cycle")
     return combined_loaders
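For context on how the mapping above is consumed: iterating a CombinedLoader built from a dict yields one dict per step, keyed exactly like the input, with shorter loaders cycled in "max_size_cycle" mode. A minimal, self-contained sketch with made-up tensor data (the import path assumes PyTorch Lightning 1.x, where CombinedLoader lives in pytorch_lightning.trainer.supporters):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.supporters import CombinedLoader

# hypothetical stand-ins for self.train_u / self.train_l / self.train above
loaders = {
    "u": DataLoader(TensorDataset(torch.zeros(6, 2)), batch_size=2),
    "l": DataLoader(TensorDataset(torch.ones(4, 2)), batch_size=2),
    "real": DataLoader(TensorDataset(torch.full((8, 2), 2.0)), batch_size=2),
}
combined = CombinedLoader(loaders, "max_size_cycle")
assert len(combined) == 4  # length of the longest loader ("real")
for batch in combined:
    # each step yields a dict with the same keys as `loaders`
    assert set(batch) == {"u", "l", "real"}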
Example 13
 def val_dataloader(self):
     loader_test = DataLoader(self.test_dataset,
                              int(len(self.test_dataset) / 10))
     loader_u = DataLoader(self.train_dataset_u,
                           int(len(self.train_dataset_u) / 10))
     loaders = {"u": loader_u, "test": loader_test}
     combined_loaders = CombinedLoader(loaders, "max_size_cycle")
     return combined_loaders
Example 14
def test_prefetch_iterator(use_combined_loader):
    """Test the DataFetcher with PyTorch IterableDataset."""
    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(5):
        iterator = DataFetcher(prefetch_batches=prefetch_batches)
        assert iterator.prefetch_batches == prefetch_batches

        if use_combined_loader:
            loader = CombinedLoader(
                [DataLoader(IterDataset()),
                 DataLoader(IterDataset())])
        else:
            loader = DataLoader(IterDataset())
        iterator.setup(loader)

        def generate():
            generated = [
                (iterator.fetched, *data)
                for i, data in enumerate(iterator, prefetch_batches + 1)
            ]
            assert iterator.fetched == 3
            assert iterator.done
            return generated

        is_last_batch = [False, False, prefetch_batches > 0]
        fetched = list(range(prefetch_batches + 1, 4))
        fetched += [3] * (3 - len(fetched))
        if use_combined_loader:
            batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
        else:
            batches = [1, 2, 3]
        expected = list(zip(fetched, batches, is_last_batch))
        assert len(expected) == 3

        assert generate() == expected
        # validate reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    loader = DataLoader(EmptyIterDataset())
    iterator = DataFetcher()
    iterator.setup(loader)
    assert not list(iterator)
Example 15
    def val_dataloader(self, *args: Any,
                       **kwargs: Any) -> CombinedLoader:  # type: ignore
        """
        The val dataloader
        """
        dataloaders = {
            SSLDataModuleType.ENCODER: self.encoder_module.val_dataloader(),
            SSLDataModuleType.LINEAR_HEAD:
            self.linear_head_module.val_dataloader()
        }

        return CombinedLoader(dataloaders, mode="max_size_cycle")
Example 16
 def get_combined_loader(self, encoder_loader: Sized,
                         linear_head_loader: Sized) -> CombinedLoader:
     """
     Creates a CombinedLoader from the data loaders for the encoder and the linear head.
     The cycle mode is chosen such that in all cases the encoder dataset is only cycled through once.
     :param encoder_loader: The dataloader to use for the SSL encoder.
     :param linear_head_loader: The dataloader to use for the linear head.
     """
     mode = self._cycle_mode(len(encoder_loader), len(linear_head_loader))
     dataloaders = {
         SSLDataModuleType.ENCODER: encoder_loader,
         SSLDataModuleType.LINEAR_HEAD: linear_head_loader
     }
     return CombinedLoader(dataloaders, mode=mode)
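The `_cycle_mode` helper used above is not part of this example. Given the docstring's requirement that the encoder dataset is only cycled through once, a plausible sketch of such a helper (hypothetical, not the actual implementation) is:

 def _cycle_mode(self, encoder_length: int, linear_head_length: int) -> str:
     # Hypothetical sketch: pick a CombinedLoader mode so the encoder data is
     # never repeated within a single pass.
     if encoder_length >= linear_head_length:
         # the encoder loader is the longest, so "max_size_cycle" only cycles the linear head
         return "max_size_cycle"
     # the encoder loader is shorter; "min_size" stops after one full pass over it
     return "min_size"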
Example 17
def test_prefetch_iterator(use_combined_loader):
    """Test the DataFetcher with PyTorch IterableDataset."""
    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    for prefetch_batches in range(0, 4):
        if use_combined_loader:
            loader = CombinedLoader(
                [DataLoader(IterDataset()),
                 DataLoader(IterDataset())])
            expected = [
                ([tensor([1]), tensor([1])], False),
                ([tensor([2]), tensor([2])], False),
                ([tensor([3]), tensor([3])], True),
            ]
        else:
            loader = DataLoader(IterDataset())
            expected = [(1, False), (2, False), (3, True)]
        iterator = DataFetcher(prefetch_batches=prefetch_batches)
        prefetch_batches += 1
        assert iterator.prefetch_batches == prefetch_batches
        iterator.setup(loader)

        def generate():
            generated = []
            for idx, data in enumerate(iterator, 1):
                if iterator.done:
                    assert iterator.fetched == 3
                else:
                    assert iterator.fetched == (idx + prefetch_batches)
                generated.append(data)
            return generated

        assert generate() == expected
        # validate reset works properly.
        assert generate() == expected
        assert iterator.fetched == 3

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    dataloader = DataLoader(EmptyIterDataset())
    iterator = DataFetcher()
    iterator.setup(dataloader)
    assert list(iterator) == []
Example 18
    def on_save_checkpoint(self) -> Dict:
        state_dict = super().on_save_checkpoint()

        if (
            self.trainer is not None
            and self.trainer.state._fault_tolerant_mode.is_enabled
            and self._data_fetcher is not None
            and not self._num_completed_batches_reached()  # did not finish
            and self.batch_progress.current.ready  # did start
        ):
            state = CombinedLoader._state_dict_fn(self._data_fetcher.dataloader_iter, self._has_completed())
            if state:
                state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(state)

        return state_dict
Example 19
def test_combined_loader_sequence_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders."""
    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]

    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders)

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2

    assert idx == len(combined_loader) - 1
Example 20
    def val_dataloader(self):
        dataset_root = self.hydra_conf["trainer"]["dataset_root"]
        dataset_paths = self.hydra_conf["trainer"]["valid_dataset"].split("*")

        scene_dataset, img_dataset = self.setup_dataset(dataset_root, dataset_paths)
        self.valid_scene_dataset = torch.utils.data.ConcatDataset(scene_dataset)
        self.valid_img_dataset = torch.utils.data.ConcatDataset(img_dataset)

        valid_scene_sampler = My_ddp_sampler2(self.valid_scene_dataset, self.batch_size,
                                             v_sample_mode="internal", shuffle=False)
        valid_img_sampler = My_ddp_sampler2(self.valid_img_dataset, self.batch_size,
                                           v_sample_mode="internal", shuffle=False)
        if self.involved_imgs:
            combined_dataset = {
                "scene": DataLoader(self.valid_scene_dataset,
                                    batch_size=self.batch_size,
                                    num_workers=self.scene_worker,
                                    shuffle=False,
                                    pin_memory=True,
                                    collate_fn=self.dataset_builder.collate_fn,
                                    sampler=valid_scene_sampler,
                                    persistent_workers=True
                                    ),
                "img": DataLoader(self.valid_img_dataset,
                                  batch_size=self.batch_size,
                                  num_workers=self.img_worker,
                                  shuffle=False,
                                  pin_memory=True,
                                  collate_fn=self.dataset_builder.collate_fn,
                                  sampler=valid_img_sampler,
                                  persistent_workers=True
                                  )}
            assert len(combined_dataset["scene"]) == len(combined_dataset["img"])
        else:
            combined_dataset = {
                "scene": DataLoader(self.valid_scene_dataset,
                                    batch_size=self.batch_size,
                                    num_workers=self.scene_worker,
                                    shuffle=False,
                                    pin_memory=True,
                                    collate_fn=self.dataset_builder.collate_fn,
                                    sampler=valid_scene_sampler,
                                    persistent_workers=True
                                    )
            }

        return CombinedLoader(combined_dataset, mode="min_size")
Example 21
def test_combined_loader_dict_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }

    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders.values())

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    assert idx == len(combined_loader) - 1
Example 22
def create_dataloader():
    dataset = range(50)
    num_workers = 2
    batch_size = 8
    sampler = FastForwardSampler(SequentialSampler(dataset))
    sampler.setup(batch_size)

    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    dataloader.fast_forward_sampler = sampler

    loader_dict = {
        "a": [DataLoader(create_iterable_dataset(3, num_workers), num_workers=num_workers, batch_size=3), dataloader],
        "b": DataLoader(
            create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler"), num_workers=0, batch_size=2
        ),
    }
    apply_to_collection(loader_dict, DataLoader, Trainer._add_sampler_metadata_collate)
    return CombinedLoader(loader_dict)
Example 23
 def train_dataloader(self):
     return CombinedLoader(
         {
             "src":
             DataLoader(self.trainset_src,
                        batch_size=self.batch_size,
                        shuffle=True,
                        pin_memory=True,
                        num_workers=self.num_workers,
                        drop_last=True),
             "tgt":
             DataLoader(self.trainset_tgt,
                        batch_size=self.batch_size,
                        shuffle=True,
                        pin_memory=True,
                        num_workers=self.num_workers,
                        drop_last=True),
         }, "max_size_cycle")
Example 24
    def val_dataloader(self):
        loaders = {
            "content":
            DataLoader(
                self.content_val,
                batch_size=1,
                shuffle=False,
                num_workers=self.workers,
                pin_memory=True,
            ),
            "style":
            DataLoader(
                self.style_val,
                batch_size=1,
                shuffle=False,
                num_workers=self.workers,
                pin_memory=True,
            ),
        }

        return CombinedLoader(loaders, "max_size_cycle")
Example 25
    def test_dataloader(self):
        dataset_root = self.hydra_conf["trainer"]["dataset_root"]
        dataset_paths = self.hydra_conf["trainer"]["test_dataset"].split("*")

        scene_dataset, img_dataset = self.setup_dataset(dataset_root, dataset_paths)

        self.test_scene_dataset = torch.utils.data.ConcatDataset(scene_dataset)
        self.test_img_dataset = torch.utils.data.ConcatDataset(img_dataset)
        if self.involved_imgs:
            combined_dataset = {
                "scene": DataLoader(self.test_scene_dataset,
                                    batch_size=self.batch_size,
                                    num_workers=self.scene_worker,
                                    shuffle=False,
                                    pin_memory=True,
                                    collate_fn=self.dataset_builder.collate_fn,
                                    ),
                "img": DataLoader(self.test_img_dataset,
                                  batch_size=self.batch_size,
                                  num_workers=self.img_worker,
                                  shuffle=False,
                                  pin_memory=True,
                                  collate_fn=self.dataset_builder.collate_fn,
                                  )}
            assert len(combined_dataset["scene"]) == len(combined_dataset["img"])
        else:
            combined_dataset = {
                "scene": DataLoader(self.test_scene_dataset,
                                    batch_size=self.batch_size,
                                    num_workers=self.scene_worker,
                                    shuffle=False,
                                    pin_memory=True,
                                    collate_fn=self.dataset_builder.collate_fn,
                                    ),
            }

        return CombinedLoader(combined_dataset, mode="min_size")
Example 26
    def val_dataloader(self) -> CombinedLoader:
        main_loader = DataLoader(
            self.val_data.batched(
                self.val_dataloader_conf["batch_size"] - sum(self.batch_size_extra),
                partial=False,
            ),
            batch_size=None,
            pin_memory=False,
            num_workers=self.val_dataloader_conf["num_workers"],
        )

        loaders = {"main": main_loader}
        for cnt, (bs, val_data) in enumerate(
            zip(self.batch_size_extra, self.extra_valid_data)
        ):
            loaders[f"extra_{cnt}"] = DataLoader(
                val_data.batched(bs, partial=False),
                batch_size=None,
                pin_memory=True,
                num_workers=bs // 2,
            )

        combined_loaders = CombinedLoader(loaders, "max_size_cycle")
        return combined_loaders
Example 27
    with pytest.raises(
            MisconfigurationException,
            match=
            rf"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
    ):
        trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING,
                                                       model)


@pytest.mark.parametrize(
    "val_dl,warns",
    [
        (DataLoader(dataset=RandomDataset(32, 64), shuffle=True), True),
        (DataLoader(dataset=RandomDataset(32, 64), sampler=list(
            range(64))), False),
        (CombinedLoader(DataLoader(dataset=RandomDataset(32, 64),
                                   shuffle=True)), True),
        (
            CombinedLoader([
                DataLoader(dataset=RandomDataset(32, 64)),
                DataLoader(dataset=RandomDataset(32, 64), shuffle=True)
            ]),
            True,
        ),
        (
            CombinedLoader({
                "dl1":
                DataLoader(dataset=RandomDataset(32, 64)),
                "dl2":
                DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
            }),
            True,
Example 28
def test_combined_loader_init_mode_error():
    """Test the ValueError when constructing `CombinedLoader`"""
    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        CombinedLoader([range(10)], "testtt")
Example 29
    def reset_train_dataloader(self, model: LightningModule) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`
        """
        self.train_dataloader = self.request_dataloader(model, "train")

        if self.overfit_batches > 0:
            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
                rank_zero_warn(
                    'You requested to overfit but enabled training dataloader shuffling.'
                    ' We are turning it off for you.'
                )
                self.train_dataloader = self.replace_sampler(
                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
                )

        # debugging
        self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

        # automatically add samplers
        self.train_dataloader = apply_to_collection(
            self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
        )

        # check the workers recursively
        apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
        elif self.num_training_batches != float('inf'):
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                'When using an IterableDataset for `limit_train_batches`,'
                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                ' `num_training_batches` to use.'
            )

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
                )
        else:
            if not has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float('inf')
                else:
                    raise MisconfigurationException(
                        'When using an IterableDataset for `train_dataloader`,'
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.'
                    )
            else:
                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)
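A small worked example of the fractional `val_check_interval` branch at the end (numbers are hypothetical): with a sized train dataloader of 100 batches and `val_check_interval=0.25`, validation runs every 25 training batches.

# hypothetical numbers illustrating the last branch above
num_training_batches = 100
val_check_interval = 0.25
val_check_batch = max(1, int(num_training_batches * val_check_interval))
assert val_check_batch == 25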
Example 30
def test_combined_data_loader_validation_test(cuda_available_mock,
                                              device_count_mock,
                                              use_fault_tolerant,
                                              replace_sampler_ddp, tmpdir):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using
    CombinedLoader."""
    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomSampler(RandomSampler):
        def __init__(self, data_source, name) -> None:
            super().__init__(data_source)
            self.name = name

    dataset = CustomDataset(range(10))
    dataloader = CombinedLoader({
        "a":
        DataLoader(CustomDataset(range(10))),
        "b":
        DataLoader(dataset, sampler=CustomSampler(dataset, "custom_sampler")),
        "c": {
            "c": DataLoader(CustomDataset(range(10))),
            "d": DataLoader(CustomDataset(range(10)))
        },
        "d": [
            DataLoader(CustomDataset(range(10))),
            DataLoader(CustomDataset(range(10)))
        ],
    })

    with mock.patch.dict(
        os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(use_fault_tolerant))}
    ):

        trainer = Trainer(replace_sampler_ddp=replace_sampler_ddp,
                          strategy="ddp",
                          gpus=2)
        dataloader = trainer._data_connector._prepare_dataloader(dataloader,
                                                                 shuffle=True)
        _count = 0
        _has_fastforward_sampler = False

    def _assert_distributed_sampler(v):
        nonlocal _count
        nonlocal _has_fastforward_sampler
        _count += 1
        if use_fault_tolerant:
            _has_fastforward_sampler = True
            assert isinstance(v, FastForwardSampler)
            v = v._sampler
        if replace_sampler_ddp:
            assert isinstance(v, DistributedSampler)
        else:
            assert isinstance(v, (SequentialSampler, CustomSampler))

    apply_to_collection(dataloader.sampler, Sampler,
                        _assert_distributed_sampler)
    assert _count == 6
    assert _has_fastforward_sampler == use_fault_tolerant

    def _assert_dataset(loader):
        d = loader.dataset
        if use_fault_tolerant:
            assert isinstance(d, CaptureMapDataset)
        else:
            assert isinstance(d, CustomDataset)

    apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset)