def on_pretrain_routine_end(self):
    """Check the restored epoch, then run predictions on the train dataloader.

    Asserts that the trainer's current epoch equals ``real_global_epoch`` and
    is greater than zero, switches ``new_trainer`` to the validating stage,
    runs ``tpipes.run_model_prediction`` on the training dataloader, and
    finally records that this hook was called.
    """
    trainer = self.trainer
    restored_epoch = trainer.current_epoch
    assert restored_epoch == real_global_epoch and restored_epoch > 0

    # If the model and state were loaded correctly, predictions will be good
    # even though we haven't trained with the newly loaded model.
    new_trainer.state.stage = RunningStage.VALIDATING

    train_loader = dm.train_dataloader()
    tpipes.run_model_prediction(trainer.lightning_module, dataloader=train_loader)

    # Flag read from outside the hook to confirm that it actually ran.
    self.on_pretrain_routine_end_called = True
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Fit with ``ddp_spawn`` on two GPUs, then run ``test()`` on the reloaded model.

    The model is trained for two short epochs, the best checkpoint is loaded
    back into a fresh ``ClassificationModel``, ``test()`` is run with a new
    trainer built from the same options, and predictions on every test
    dataloader are checked against a minimum accuracy of 0.1.
    """
    tutils.set_random_main_port()

    datamodule = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint_callback = tutils.init_checkpoint_callback(logger)

    trainer_options = {
        "enable_progress_bar": False,
        "max_epochs": 2,
        "limit_train_batches": 2,
        "limit_val_batches": 2,
        "callbacks": [checkpoint_callback],
        "logger": logger,
        "accelerator": "gpu",
        "devices": [0, 1],
        "strategy": "ddp_spawn",
        "default_root_dir": tmpdir,
    }

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=datamodule)

    data_path = tutils.get_data_path(logger, path_dir=tmpdir)
    log.info(os.listdir(data_path))

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    best_checkpoint_path = trainer.checkpoint_callback.best_model_path
    pretrained_model = ClassificationModel.load_from_checkpoint(best_checkpoint_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=datamodule)
    pretrained_model.cpu()

    # Normalize to a list so a single dataloader and several are handled alike.
    test_loaders = datamodule.test_dataloader()
    test_loaders = test_loaders if isinstance(test_loaders, list) else [test_loaders]

    for test_loader in test_loaders:
        tpipes.run_model_prediction(pretrained_model, test_loader, min_acc=0.1)
def on_validation_start(self):
    """Check the restored epoch, then run predictions on the val dataloader.

    Asserts that the trainer's current epoch equals ``real_global_epoch`` and
    is greater than zero before calling ``tpipes.run_model_prediction`` with
    the trainer's lightning module and the validation dataloader.
    """
    trainer = self.trainer
    restored_epoch = trainer.current_epoch
    assert restored_epoch == real_global_epoch and restored_epoch > 0

    val_loader = dm.val_dataloader()
    tpipes.run_model_prediction(trainer.lightning_module, dataloader=val_loader)