def test_model_checkpoint_save_last(tmpdir):
    """Tests that save_last produces only one last checkpoint."""
    model = EvalModelTemplate()
    epochs = 3

    # CHECKPOINT_NAME_LAST is class-level state shared by every test in the
    # process. Remember the current value and restore it in ``finally`` so a
    # failing fit/assertion cannot leak the override into later tests.
    original_name_last = ModelCheckpoint.CHECKPOINT_NAME_LAST
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    try:
        model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=-1, save_last=True)
        trainer = Trainer(
            default_root_dir=tmpdir,
            early_stop_callback=False,
            checkpoint_callback=model_checkpoint,
            max_epochs=epochs,
            logger=False,
        )
        trainer.fit(model)

        # Name the callback is expected to give the "last" checkpoint,
        # formatted with the overridden template and the final epoch index.
        last_filename = model_checkpoint._format_checkpoint_name(
            ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {}
        )
        last_filename = last_filename + '.ckpt'
        assert str(tmpdir / last_filename) == model_checkpoint.last_model_path

        # One checkpoint per epoch plus exactly one "last" checkpoint.
        expected_files = {f'epoch={i}.ckpt' for i in range(epochs)} | {last_filename}
        assert set(os.listdir(tmpdir)) == expected_files
    finally:
        ModelCheckpoint.CHECKPOINT_NAME_LAST = original_name_last
def test_model_checkpoint_format_checkpoint_name(tmpdir):
    """Checks the checkpoint names built from filename templates and metric dicts.

    Covers empty/``None`` templates, templates without format groups, a prefix
    with a custom join character, an unset/empty/relative ``dirpath``, the
    version suffix, metric names containing slashes, and
    ``auto_insert_metric_name=False``.
    """
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
    assert ckpt_name == "epoch=3-step=2"

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test")
    assert ckpt_name == "test-epoch=3-step=2"

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name("ckpt", {}, prefix="test")
    assert ckpt_name == "test-ckpt"

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
    assert ckpt_name == "epoch=003-acc=0.03"

    # prefix
    # CHECKPOINT_JOIN_CHAR is class-level state; restore it in ``finally`` so a
    # failing assertion cannot leak the "@" override into other tests.
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = "@"
    try:
        ckpt_name = ModelCheckpoint._format_checkpoint_name(
            "{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test"
        )
        assert ckpt_name == "test@epoch=3,acc=0.03000"
    finally:
        ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=None).format_checkpoint_name({"epoch": 3, "step": 2})
    assert ckpt_name == "epoch=3-step=2.ckpt"
    ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath="").format_checkpoint_name({"epoch": 5, "step": 4})
    assert ckpt_name == "epoch=5-step=4.ckpt"

    # CWD
    ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=".").format_checkpoint_name({"epoch": 3, "step": 4})
    assert ckpt_name == str(Path(".").resolve() / "epoch=3-step=4.ckpt")

    # with version
    ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, filename="name")
    ckpt_name = ckpt.format_checkpoint_name({}, ver=3)
    assert ckpt_name == tmpdir / "name-v3.ckpt"

    # using slashes
    ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=None, filename="{epoch}_{val/loss:.5f}")
    ckpt_name = ckpt.format_checkpoint_name({"epoch": 4, "val/loss": 0.03})
    assert ckpt_name == "epoch=4_val/loss=0.03000.ckpt"

    # auto_insert_metric_name=False
    ckpt_name = ModelCheckpoint._format_checkpoint_name(
        "epoch={epoch:03d}-val_acc={val/acc}",
        {"epoch": 3, "val/acc": 0.03},
        auto_insert_metric_name=False,
    )
    assert ckpt_name == "epoch=003-val_acc=0.03"
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """Tests that the save_last checkpoint contains the latest information."""
    seed_everything(100)
    model = EvalModelTemplate()
    num_epochs = 3

    # CHECKPOINT_NAME_LAST is class-level state shared by every test in the
    # process. Remember the current value and restore it in ``finally`` so a
    # failing fit/load/assertion cannot leak the override into later tests.
    original_name_last = ModelCheckpoint.CHECKPOINT_NAME_LAST
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    try:
        model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
        trainer = Trainer(
            default_root_dir=tmpdir,
            early_stop_callback=False,
            checkpoint_callback=model_checkpoint,
            max_epochs=num_epochs,
        )
        trainer.fit(model)

        last_filename = model_checkpoint._format_checkpoint_name(
            ModelCheckpoint.CHECKPOINT_NAME_LAST, num_epochs - 1, {}
        )
        # Regular checkpoint path for the final epoch index.
        path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
        # Path of the additional checkpoint requested with save_last=True.
        path_last = str(tmpdir / f'{last_filename}.ckpt')
        assert path_last_epoch != path_last

        ckpt_last_epoch = torch.load(path_last_epoch)
        ckpt_last = torch.load(path_last)

        # Trainer progress counters must match between the two files.
        trainer_keys = ("epoch", "global_step")
        for key in trainer_keys:
            assert ckpt_last_epoch[key] == ckpt_last[key]

        # The callback state stored in both files must match as well.
        checkpoint_callback_keys = ("best_model_score", "best_model_path")
        for key in checkpoint_callback_keys:
            assert (
                ckpt_last["callbacks"][type(model_checkpoint)][key]
                == ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
            )

        # it is easier to load the model objects than to iterate over the raw dict of tensors
        model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
        model_last = EvalModelTemplate.load_from_checkpoint(path_last)
        for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
            assert w0.eq(w1).all()
    finally:
        ModelCheckpoint.CHECKPOINT_NAME_LAST = original_name_last
def test_model_checkpoint_format_checkpoint_name(tmpdir):
    """Checks the checkpoint names built from filename templates, epoch, step and metrics.

    Covers empty/``None`` templates, templates without format groups, a prefix
    with a custom join character, an unset/empty/relative ``dirpath``, the
    version suffix, metric names containing slashes, and the deprecated
    ``filepath`` argument.
    """
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
    assert ckpt_name == 'epoch=3-step=2'

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-epoch=3-step=2'

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-ckpt'

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
    assert ckpt_name == 'epoch=003-acc=0.03'

    # prefix
    # CHECKPOINT_JOIN_CHAR is class-level state; restore it in ``finally`` so a
    # failing assertion cannot leak the '@' override into other tests.
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
    try:
        ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test')
        assert ckpt_name == 'test@epoch=3,acc=0.03000'
    finally:
        ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {})
    assert ckpt_name == 'epoch=3-step=2.ckpt'
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {})
    assert ckpt_name == 'epoch=5-step=4.ckpt'

    # CWD
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')

    # with ver
    ckpt_name = ModelCheckpoint(
        monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test'
    ).format_checkpoint_name(3, 2, {}, ver=3)
    assert ckpt_name == tmpdir / 'test-name-v3.ckpt'

    # using slashes
    ckpt_name = ModelCheckpoint(
        monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}'
    ).format_checkpoint_name(4, 3, {'val/loss': 0.03})
    assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

    # TODO: Checks with filepath. To be removed in v1.2
    # CWD
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 2, {})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=2.ckpt')

    # dir does not exist so it is used as filename
    filepath = tmpdir / 'dir'
    ckpt_name = ModelCheckpoint(
        monitor='early_stop_on', filepath=filepath, prefix='test'
    ).format_checkpoint_name(3, 2, {})
    assert ckpt_name == tmpdir / 'test-dir.ckpt'

    # now, dir exists
    os.mkdir(filepath)
    ckpt_name = ModelCheckpoint(
        monitor='early_stop_on', filepath=filepath, prefix='test'
    ).format_checkpoint_name(3, 2, {})
    assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'
def test_model_checkpoint_format_checkpoint_name(tmpdir):
    """Checks the checkpoint names built from filename templates, epoch, step and metrics.

    Covers empty/``None`` templates, templates without format groups, a prefix
    with a custom join character, an unset/empty/relative ``dirpath``, the
    version suffix, metric names containing slashes, and
    ``auto_insert_metric_name=False``.
    """
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
    assert ckpt_name == 'epoch=3-step=2'

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-epoch=3-step=2'

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-ckpt'

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
    assert ckpt_name == 'epoch=003-acc=0.03'

    # prefix
    # CHECKPOINT_JOIN_CHAR is class-level state; restore it in ``finally`` so a
    # failing assertion cannot leak the '@' override into other tests.
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
    try:
        ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test')
        assert ckpt_name == 'test@epoch=3,acc=0.03000'
    finally:
        ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {})
    assert ckpt_name == 'epoch=3-step=2.ckpt'
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {})
    assert ckpt_name == 'epoch=5-step=4.ckpt'

    # CWD
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')

    # with version
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name')
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3)
    assert ckpt_name == tmpdir / 'name-v3.ckpt'

    # using slashes
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}')
    ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
    assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

    # auto_insert_metric_name=False
    ckpt_name = ModelCheckpoint._format_checkpoint_name(
        'epoch={epoch:03d}-val_acc={val/acc}',
        3,
        2,
        {'val/acc': 0.03},
        auto_insert_metric_name=False,
    )
    assert ckpt_name == 'epoch=003-val_acc=0.03'