# NOTE: the imports below are an assumption reconstructed from the identifiers
# used in these tests; exact module paths vary across pytorch-lightning versions.
import pytest

import tests.models.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TrainsLogger
from pytorch_lightning.utilities.debugging import MisconfigurationException
from tests.models import (
    LightInfValDataloader,
    LightTrainDataloader,
    LightningTestMixin,
    LightningTestModel,
    LightningTestModelBase,
    LightningTestModelBaseWithoutDataloader,
    TestModelBase,
)


def test_train_val_dataloaders_passed_to_fit(tmpdir):
    """Verify that train & val dataloaders can be passed to fit()."""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBaseWithoutDataloader):
        pass

    hparams = tutils.get_hparams()

    # run only a fraction of the data for a single epoch
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )

    # train & val dataloaders are passed to fit() instead of living on the model
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(
        train_dataloader=model._dataloader(train=True),
        val_dataloader=model._dataloader(train=False),
    )
    result = trainer.fit(model, **fit_options)

    assert result == 1, 'training failed to complete'
    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
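# The `_dataloader(train=...)` helper above comes from the test-fixture base
# class and is assumed to return a plain PyTorch DataLoader. A hypothetical,
# self-contained equivalent (names and shapes are illustrative only):
import torch
from torch.utils.data import DataLoader, TensorDataset


def _sketch_dataloader(train=True, batch_size=32):
    """Hypothetical stand-in for the fixture's `_dataloader` helper."""
    n = 256 if train else 64
    dataset = TensorDataset(torch.randn(n, 1, 28, 28), torch.randint(0, 10, (n,)))
    return DataLoader(dataset, batch_size=batch_size, shuffle=train)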
def test_inf_val_dataloader(tmpdir):
    """Test an infinite val dataloader (e.g. one backed by an IterableDataset)."""
    tutils.reset_seed()

    class CurrentTestModel(LightInfValDataloader, LightningTestModel):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # a fractional val_percent_check cannot be applied to a loader with no length
    with pytest.raises(MisconfigurationException):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            val_percent_check=0.5,
        )
        trainer.fit(model)

    # without val_percent_check, fitting with an infinite val loader should work
    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)
    result = trainer.fit(model)

    # verify training completed
    assert result == 1
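# Context for the test above: an "infinite" val loader wraps an IterableDataset,
# which defines no __len__, so a fractional `val_percent_check` cannot be
# resolved into a batch count and Trainer raises MisconfigurationException.
# A minimal sketch of such a loader (the LightInfValDataloader mixin's actual
# implementation may differ):
import torch
from torch.utils.data import DataLoader, IterableDataset


class _SketchInfiniteDataset(IterableDataset):
    """Yields random samples forever; deliberately defines no __len__."""

    def __iter__(self):
        while True:
            yield torch.randn(1, 28, 28), 0


def _sketch_infinite_loader(batch_size=32):
    return DataLoader(_SketchInfiniteDataset(), batch_size=batch_size)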
def test_running_test_without_val(tmpdir):
    """Verify `test()` works on a model with no `val_loader`."""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestMixin, LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        show_progress_bar=False,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        test_percent_check=0.2,
        checkpoint_callback=checkpoint,
        logger=logger,
        early_stop_callback=False,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result == 1, 'training failed to complete'

    trainer.test()

    # test we have good test accuracy
    tutils.assert_ok_model_acc(trainer)
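# The LightningTestMixin fixture is assumed to contribute the hooks that
# `trainer.test()` needs when the model defines no val loader. A hypothetical
# sketch using the old pre-0.9 hook names consistent with this file (the real
# mixin may differ); it only works when mixed into a LightningModule:
import torch


class _SketchTestMixin:
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        return {'test_loss': torch.nn.functional.cross_entropy(logits, y)}

    def test_end(self, outputs):
        avg_loss = torch.stack([out['test_loss'] for out in outputs]).mean()
        return {'test_loss': avg_loss}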
def test_trains_logger(tmpdir):
    """Verify that basic functionality of the TRAINS logger works."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # bypass mode lets the logger run without reaching a live TRAINS server
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name='lightning_log', task_name='pytorch lightning test')

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger,
    )

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    logger.finalize()
    assert result == 1, 'Training failed'
def test_dataloader_config_errors(tmpdir):
    """Verify misconfigured percent checks and val_check_interval raise ValueError."""
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # percent check < 0
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=-0.1,
    )
    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # percent check > 1
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=1.1,
    )
    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # int val_check_interval > number of batches in the train loader
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=10000,
    )
    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)

    # float val_check_interval > 1
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_check_interval=1.1,
    )
    trainer = Trainer(**trainer_options)
    with pytest.raises(ValueError):
        trainer.fit(model)
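# For contrast with the error cases above, configurations the same arguments
# are expected to accept (assumed semantics: *_percent_check in [0.0, 1.0];
# val_check_interval either a float in (0.0, 1.0] or an int no larger than
# the number of training batches):
#
#   Trainer(default_save_path=tmpdir, max_epochs=1, train_percent_check=0.2)  # use 20% of train data
#   Trainer(default_save_path=tmpdir, max_epochs=1, val_check_interval=0.5)   # validate twice per epoch
#   Trainer(default_save_path=tmpdir, max_epochs=1, val_check_interval=100)   # validate every 100 batches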