def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
    """Verify that dataloaders can be passed."""
    model = EvalModelTemplate()
    if n == 1:
        dataloaders = model.dataloader(train=False)
    else:
        # two identical eval loaders; wire up the multi-dataloader hooks
        dataloaders = [model.dataloader(train=False)] * 2
        model.validation_step = model.validation_step__multiple_dataloaders
        model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
        model.test_step = model.test_step__multiple_dataloaders

    # train, multiple val and multiple test passed to fit
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )
    trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders)

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert len(trainer.val_dataloaders) == n

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path

    # both eval entry points accept the same loaders
    trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)

    assert len(trainer.val_dataloaders) == n
    assert len(trainer.test_dataloaders) == n
def test_warning_with_few_workers(tmpdir):
    """ Test that error is raised if dataloader with only a few workers is used """
    model = EvalModelTemplate()

    # logger file to get meta
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    fit_options = dict(
        train_dataloader=model.dataloader(train=True),
        val_dataloaders=model.dataloader(train=False),
    )
    test_options = dict(test_dataloaders=model.dataloader(train=False))

    trainer = Trainer(**trainer_options)

    # fit model — each phase should emit its own few-workers warning
    with pytest.warns(UserWarning, match='train'):
        trainer.fit(model, **fit_options)

    with pytest.warns(UserWarning, match='val'):
        trainer.fit(model, **fit_options)

    with pytest.warns(UserWarning, match='test'):
        trainer.test(**test_options)
def test_warning_with_few_workers(tmpdir, ckpt_path):
    """ Test that error is raised if dataloader with only a few workers is used """
    model = EvalModelTemplate()

    # logger file to get meta
    trainer_options = dict(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    train_dl = model.dataloader(train=True)
    train_dl.num_workers = 0

    val_dl = model.dataloader(train=False)
    val_dl.num_workers = 0

    # FIX: this loader was previously assigned to `train_dl`, silently
    # discarding the real (train=True) train loader built above.
    test_dl = model.dataloader(train=False)
    test_dl.num_workers = 0

    fit_options = dict(train_dataloader=train_dl, val_dataloaders=val_dl)
    trainer = Trainer(**trainer_options)

    # fit model — each phase should emit its own few-workers warning
    with pytest.warns(UserWarning, match='train'):
        trainer.fit(model, **fit_options)

    with pytest.warns(UserWarning, match='val'):
        trainer.fit(model, **fit_options)

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    test_options = dict(test_dataloaders=test_dl, ckpt_path=ckpt_path)
    with pytest.warns(UserWarning, match='test'):
        trainer.test(**test_options)
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
    """Verify that dataloaders can be passed to fit"""
    model = EvalModelTemplate()
    trainer_options = dict(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert results

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert results

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):
    """Verify train, val & test dataloader(s) can be passed to fit and test method"""
    model = EvalModelTemplate()

    # train, val and test passed to fit
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
    result = trainer.fit(
        model,
        train_dataloader=model.dataloader(train=True),
        val_dataloaders=model.dataloader(train=False),
    )

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
    """Verify that multiple val & test dataloaders can be passed to fit."""
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders

    # train, multiple val and multiple test passed to fit
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
    trainer.fit(
        model,
        train_dataloader=model.dataloader(train=True),
        val_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)],
    )

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(
        test_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)],
        ckpt_path=ckpt_path,
    )

    assert len(trainer.val_dataloaders) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to fit"""
    model = EvalModelTemplate()
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert results

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model, val_dataloaders=model.dataloader(train=False))
    assert results
    trainer.test(test_dataloaders=model.dataloader(train=False))

    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """Verify that multiple val & test dataloaders can be passed to fit."""
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders

    # train, multiple val and multiple test passed to fit
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    fit_options = dict(
        train_dataloader=model.dataloader(train=True),
        val_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)],
    )
    test_options = dict(
        test_dataloaders=[model.dataloader(train=False), model.dataloader(train=False)],
    )

    trainer.fit(model, **fit_options)
    trainer.test(**test_options)

    assert len(trainer.val_dataloaders) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path):
    """ Test that error is raised if dataloader with only a few workers is used """
    model = EvalModelTemplate()
    model.training_step = model.training_step__multiple_dataloaders
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # logger file to get meta
    train_dl = model.dataloader(train=True)
    train_dl.num_workers = 0

    val_dl = model.dataloader(train=False)
    val_dl.num_workers = 0

    # FIX: this loader was previously assigned to `train_dl`, clobbering the
    # real (train=True) train loader built above.
    test_dl = model.dataloader(train=False)
    test_dl.num_workers = 0

    train_multi_dl = {'a': train_dl, 'b': train_dl}
    val_multi_dl = [val_dl, val_dl]
    test_multi_dl = [test_dl, test_dl]

    fit_options = dict(train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # fit model — each phase should emit its own few-workers warning
    with pytest.warns(
        UserWarning,
        match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.'
    ):
        trainer.fit(model, **fit_options)

    with pytest.warns(
        UserWarning,
        match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.'
    ):
        trainer.fit(model, **fit_options)

    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    test_options = dict(test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
    with pytest.warns(
        UserWarning,
        match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.'
    ):
        trainer.test(**test_options)
def test_warning_on_wrong_test_settigs(tmpdir):
    """ Test the following cases related to test configuration of model:
        * error if `test_dataloader()` is overridden but `test_step()` is not
        * if both `test_dataloader()` and `test_step()` are overridden,
          throw warning if `test_epoch_end()` is not defined
        * error if `test_step()` is overridden but `test_dataloader()` is not
    """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # ----------------
    # if have test_dataloader should have test_step
    # ----------------
    with pytest.raises(MisconfigurationException):
        model = EvalModelTemplate(hparams)
        model.test_step = None
        trainer.fit(model)

    # ----------------
    # if have test_dataloader and test_step recommend test_epoch_end
    # ----------------
    with pytest.warns(RuntimeWarning):
        model = EvalModelTemplate(hparams)
        model.test_epoch_end = None
        trainer.test(model)

    # ----------------
    # if have test_step and NO test_dataloader passed in tell user to pass test_dataloader
    # ----------------
    with pytest.raises(MisconfigurationException):
        model = EvalModelTemplate(hparams)
        model.test_dataloader = lambda: None
        trainer.test(model)

    # ----------------
    # if have test_dataloader and NO test_step tell user to implement test_step
    # ----------------
    with pytest.raises(MisconfigurationException):
        model = EvalModelTemplate(hparams)
        model.test_dataloader = lambda: None
        model.test_step = None
        trainer.test(model, test_dataloaders=model.dataloader(train=False))

    # ----------------
    # if have test_dataloader and test_step but no test_epoch_end warn user
    # ----------------
    with pytest.warns(RuntimeWarning):
        model = EvalModelTemplate(hparams)
        model.test_dataloader = lambda: None
        model.test_epoch_end = None
        trainer.test(model, test_dataloaders=model.dataloader(train=False))
def test_train_val_dataloaders_passed_to_fit(tmpdir):
    """ Verify that train & val dataloader can be passed to fit """

    # train, val passed to fit
    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
    result = trainer.fit(
        model,
        train_dataloader=model.dataloader(train=True),
        val_dataloaders=model.dataloader(train=False),
    )

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
    """ Test that error is raised if dataloader with only a few workers is used """
    model = EvalModelTemplate()
    model.training_step = model.training_step__multiple_dataloaders
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    val_dl = model.dataloader(train=False)
    val_dl.num_workers = 0

    # NOTE(review): a train=False loader feeds both the train and test
    # collections here — presumably just to keep the fixture cheap; confirm.
    train_dl = model.dataloader(train=False)
    train_dl.num_workers = 0

    train_multi_dl = {'a': train_dl, 'b': train_dl}
    val_multi_dl = [val_dl, val_dl]
    test_multi_dl = [train_dl, train_dl]

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    # non-train stages name their loaders with an index ("val dataloader 0")
    with pytest.warns(
        UserWarning,
        match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    ):
        if stage == 'test':
            ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == 'specific' else ckpt_path
            trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
        else:
            trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
def test_train_dataloader_passed_to_fit(tmpdir):
    """Verify that train dataloader can be passed to fit """

    # only train passed to fit
    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
    result = trainer.fit(model, train_dataloader=model.dataloader(train=True))

    assert result == 1
def test_all_dataloaders_passed_to_fit(tmpdir):
    """Verify train, val & test dataloader(s) can be passed to fit and test method"""
    model = EvalModelTemplate()

    # train, val and test passed to fit
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_percent_check=0.1, train_percent_check=0.2)
    fit_options = dict(
        train_dataloader=model.dataloader(train=True),
        val_dataloaders=model.dataloader(train=False),
    )
    test_options = dict(test_dataloaders=model.dataloader(train=False))

    result = trainer.fit(model, **fit_options)
    trainer.test(**test_options)

    assert result == 1
    assert len(trainer.val_dataloaders) == 1, \
        f'val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
    assert len(trainer.test_dataloaders) == 1, \
        f'test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
def test_train_dataloader_passed_to_fit(tmpdir):
    """Verify that train dataloader can be passed to fit """

    # only train passed to fit
    model = EvalModelTemplate(tutils.get_default_hparams())
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    result = trainer.fit(model, train_dataloader=model.dataloader(train=True))

    assert result == 1
def test_error_on_dataloader_passed_to_fit(tmpdir):
    """Verify that when the auto scale batch size feature raises an error
       if a train dataloader is passed to fit """

    # only train passed to fit
    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
        auto_scale_batch_size='power',
    )

    # auto_scale_batch_size cannot tune a loader it does not own
    with pytest.raises(MisconfigurationException):
        trainer.fit(model, train_dataloader=model.dataloader(train=True))
def test_train_dataloader_passed_to_fit(tmpdir):
    """Verify that train dataloader can be passed to fit """

    # only train passed to fit
    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )
    trainer.fit(model, train_dataloader=model.dataloader(train=True))

    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
def test_test_loop_config(tmpdir):
    """When either test loop or test data are missing"""
    hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # has test loop but no test data
    with pytest.warns(UserWarning):
        model = EvalModelTemplate(**hparams)
        model.test_dataloader = None
        trainer.test(model)

    # has test data but no test loop
    with pytest.warns(UserWarning):
        model = EvalModelTemplate(**hparams)
        model.test_step = None
        trainer.test(model, test_dataloaders=model.dataloader(train=False))