def test_trainer_predict_verify_config(tmpdir, datamodule):
    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

    class TestLightningDataModule(LightningDataModule):
        def __init__(self, dataloaders):
            super().__init__()
            self._dataloaders = dataloaders

        def test_dataloader(self):
            return self._dataloaders

        def predict_dataloader(self):
            return self._dataloaders

    data = [
        torch.utils.data.DataLoader(RandomDataset(32, 2)),
        torch.utils.data.DataLoader(RandomDataset(32, 2))
    ]
    if datamodule:
        data = TestLightningDataModule(data)

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir)
    results = trainer.predict(model, data)

    assert len(results) == 2
    assert results[0][0].shape == torch.Size([1, 2])
 def test_dataloader(self):
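     # `suffix` is a flag from the enclosing test's scope (not shown here)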
     if suffix:
         return [
             torch.utils.data.DataLoader(RandomDataset(32, 64)),
             torch.utils.data.DataLoader(RandomDataset(32, 64)),
         ]
     return super().test_dataloader()
 def predict_dataloader(self):
     return CombinedLoader({
         "a": DataLoader(RandomDataset(32, 8), batch_size=2),
         "b": DataLoader(RandomDataset(32, 8), batch_size=4),
     })
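
# For illustration (an addition, not part of the original snippet): with the
# pytorch_lightning 1.x CombinedLoader used here, iteration yields one dict
# of batches per step, e.g.
def _combined_loader_sketch():
    combined = CombinedLoader({
        "a": DataLoader(RandomDataset(32, 8), batch_size=2),
        "b": DataLoader(RandomDataset(32, 8), batch_size=4),
    })
    batch = next(iter(combined))
    assert batch["a"].shape == torch.Size([2, 32])
    assert batch["b"].shape == torch.Size([4, 32])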
Example #4
 def train_dataloader(self):
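     # BATCH_SIZE and EXPECT_NUM_BATCHES_PROCESSED are constants defined in
     # the originating test module (not shown here)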
     if self.trigger_stop_iteration:
         return DataLoader(
             RandomDataset(BATCH_SIZE,
                           2 * EXPECT_NUM_BATCHES_PROCESSED))
     return DataLoader(
         RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED))
Example #5
def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
    # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
    # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance

    class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
        def __init__(self, *args, dataloader, **kwargs):
            super().__init__(*args, **kwargs)
            self.dataloader = dataloader

        def _shutdown_workers(self):
            self.dataloader.count_shutdown_workers += 1
            super()._shutdown_workers()

    class TestDataLoader(DataLoader):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.count_shutdown_workers = 0

        def _get_iterator(self):
            if self.num_workers == 0:
                return super()._get_iterator()
            else:
                self.check_worker_number_rationality()
                return _TestMultiProcessingDataLoaderIter(self, dataloader=self)

    train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
    val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)

    class TestCallback(Callback):
        def on_train_epoch_end(self, trainer, *_):
            if trainer.current_epoch == 1:
                raise CustomException

    max_epochs = 3

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=max_epochs,
        callbacks=TestCallback() if should_fail else None,
    )

    if should_fail:
        with pytest.raises(CustomException):
            trainer.fit(model, train_dataloader, val_dataloader)
    else:
        trainer.fit(model, train_dataloader, val_dataloader)

    expected = 2 if should_fail else (2 if persistent_workers else max_epochs)
    assert train_dataloader.count_shutdown_workers == expected
    # at the end of sanity checking, the validation workers are shut down too.
    expected = 2 if persistent_workers else (3 if should_fail else max_epochs + 1)
    assert val_dataloader.count_shutdown_workers == expected
    assert train_dataloader._iterator is None
    assert val_dataloader._iterator is None
def test_model_nohparams_train_test(tmpdir, cls):
    """Test models that do not take any argument in init."""

    model = cls()
    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)

    train_loader = DataLoader(RandomDataset(32, 64), batch_size=32)
    trainer.fit(model, train_loader)

    test_loader = DataLoader(RandomDataset(32, 64), batch_size=32)
    trainer.test(dataloaders=test_loader)
Example #7
def test_update_dataloader_with_multiprocessing_context():
    """This test verifies that `_update_dataloader` preserves the multiprocessing context."""
    train = RandomDataset(32, 64)
    context = "spawn"
    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
    new_data_loader = _update_dataloader(train, SequentialSampler(train.dataset))
    assert new_data_loader.multiprocessing_context == train.multiprocessing_context
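
# For illustration (not Lightning's implementation): `_update_dataloader`
# essentially re-instantiates the DataLoader with the new sampler while
# carrying over the remaining constructor arguments, roughly:
def _update_dataloader_sketch(dataloader, sampler):
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        num_workers=dataloader.num_workers,
        multiprocessing_context=dataloader.multiprocessing_context,
        sampler=sampler,
    )
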
def test_loader_detaching():
    """Checks that the dataloaders are reset after the entry point."""
    class LoaderTestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            assert len(model.train_dataloader()) == 10
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            assert len(model.val_dataloader()) == 10
            return super().validation_step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            assert len(model.test_dataloader()) == 10
            return super().test_step(batch, batch_idx)

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            assert len(model.predict_dataloader()) == 10
            return super().predict_step(batch,
                                        batch_idx,
                                        dataloader_idx=dataloader_idx)

    loader = DataLoader(RandomDataset(32, 10), batch_size=1)

    model = LoaderTestModel()

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer = Trainer(fast_dev_run=1)
    trainer.fit(model, loader, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.validate(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.predict(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.test(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64
def test_trainer_predict_verify_config(tmpdir, datamodule):
    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

    class TestLightningDataModule(LightningDataModule):
        def __init__(self, dataloaders):
            super().__init__()
            self._dataloaders = dataloaders

        def test_dataloader(self):
            return self._dataloaders

        def predict_dataloader(self):
            return self._dataloaders

    dataloaders = [
        torch.utils.data.DataLoader(RandomDataset(32, 2)),
        torch.utils.data.DataLoader(RandomDataset(32, 2))
    ]

    model = TestModel()

    trainer = Trainer(default_root_dir=tmpdir)

    if datamodule:
        datamodule = TestLightningDataModule(dataloaders)
        results = trainer.predict(model, datamodule=datamodule)
    else:
        results = trainer.predict(model, dataloaders=dataloaders)

    assert len(results) == 2
    assert results[0][0].shape == torch.Size([1, 2])

    model.predict_dataloader = None

    with pytest.raises(MisconfigurationException,
                       match="Dataloader not found for `Trainer.predict`"):
        trainer.predict(model)
Example #10
def test_amp_gpus(tmpdir, strategy, precision, gpus):
    """Make sure combinations of AMP and training types work if supported."""
    tutils.reset_seed()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=gpus, strategy=strategy, precision=precision)

    model = AMPTestModel()
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model, DataLoader(RandomDataset(32, 64)))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
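
# `AMPTestModel` is a helper from the Lightning test suite; a rough stand-in
# (the real helper differs) could assert that the forward pass produces a
# reduced-precision output while AMP is active:
class _AMPTestModelSketch(BoringModel):
    def training_step(self, batch, batch_idx):
        output = self(batch)
        # with precision=16 (or bf16) autocast lowers the linear layer's output dtype
        assert output.dtype in (torch.float16, torch.bfloat16)
        return self.loss(batch, output)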
Example #11
def test_amp_cpus(tmpdir, accelerator, precision, num_processes):
    """Make sure combinations of AMP and training types work if supported."""
    tutils.reset_seed()

    trainer = Trainer(
        default_root_dir=tmpdir, num_processes=num_processes, max_epochs=1, accelerator=accelerator, precision=precision
    )

    model = AMPTestModel()
    # tutils.run_model_test(trainer_options, model)
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model, DataLoader(RandomDataset(32, 64)))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
Example #12
def test_amp_single_gpu_ddp_spawn(tmpdir):
    """Make sure DP/DDP + AMP work."""
    tutils.reset_seed()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      gpus=1,
                      accelerator="ddp_spawn",
                      precision=16)

    model = AMPTestModel()
    # tutils.run_model_test(trainer_options, model)
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model, DataLoader(RandomDataset(32, 64)))
    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_dataloader_warnings(num_workers):
    class TestModel(BoringModel):
        def on_train_start(self, *_) -> None:
            raise SystemExit()

    dl = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
    if hasattr(dl, "persistent_workers"):
        if num_workers == 0:
            warn_str = "Consider setting num_workers>0 and persistent_workers=True"
        else:
            warn_str = "Consider setting persistent_workers=True"
    else:
        warn_str = "Consider setting accelerator=ddp"

    trainer = Trainer(accelerator="ddp_spawn")
    with pytest.warns(UserWarning, match=warn_str), pytest.raises(SystemExit):
        trainer.fit(TestModel(), dl)
def test_prediction_writer_hook_call_intervals(tmpdir):
    """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
    interval."""
    DummyPredictionWriter.write_on_batch_end = Mock()
    DummyPredictionWriter.write_on_epoch_end = Mock()

    dataloader = DataLoader(RandomDataset(32, 64))

    model = BoringModel()
    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    results = trainer.predict(model, dataloaders=dataloader)
    assert len(results) == 4
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("batch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 0

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 0
    assert cb.write_on_epoch_end.call_count == 1
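
# `DummyPredictionWriter` is a test helper; a minimal stand-in (the real
# helper may differ) subclasses `BasePredictionWriter` and implements the two
# hooks as no-ops:
from pytorch_lightning.callbacks import BasePredictionWriter

class _DummyPredictionWriterSketch(BasePredictionWriter):
    def write_on_batch_end(self, *args, **kwargs):
        pass

    def write_on_epoch_end(self, *args, **kwargs):
        pass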
Example #15
def test_loader_detaching():
    """Checks that the dataloaders are reset after the entry point."""

    loader = DataLoader(RandomDataset(32, 10), batch_size=1)

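    # `LoaderTestModel` is the BoringModel subclass defined in the earlier
    # `test_loader_detaching` example above.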
    model = LoaderTestModel()

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer = Trainer(fast_dev_run=1)
    trainer.fit(model, loader, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.validate(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.predict(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64

    trainer.test(model, loader)

    assert len(model.train_dataloader()) == 64
    assert len(model.val_dataloader()) == 64
    assert len(model.predict_dataloader()) == 64
    assert len(model.test_dataloader()) == 64
Example #16
def test_model_tpu_early_stop(tmpdir):
    """Test that early stopping works during TPU training."""
    class CustomBoringModel(BoringModel):
        def validation_step(self, *args, **kwargs):
            out = super().validation_step(*args, **kwargs)
            self.log('val_loss', out['x'])
            return out

    tutils.reset_seed()
    model = CustomBoringModel()
    trainer = Trainer(
        callbacks=[EarlyStopping(monitor='val_loss')],
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        tpu_cores=8,
    )
    trainer.fit(model)
    trainer.test(
        test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
def test_prediction_writer_batch_indices(tmpdir, num_workers):
    DummyPredictionWriter.write_on_batch_end = Mock()
    DummyPredictionWriter.write_on_epoch_end = Mock()

    dataloader = DataLoader(RandomDataset(32, 64),
                            batch_size=4,
                            num_workers=num_workers)
    model = BoringModel()
    writer = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=writer)
    trainer.predict(model, dataloaders=dataloader)

    writer.write_on_batch_end.assert_has_calls([
        call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0),
        call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0),
        call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0),
        call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0),
    ])

    writer.write_on_epoch_end.assert_has_calls([
        call(trainer, model, ANY,
             [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]),
    ])
Example #18
def test_model_tpu_early_stop(tmpdir):
    """Test that early stopping works during TPU training."""
    class CustomBoringModel(BoringModel):
        def validation_step(self, *args, **kwargs):
            out = super().validation_step(*args, **kwargs)
            self.log("val_loss", out["x"])
            return out

    tutils.reset_seed()
    model = CustomBoringModel()
    trainer = Trainer(
        callbacks=[EarlyStopping(monitor="val_loss")],
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        accelerator="tpu",
        devices=8,
    )
    trainer.fit(model)
    trainer.test(
        dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
Example #19
 def val_dataloader(self):
     return DataLoader(RandomDataset(32, 2000), batch_size=32)
Example #20
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test

if _TPU_AVAILABLE:
    import torch_xla
    import torch_xla.distributed.xla_multiprocessing as xmp
    SERIAL_EXEC = xmp.MpSerialExecutor()

_LARGER_DATASET = RandomDataset(32, 2000)


# training on 8 cores needs a big dataset
def _serial_train_loader():
    return DataLoader(_LARGER_DATASET, batch_size=32)


class SerialLoaderBoringModel(BoringModel):
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 2000), batch_size=32)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 2000), batch_size=32)

Example #21
 def train_dataloader(self):
     return DataLoader(RandomDataset(32, 64), batch_size=2)
Example #22
 def predict_dataloader(self):
     return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
Example #23
 def train_dataloader(self):
     return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn)
Example #24
 def train_dataloader(self):
     # override to test the `is_last_batch` value
     return DataLoader(RandomDataset(32, n_batches))
 def val_dataloader(self):
     return DataLoader(RandomDataset(32, 64),
                       batch_size=getattr(self, "batch_size", 1))
def test_dataloaders_with_missing_keyword_arguments():
    trainer = Trainer()
    ds = RandomDataset(10, 20)

    class TestDataLoader(DataLoader):
        def __init__(self, dataset):
            super().__init__(dataset)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    match = escape(
        "missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(loader, sampler, mode="fit")
    match = escape(
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']"
    )
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, dataset, *args, **kwargs):
            super().__init__(dataset)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    trainer.replace_sampler(loader, sampler, mode="fit")
    trainer.replace_sampler(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, *foo, **bar):
            super().__init__(*foo, **bar)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    trainer.replace_sampler(loader, sampler, mode="fit")
    trainer.replace_sampler(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, *args, shuffle=False):
            self.num_feat = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, ds)
    sampler = SequentialSampler(ds)
    match = escape("missing arguments are ['batch_sampler', 'sampler']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(loader, sampler, mode="fit")
    match = escape(
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']"
    )
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, **kwargs):
            self.feat_num = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, ds)
    sampler = SequentialSampler(ds)
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(loader, sampler, mode="fit")
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(loader, sampler, mode="predict")
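
# Takeaway sketch (not part of the original test): a DataLoader subclass that
# `replace_sampler` can re-instantiate cleanly forwards `**kwargs` to the
# parent and stores extra arguments under their constructor-argument names:
class _FeatDataLoaderSketch(DataLoader):
    def __init__(self, num_feat, dataset, **kwargs):
        self.num_feat = num_feat  # attribute name matches the argument name
        super().__init__(dataset, **kwargs)
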
 def train_dataloader(self, *args, **kwargs) -> DataLoader:
     return DataLoader(RandomDataset(32, 64), batch_size=32)
Example #28
 def train_dataloader(self):
     return DataLoader(RandomDataset(BATCH_SIZE, DATASET_LEN))
 def test_dataloader(self):
     return [
         torch.utils.data.DataLoader(RandomDataset(32, 64))
         for _ in range(num_dataloaders)
     ]
 def test_dataloader(self):
     return [
         torch.utils.data.DataLoader(RandomDataset(32, 64)),
         torch.utils.data.DataLoader(RandomDataset(32, 64))
     ]