def test_hparams_save_load(tmpdir):
    """Train an ``EvalModelTemplate`` for one epoch, then restore it from its checkpoint.

    Args:
        tmpdir: pytest temporary-directory fixture used as the trainer's root dir.
    """
    hparams = vars(tutils.get_default_hparams())
    model = EvalModelTemplate(hparams)

    trainer_kwargs = {'default_root_dir': tmpdir, 'max_epochs': 1}
    trainer = Trainer(**trainer_kwargs)

    # The test expects ``fit`` to report success by returning 1.
    assert trainer.fit(model) == 1

    # Restore from the directory the checkpoint callback wrote to during training.
    checkpoint_dir = trainer.checkpoint_callback.dirpath
    restored = tutils.load_model_from_checkpoint(
        checkpoint_dir,
        module_class=EvalModelTemplate,
    )
    assert restored
def test_hparams_save_load(tmpdir):
    """Train a dict-hparams model whose hparams include a lambda, then reload it.

    The ``'failed_key'`` entry holds a lambda. Presumably this checks that a
    value which cannot be serialized does not break checkpoint save/load.
    NOTE(review): confirm that intent.

    Args:
        tmpdir: pytest temporary-directory fixture used as the trainer's root dir.
    """
    model = DictHparamsModel({
        'in_features': 28 * 28,
        'out_features': 10,
        'failed_key': lambda x: x,
    })
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
    )

    # fit model; the test expects a return value of 1
    result = trainer.fit(model)
    assert result == 1

    # try to load the model now
    pretrained_model = tutils.load_model_from_checkpoint(
        trainer.checkpoint_callback.dirpath,
        module_class=DictHparamsModel,
    )
    # Previously the reloaded model was never checked. Assert it like the
    # EvalModelTemplate variant does.
    assert pretrained_model
def test_hparams_save_load(tmpdir):
    """Train a dict-hparams model for one epoch, then reload it from its checkpoint.

    Args:
        tmpdir: pytest temporary-directory fixture used as the trainer's root dir.
    """
    model = DictHparamsModel({'in_features': 28 * 28, 'out_features': 10})

    # ``default_root_dir`` replaces the deprecated ``default_save_path`` argument.
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
    )

    # fit model; the test expects a return value of 1
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result == 1

    # try to load the model now
    pretrained_model = tutils.load_model_from_checkpoint(
        trainer.checkpoint_callback.dirpath,
        module_class=DictHparamsModel,
    )
    # Previously the reloaded model was never checked.
    assert pretrained_model