Example #1
def test_has_len():
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
Example #2
def test_has_len():
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
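
Examples 1 and 2 exercise the same `has_len` utility as it changed between releases: one version warns on a zero-length loader, the other raises. A minimal sketch of the underlying idea, using only standard PyTorch (the name `has_len_sketch` is hypothetical, not Lightning's API):

import warnings
from torch.utils.data import DataLoader

def has_len_sketch(dataloader: DataLoader) -> bool:
    """Return True if the loader reports a length; iterable-style loaders raise TypeError."""
    try:
        length = len(dataloader)
    except TypeError:
        return False  # e.g. a DataLoader over an IterableDataset without __len__
    if length == 0:
        warnings.warn("`DataLoader` returned 0 length.")
    return True
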
Example #3
def test_progress_bar_max_val_check_interval(tmpdir, total_train_samples,
                                             train_batch_size,
                                             total_val_samples, val_batch_size,
                                             val_check_interval):
    world_size = 2
    train_data = DataLoader(RandomDataset(32, total_train_samples),
                            batch_size=train_batch_size)
    val_data = DataLoader(RandomDataset(32, total_val_samples),
                          batch_size=val_batch_size)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        val_check_interval=val_check_interval,
        gpus=world_size,
        strategy="ddp",
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    total_train_batches = total_train_samples // (train_batch_size *
                                                  world_size)
    val_check_batch = max(1, int(total_train_batches * val_check_interval))
    assert trainer.val_check_batch == val_check_batch
    val_checks_per_epoch = total_train_batches / val_check_batch
    total_val_batches = total_val_samples // (val_batch_size * world_size)
    assert trainer.progress_bar_callback.total_train_batches == total_train_batches
    assert trainer.progress_bar_callback.total_val_batches == total_val_batches
    total_val_batches = total_val_batches * val_checks_per_epoch
    if trainer.is_global_zero:
        assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches
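
The expected progress-bar totals in Example 3 follow from integer arithmetic on the parametrized sizes. A worked instance with illustrative numbers (these values are assumptions, not the test's actual parametrization):

total_train_samples, train_batch_size, world_size = 8, 1, 2
total_val_samples, val_batch_size, val_check_interval = 4, 1, 0.5

total_train_batches = total_train_samples // (train_batch_size * world_size)  # 4 batches per process
val_check_batch = max(1, int(total_train_batches * val_check_interval))       # validate every 2 batches
val_checks_per_epoch = total_train_batches / val_check_batch                  # 2.0 checks per epoch
total_val_batches = total_val_samples // (val_batch_size * world_size)        # 2 val batches per check
assert total_train_batches + total_val_batches * val_checks_per_epoch == 8.0  # main bar total
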
Example #4
def test_combined_dataloader_for_training_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, use_combined_loader: bool
):
    """When providing a CombinedLoader as the training data, it should be correctly receive the distributed
    samplers."""
    mode = "min_size" if is_min_size_mode else "max_size_cycle"
    dim = 3
    n1 = 8
    n2 = 6
    dataloader = {
        "a": DataLoader(RandomDataset(dim, n1), batch_size=1),
        "b": DataLoader(RandomDataset(dim, n2), batch_size=1),
    }
    if use_combined_loader:
        dataloader = CombinedLoader(dataloader, mode=mode)
    expected_length_before_ddp = min(n1, n2) if is_min_size_mode else max(n1, n2)
    expected_length_after_ddp = expected_length_before_ddp // 2 if replace_sampler_ddp else expected_length_before_ddp
    model = BoringModel()
    trainer = Trainer(
        strategy="ddp",
        accelerator="auto",
        devices=2,
        replace_sampler_ddp=replace_sampler_ddp,
        multiple_trainloader_mode="max_size_cycle" if use_combined_loader else mode,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert isinstance(trainer.train_dataloader, CombinedLoader)
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp
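
The length expectations above reduce to simple arithmetic: `min_size` stops with the shortest loader, `max_size_cycle` cycles it, and with `replace_sampler_ddp=True` each of the two processes sees half of the batches. Mirroring the test's n1=8 and n2=6:

n1, n2, world_size = 8, 6, 2
expected_min_size = min(n1, n2)        # 6 batches: iteration ends with the shorter loader "b"
expected_max_size_cycle = max(n1, n2)  # 8 batches: loader "b" is cycled back to its start
assert expected_min_size // world_size == 3        # per-process length under DDP
assert expected_max_size_cycle // world_size == 4
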
Example #5
def test_boring_lite_model_ddp(precision, strategy, devices, accelerator,
                               tmpdir):
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 4))
    model = BoringModel()
    num_epochs = 1
    state_dict = deepcopy(model.state_dict())

    lite = LiteRunner(precision=precision,
                      strategy=strategy,
                      devices=devices,
                      accelerator=accelerator)
    lite.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir)

    lite_model_state_dict = model.state_dict()

    for w_pure, w_lite in zip(state_dict.values(),
                              lite_model_state_dict.values()):
        assert not torch.equal(w_pure.cpu(), w_lite.cpu())

    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 4))
    model = BoringModel()
    run(lite.global_rank, model, train_dataloader, num_epochs, precision,
        accelerator, tmpdir)
    pure_model_state_dict = model.state_dict()

    for w_pure, w_lite in zip(pure_model_state_dict.values(),
                              lite_model_state_dict.values()):
        assert torch.equal(w_pure.cpu(), w_lite.cpu())
Example #6
def test_has_len_all_rank():
    trainer = Trainer(fast_dev_run=True)
    model = BoringModel()

    with pytest.raises(
            MisconfigurationException,
            match="Total length of `Dataloader` across ranks is zero."):
        assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)),
                                     trainer.strategy, model)

    assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy,
                             model)
Example #7
def test_has_len_all_rank():
    trainer = Trainer(fast_dev_run=True)
    model = BoringModel()

    with pytest.warns(
            UserWarning,
            match="Total length of `DataLoader` across ranks is zero."):
        assert has_len_all_ranks(DataLoader(RandomDataset(0, 0)),
                                 trainer.strategy, model)

    assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy,
                             model)
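
Examples 6 and 7 are again two releases of the same check: `has_len_all_ranks` aggregates the local loader length across all processes and treats an all-rank total of zero as a misconfiguration (raising in one version, warning in the other). A conceptual sketch with the distributed reduction stubbed out as a plain argument (`has_len_all_ranks_sketch` and `total_length_across_ranks` are illustrative names, not Lightning's API):

def has_len_all_ranks_sketch(local_length: int, total_length_across_ranks: int) -> bool:
    # total_length_across_ranks stands in for a sum of len(dataloader) over all ranks
    if total_length_across_ranks == 0:
        raise ValueError("Total length of `Dataloader` across ranks is zero.")
    return local_length > 0
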
Example #8
def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir):
    class TestModel(BoringModel):
        def on_train_epoch_end(self):
            assert all(v == 3 for v in self.trainer.callback_metrics.values())

        def on_validation_epoch_end(self):
            assert all(v == 3 for v in self.trainer.callback_metrics.values())

        def on_train_batch_start(self, batch, batch_idx):
            self.log("on_train_batch_start",
                     1.0,
                     on_step=False,
                     on_epoch=True,
                     reduce_fx="sum")

        def on_train_batch_end(self, outputs, batch, batch_idx):
            self.log("on_train_batch_end",
                     1.0,
                     on_step=False,
                     on_epoch=True,
                     reduce_fx="sum")

        def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
            self.log("on_validation_batch_start", 1.0, reduce_fx="sum")

        def on_validation_batch_end(self, outputs, batch, batch_idx,
                                    dataloader_idx):
            self.log("on_validation_batch_end", 1.0, reduce_fx="sum")

        def training_epoch_end(self, *_) -> None:
            self.log("training_epoch_end", 3.0, reduce_fx="mean")
            assert self.trainer._results[
                "training_epoch_end.training_epoch_end"].value == 3.0

        def validation_epoch_end(self, *_) -> None:
            self.log("validation_epoch_end", 3.0, reduce_fx="mean")
            assert self.trainer._results[
                "validation_epoch_end.validation_epoch_end"].value == 3.0

    model = TestModel()
    trainer = Trainer(
        enable_progress_bar=False,
        limit_train_batches=3,
        limit_val_batches=3,
        num_sanity_val_steps=3,
        max_epochs=1,
    )
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
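
The epoch-end assertions rely on simple accumulation: each batch hook logs 1.0 with `reduce_fx="sum"`, and the run is capped at 3 train and 3 val batches, so every reduced metric ends the epoch at 3.0:

logged_per_batch, limit_batches = 1.0, 3
assert logged_per_batch * limit_batches == 3.0  # value seen in callback_metrics at epoch end
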
Example #9
def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_sampler_ddp):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader
    with ddp and `max_size_cycle` mode."""
    trainer = Trainer(strategy="ddp", accelerator=accelerator, devices=2, replace_sampler_ddp=replace_sampler_ddp)

    dataloader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader) == 4 if replace_sampler_ddp else 8

    for a_length in [6, 8, 10]:
        dataloader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )

        length = max(a_length, 8)
        assert len(dataloader) == length
        dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
        assert len(dataloader) == length // 2 if replace_sampler_ddp else length
        if replace_sampler_ddp:
            last_batch = list(dataloader)[-1]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    dataloader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    assert get_len(dataloader) == float("inf")
    assert len(dataloader.loaders["b"].loader) == 8
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.loaders["b"].loader) == 4 if replace_sampler_ddp else 8
    assert get_len(dataloader) == float("inf")
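
The `last_batch` expectations come from how `max_size_cycle` wraps the shorter loader: once it is exhausted it restarts from index 0 while the longer loader keeps going (the test additionally applies DDP striding across the two processes). A plain-Python sketch of just the cycling part:

from itertools import cycle, islice

a, b = range(6), range(8)
combined = list(islice(zip(cycle(a), cycle(b)), max(len(a), len(b))))
assert combined[-1] == (1, 7)  # "a" has wrapped around while "b" finishes its 8 items
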
Example #10
def test_replace_sampler_with_multiprocessing_context(tmpdir):
    """
    This test verifies that replace_sampler conserves multiprocessing context
    """
    train = RandomDataset(32, 64)
    context = 'spawn'
    train = DataLoader(train,
                       batch_size=32,
                       num_workers=2,
                       multiprocessing_context=context,
                       shuffle=True)

    class ExtendedBoringModel(BoringModel):
        def train_dataloader(self):
            return train

    trainer = Trainer(
        max_epochs=1,
        progress_bar_refresh_rate=20,
        overfit_batches=5,
    )

    new_data_loader = trainer.replace_sampler(train,
                                              SequentialSampler(train.dataset))
    assert (new_data_loader.multiprocessing_context ==
            train.multiprocessing_context)
Example #11
def test_get_len():
    assert get_len(DataLoader(RandomDataset(1, 1))) == 1

    value = get_len(DataLoader(RandomIterableDataset(1, 1)))

    assert isinstance(value, float)
    assert value == float("inf")
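
A minimal sketch of what a `get_len`-style helper does, assuming only standard Python (`get_len_sketch` is a hypothetical name): return the length when the loader reports one, and fall back to infinity otherwise.

def get_len_sketch(dataloader):
    try:
        return len(dataloader)
    except TypeError:
        # DataLoaders over IterableDatasets without __len__ land here
        return float("inf")
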
Example #12
def test_swa_deepcopy(tmpdir):
    """Test to ensure SWA Callback doesn't deepcopy dataloaders and datamodule potentially leading to OOM"""
    class TestSWA(StochasticWeightAveraging):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.on_before_accelerator_backend_setup_called = False

        def on_before_accelerator_backend_setup(self, trainer: 'Trainer',
                                                pl_module: 'LightningModule'):
            super().on_before_accelerator_backend_setup(trainer, pl_module)
            assert self._average_model.train_dataloader is not pl_module.train_dataloader
            assert self._average_model.train_dataloader.__self__ == self._average_model
            assert isinstance(pl_module.train_dataloader, _PatchDataLoader)
            assert self._average_model.trainer is None
            self.on_before_accelerator_backend_setup_called = True

    model = BoringModel()
    swa = TestSWA()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=swa,
        fast_dev_run=True,
    )
    trainer.fit(model, train_dataloader=DataLoader(RandomDataset(32, 2)))
    assert swa.on_before_accelerator_backend_setup_called
Example #13
def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator,
                                     tmpdir):
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 8))
    model = BoringModel()
    num_epochs = 1
    state_dict = deepcopy(model.state_dict())

    lite = LiteRunner(precision=precision,
                      strategy=strategy,
                      devices=devices,
                      accelerator=accelerator)
    checkpoint_path = lite.run(model,
                               train_dataloader,
                               num_epochs=num_epochs,
                               tmpdir=tmpdir)
    spawn_model_state_dict = torch.load(checkpoint_path)

    for w_pure, w_lite in zip(state_dict.values(),
                              spawn_model_state_dict.values()):
        assert not torch.equal(w_pure.cpu(), w_lite.cpu())

    model.load_state_dict(state_dict)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(find_free_network_port())
    mp.spawn(run,
             args=(model, train_dataloader, num_epochs, precision, accelerator,
                   tmpdir),
             nprocs=2)
    spawn_pure_model_state_dict = torch.load(
        os.path.join(tmpdir, "model_spawn.pt"))

    for w_pure, w_lite in zip(spawn_pure_model_state_dict.values(),
                              spawn_model_state_dict.values()):
        assert torch.equal(w_pure.cpu(), w_lite.cpu())
Example #14
def test_boring_lite_model_single_device(precision, strategy, devices,
                                         accelerator, tmpdir):
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 8))
    model = BoringModel()
    num_epochs = 1
    state_dict = deepcopy(model.state_dict())

    lite = LiteRunner(precision=precision,
                      strategy=strategy,
                      devices=devices,
                      accelerator=accelerator)
    lite.run(model, train_dataloader, num_epochs=num_epochs)
    lite_state_dict = model.state_dict()

    with precision_context(precision, accelerator):
        model.load_state_dict(state_dict)
        pure_state_dict = main(lite.to_device,
                               model,
                               train_dataloader,
                               num_epochs=num_epochs)

    state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device)
    for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()):
        assert not torch.equal(w_pure, w_lite)

    for w_pure, w_lite in zip(pure_state_dict.values(),
                              lite_state_dict.values()):
        assert torch.equal(w_pure, w_lite)
Example #15
def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
    world_size = 2
    total_train_samples = 16
    train_batch_size = 4
    total_val_samples = 2
    val_batch_size = 1
    train_data = DataLoader(RandomDataset(32, 8), batch_size=train_batch_size)
    val_data = DataLoader(RandomDataset(32, total_val_samples),
                          batch_size=val_batch_size)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        val_check_interval=val_check_interval,
        accelerator="gpu",
        devices=world_size,
        strategy="ddp",
        enable_progress_bar=True,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    total_train_batches = total_train_samples // (train_batch_size *
                                                  world_size)
    val_check_batch = max(1, int(total_train_batches * val_check_interval))
    assert trainer.val_check_batch == val_check_batch
    val_checks_per_epoch = total_train_batches / val_check_batch
    total_val_batches = total_val_samples // (val_batch_size * world_size)
    pbar_callback = trainer.progress_bar_callback

    if trainer.is_global_zero:
        assert pbar_callback.val_progress_bar.n == total_val_batches
        assert pbar_callback.val_progress_bar.total == total_val_batches
        total_val_batches = total_val_batches * val_checks_per_epoch
        assert pbar_callback.main_progress_bar.n == (
            total_train_batches + total_val_batches) // world_size
        assert pbar_callback.main_progress_bar.total == (
            total_train_batches + total_val_batches) // world_size
        assert pbar_callback.is_enabled
Example #16
def test_has_iterable_dataset():
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))

    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        def __iter__(self):
            yield 1
            return self

    assert not has_iterable_dataset(
        DataLoader(MockDatasetWithoutIterableDataset(1, 1)))
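
The third assertion holds because `MockDatasetWithoutIterableDataset` defines `__iter__` but does not subclass `IterableDataset`, so an isinstance-based check rejects it. A hedged sketch of such a check (not necessarily Lightning's exact code):

from torch.utils.data import DataLoader, IterableDataset

def has_iterable_dataset_sketch(dataloader: DataLoader) -> bool:
    return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)
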
Example #17
def test_update_dataloader_with_multiprocessing_context():
    """This test verifies that replace_sampler conserves multiprocessing context."""
    train = RandomDataset(32, 64)
    context = "spawn"
    train = DataLoader(train,
                       batch_size=32,
                       num_workers=2,
                       multiprocessing_context=context,
                       shuffle=True)
    new_data_loader = _update_dataloader(train,
                                         SequentialSampler(train.dataset))
    assert new_data_loader.multiprocessing_context == train.multiprocessing_context
Example #18
def test_error_raised_with_insufficient_float_limit_train_dataloader():
    batch_size = 16
    dl = DataLoader(RandomDataset(32, batch_size * 9), batch_size=batch_size)
    trainer = Trainer(limit_train_batches=0.1)
    model = BoringModel()

    trainer._data_connector.attach_data(model=model, train_dataloaders=dl)
    with pytest.raises(
            MisconfigurationException,
            match=
            "Please increase the `limit_train_batches` argument. Try at least",
    ):
        trainer.reset_train_dataloader(model)
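
Why the error is expected: the dataloader holds 9 full batches, and a float `limit_train_batches=0.1` truncates to zero batches, which Lightning reports as a misconfiguration:

batch_size = 16
num_batches = (batch_size * 9) // batch_size  # 9 batches in the dataloader
assert int(num_batches * 0.1) == 0            # the float limit resolves to zero train batches
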
Example #19
def test_quantization_val_test_predict(tmpdir):
    """Test the default quantization aware training not affected by validating, testing and predicting."""
    seed_everything(42)
    num_features = 16
    dm = RegressDataModule(num_features=num_features)
    qmodel = RegressionModel()

    val_test_predict_qmodel = copy.deepcopy(qmodel)
    trainer = Trainer(
        callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        val_check_interval=1,
        num_sanity_val_steps=1,
        max_epochs=4,
    )
    trainer.fit(val_test_predict_qmodel, datamodule=dm)
    trainer.validate(model=val_test_predict_qmodel,
                     datamodule=dm,
                     verbose=False)
    trainer.test(model=val_test_predict_qmodel, datamodule=dm, verbose=False)
    trainer.predict(model=val_test_predict_qmodel,
                    dataloaders=[
                        torch.utils.data.DataLoader(
                            RandomDataset(num_features, 16))
                    ])

    expected_qmodel = copy.deepcopy(qmodel)
    # No validation in ``expected_qmodel`` fitting.
    Trainer(
        callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=0,
        max_epochs=4,
    ).fit(expected_qmodel, datamodule=dm)

    expected_state_dict = expected_qmodel.state_dict()
    for key, value in val_test_predict_qmodel.state_dict().items():
        expected_value = expected_state_dict[key]
        assert torch.allclose(value, expected_value)
Example #20
def test_swa_deepcopy(tmpdir):
    """Test to ensure SWA Callback doesn't deepcopy dataloaders and datamodule potentially leading to OOM."""

    class TestSWA(StochasticWeightAveraging):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.setup_called = False

        def setup(self, trainer, pl_module, stage) -> None:
            super().setup(trainer, pl_module, stage)
            assert self._average_model.train_dataloader is not pl_module.train_dataloader
            assert self._average_model.train_dataloader.__self__ == self._average_model
            assert self._average_model.trainer is None
            self.setup_called = True

    model = BoringModel()
    swa = TestSWA(swa_lrs=1e-2)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
    trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 2)))
    assert swa.setup_called
Example #21
def test_loader_detaching():
    """Checks that the loader has been reset after the entrypoint."""

    loader = DataLoader(RandomDataset(32, 10), batch_size=1)

    model = LoaderTestModel()

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer = Trainer(fast_dev_run=1)
    trainer.fit(model, loader, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.validate(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.predict(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.test(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64
Example #22
def predict_dataloader(self):
    return DataLoader(RandomDataset(32, 64))
Example #23
@RunIf(rich=True)
def test_rich_progress_bar_refresh_rate():
    progress_bar = RichProgressBar(refresh_rate_per_second=1)
    assert progress_bar.is_enabled
    assert not progress_bar.is_disabled
    progress_bar = RichProgressBar(refresh_rate_per_second=0)
    assert not progress_bar.is_enabled
    assert progress_bar.is_disabled


@RunIf(rich=True)
@mock.patch(
    "pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize(
    "dataset", [RandomDataset(32, 64),
                RandomIterableDataset(32, 64)])
def test_rich_progress_bar(progress_update, tmpdir, dataset):
    class TestModel(BoringModel):
        def train_dataloader(self):
            return DataLoader(dataset=dataset)

        def val_dataloader(self):
            return DataLoader(dataset=dataset)

        def test_dataloader(self):
            return DataLoader(dataset=dataset)

        def predict_dataloader(self):
            return DataLoader(dataset=dataset)
Example #24
def test_dataloader(self):
    assert self._setup
    return DataLoader(RandomDataset(32, 64), batch_size=2)
Example #25
def test_dataloader(self):
    return torch.utils.data.DataLoader(RandomDataset(32, 64))
Example #26
def train_dataloader(self):
    # batch target memory >= 100x boring_model size
    batch_size = self.num_params * 100 // 32 + 1
    return DataLoader(RandomDataset(32, 5000), batch_size=batch_size)
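
A worked instance of the comment above, under the assumption that `num_params` is 66 (Lightning's test-helper `BoringModel` is a single `Linear(32, 2)` layer; treat the count as an assumption): each `RandomDataset` sample has 32 elements, so the batch must cover at least 100x the parameter count.

num_params = 66                          # assumed parameter count of the boring model
batch_size = num_params * 100 // 32 + 1  # 207
assert batch_size * 32 >= num_params * 100
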
Example #27
    assert isinstance(trainer.progress_bar_callback, RichProgressBar)


@RunIf(rich=True)
def test_rich_progress_bar_refresh_rate_enabled():
    progress_bar = RichProgressBar(refresh_rate=1)
    assert progress_bar.is_enabled
    assert not progress_bar.is_disabled
    progress_bar = RichProgressBar(refresh_rate=0)
    assert not progress_bar.is_enabled
    assert progress_bar.is_disabled


@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize("dataset", [RandomDataset(32, 64), RandomIterableDataset(32, 64)])
def test_rich_progress_bar(progress_update, tmpdir, dataset):
    class TestModel(BoringModel):
        def train_dataloader(self):
            return DataLoader(dataset=dataset)

        def val_dataloader(self):
            return DataLoader(dataset=dataset)

        def test_dataloader(self):
            return DataLoader(dataset=dataset)

        def predict_dataloader(self):
            return DataLoader(dataset=dataset)

    model = TestModel()
Example #28
class BoringModelNoDataloaders(BoringModel):
    def train_dataloader(self):
        raise NotImplementedError

    def val_dataloader(self):
        raise NotImplementedError

    def test_dataloader(self):
        raise NotImplementedError

    def predict_dataloader(self):
        raise NotImplementedError


_loader = DataLoader(RandomDataset(32, 64))
_loader_no_len = CustomNotImplementedErrorDataloader(_loader)


@pytest.mark.parametrize(
    "train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders",
    [
        (_loader_no_len, None, None, None),
        (None, _loader_no_len, None, None),
        (None, None, _loader_no_len, None),
        (None, None, None, _loader_no_len),
        (None, [_loader, _loader_no_len], None, None),
    ],
)
@mock.patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm")
def test_error_iterable_dataloaders_passed_to_fit(_, tmpdir, train_dataloaders,
Example #29
def test_mp_device_dataloader_attribute(_):
    dataset = RandomDataset(32, 64)
    dataloader = TPUSpawnStrategy().process_dataloader(DataLoader(dataset))
    assert dataloader.dataset == dataset
Example #30
def train_dataloader(self):
    seed_everything(42)
    return torch.utils.data.DataLoader(RandomDataset(32, 64))