Example #1: auto_add_sampler re-instantiates a DataLoader subclass without dropping its custom kwargs (DDP, two GPUs).

import torch

from pytorch_lightning import Trainer


def test_dataloader_reinit_for_subclass():
    class CustomDataLoader(torch.utils.data.DataLoader):
        def __init__(self,
                     dataset,
                     batch_size=1,
                     shuffle=False,
                     sampler=None,
                     batch_sampler=None,
                     num_workers=0,
                     collate_fn=None,
                     pin_memory=False,
                     drop_last=False,
                     timeout=0,
                     worker_init_fn=None,
                     dummy_kwarg=None):
            super().__init__(dataset, batch_size, shuffle, sampler,
                             batch_sampler, num_workers, collate_fn,
                             pin_memory, drop_last, timeout, worker_init_fn)

            self.dummy_kwarg = dummy_kwarg

    trainer = Trainer(gpus=[0, 1], num_nodes=1, distributed_backend='ddp')

    class CustomDummyObj:
        sampler = None

    result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
    assert isinstance(result,
                      CustomDummyObj), "Wrongly reinstantiated data loader"

    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))),
                                      train=True)
    assert isinstance(result, torch.utils.data.DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert hasattr(result, 'dummy_kwarg')
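
These assertions can be reproduced outside Lightning. Below is a minimal sketch of the re-instantiation idea the test relies on: read the subclass's constructor arguments back from same-named attributes on the instance and rebuild the same type around a new sampler. The helper name reinstantiate_with_sampler is illustrative, not Lightning's actual implementation.

import inspect

from torch.utils.data import DataLoader, SequentialSampler


def reinstantiate_with_sampler(dataloader, sampler):
    # Read every constructor argument the (sub)class declares back from a
    # same-named attribute on the live instance, skipping those that would
    # conflict with the new sampler.
    params = inspect.signature(type(dataloader).__init__).parameters
    skip = {"self", "dataset", "shuffle", "sampler", "batch_sampler"}
    kwargs = {
        name: getattr(dataloader, name)
        for name in params
        if name not in skip and hasattr(dataloader, name)
    }
    # Rebuilding via type(dataloader) keeps the subclass (and attributes
    # such as dummy_kwarg) instead of downgrading to a plain DataLoader.
    return type(dataloader)(dataloader.dataset, sampler=sampler, **kwargs)


dataset = list(range(10))
loader = reinstantiate_with_sampler(DataLoader(dataset, batch_size=2),
                                    SequentialSampler(dataset))
assert isinstance(loader, DataLoader) and loader.batch_size == 2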
Example #2: the same check under accelerator='ddp_spawn', plus a shuffled loader and an error when a custom sampler would be replaced.

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_dataloader_reinit_for_subclass(tmpdir):

    class CustomDataLoader(torch.utils.data.DataLoader):

        def __init__(
            self,
            dataset,
            batch_size=1,
            shuffle=False,
            sampler=None,
            batch_sampler=None,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            dummy_kwarg=None,
            **kwargs
        ):
            super().__init__(
                dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last,
                timeout, worker_init_fn
            )

            self.dummy_kwarg = dummy_kwarg

    trainer = Trainer(
        gpus=[0, 1],
        num_nodes=1,
        accelerator='ddp_spawn',
        default_root_dir=tmpdir,
    )

    class CustomDummyObj:
        sampler = None

    result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

    dataset = list(range(1000))
    result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
    assert isinstance(result, torch.utils.data.DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert hasattr(result, 'dummy_kwarg')

    # Shuffled DataLoader should also work
    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True)
    assert isinstance(result, torch.utils.data.DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert hasattr(result, 'dummy_kwarg')

    class CustomSampler(torch.utils.data.Sampler):
        pass

    # Should raise an error if existing sampler is being replaced
    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
        trainer.auto_add_sampler(
            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True
        )
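
The pytest.raises block at the end encodes a general rule: only the samplers a DataLoader creates for itself (SequentialSampler or RandomSampler) may be swapped for a DistributedSampler; a user-supplied sampler must not be silently discarded. A minimal sketch of such a guard, with an illustrative name and message rather than Lightning's exact check:

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


def assert_sampler_replaceable(dataloader):
    # SequentialSampler/RandomSampler are the auto-created defaults;
    # anything else was configured deliberately and must not be thrown away.
    if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
        raise ValueError("custom sampler would be replaced by DistributedSampler")


assert_sampler_replaceable(DataLoader(list(range(10))))                # default sampler: ok
assert_sampler_replaceable(DataLoader(list(range(10)), shuffle=True))  # RandomSampler: ok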
Example #3: DistributedSampler injection into every DataLoader nested inside a CombinedLoader.

from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection


# cuda_available_mock and device_count_mock are supplied by mock.patch
# decorators that this listing does not show.
def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using
    CombinedLoader."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    dataloader = CombinedLoader(
        {
            "a": DataLoader(CustomDataset(range(10))),
            "b": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
            "e": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
        }
    )

    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)
    _count = 0

    def _assert_distributed_sampler(v):
        nonlocal _count
        _count += 1
        assert isinstance(v, DistributedSampler)

    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
    assert _count == 5
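
What "properly injected" means can be shown with plain torch: each of the five inner DataLoaders ends up rebuilt around a DistributedSampler that hands every rank its own shard. A minimal standalone sketch, passing num_replicas and rank explicitly so no process group is needed:

from torch.utils.data import DataLoader, DistributedSampler

dataset = list(range(10))
# Rank 0 of a 2-process job draws from its own half of the indices.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
loader = DataLoader(dataset, sampler=sampler)
assert len(sampler) == 5  # each rank gets ceil(len(dataset) / num_replicas) samples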
Example #4: ddp_cpu variant that also asserts the exact MisconfigurationException message when a custom sampler would be replaced.

import pytest

from torch.utils.data import DataLoader, Sampler

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_dataloader_reinit_for_subclass():
    class CustomDataLoader(DataLoader):
        def __init__(
            self,
            dataset,
            batch_size=1,
            shuffle=False,
            sampler=None,
            batch_sampler=None,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            dummy_kwarg=None,
        ):
            super().__init__(
                dataset,
                batch_size,
                shuffle,
                sampler,
                batch_sampler,
                num_workers,
                collate_fn,
                pin_memory,
                drop_last,
                timeout,
                worker_init_fn,
            )
            self.dummy_kwarg = dummy_kwarg
            self.something_unrelated = 1

    trainer = Trainer(num_processes=1, accelerator="ddp_cpu")

    class CustomDummyObj:
        sampler = None

    result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
    assert isinstance(result,
                      CustomDummyObj), "Wrongly reinstantiated data loader"

    dataset = list(range(10))
    result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
    assert isinstance(result, DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert result.dummy_kwarg is None

    # Shuffled DataLoader should also work
    result = trainer.auto_add_sampler(CustomDataLoader(dataset, shuffle=True),
                                      shuffle=True)
    assert isinstance(result, DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert result.dummy_kwarg is None

    class CustomSampler(Sampler):
        pass

    # Should raise an error if existing sampler is being replaced
    dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
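    # The doubled space in the match pattern is intentional: the raised
    # message appears to contain it verbatim in this Lightning version.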
    with pytest.raises(MisconfigurationException,
                       match="will be replaced  by `DistributedSampler`"):
        trainer.auto_add_sampler(dataloader, shuffle=True)