def test_multiple_val_dataloader(tmpdir):
    """Check that a model exposing two validation dataloaders trains and predicts on both.

    The model combines ``LightningValidationMultipleDataloadersMixin`` with
    ``LightningTestModelBase``. After one epoch of training, the test asserts
    that ``fit`` returned 1, that the trainer registered exactly two validation
    dataloaders, and that prediction runs on each of them.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningValidationMultipleDataloadersMixin, LightningTestModelBase):
        pass

    model = CurrentTestModel(tutils.get_hparams())

    # short run: a single epoch over all training data and 10% of validation data
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=1.0,
    )
    fit_status = trainer.fit(model)

    # fit() reports successful completion with 1
    assert fit_status == 1

    # the mixin provides two validation loaders; both must be picked up
    assert len(trainer.get_val_dataloaders()) == 2, \
        'Multiple val_dataloaders not initiated properly'

    # every validation set should produce acceptable predictions
    for val_loader in trainer.get_val_dataloaders():
        tutils.run_prediction(val_loader, trainer.model)
def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to fit.

    Runs two fits with the same trainer configuration:

    1. only a validation dataloader is passed to ``fit``;
    2. both a validation and a test dataloader are passed to ``fit``.

    After each fit the test checks that training completed (``fit`` returned 1)
    and that the passed loaders were registered on the trainer.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # trainer configuration shared by both fits (short run: 1 epoch, partial data)
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # fit model, passing only a val dataloader
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False))
    result = trainer.fit(model, **fit_options)

    # verify training completed and the passed val dataloader was registered
    assert result == 1
    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'

    # fit model, passing both a val and a test dataloader
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False),
                       test_dataloader=model._dataloader(train=False))
    result = trainer.fit(model, **fit_options)

    # verify training completed and both passed dataloaders were registered
    assert result == 1
    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'
def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Verify that multiple val & test dataloaders can be passed to fit.

    The model defines no dataloaders of its own. ``fit`` receives one train
    dataloader, two val dataloaders and two test dataloaders. The test checks
    that training completed (``fit`` returned 1) and that both lists of
    loaders were registered on the trainer.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBaseWithoutDataloader):
        pass

    hparams = tutils.get_hparams()

    # short run: 1 epoch over 20% of training data and 10% of validation data
    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           val_percent_check=0.1,
                           train_percent_check=0.2)

    # train, multiple val and multiple test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloader=[model._dataloader(train=False),
                                       model._dataloader(train=False)],
                       test_dataloader=[model._dataloader(train=False),
                                        model._dataloader(train=False)])
    result = trainer.fit(model, **fit_options)

    # verify training completed before inspecting the registered loaders
    assert result == 1
    assert len(trainer.get_val_dataloaders()) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'