def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Check that ``strict`` governs loading a checkpoint holding extra parameters.

    A model with an additional ``c_d3`` layer is trained and saved. Loading that
    checkpoint into a plain ``EvalModelTemplate`` must fail when ``strict`` is
    left at its default, and must succeed when ``strict=False`` is passed.

    The checkpoint is addressed by its local path, or by an HTTP URL built from
    ``tmpdir_server`` when ``url_ckpt`` is truthy.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv("TORCH_HOME", tmpdir)

    model = EvalModelTemplate()
    # Extra layer: present in the saved checkpoint, absent from a fresh EvalModelTemplate
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, "hparams.yaml")
    ckpt_path = (
        f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
        if url_ckpt
        else new_weights_path
    )

    # Default (strict) loading must be rejected *because of the extra keys*.
    # Matching on the error text keeps unrelated failures (bad path, server
    # error, ...) from being mistaken for the expected rejection.
    with pytest.raises(RuntimeError, match=r"Unexpected key\(s\) in state_dict"):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
        )

    # Non-strict loading must succeed. The call is left unguarded so that a
    # failure reports the real exception and traceback.
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )
def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Load a checkpoint that has more parameters than the target model.

    The trained model carries an extra ``c_d3`` layer, so its checkpoint holds
    keys that a plain ``EvalModelTemplate`` does not define. Loading with
    ``strict=False`` is expected to work; loading with ``strict=True`` is
    expected to raise a ``RuntimeError`` naming the unexpected keys.

    The checkpoint is addressed by its local path, or by an HTTP URL built from
    ``tmpdir_server`` when ``url_ckpt`` is truthy.
    """
    # point $TORCH_HOME (torch hub's cache path) at tmpdir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    # template model plus one additional linear layer
    model = EvalModelTemplate()
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger whose data path is later used to locate hparams.yaml
    logger = tutils.get_default_logger(tmpdir)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )

    # training has to finish successfully before anything is saved
    assert trainer.fit(model) == 1

    # write the checkpoint that will be loaded back below
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # hparams file as reported by the test utilities for this logger
    log_dir = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_file = os.path.join(log_dir, 'hparams.yaml')

    # pick between the HTTP location and the local file
    ckpt_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}'
    if url_ckpt:
        checkpoint = ckpt_url
    else:
        checkpoint = new_weights_path

    load_kwargs = dict(checkpoint_path=checkpoint, hparams_file=hparams_file)

    # non-strict: the surplus keys are tolerated
    EvalModelTemplate.load_from_checkpoint(strict=False, **load_kwargs)

    # strict: the surplus keys must be reported
    expected_error = r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'
    with pytest.raises(RuntimeError, match=expected_error):
        EvalModelTemplate.load_from_checkpoint(strict=True, **load_kwargs)