def assert_good_acc():
    """Check the restored trainer's epoch, then run predictions with its model.

    Reads `new_trainer`, `real_global_epoch` and `trainer` from the
    surrounding scope (not visible in this block).
    """
    restored_epoch = new_trainer.current_epoch
    assert restored_epoch == real_global_epoch and restored_epoch > 0

    # If the model and its state were loaded correctly, predictions should be
    # good even though nothing has been trained with the newly loaded model.
    restored_model = new_trainer.model
    restored_model.eval()

    # NOTE(review): the dataloader comes from `trainer`, not `new_trainer`;
    # kept exactly as in the original.
    train_loader = trainer.get_train_dataloader()
    tutils.run_prediction(train_loader, restored_model, dp=True)
def test_running_test_pretrained_model_ddp(tmpdir):
    """Fit with the `ddp` backend, reload the checkpoint, and run `test()` on it.

    Returns immediately, without asserting anything, when
    `tutils.can_run_gpu_test()` is falsy.
    """
    if not tutils.can_run_gpu_test():
        return

    tutils.reset_seed()
    tutils.set_random_master_port()

    model = LightningTestModel(tutils.get_hparams())

    # The logger provides the experiment metadata; the checkpoint callback
    # built from it provides the saved weights.
    logger = tutils.get_test_tube_logger(tmpdir, False)
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = {
        'show_progress_bar': False,
        'max_epochs': 1,
        'train_percent_check': 0.4,
        'val_percent_check': 0.2,
        'checkpoint_callback': checkpoint,
        'logger': logger,
        'gpus': [0, 1],
        'distributed_backend': 'ddp',
    }

    # Train once so that there is a checkpoint to reload.
    fit_trainer = Trainer(**trainer_options)
    fit_status = fit_trainer.fit(model)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # Training must report success before the checkpoint is used.
    assert fit_status == 1, 'training failed to complete'

    pretrained_model = tutils.load_model(
        logger,
        fit_trainer.checkpoint_callback.filepath,
        module_class=LightningTestModel,
    )

    # Run the test set through a second trainer built from the same options.
    test_trainer = Trainer(**trainer_options)
    test_trainer.test(pretrained_model)

    # `test_dataloader()` may return one loader or a list; normalise to a list.
    loaders = model.test_dataloader()
    loaders = loaders if isinstance(loaders, list) else [loaders]

    for loader in loaders:
        tutils.run_prediction(loader, pretrained_model)
def test_multiple_test_dataloader(tmpdir):
    """Check that a model with multiple test dataloaders exposes two of them."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightTestMultipleDataloadersMixin,
        LightEmptyTestStep,
        TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_hparams())

    # Fit, then run the test loop once so the test dataloaders are available.
    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2,
    )
    trainer.fit(model)
    trainer.test()

    # There should be exactly two test loaders.
    loader_count = len(trainer.test_dataloaders)
    assert loader_count == 2, 'Multiple test_dataloaders not initiated properly'

    # Predictions should be good on every test set.
    for loader in trainer.test_dataloaders:
        tutils.run_prediction(loader, trainer.model)

    # Run the test method again.
    trainer.test()
def test_multiple_val_dataloader(tmpdir):
    """Check that a model with multiple val dataloaders exposes two of them."""
    tutils.reset_seed()

    class CurrentTestModel(
        LightTrainDataloader,
        LightValidationMultipleDataloadersMixin,
        TestModelBase,
    ):
        pass

    model = CurrentTestModel(tutils.get_hparams())

    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=1.0,
    )
    fit_status = trainer.fit(model)

    # Training must report success.
    assert fit_status == 1

    # There should be exactly two validation loaders.
    loader_count = len(trainer.val_dataloaders)
    assert loader_count == 2, 'Multiple val_dataloaders not initiated properly'

    # Predictions should be good on every validation set.
    for loader in trainer.val_dataloaders:
        tutils.run_prediction(loader, trainer.model)