def test_no_val_end_module(tmpdir):
    """Fit ``CurrentTestModel`` for one epoch, checkpoint it, and reload it from meta tags.

    The model is assembled locally from ``LightningValidationStepMixin`` and
    ``LightningTestModelBase``. After training, the trainer writes a checkpoint
    into ``tmpdir`` and a second model is rebuilt from that checkpoint together
    with the ``meta_tags.csv`` file found under the logger's data path.

    NOTE(review): judging by the test name, the mixin presumably supplies a
    validation step without a validation-end hook -- confirm against the mixin.

    :param tmpdir: directory used for logger output and checkpoints.
    """
    tutils.reset_seed()

    class CurrentTestModel(LightningValidationStepMixin, LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # The logger output is needed later to locate meta_tags.csv.
    logger = tutils.get_test_tube_logger(tmpdir, False)

    trainer_options = dict(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )

    # Fit the model.
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # Training must report success. The message previously mentioned
    # "amp + ddp", which this test does not configure.
    assert result == 1, 'training failed to complete'

    # Save the trained weights through the trainer.
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # Rebuild a model from the checkpoint and the logged hyperparameters.
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')

    # NOTE(review): the reload goes through LightningTestModel rather than
    # CurrentTestModel -- confirm this is intentional.
    model_2 = LightningTestModel.load_from_metrics(
        weights_path=new_weights_path,
        tags_csv=tags_path)
    model_2.eval()
def test_model_saving_loading():
    """Check that a model reloaded from a checkpoint and meta tags predicts identically.

    The test fits ``LightningTestModel`` for one epoch, records its predictions
    on one test batch, saves a checkpoint through the trainer, rebuilds a second
    model via ``load_from_metrics`` and asserts that both models produce equal
    outputs for the same input.

    The save directory is cleared even when an assertion fails, so a failing
    run does not leave artifacts behind for later tests.

    :return:
    """
    reset_seed()

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    try:
        # The logger output is needed later to locate meta_tags.csv.
        logger = get_test_tube_logger(False)
        logger.log_hyperparams(hparams)
        logger.save()

        trainer_options = dict(
            max_nb_epochs=1,
            logger=logger,
            checkpoint_callback=ModelCheckpoint(save_dir),
        )

        # Fit the model.
        trainer = Trainer(**trainer_options)
        result = trainer.fit(model)

        # Training must report success. The message previously mentioned
        # "amp + ddp", which this test does not configure.
        assert result == 1, 'training failed to complete'

        # Grab one batch to predict on: the inner loop stops at the first
        # batch of each dataloader, so `batch` ends up being the first batch
        # of the last dataloader yielded.
        for dataloader in model.test_dataloader():
            for batch in dataloader:
                break

        x, _ = batch
        x = x.view(x.size(0), -1)

        # Generate predictions before saving the model.
        model.eval()
        pred_before_saving = model(x)

        # Save the trained weights through the trainer.
        new_weights_path = os.path.join(save_dir, 'save_test.ckpt')
        trainer.save_checkpoint(new_weights_path)

        # Rebuild a model from the checkpoint and the logged hyperparameters.
        tags_path = logger.experiment.get_data_path(logger.experiment.name,
                                                    logger.experiment.version)
        tags_path = os.path.join(tags_path, 'meta_tags.csv')
        model_2 = LightningTestModel.load_from_metrics(
            weights_path=new_weights_path,
            tags_csv=tags_path)
        model_2.eval()

        # Both models must produce exactly the same predictions.
        new_pred = model_2(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
    finally:
        clear_save_dir()
def test_no_val_end_module():
    """Fit ``CurrentTestModel`` for one epoch, checkpoint it, and reload it from meta tags.

    The model is assembled locally from ``LightningValidationStepMixin`` and
    ``LightningTestModelBase``. After training, the trainer writes a checkpoint
    into the save directory and a second model is rebuilt from that checkpoint
    together with the ``meta_tags.csv`` file of the logger's experiment.

    NOTE(review): judging by the test name, the mixin presumably supplies a
    validation step without a validation-end hook -- confirm against the mixin.

    The save directory is cleared even when an assertion fails, so a failing
    run does not leave artifacts behind for later tests.

    :return:
    """
    reset_seed()

    class CurrentTestModel(LightningValidationStepMixin, LightningTestModelBase):
        pass

    hparams = get_hparams()
    model = CurrentTestModel(hparams)

    save_dir = init_save_dir()

    try:
        # The logger output is needed later to locate meta_tags.csv.
        logger = get_test_tube_logger(False)
        logger.log_hyperparams(hparams)
        logger.save()

        trainer_options = dict(
            max_nb_epochs=1,
            logger=logger,
            checkpoint_callback=ModelCheckpoint(save_dir),
        )

        # Fit the model.
        trainer = Trainer(**trainer_options)
        result = trainer.fit(model)

        # Training must report success. The message previously mentioned
        # "amp + ddp", which this test does not configure.
        assert result == 1, 'training failed to complete'

        # Save the trained weights through the trainer.
        new_weights_path = os.path.join(save_dir, 'save_test.ckpt')
        trainer.save_checkpoint(new_weights_path)

        # Rebuild a model from the checkpoint and the logged hyperparameters.
        tags_path = logger.experiment.get_data_path(logger.experiment.name,
                                                    logger.experiment.version)
        tags_path = os.path.join(tags_path, 'meta_tags.csv')

        # NOTE(review): the reload goes through LightningTestModel rather than
        # CurrentTestModel -- confirm this is intentional.
        model_2 = LightningTestModel.load_from_metrics(
            weights_path=new_weights_path,
            tags_csv=tags_path)
        model_2.eval()
    finally:
        clear_save_dir()
def test_no_val_end_module():
    """Fit ``CurrentTestModel`` for one epoch, checkpoint it, and reload it from meta tags.

    The model is assembled locally from ``LightningValidationStepMixin`` and
    ``LightningTestModelBase``. After training, the trainer writes a checkpoint
    into the save directory and a second model is rebuilt from that checkpoint
    together with the experiment's ``meta_tags.csv`` file.

    NOTE(review): judging by the test name, the mixin presumably supplies a
    validation step without a validation-end hook -- confirm against the mixin.

    The save directory is cleared even when an assertion fails, so a failing
    run does not leave artifacts behind for later tests.

    :return:
    """
    class CurrentTestModel(LightningValidationStepMixin, LightningTestModelBase):
        pass

    hparams = get_hparams()
    model = CurrentTestModel(hparams)

    save_dir = init_save_dir()

    try:
        # The experiment output is needed later to locate meta_tags.csv.
        exp = get_exp(False)
        exp.argparse(hparams)
        exp.save()

        trainer_options = dict(
            max_nb_epochs=1,
            cluster=SlurmCluster(),
            experiment=exp,
            checkpoint_callback=ModelCheckpoint(save_dir),
        )

        # Fit the model.
        trainer = Trainer(**trainer_options)
        result = trainer.fit(model)

        # Training must report success. The message previously mentioned
        # "amp + ddp", which this test does not configure.
        assert result == 1, 'training failed to complete'

        # Save the trained weights through the trainer.
        new_weights_path = os.path.join(save_dir, 'save_test.ckpt')
        trainer.save_checkpoint(new_weights_path)

        # Rebuild a model from the checkpoint and the logged hyperparameters.
        tags_path = exp.get_data_path(exp.name, exp.version)
        tags_path = os.path.join(tags_path, 'meta_tags.csv')

        # NOTE(review): the reload goes through LightningTestModel rather than
        # CurrentTestModel -- confirm this is intentional.
        model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
                                                       tags_csv=tags_path,
                                                       on_gpu=False)
        model_2.eval()
    finally:
        clear_save_dir()