def test_testpass_overrides(tmpdir):
    """Check which test-hook overrides ``Trainer.test`` accepts or rejects.

    Each case builds a fresh ``EvalModelTemplate``, resets one hook to the
    ``LightningModule`` base implementation and runs ``Trainer.test``:

    * ``test_dataloader`` reset: ``MisconfigurationException`` whose message
      matches ``not implement `test_dataloader```.
    * ``test_step`` reset: ``MisconfigurationException``.
    * ``test_step_end`` reset, or nothing reset: no exception expected.

    Only the ``Trainer.test`` call sits inside ``pytest.raises`` so that an
    exception raised while building the model cannot satisfy the check.
    """
    # todo: check duplicated tests against trainer_checks
    hparams = EvalModelTemplate.get_default_hparams()

    # Misconfig when test_dataloader is not implemented
    model = EvalModelTemplate(**hparams)
    model.test_dataloader = LightningModule.test_dataloader
    with pytest.raises(MisconfigurationException, match='.*not implement `test_dataloader`.*'):
        Trainer(default_root_dir=tmpdir).test(model)

    # Misconfig when test_step is not implemented
    model = EvalModelTemplate(**hparams)
    model.test_step = LightningModule.test_step
    with pytest.raises(MisconfigurationException):
        Trainer(default_root_dir=tmpdir).test(model)

    # No exceptions when test_step is implemented, with or without test_step_end
    model = EvalModelTemplate(**hparams)
    model.test_step_end = LightningModule.test_step_end
    Trainer(default_root_dir=tmpdir).test(model)

    model = EvalModelTemplate(**hparams)
    Trainer(default_root_dir=tmpdir).test(model)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
    model = EvalModelTemplate()

    # use the multi-dataloader variants of the validation hooks ...
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

    # ... and of the test hooks
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # train, multiple val and multiple test loaders, each limited to a fraction
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    # expected counts are the loader lengths scaled by the limit, truncated to int
    want_train = int(len(trainer.train_dataloader) * limit_train_batches)
    want_val = []
    for loader in trainer.val_dataloaders:
        want_val.append(int(len(loader) * limit_val_batches))

    assert trainer.num_training_batches == want_train
    assert trainer.num_val_batches == want_val

    trainer.test(ckpt_path=None)

    want_test = []
    for loader in trainer.test_dataloaders:
        want_test.append(int(len(loader) * limit_test_batches))
    assert trainer.num_test_batches == want_test
def test_full_train_loop_with_results_obj_dp(tmpdir):
    """Fit with the DP backend on GPUs 0 and 1 and check the logged metric keys.

    The model uses the ``*_full_loop_result_obj_dp`` training hooks with the
    validation and test hooks/dataloaders set to ``None``. After fitting, the
    metrics recorded by ``trainer.dev_debugger`` must contain the step,
    step-end and epoch-end keys.
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    num_batches = 10
    num_epochs = 3

    model = EvalModelTemplate()

    # training-only run: drop the validation and test hooks and loaders
    model.validation_step = None
    model.val_dataloader = None
    model.test_step = None
    model.test_dataloader = None

    model.training_step = model.training_step_full_loop_result_obj_dp
    model.training_step_end = model.training_step_end_full_loop_result_obj_dp
    model.training_epoch_end = model.training_epoch_end_full_loop_result_obj_dp

    trainer = Trainer(
        default_root_dir=tmpdir,
        distributed_backend='dp',
        gpus=[0, 1],
        max_epochs=num_epochs,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=num_batches,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure we saw all the correct keys
    logged_keys = {
        key
        for entry in trainer.dev_debugger.logged_metrics
        for key in entry.keys()
    }
    for expected_key in (
        'train_step_metric',
        'train_step_end_metric',
        'epoch_train_epoch_end_metric',
    ):
        assert expected_key in logged_keys
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
    """ Test that a warning is raised if dataloaders with only a few workers are used """
    # NOTE(review): `_`, `ckpt_path` and `stage` are presumably injected by
    # decorators that are not visible from this block - confirm at the def site.
    model = EvalModelTemplate()
    model.training_step = model.training_step__multiple_dataloaders
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # two loaders, both forced to zero workers
    val_dl = model.dataloader(train=False)
    val_dl.num_workers = 0
    train_dl = model.dataloader(train=False)
    train_dl.num_workers = 0

    # the same loader object is reused inside each multi-loader collection
    train_multi_dl = {'a': train_dl, 'b': train_dl}
    val_multi_dl = [val_dl, val_dl]
    test_multi_dl = [train_dl, train_dl]

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )

    expected_warning = pytest.warns(
        UserWarning,
        match=f'The dataloader, {stage} dataloader{" 0" if stage != "train" else ""}, does not have many workers'
    )
    with expected_warning:
        if stage != 'test':
            trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)
        else:
            # 'specific' means: test against the best checkpoint path tracked by the trainer
            if ckpt_path == 'specific':
                ckpt_path = trainer.checkpoint_callback.best_model_path
            trainer.test(model, test_dataloaders=test_multi_dl, ckpt_path=ckpt_path)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
    model = EvalModelTemplate()

    # multi-dataloader variants of the validation hooks ...
    model.val_dataloader = model.val_dataloader__multiple_mixed_length
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

    # ... and of the test hooks
    model.test_dataloader = model.test_dataloader__multiple_mixed_length
    model.test_step = model.test_step__multiple_dataloaders
    model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

    # train, multiple val and multiple test loaders, each limited to a batch count
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    # -------------------------------------------
    # the trainer must have stored the requested limits
    # -------------------------------------------
    assert trainer.num_training_batches == limit_train_batches
    assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)

    trainer.test(ckpt_path=None)

    # when the limit is greater than the number of test batches it should be the num in loaders
    test_dataloader_lengths = [len(loader) for loader in model.test_dataloader()]
    limit_exceeds_loaders = limit_test_batches > 1e10
    if limit_exceeds_loaders:
        expected_test_batches = test_dataloader_lengths
    else:
        expected_test_batches = [limit_test_batches] * len(trainer.test_dataloaders)
    assert trainer.num_test_batches == expected_test_batches

    # -------------------------------------------
    # the debugger must have seen the expected number of batches
    # -------------------------------------------
    num_val_dataloaders = len(model.val_dataloader())
    num_test_dataloaders = len(model.test_dataloader())

    # the per-loader "seen batches" checks only apply when training ran
    if not (limit_train_batches > 0):
        return

    # make sure val batches are as expected
    seen_val = trainer.dev_debugger.num_seen_val_check_batches
    assert len(seen_val) == num_val_dataloaders
    for _idx, seen in seen_val.items():
        assert seen == limit_val_batches

    # make sure test batches are as expected
    seen_test = trainer.dev_debugger.num_seen_test_check_batches
    assert len(seen_test) == num_test_dataloaders
    for loader_idx, seen in seen_test.items():
        if limit_exceeds_loaders:
            assert seen == test_dataloader_lengths[loader_idx]
        else:
            assert seen == limit_test_batches
def test_wrong_test_settigs(tmpdir):
    """ Test the following cases related to test configuration of model:
        * error if `test_dataloader()` is overridden but `test_step()` is not
        * if both `test_dataloader()` and `test_step()` is overridden,
            throw warning if `test_epoch_end()` is not defined
        * error if `test_step()` is overridden but `test_dataloader()` is not
        * no error when only testing and the training/validation hooks and
            dataloaders are removed

    Model setup is kept outside the `pytest.raises` / `pytest.warns` blocks so
    that only the trainer call can satisfy the expectation.
    """
    hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # ----------------
    # if have test_dataloader should have test_step
    # ----------------
    # NOTE(review): this case goes through fit() while the others use test() - confirm intended
    model = EvalModelTemplate(hparams)
    model.test_step = None
    with pytest.raises(MisconfigurationException):
        trainer.fit(model)

    # ----------------
    # if have test_dataloader and test_step recommend test_epoch_end
    # ----------------
    model = EvalModelTemplate(hparams)
    model.test_epoch_end = None
    with pytest.warns(RuntimeWarning):
        trainer.test(model)

    # ----------------
    # if have test_step and NO test_dataloader passed in tell user to pass test_dataloader
    # ----------------
    model = EvalModelTemplate(hparams)
    model.test_dataloader = LightningModule.test_dataloader
    with pytest.raises(MisconfigurationException):
        trainer.test(model)

    # ----------------
    # if have test_dataloader and NO test_step tell user to implement test_step
    # ----------------
    model = EvalModelTemplate(hparams)
    model.test_dataloader = LightningModule.test_dataloader
    model.test_step = None
    with pytest.raises(MisconfigurationException):
        trainer.test(model, test_dataloaders=model.dataloader(train=False))

    # ----------------
    # if have test_dataloader and test_step but no test_epoch_end warn user
    # ----------------
    model = EvalModelTemplate(hparams)
    model.test_dataloader = LightningModule.test_dataloader
    model.test_epoch_end = None
    with pytest.warns(RuntimeWarning):
        trainer.test(model, test_dataloaders=model.dataloader(train=False))

    # ----------------
    # if we are just testing, no need for train_dataloader, training_step,
    # val_dataloader, and validation_step
    # ----------------
    model = EvalModelTemplate(hparams)
    model.test_dataloader = LightningModule.test_dataloader
    model.train_dataloader = None
    # clear the hooks under their real names: `train_step` / `val_step` are not
    # hooks, so assigning to them left the training and validation steps in place
    model.training_step = None
    model.val_dataloader = None
    model.validation_step = None
    trainer.test(model, test_dataloaders=model.dataloader(train=False))