def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """Check that the ``save_last`` checkpoint matches the final epoch's checkpoint.

    Trains for ``num_epochs`` epochs with ``save_top_k=num_epochs`` and
    ``save_last=True``. The checkpoint formatted for epoch index
    ``num_epochs - 1`` is then compared with the file named by
    ``ModelCheckpoint.CHECKPOINT_NAME_LAST``: the trainer progress entries,
    the best-score / best-path entries and the model weights must be equal.
    """
    seed_everything(100)
    model = EvalModelTemplate()
    num_epochs = 3
    checkpoint_cb = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=False,
        checkpoint_callback=checkpoint_cb,
        max_epochs=num_epochs,
    )
    trainer.fit(model)

    # Two distinct files are expected: the one for the final epoch index
    # (num_epochs - 1) and the one requested through save_last=True.
    final_epoch_path = checkpoint_cb.format_checkpoint_name(num_epochs - 1, {})
    last_path = str(tmpdir / ModelCheckpoint.CHECKPOINT_NAME_LAST)
    assert final_epoch_path != last_path

    final_epoch_state = torch.load(final_epoch_path)
    last_state = torch.load(last_path)

    # Every one of these entries must be the same in both checkpoints.
    shared_keys = (
        "epoch",
        "global_step",
        ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE,
        ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH,
    )
    assert all(final_epoch_state[key] == last_state[key] for key in shared_keys)

    # Comparing parameters of re-loaded models is simpler than walking the
    # raw state dicts stored in the checkpoints.
    final_epoch_model = EvalModelTemplate.load_from_checkpoint(final_epoch_path)
    last_model = EvalModelTemplate.load_from_checkpoint(last_path)
    weight_pairs = zip(final_epoch_model.parameters(), last_model.parameters())
    assert all(expected.eq(actual).all() for expected, actual in weight_pairs)
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """Check that the ``save_last`` checkpoint matches the final epoch's checkpoint.

    A ``ModelCheckpoint`` monitoring ``val_loss`` keeps one checkpoint per
    epoch and additionally saves a "last" checkpoint. The checkpoint for epoch
    index ``num_epochs - 1`` and the one at ``last_model_path`` must be
    different files holding the same trainer progress and model weights.
    """
    seed_everything(100)
    model = EvalModelTemplate()
    num_epochs = 3
    checkpoint_cb = ModelCheckpoint(
        monitor='val_loss',
        filepath=tmpdir,
        save_top_k=num_epochs,
        save_last=True,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=False,
        checkpoint_callback=checkpoint_cb,
        max_epochs=num_epochs,
    )
    trainer.fit(model)

    final_epoch_path = checkpoint_cb.format_checkpoint_name(num_epochs - 1, {})
    assert final_epoch_path != checkpoint_cb.last_model_path

    final_epoch_state = torch.load(final_epoch_path)
    last_state = torch.load(checkpoint_cb.last_model_path)
    progress_keys = ("epoch", "global_step")
    assert all(final_epoch_state[key] == last_state[key] for key in progress_keys)

    # Comparing parameters of re-loaded models is simpler than walking the
    # raw state dicts stored in the checkpoints.
    final_epoch_model = EvalModelTemplate.load_from_checkpoint(final_epoch_path)
    last_model = EvalModelTemplate.load_from_checkpoint(checkpoint_cb.last_model_path)
    weight_pairs = zip(final_epoch_model.parameters(), last_model.parameters())
    assert all(expected.eq(actual).all() for expected, actual in weight_pairs)
def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Verify that ``strict`` decides whether a checkpoint with extra weights loads.

    A model carrying an additional ``c_d3`` layer is trained and saved. Loading
    that checkpoint into a plain ``EvalModelTemplate`` must fail with the
    default (strict) settings and must succeed with ``strict=False``. When
    ``url_ckpt`` is truthy the checkpoint is addressed through the HTTP server
    described by ``tmpdir_server`` instead of the local path.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir;
    # environment values must be str, so the path object is converted explicitly
    monkeypatch.setenv("TORCH_HOME", str(tmpdir))

    model = EvalModelTemplate()
    # Extra layer that the stock EvalModelTemplate does not define
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, "hparams.yaml")
    ckpt_path = (
        f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
        if url_ckpt
        else new_weights_path
    )

    # Strict loading must be rejected *because of the extra layer*; matching on
    # the error text keeps unrelated failures (bad path, server error) from
    # being mistaken for the expected one.
    with pytest.raises(RuntimeError, match=r"Unexpected key\(s\) in state_dict"):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
        )

    # With strict=False the extra weights are tolerated. The call is made
    # directly so that an unexpected failure surfaces with its own traceback.
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """Check that the ``save_last`` checkpoint matches the final epoch's checkpoint.

    A ``ModelCheckpoint`` monitoring ``early_stop_on`` keeps one checkpoint per
    epoch plus a ``last.ckpt``. The test expects ``last.ckpt`` to be the path
    reported by ``last_model_path`` and to carry the same trainer progress,
    the same state for this callback type and the same model weights as
    ``epoch={num_epochs - 1}.ckpt``.
    """
    seed_everything(100)
    model = EvalModelTemplate()
    num_epochs = 3
    checkpoint_cb = ModelCheckpoint(
        monitor='early_stop_on',
        dirpath=tmpdir,
        save_top_k=num_epochs,
        save_last=True,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint_cb,
        max_epochs=num_epochs,
    )
    trainer.fit(model)

    final_epoch_path = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
    last_path = str(tmpdir / "last.ckpt")
    assert last_path == checkpoint_cb.last_model_path

    final_epoch_state = torch.load(final_epoch_path)
    last_state = torch.load(last_path)
    progress_keys = ("epoch", "global_step")
    assert all(final_epoch_state[key] == last_state[key] for key in progress_keys)

    # Callback state is looked up under the callback's type in both files.
    callback_type = type(checkpoint_cb)
    last_callback_state = last_state["callbacks"][callback_type]
    final_epoch_callback_state = final_epoch_state["callbacks"][callback_type]
    assert last_callback_state == final_epoch_callback_state

    # Comparing parameters of re-loaded models is simpler than walking the
    # raw state dicts stored in the checkpoints.
    final_epoch_model = EvalModelTemplate.load_from_checkpoint(final_epoch_path)
    last_model = EvalModelTemplate.load_from_checkpoint(checkpoint_cb.last_model_path)
    weight_pairs = zip(final_epoch_model.parameters(), last_model.parameters())
    assert all(expected.eq(actual).all() for expected, actual in weight_pairs)
def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Verify ``strict`` handling when the checkpoint has more parameters than the model.

    A model carrying an additional ``c_d3`` layer is trained and saved. Loading
    that checkpoint into a plain ``EvalModelTemplate`` must succeed with
    ``strict=False`` and raise ``RuntimeError`` naming the unexpected
    ``c_d3`` keys with ``strict=True``. When ``url_ckpt`` is truthy the
    checkpoint is addressed through the HTTP server described by
    ``tmpdir_server`` instead of the local path.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir;
    # environment values must be str, so the path object is converted explicitly
    monkeypatch.setenv('TORCH_HOME', str(tmpdir))

    model = EvalModelTemplate()
    # Extra layer that the stock EvalModelTemplate does not define
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml')
    # URL under which the test server exposes the saved checkpoint file
    ckpt_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}'
    ckpt_path = ckpt_url if url_ckpt else new_weights_path

    # Non-strict loading tolerates the extra weights.
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )

    # Strict loading must fail and report exactly the extra layer's keys.
    with pytest.raises(
        RuntimeError,
        match=r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"',
    ):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
            strict=True,
        )
def test_no_val_end_module(tmpdir):
    """Train, save a checkpoint by hand, and reload it with the logged meta tags.

    The model is fitted for one epoch, the trainer state is written to
    ``save_test.ckpt`` and a new ``EvalModelTemplate`` is built from that file
    together with the ``meta_tags.csv`` found in the logger's data directory.
    """
    model = EvalModelTemplate(tutils.get_default_hparams())

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(tmpdir))
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    tags_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'meta_tags.csv')
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        tags_csv=tags_path,
    )
    model_2.eval()
def test_no_val_module(tmpdir):
    """Train, save a checkpoint by hand, and reload it from the checkpoint alone.

    After one epoch of training the trainer state is written to
    ``save_test.ckpt``. The file must contain an ``'hparams'`` entry, and a new
    ``EvalModelTemplate`` is then built from the checkpoint path with no
    further arguments.
    """
    model = EvalModelTemplate(tutils.get_default_hparams())

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    trainer = Trainer(max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(tmpdir))

    # fit model
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # assert ckpt has hparams
    ckpt = torch.load(new_weights_path)
    assert 'hparams' in ckpt, 'hparams missing from checkpoints'

    # won't load without hparams in the ckpt
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=new_weights_path,
    )
    model_2.eval()
def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Train, save a checkpoint by hand, and reload it with the logged hparams file.

    The model is fitted for one epoch, the trainer state is written to
    ``save_test.ckpt`` and a new ``EvalModelTemplate`` is built from it together
    with the ``hparams.yaml`` found in the logger's data directory. When
    ``url_ckpt`` is truthy the checkpoint is addressed through the HTTP server
    described by ``tmpdir_server`` instead of the local path.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir;
    # environment values must be str, so the path object is converted explicitly
    monkeypatch.setenv('TORCH_HOME', str(tmpdir))

    model = EvalModelTemplate()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(tmpdir))
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml')
    if url_ckpt:
        ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}'
    else:
        ckpt_path = new_weights_path
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
    )
    model_2.eval()
def test_no_val_module(tmpdir):
    """Train, save a checkpoint by hand, and reload it with the logged hparams file.

    After one epoch of training the trainer state is written to
    ``save_test.ckpt``. The file must contain the
    ``CHECKPOINT_KEY_MODULE_ARGS`` entry, and a new ``EvalModelTemplate`` is
    then built from it together with the ``hparams.yaml`` found in the logger's
    data directory.
    """
    model = EvalModelTemplate()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    trainer = Trainer(max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(tmpdir))

    # fit model
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # assert ckpt has the module arguments
    ckpt = torch.load(new_weights_path)
    assert CHECKPOINT_KEY_MODULE_ARGS in ckpt, 'module_arguments missing from checkpoints'

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml')
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        hparams_file=hparams_path,
    )
    model_2.eval()
def test_running_test_pretrained_model_cpu(tmpdir):
    """Run ``Trainer.test`` on a model restored from the best checkpoint.

    A model is trained for three epochs on a fraction of the data, reloaded
    from the checkpoint callback's ``best_model_path`` and evaluated with a
    second trainer built from the same options.
    """
    model = EvalModelTemplate()

    # The logger provides the meta information; the checkpoint callback is
    # derived from it and provides the weights.
    logger = tutils.get_default_logger(tmpdir)
    checkpoint_cb = tutils.init_checkpoint_callback(logger)

    trainer_options = {
        'progress_bar_refresh_rate': 0,
        'max_epochs': 3,
        'limit_train_batches': 0.4,
        'limit_val_batches': 0.2,
        'checkpoint_callback': checkpoint_cb,
        'logger': logger,
        'default_root_dir': tmpdir,
    }

    # Training has to finish successfully before anything can be reloaded.
    fit_trainer = Trainer(**trainer_options)
    fit_result = fit_trainer.fit(model)
    assert fit_result == 1, 'training failed to complete'

    best_path = fit_trainer.checkpoint_callback.best_model_path
    restored_model = EvalModelTemplate.load_from_checkpoint(best_path)

    test_trainer = Trainer(**trainer_options)
    test_trainer.test(restored_model)

    # The restored model must still reach an acceptable test accuracy.
    tutils.assert_ok_model_acc(test_trainer)
def test_model_saving_loading(tmpdir):
    """Check that a reloaded checkpoint reproduces the trained model's predictions.

    The model is fitted for one epoch and used to predict on one test batch.
    The trainer state is then written to ``save_test.ckpt``, a new
    ``EvalModelTemplate`` is built from it together with the logged
    ``hparams.yaml``, and its predictions on the same batch must be
    element-wise equal to the earlier ones.
    """
    model = EvalModelTemplate()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    trainer_options = dict(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # training complete
    assert result == 1, 'training failed to complete'

    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # Use the first batch of the last test dataloader, fetched explicitly
    # instead of relying on a loop variable that outlives its loop.
    x, _ = next(iter(dataloaders[-1]))
    # flatten every sample to one vector, keeping the batch dimension
    x = x.view(x.size(0), -1)

    # generate preds before saving model
    model.eval()
    pred_before_saving = model(x)

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml')
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        hparams_file=hparams_path,
    )
    model_2.eval()

    # make prediction
    # assert that both predictions are the same
    new_pred = model_2(x)
    assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """Check that the ``save_last`` checkpoint matches the final epoch's checkpoint.

    ``ModelCheckpoint.CHECKPOINT_NAME_LAST`` is temporarily set to the template
    ``'last-{epoch}'`` so that the "last" checkpoint gets a formatted name. The
    checkpoint for epoch index ``num_epochs - 1`` and that "last" checkpoint
    must be different files holding the same trainer progress, the same
    best-score / best-path callback state and the same model weights.

    The class attribute is restored in a ``finally`` block so that a failing
    assertion cannot leak the patched name into other tests.
    """
    seed_everything(100)
    model = EvalModelTemplate()
    num_epochs = 3

    original_name_last = ModelCheckpoint.CHECKPOINT_NAME_LAST
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    try:
        model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
        trainer = Trainer(
            default_root_dir=tmpdir,
            early_stop_callback=False,
            checkpoint_callback=model_checkpoint,
            max_epochs=num_epochs,
        )
        trainer.fit(model)

        # Both names are formatted for the final epoch index (num_epochs - 1).
        last_filename = model_checkpoint._format_checkpoint_name(
            ModelCheckpoint.CHECKPOINT_NAME_LAST, num_epochs - 1, {}
        )
        path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
        path_last = str(tmpdir / f'{last_filename}.ckpt')
        assert path_last_epoch != path_last

        ckpt_last_epoch = torch.load(path_last_epoch)
        ckpt_last = torch.load(path_last)

        trainer_keys = ("epoch", "global_step")
        for key in trainer_keys:
            assert ckpt_last_epoch[key] == ckpt_last[key]

        # Callback state is stored under the callback's type in each checkpoint.
        callback_type = type(model_checkpoint)
        checkpoint_callback_keys = ("best_model_score", "best_model_path")
        for key in checkpoint_callback_keys:
            assert (
                ckpt_last["callbacks"][callback_type][key]
                == ckpt_last_epoch["callbacks"][callback_type][key]
            )

        # it is easier to load the model objects than to iterate over the raw dict of tensors
        model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
        model_last = EvalModelTemplate.load_from_checkpoint(path_last)
        for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
            assert w0.eq(w1).all()
    finally:
        # Always undo the class-level patch, even when an assertion fails.
        ModelCheckpoint.CHECKPOINT_NAME_LAST = original_name_last
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Run ``Trainer.test`` on a restored model using the ``ddp_spawn`` backend.

    A model is trained for two epochs on GPUs 0 and 1, reloaded from the
    checkpoint callback's ``best_model_path`` and evaluated with a second
    trainer built from the same options. The reported ``test_acc`` must exceed
    0.5, after which predictions are run on every test dataloader.
    """
    tutils.set_random_master_port()

    model = EvalModelTemplate()

    # The logger provides the meta information; the checkpoint callback is
    # derived from it and provides the weights.
    logger = tutils.get_default_logger(tmpdir)
    checkpoint_cb = tutils.init_checkpoint_callback(logger)

    trainer_options = {
        'progress_bar_refresh_rate': 0,
        'max_epochs': 2,
        'limit_train_batches': 0.4,
        'limit_val_batches': 0.2,
        'checkpoint_callback': checkpoint_cb,
        'logger': logger,
        'gpus': [0, 1],
        'distributed_backend': 'ddp_spawn',
        'default_root_dir': tmpdir,
    }

    fit_trainer = Trainer(**trainer_options)
    fit_result = fit_trainer.fit(model)

    # Record what the run left in the logger's data directory.
    data_dir = tutils.get_data_path(logger, path_dir=tmpdir)
    log.info(os.listdir(data_dir))

    assert fit_result == 1, 'training failed to complete'

    best_path = fit_trainer.checkpoint_callback.best_model_path
    pretrained_model = EvalModelTemplate.load_from_checkpoint(best_path)

    # Evaluate on the test set with a fresh trainer.
    test_trainer = Trainer(**trainer_options)
    test_results = test_trainer.test(pretrained_model)
    pretrained_model.cpu()

    acc = test_results[0]['test_acc']
    assert acc > 0.5, f"Model failed to get expected {0.5} accuracy. test_acc = {acc}"

    # A single dataloader is wrapped so both cases share one loop.
    test_loaders = model.test_dataloader()
    test_loaders = test_loaders if isinstance(test_loaders, list) else [test_loaders]
    for loader in test_loaders:
        tpipes.run_prediction(loader, pretrained_model)
def test_load_model_from_checkpoint(tmpdir):
    """Reload the last saved checkpoint and compare it with the trained model.

    The model is trained for two epochs while every checkpoint is kept
    (``save_top_k=-1``). The ``*.ckpt`` file that sorts last by name in the
    callback's ``dirpath`` is reloaded; its hyper-parameters and weights must
    equal those of the trained model, and it is finally evaluated with a
    second trainer built from the same options.
    """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer_options = {
        'progress_bar_refresh_rate': 0,
        'max_epochs': 2,
        'limit_train_batches': 0.4,
        'limit_val_batches': 0.2,
        'checkpoint_callback': ModelCheckpoint(tmpdir, save_top_k=-1),
        'default_root_dir': tmpdir,
    }

    fit_trainer = Trainer(**trainer_options)
    fit_result = fit_trainer.fit(model)
    fit_trainer.test(ckpt_path=None)

    assert fit_result == 1, 'training failed to complete'

    # Pick the checkpoint file whose name sorts last.
    ckpt_pattern = os.path.join(fit_trainer.checkpoint_callback.dirpath, "*.ckpt")
    ckpt_files = sorted(glob.glob(ckpt_pattern))
    last_checkpoint = ckpt_files[-1]
    pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint)

    # Every hyper-parameter must be exposed unchanged on the reloaded model.
    for name, expected in hparams.items():
        assert getattr(pretrained_model, name) == expected

    # The reloaded weights must equal the trained ones, parameter by parameter.
    for trained_p, loaded_p in zip(model.parameters(), pretrained_model.parameters()):
        assert torch.all(torch.eq(trained_p, loaded_p)), 'loaded weights are not the same as the saved weights'

    test_trainer = Trainer(**trainer_options)
    test_trainer.test(pretrained_model)

    # The reloaded model must still reach an acceptable test accuracy.
    tutils.assert_ok_model_acc(test_trainer)
def run_test_from_config(trainer_options):
    """Train the default model with ``trainer_options`` and verify checkpointing.

    ``trainer_options`` must contain ``'weights_save_path'``; a
    ``ModelCheckpoint`` for that path is stored into the same dict under
    ``'checkpoint_callback'`` before the trainer is built. After training, two
    Horovod workers are expected. Ranks above zero only confirm that they hold
    no checkpoint information; rank zero reloads the best checkpoint, runs
    predictions and exercises HPC save/load.
    """
    set_random_master_port()
    reset_seed()

    save_dir = trainer_options['weights_save_path']
    trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)

    model = EvalModelTemplate()

    trainer = Trainer(**trainer_options)
    fit_result = trainer.fit(model)
    assert fit_result == 1

    # Horovod should be initialized following training. If not, this will raise an exception.
    assert hvd.size() == 2

    if trainer.global_rank > 0:
        # on higher ranks the checkpoint location is unknown
        # we want to test checkpointing on rank 0 only
        assert not hasattr(trainer, 'ckpt_path')
        assert not trainer.checkpoint_callback.best_model_path
        return

    # Rank zero: restore the model from the best checkpoint.
    best_path = trainer.checkpoint_callback.best_model_path
    pretrained_model = EvalModelTemplate.load_from_checkpoint(best_path)

    # Run predictions on every test dataloader (a single one is wrapped).
    loaders = model.test_dataloader()
    loaders = loaders if isinstance(loaders, list) else [loaders]
    for loader in loaders:
        run_prediction(loader, pretrained_model)

    # test HPC loading / saving
    trainer.hpc_save(save_dir, trainer.logger)
    trainer.hpc_load(save_dir, on_gpu=args.on_gpu)

    if args.on_gpu:
        trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)
        # Test the root_gpu property
        assert trainer.root_gpu == hvd.local_rank()
def run_test_from_config(trainer_options):
    """Train the default model with ``trainer_options`` and verify checkpointing.

    ``trainer_options`` must contain ``'weights_save_path'``; a one-element
    list holding a ``ModelCheckpoint`` for that directory is stored into the
    same dict under ``'callbacks'`` before the trainer is built. After
    training, two Horovod workers are expected. Ranks above zero return
    immediately; rank zero reloads the best checkpoint, runs predictions and
    exercises HPC save/load through the trainer's checkpoint connector.
    """
    set_random_master_port()
    reset_seed()

    save_dir = trainer_options['weights_save_path']
    trainer_options['callbacks'] = [ModelCheckpoint(dirpath=save_dir)]

    model = EvalModelTemplate()

    trainer = Trainer(**trainer_options)
    fit_result = trainer.fit(model)
    assert fit_result == 1

    # Horovod should be initialized following training. If not, this will raise an exception.
    assert hvd.size() == 2

    # Only rank zero performs the checkpoint checks below.
    if trainer.global_rank > 0:
        return

    # Restore the model from the best checkpoint.
    best_path = trainer.checkpoint_callback.best_model_path
    pretrained_model = EvalModelTemplate.load_from_checkpoint(best_path)

    # Run predictions on every test dataloader (a single one is wrapped).
    loaders = model.test_dataloader()
    loaders = loaders if isinstance(loaders, list) else [loaders]
    for loader in loaders:
        run_prediction(loader, pretrained_model)

    connector = trainer.checkpoint_connector

    # test HPC saving
    connector.hpc_save(save_dir, trainer.logger)

    # test HPC loading
    hpc_ckpt_path = connector.get_max_ckpt_path_from_folder(save_dir)
    connector.hpc_load(hpc_ckpt_path, on_gpu=args.on_gpu)

    if args.on_gpu:
        trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)
        # Test the root_gpu property
        assert trainer.root_gpu == hvd.local_rank()
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Train, save a checkpoint by hand, and reload it with the logged hparams file.

    After one epoch of training the trainer state is written to
    ``save_test.ckpt``. The file must contain the
    ``LightningModule.CHECKPOINT_HYPER_PARAMS_KEY`` entry, and a new
    ``EvalModelTemplate`` is then built from it together with the
    ``hparams.yaml`` found in the logger's data directory. When ``url_ckpt`` is
    truthy the checkpoint is addressed through the HTTP server described by
    ``tmpdir_server`` instead of the local path.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv("TORCH_HOME", str(tmpdir))

    model = EvalModelTemplate()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
    )

    # fit model
    result = trainer.fit(model)

    # training complete
    assert result == 1, "training failed to complete"

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # assert ckpt has hparams
    ckpt = torch.load(new_weights_path)
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt, "hyper parameters missing from checkpoints"

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), "hparams.yaml")
    if url_ckpt:
        ckpt_path = f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
    else:
        ckpt_path = new_weights_path
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
    )
    model_2.eval()
def test_load_past_checkpoint(tmpdir, past_key):
    """Check that hyper-parameters stored under an older key are still loaded.

    A checkpoint produced by a one-epoch run is rewritten on disk so that its
    hyper-parameters live under ``past_key`` instead of
    ``LightningModule.CHECKPOINT_HYPER_PARAMS_KEY``, with ``batch_size``
    changed to ``-17``. The model reloaded from that file must report the
    modified batch size.
    """
    model = EvalModelTemplate()

    # A short training run produces the checkpoint to tamper with.
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    ckpt_file = _raw_checkpoint_path(trainer)
    ckpt = torch.load(ckpt_file)

    # Move the hyper-parameters to the old key, marking them with a sentinel
    # batch size so the reload can be recognised, then drop the current key.
    current_key = LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
    ckpt[past_key] = ckpt[current_key]
    ckpt[past_key]['batch_size'] = -17
    del ckpt[current_key]

    # Overwrite the checkpoint on disk with the rewritten contents.
    torch.save(ckpt, ckpt_file)

    # The sentinel value proves the hyper-parameters came from past_key.
    reloaded = EvalModelTemplate.load_from_checkpoint(ckpt_file)
    assert reloaded.hparams.batch_size == -17
def test_model_checkpoint_only_weights(tmpdir):
    """Check weights-only checkpoints: no trainer state, loadable, not resumable.

    Covers both a checkpoint written by a ``ModelCheckpoint`` configured with
    ``save_weights_only=True`` and one written directly through
    ``trainer.save_checkpoint(..., weights_only=True)``. Neither may contain
    optimizer or LR-scheduler state, the model must load from the former, and
    restoring the training state from the latter must raise ``KeyError``.
    """
    # Entries that only a full (resumable) checkpoint is allowed to carry.
    trainer_state_keys = ('optimizer_states', 'lr_schedulers')

    model = EvalModelTemplate()
    weights_only_cb = ModelCheckpoint(tmpdir, monitor='early_stop_on', save_weights_only=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        checkpoint_callback=weights_only_cb,
    )

    fit_result = trainer.fit(model)
    assert fit_result == 1, 'training failed to complete'

    # First path tracked by the callback's best_k_models mapping.
    saved_paths = list(trainer.checkpoint_callback.best_k_models.keys())
    checkpoint_path = saved_paths[0]

    # The callback-written checkpoint must not hold trainer state ...
    checkpoint = torch.load(checkpoint_path)
    for state_key in trainer_state_keys:
        assert state_key not in checkpoint, 'checkpoint should contain only model weights'

    # ... and must still be enough to build the model.
    assert EvalModelTemplate.load_from_checkpoint(checkpoint_path=checkpoint_path)

    # Same check for a checkpoint saved directly with weights_only=True.
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path, weights_only=True)

    checkpoint = torch.load(new_weights_path)
    for state_key in trainer_state_keys:
        assert state_key not in checkpoint, 'checkpoint should contain only model weights'

    # Resuming training from a weights-only checkpoint has to be refused.
    with pytest.raises(KeyError, match='checkpoint contains only the model'):
        trainer.checkpoint_connector.restore_training_state(checkpoint)