def test_trainer_predict_verify_config(tmpdir, datamodule):
    """``Trainer.predict`` accepts either raw dataloaders or a datamodule."""

    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

    class TestLightningDataModule(LightningDataModule):
        def __init__(self, dataloaders):
            super().__init__()
            self._dataloaders = dataloaders

        def test_dataloader(self):
            return self._dataloaders

        def predict_dataloader(self):
            return self._dataloaders

    loaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)) for _ in range(2)]
    # Wrap the loaders in a datamodule when the fixture asks for one.
    data = TestLightningDataModule(loaders) if datamodule else loaders

    trainer = Trainer(default_root_dir=tmpdir)
    predictions = trainer.predict(TestModel(), data)

    # One prediction list per dataloader; each batch output is (1, 2).
    assert len(predictions) == 2
    assert predictions[0][0].shape == torch.Size([1, 2])
def test_dataloader(self):
    # Without a suffix, defer to the parent's single test dataloader.
    if not suffix:
        return super().test_dataloader()
    # With a suffix, exercise the multiple-dataloader code path.
    return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(2)]
def predict_dataloader(self):
    # Two named loaders with different batch sizes, combined into one.
    loaders = {
        "a": DataLoader(RandomDataset(32, 8), batch_size=2),
        "b": DataLoader(RandomDataset(32, 8), batch_size=4),
    }
    return CombinedLoader(loaders)
def train_dataloader(self):
    # When triggering StopIteration, serve twice the expected number of
    # batches so iteration runs past EXPECT_NUM_BATCHES_PROCESSED.
    length = EXPECT_NUM_BATCHES_PROCESSED * (2 if self.trigger_stop_iteration else 1)
    return DataLoader(RandomDataset(BATCH_SIZE, length))
def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
    """Verify dataloader workers are shut down the expected number of times.

    Counts calls to ``_shutdown_workers`` on the train/val dataloader iterators
    across a normal fit and a fit interrupted by an exception at epoch 1.
    """
    # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
    # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance

    class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
        def __init__(self, *args, dataloader, **kwargs):
            super().__init__(*args, **kwargs)
            self.dataloader = dataloader

        def _shutdown_workers(self):
            # Count every shutdown on the owning DataLoader, then delegate.
            self.dataloader.count_shutdown_workers += 1
            super()._shutdown_workers()

    class TestDataLoader(DataLoader):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.count_shutdown_workers = 0

        def _get_iterator(self):
            if self.num_workers == 0:
                return super()._get_iterator()
            self.check_worker_number_rationality()
            return _TestMultiProcessingDataLoaderIter(self, dataloader=self)

    train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
    val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)

    class TestCallback(Callback):
        def on_train_epoch_end(self, trainer, *_):
            # Interrupt training mid-run on the second epoch.
            if trainer.current_epoch == 1:
                raise CustomException

    max_epochs = 3
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=max_epochs,
        callbacks=TestCallback() if should_fail else None,
    )
    if should_fail:
        with pytest.raises(CustomException):
            trainer.fit(model, train_dataloader, val_dataloader)
    else:
        trainer.fit(model, train_dataloader, val_dataloader)

    # FIX: the original asserts lacked parentheses around the expected value.
    # `x == 2 if cond else y` parses as `(x == 2) if cond else y`, so the
    # `else` branch was a bare truthy value and the assertion was vacuous.
    assert train_dataloader.count_shutdown_workers == (2 if should_fail else (2 if persistent_workers else max_epochs))
    # on sanity checking end, the workers are being deleted too.
    assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else (3 if should_fail else max_epochs + 1))
    assert train_dataloader._iterator is None
    assert val_dataloader._iterator is None
def test_model_nohparams_train_test(tmpdir, cls):
    """Test models that do not take any argument in init."""
    model = cls()
    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)

    # Fit and test on fresh 64-sample random loaders.
    trainer.fit(model, DataLoader(RandomDataset(32, 64), batch_size=32))
    trainer.test(dataloaders=DataLoader(RandomDataset(32, 64), batch_size=32))
def test_update_dataloader_with_multiprocessing_context():
    """This test verifies that replace_sampler conserves multiprocessing context."""
    dataset = RandomDataset(32, 64)
    original = DataLoader(dataset, batch_size=32, num_workers=2, multiprocessing_context="spawn", shuffle=True)
    # Swapping the sampler must keep the multiprocessing context intact.
    updated = _update_dataloader(original, SequentialSampler(original.dataset))
    assert updated.multiprocessing_context == original.multiprocessing_context
def test_loader_detaching(): """Checks that the loader has been resetted after the entrypoint.""" class LoaderTestModel(BoringModel): def training_step(self, batch, batch_idx): assert len(model.train_dataloader()) == 10 return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): assert len(model.val_dataloader()) == 10 return super().validation_step(batch, batch_idx) def test_step(self, batch, batch_idx): assert len(model.test_dataloader()) == 10 return super().test_step(batch, batch_idx) def predict_step(self, batch, batch_idx, dataloader_idx=None): assert len(model.predict_dataloader()) == 10 return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) loader = DataLoader(RandomDataset(32, 10), batch_size=1) model = LoaderTestModel() assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64 trainer = Trainer(fast_dev_run=1) trainer.fit(model, loader, loader) assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64 trainer.validate(model, loader) assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64 trainer.predict(model, loader) assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64 trainer.test(model, loader) assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64
def test_trainer_predict_verify_config(tmpdir, datamodule):
    """``Trainer.predict`` works with dataloaders or a datamodule, and raises
    a clear error when no predict dataloader can be found."""

    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

    class TestLightningDataModule(LightningDataModule):
        def __init__(self, dataloaders):
            super().__init__()
            self._dataloaders = dataloaders

        def test_dataloader(self):
            return self._dataloaders

        def predict_dataloader(self):
            return self._dataloaders

    dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)) for _ in range(2)]
    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir)

    if datamodule:
        results = trainer.predict(model, datamodule=TestLightningDataModule(dataloaders))
    else:
        results = trainer.predict(model, dataloaders=dataloaders)

    # One prediction list per dataloader; each batch output is (1, 2).
    assert len(results) == 2
    assert results[0][0].shape == torch.Size([1, 2])

    # With no dataloaders available at all, predict must raise.
    model.predict_dataloader = None
    with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
        trainer.predict(model)
def test_amp_gpus(tmpdir, strategy, precision, gpus):
    """Make sure combinations of AMP and training types work if supported."""
    tutils.reset_seed()

    model = AMPTestModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=gpus, strategy=strategy, precision=precision)

    # Run all three entrypoints under the given AMP configuration.
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model, DataLoader(RandomDataset(32, 64)))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_amp_cpus(tmpdir, accelerator, precision, num_processes):
    """Make sure combinations of AMP and training types work if supported."""
    tutils.reset_seed()

    model = AMPTestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_processes=num_processes,
        max_epochs=1,
        accelerator=accelerator,
        precision=precision,
    )

    # Run all three entrypoints under the given AMP configuration.
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model, DataLoader(RandomDataset(32, 64)))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_amp_single_gpu_ddp_spawn(tmpdir):
    """Make sure DP/DDP + AMP work."""
    tutils.reset_seed()

    model = AMPTestModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=1, accelerator="ddp_spawn", precision=16)

    # Run all three entrypoints with ddp_spawn + 16-bit precision.
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model, DataLoader(RandomDataset(32, 64)))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_dataloader_warnings(num_workers):
    class TestModel(BoringModel):
        def on_train_start(self, *_) -> None:
            # Abort immediately; only the warning emitted before this matters.
            raise SystemExit()

    dl = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
    # The expected warning depends on DataLoader capabilities and num_workers.
    if not hasattr(dl, "persistent_workers"):
        warn_str = "Consider setting accelerator=ddp"
    elif num_workers == 0:
        warn_str = "Consider setting num_workers>0 and persistent_workers=True"
    else:
        warn_str = "Consider setting persistent_workers=True"

    trainer = Trainer(accelerator="ddp_spawn")
    with pytest.warns(UserWarning, match=warn_str), pytest.raises(SystemExit):
        trainer.fit(TestModel(), dl)
def test_prediction_writer_hook_call_intervals(tmpdir):
    """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined interval."""
    DummyPredictionWriter.write_on_batch_end = Mock()
    DummyPredictionWriter.write_on_epoch_end = Mock()

    dataloader = DataLoader(RandomDataset(32, 64))
    model = BoringModel()

    # First run also checks that predictions are returned by default.
    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    results = trainer.predict(model, dataloaders=dataloader)
    assert len(results) == 4
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    # Remaining runs use `return_predictions=False`; expected hook counts per interval.
    for interval, expected_batch_calls, expected_epoch_calls in (
        ("batch_and_epoch", 4, 1),
        ("batch", 4, 0),
        ("epoch", 0, 1),
    ):
        DummyPredictionWriter.write_on_batch_end.reset_mock()
        DummyPredictionWriter.write_on_epoch_end.reset_mock()
        cb = DummyPredictionWriter(interval)
        trainer = Trainer(limit_predict_batches=4, callbacks=cb)
        trainer.predict(model, dataloaders=dataloader, return_predictions=False)
        assert cb.write_on_batch_end.call_count == expected_batch_calls
        assert cb.write_on_epoch_end.call_count == expected_epoch_calls
def test_loader_detaching(): """Checks that the loader has been resetted after the entrypoint.""" loader = DataLoader(RandomDataset(32, 10), batch_size=1) model = LoaderTestModel() assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64 trainer = Trainer(fast_dev_run=1) trainer.fit(model, loader, loader) assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64 trainer.validate(model, loader) assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64 trainer.predict(model, loader) assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64 trainer.test(model, loader) assert len(model.train_dataloader()) == 64 assert len(model.val_dataloader()) == 64 assert len(model.predict_dataloader()) == 64 assert len(model.test_dataloader()) == 64
def test_model_tpu_early_stop(tmpdir):
    """Test if single TPU core training works"""

    class CustomBoringModel(BoringModel):
        def validation_step(self, *args, **kwargs):
            # Log the validation output so EarlyStopping can monitor it.
            out = super().validation_step(*args, **kwargs)
            self.log('val_loss', out['x'])
            return out

    tutils.reset_seed()
    model = CustomBoringModel()
    trainer = Trainer(
        callbacks=[EarlyStopping(monitor='val_loss')],
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        tpu_cores=8,
    )
    trainer.fit(model)
    trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
def test_prediction_writer_batch_indices(tmpdir, num_workers):
    DummyPredictionWriter.write_on_batch_end = Mock()
    DummyPredictionWriter.write_on_epoch_end = Mock()

    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
    model = BoringModel()
    writer = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=writer)
    trainer.predict(model, dataloaders=dataloader)

    # Each batch hook receives the 4 sample indices of its batch.
    expected_indices = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
    writer.write_on_batch_end.assert_has_calls(
        [call(trainer, model, ANY, indices, ANY, batch_idx, 0) for batch_idx, indices in enumerate(expected_indices)]
    )
    # The epoch hook receives all batch indices for the single dataloader at once.
    writer.write_on_epoch_end.assert_has_calls([call(trainer, model, ANY, [expected_indices])])
def test_model_tpu_early_stop(tmpdir):
    """Test if single TPU core training works."""

    class CustomBoringModel(BoringModel):
        def validation_step(self, *args, **kwargs):
            # Log the validation output so EarlyStopping can monitor it.
            out = super().validation_step(*args, **kwargs)
            self.log("val_loss", out["x"])
            return out

    tutils.reset_seed()
    model = CustomBoringModel()
    trainer = Trainer(
        callbacks=[EarlyStopping(monitor="val_loss")],
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        accelerator="tpu",
        devices=8,
    )
    trainer.fit(model)
    trainer.test(dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
def val_dataloader(self):
    # Validation loader over a 2000-sample random dataset, batches of 32.
    dataset = RandomDataset(32, 2000)
    return DataLoader(dataset, batch_size=32)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
from tests.helpers.utils import pl_multi_process_test

# torch_xla is only importable on TPU-enabled environments.
if _TPU_AVAILABLE:
    import torch_xla
    import torch_xla.distributed.xla_multiprocessing as xmp

    SERIAL_EXEC = xmp.MpSerialExecutor()

_LARGER_DATASET = RandomDataset(32, 2000)  # 8 cores needs a big dataset


def _serial_train_loader():
    # Shared training loader over the module-level large dataset.
    return DataLoader(_LARGER_DATASET, batch_size=32)


class SerialLoaderBoringModel(BoringModel):
    """BoringModel variant whose dataloaders use a 2000-sample dataset."""

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 2000), batch_size=32)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 2000), batch_size=32)
def train_dataloader(self):
    # 64-sample random dataset served in batches of 2.
    return DataLoader(dataset=RandomDataset(32, 64), batch_size=2)
def predict_dataloader(self):
    # Two independent loaders to exercise multi-dataloader prediction.
    return [DataLoader(RandomDataset(32, 64)) for _ in range(2)]
def train_dataloader(self):
    # Uses the `collate_fn` from the enclosing scope.
    dataset = RandomDataset(32, 64)
    return torch.utils.data.DataLoader(dataset, collate_fn=collate_fn)
def train_dataloader(self):
    # override to test the `is_last_batch` value
    dataset = RandomDataset(32, n_batches)
    return DataLoader(dataset)
def val_dataloader(self):
    # Honor a `batch_size` attribute when present; default to 1 otherwise.
    batch_size = getattr(self, "batch_size", 1)
    return DataLoader(RandomDataset(32, 64), batch_size=batch_size)
def test_dataloaders_with_missing_keyword_arguments():
    """`replace_sampler` must report DataLoader subclass init args/attributes
    it cannot re-supply when re-instantiating the loader."""
    trainer = Trainer()
    ds = RandomDataset(10, 20)

    # Case 1: subclass drops every keyword argument -> re-instantiation fails.
    class TestDataLoader(DataLoader):
        def __init__(self, dataset):
            super().__init__(dataset)

    dl = TestDataLoader(ds)
    new_sampler = SequentialSampler(ds)
    match = escape("missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(dl, new_sampler, mode="fit")
    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(dl, new_sampler, mode="predict")

    # Case 2: *args/**kwargs are accepted (even if discarded) -> no error.
    class TestDataLoader(DataLoader):
        def __init__(self, dataset, *args, **kwargs):
            super().__init__(dataset)

    dl = TestDataLoader(ds)
    new_sampler = SequentialSampler(ds)
    trainer.replace_sampler(dl, new_sampler, mode="fit")
    trainer.replace_sampler(dl, new_sampler, mode="predict")

    # Case 3: fully variadic signature that forwards everything -> no error.
    class TestDataLoader(DataLoader):
        def __init__(self, *foo, **bar):
            super().__init__(*foo, **bar)

    dl = TestDataLoader(ds)
    new_sampler = SequentialSampler(ds)
    trainer.replace_sampler(dl, new_sampler, mode="fit")
    trainer.replace_sampler(dl, new_sampler, mode="predict")

    # Case 4: extra leading arg plus explicit `shuffle` -> only the remaining
    # un-forwardable keywords are reported.
    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, *args, shuffle=False):
            self.num_feat = num_feat
            super().__init__(dataset)

    dl = TestDataLoader(1, ds)
    new_sampler = SequentialSampler(ds)
    match = escape("missing arguments are ['batch_sampler', 'sampler']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(dl, new_sampler, mode="fit")
    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(dl, new_sampler, mode="predict")

    # Case 5: init arg stored under a different attribute name -> the missing
    # attribute is reported.
    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, **kwargs):
            self.feat_num = num_feat
            super().__init__(dataset)

    dl = TestDataLoader(1, ds)
    new_sampler = SequentialSampler(ds)
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(dl, new_sampler, mode="fit")
    with pytest.raises(MisconfigurationException, match=match):
        trainer.replace_sampler(dl, new_sampler, mode="predict")
def train_dataloader(self, *args, **kwargs) -> DataLoader:
    # Extra positional/keyword arguments are accepted but unused.
    return DataLoader(dataset=RandomDataset(32, 64), batch_size=32)
def train_dataloader(self):
    # Dataset dimensions come from the module-level constants.
    dataset = RandomDataset(BATCH_SIZE, DATASET_LEN)
    return DataLoader(dataset)
def test_dataloader(self):
    # Build one independent dataloader per requested index
    # (`num_dataloaders` comes from the enclosing scope).
    loaders = []
    for _ in range(num_dataloaders):
        loaders.append(torch.utils.data.DataLoader(RandomDataset(32, 64)))
    return loaders
def test_dataloader(self):
    # Two independent dataloaders for the multi-dataloader test path.
    return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(2)]