def test_multiple_val_dataloader(tmpdir):
    """Verify multiple val_dataloader."""
    template = EvalModelTemplate()
    # swap in the multi-dataloader hooks on the template instance
    template.val_dataloader = template.val_dataloader__multiple
    template.validation_step = template.validation_step__multiple_dataloaders
    template.validation_epoch_end = template.validation_epoch_end__multiple_dataloaders

    # fit the model with a short run
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=1.0,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(template)

    # training must have completed successfully
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    # both val loaders should have been registered
    assert len(trainer.val_dataloaders) == 2, 'Multiple val_dataloaders not initiated properly'

    # each val set should yield good predictions
    for val_dl in trainer.val_dataloaders:
        tpipes.run_prediction_eval_model_template(trained_model=template, dataloader=val_dl)
def on_pretrain_routine_end(self):
    # the restored trainer must have resumed past epoch 0 at the expected epoch
    assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

    # if model and state loaded correctly, predictions will be good even though we
    # haven't trained with the new loaded model
    new_trainer.state.stage = RunningStage.VALIDATING
    tpipes.run_prediction_eval_model_template(
        self.trainer.lightning_module, dataloader=self.train_dataloader()
    )
    self.on_pretrain_routine_end_called = True
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_master_port()
    dm = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    options = {
        'progress_bar_refresh_rate': 0,
        'max_epochs': 2,
        'limit_train_batches': 2,
        'limit_val_batches': 2,
        'callbacks': [checkpoint],
        'logger': logger,
        'gpus': [0, 1],
        'accelerator': 'ddp_spawn',
        'default_root_dir': tmpdir,
    }

    # fit model
    trainer = Trainer(**options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # correct result and ok accuracy
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path
    )

    # run test set on a fresh trainer with the restored weights
    new_trainer = Trainer(**options)
    new_trainer.test(pretrained_model)
    pretrained_model.cpu()

    # normalize to a list so single- and multi-loader cases share one path
    test_loaders = dm.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    for test_dl in test_loaders:
        tpipes.run_prediction_eval_model_template(pretrained_model, test_dl, min_acc=0.1)
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""
    tutils.set_random_main_port()
    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    options = {
        'enable_progress_bar': False,
        'max_epochs': 2,
        'limit_train_batches': 5,
        'limit_val_batches': 5,
        'callbacks': [checkpoint],
        'logger': logger,
        'gpus': [0, 1],
        'strategy': "dp",
        'default_root_dir': tmpdir,
    }

    # fit model
    trainer = Trainer(**options)
    trainer.fit(model, datamodule=dm)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = CustomClassificationModelDP.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path
    )

    # run test set on a fresh trainer with the restored weights
    new_trainer = Trainer(**options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    # normalize to a list so single- and multi-loader cases share one path
    test_loaders = dm.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    for test_dl in test_loaders:
        tpipes.run_prediction_eval_model_template(pretrained_model, test_dl)
def test_multiple_eval_dataloader(tmpdir, ckpt_path):
    """Verify multiple evaluation dataloaders."""

    class MultipleTestDataloaderModel(EvalModelTemplate):
        # two identical eval loaders so the multi-dataloader path is exercised
        def test_dataloader(self):
            return [self.dataloader(train=False), self.dataloader(train=False)]

        def test_step(self, *args, **kwargs):
            return super().test_step__multiple_dataloaders(*args, **kwargs)

        def val_dataloader(self):
            # validation reuses the same pair of loaders
            return self.test_dataloader()

        def validation_step(self, *args, **kwargs):
            # reuse the test step and rename its output keys for validation
            output = self.test_step(*args, **kwargs)
            return {k.replace("test_", "val_"): v for k, v in output.items()}

    model = MultipleTestDataloaderModel()

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=10,
        limit_train_batches=100,
    )
    trainer.fit(model)
    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path

    trainer.validate(ckpt_path=ckpt_path, verbose=False)
    # both val loaders should have been registered
    assert len(trainer.val_dataloaders) == 2
    # each val loader should yield good predictions
    for eval_dl in trainer.val_dataloaders:
        tpipes.run_prediction_eval_model_template(trainer.model, eval_dl)

    trainer.test(ckpt_path=ckpt_path, verbose=False)
    assert len(trainer.test_dataloaders) == 2
    for eval_dl in trainer.test_dataloaders:
        tpipes.run_prediction_eval_model_template(trainer.model, eval_dl)
def test_multiple_test_dataloader(tmpdir, ckpt_path):
    """Verify multiple test_dataloader."""
    model_template = EvalModelTemplate()

    class MultipleTestDataloaderModel(EvalModelTemplate):
        # two identical test loaders so the multi-dataloader path is exercised
        def test_dataloader(self):
            return [self.dataloader(train=False), self.dataloader(train=False)]

        def test_step(self, batch, batch_idx, *args, **kwargs):
            # delegate to the template's multi-dataloader test step
            return model_template.test_step__multiple_dataloaders(
                batch, batch_idx, *args, **kwargs
            )

    model = MultipleTestDataloaderModel()

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=10,
        limit_train_batches=100,
    )
    trainer.fit(model)
    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(ckpt_path=ckpt_path)

    # both test loaders should have been registered
    assert len(trainer.test_dataloaders) == 2, 'Multiple test_dataloaders not initiated properly'

    # each test set should yield good predictions
    for test_dl in trainer.test_dataloaders:
        tpipes.run_prediction_eval_model_template(trainer.model, test_dl)

    # run the test method a second time (NOTE(review): intentionally repeats the
    # earlier `trainer.test` call — presumably to check re-running test is safe)
    trainer.test(ckpt_path=ckpt_path)