Beispiel #1
0
def test_model_checkpoint_save_last(tmpdir):
    """Tests that save_last produces only one last checkpoint.

    Trains for a fixed number of epochs with ``save_top_k=-1`` and
    ``save_last=True`` and asserts that ``tmpdir`` contains exactly one
    ``epoch=<i>.ckpt`` file per epoch plus a single "last" checkpoint whose
    path matches ``model_checkpoint.last_model_path``.
    """
    model = EvalModelTemplate()
    epochs = 3
    # CHECKPOINT_NAME_LAST is a class attribute shared by every test in the
    # session: remember its value and restore it in ``finally`` so a failing
    # assertion below cannot leak the modified template into other tests.
    name_last_org = ModelCheckpoint.CHECKPOINT_NAME_LAST
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    try:
        model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=-1, save_last=True)
        trainer = Trainer(
            default_root_dir=tmpdir,
            early_stop_callback=False,
            checkpoint_callback=model_checkpoint,
            max_epochs=epochs,
            logger=False,
        )
        trainer.fit(model)

        # The last checkpoint is named after the final (0-based) epoch index.
        last_filename = model_checkpoint._format_checkpoint_name(
            ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {}
        )
        last_filename = last_filename + '.ckpt'
        assert str(tmpdir / last_filename) == model_checkpoint.last_model_path

        expected_files = {f'epoch={i}.ckpt' for i in range(epochs)} | {last_filename}
        assert set(os.listdir(tmpdir)) == expected_files
    finally:
        ModelCheckpoint.CHECKPOINT_NAME_LAST = name_last_org
Beispiel #2
0
def test_model_checkpoint_format_checkpoint_name(tmpdir):
    """Check the checkpoint names produced by ``ModelCheckpoint`` formatting.

    Exercises ``ModelCheckpoint._format_checkpoint_name`` with empty/``None``
    filenames, templates without groups, format specs, a custom join
    character and ``auto_insert_metric_name=False``; and
    ``format_checkpoint_name`` with unset, empty, CWD and ``tmpdir``
    ``dirpath`` values, a version suffix and metric names containing slashes.
    """
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
    assert ckpt_name == "epoch=3-step=2"

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, {"epoch": 3, "step": 2}, prefix="test")
    assert ckpt_name == "test-epoch=3-step=2"

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name("ckpt", {}, prefix="test")
    assert ckpt_name == "test-ckpt"

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
    assert ckpt_name == "epoch=003-acc=0.03"

    # prefix
    # CHECKPOINT_JOIN_CHAR is class-level state; restore it in ``finally`` so
    # a failing assertion cannot leak the custom join character to other tests.
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = "@"
    try:
        ckpt_name = ModelCheckpoint._format_checkpoint_name(
            "{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test"
        )
        assert ckpt_name == "test@epoch=3,acc=0.03000"
    finally:
        ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=None).format_checkpoint_name({"epoch": 3, "step": 2})
    assert ckpt_name == "epoch=3-step=2.ckpt"
    ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath="").format_checkpoint_name({"epoch": 5, "step": 4})
    assert ckpt_name == "epoch=5-step=4.ckpt"

    # CWD
    ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=".").format_checkpoint_name({"epoch": 3, "step": 4})
    assert ckpt_name == str(Path(".").resolve() / "epoch=3-step=4.ckpt")

    # with version
    ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, filename="name")
    ckpt_name = ckpt.format_checkpoint_name({}, ver=3)
    assert ckpt_name == tmpdir / "name-v3.ckpt"

    # using slashes
    ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=None, filename="{epoch}_{val/loss:.5f}")
    ckpt_name = ckpt.format_checkpoint_name({"epoch": 4, "val/loss": 0.03})
    assert ckpt_name == "epoch=4_val/loss=0.03000.ckpt"

    # auto_insert_metric_name=False
    ckpt_name = ModelCheckpoint._format_checkpoint_name(
        "epoch={epoch:03d}-val_acc={val/acc}", {"epoch": 3, "val/acc": 0.03}, auto_insert_metric_name=False
    )
    assert ckpt_name == "epoch=003-val_acc=0.03"
Beispiel #3
0
def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    """Tests that the save_last checkpoint contains the latest information.

    After training, the "last" checkpoint and the checkpoint of the final
    epoch are loaded and compared: trainer counters, the checkpoint
    callback's best-model bookkeeping and the model parameters must match.
    """
    seed_everything(100)
    model = EvalModelTemplate()
    num_epochs = 3
    # CHECKPOINT_NAME_LAST is a class attribute shared by every test in the
    # session: remember its value and restore it in ``finally`` so a failure
    # below cannot leak the modified template into other tests.
    name_last_org = ModelCheckpoint.CHECKPOINT_NAME_LAST
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
    try:
        model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
        trainer = Trainer(
            default_root_dir=tmpdir,
            early_stop_callback=False,
            checkpoint_callback=model_checkpoint,
            max_epochs=num_epochs,
        )
        trainer.fit(model)

        # Both names are built from the final 0-based epoch index (num_epochs - 1).
        last_filename = model_checkpoint._format_checkpoint_name(
            ModelCheckpoint.CHECKPOINT_NAME_LAST, num_epochs - 1, {}
        )
        # presumably <tmpdir>/epoch=2.ckpt -- confirm against ModelCheckpoint
        path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
        # presumably <tmpdir>/last-epoch=2.ckpt -- confirm against ModelCheckpoint
        path_last = str(tmpdir / f'{last_filename}.ckpt')
        assert path_last_epoch != path_last
        ckpt_last_epoch = torch.load(path_last_epoch)
        ckpt_last = torch.load(path_last)

        trainer_keys = ("epoch", "global_step")
        for key in trainer_keys:
            assert ckpt_last_epoch[key] == ckpt_last[key]

        # Callback state is stored in the checkpoint under the callback's type.
        callback_type = type(model_checkpoint)
        checkpoint_callback_keys = ("best_model_score", "best_model_path")
        for key in checkpoint_callback_keys:
            assert (
                ckpt_last["callbacks"][callback_type][key]
                == ckpt_last_epoch["callbacks"][callback_type][key]
            )

        # it is easier to load the model objects than to iterate over the raw dict of tensors
        model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
        model_last = EvalModelTemplate.load_from_checkpoint(path_last)
        for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
            assert w0.eq(w1).all()
    finally:
        ModelCheckpoint.CHECKPOINT_NAME_LAST = name_last_org
def test_model_checkpoint_format_checkpoint_name(tmpdir):
    """Check the checkpoint names produced by ``ModelCheckpoint`` formatting.

    Exercises ``ModelCheckpoint._format_checkpoint_name`` (positional epoch
    and step plus a metrics dict) with empty/``None`` filenames, templates
    without groups, format specs and a custom join character; and
    ``format_checkpoint_name`` with ``dirpath`` values, a version suffix,
    metric names containing slashes and ``filepath`` values.
    """
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
    assert ckpt_name == 'epoch=3-step=2'

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-epoch=3-step=2'

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-ckpt'

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
    assert ckpt_name == 'epoch=003-acc=0.03'

    # prefix
    # CHECKPOINT_JOIN_CHAR is class-level state; restore it in ``finally`` so
    # a failing assertion cannot leak the custom join character to other tests.
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
    try:
        ckpt_name = ModelCheckpoint._format_checkpoint_name(
            '{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test'
        )
        assert ckpt_name == 'test@epoch=3,acc=0.03000'
    finally:
        ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None)
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {})
    assert ckpt_name == 'epoch=3-step=2.ckpt'
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath='')
    ckpt_name = ckpt.format_checkpoint_name(5, 4, {})
    assert ckpt_name == 'epoch=5-step=4.ckpt'

    # CWD
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath='.')
    ckpt_name = ckpt.format_checkpoint_name(3, 4, {})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')

    # with ver
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test')
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3)
    assert ckpt_name == tmpdir / 'test-name-v3.ckpt'

    # using slashes
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}')
    ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
    assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

    # TODO: Checks with filepath. To be removed in v1.2
    # CWD
    ckpt = ModelCheckpoint(monitor='early_stop_on', filepath='.')
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=2.ckpt')

    # dir does not exist so it is used as filename
    filepath = tmpdir / 'dir'
    ckpt = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test')
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {})
    assert ckpt_name == tmpdir / 'test-dir.ckpt'

    # now, dir exists
    os.mkdir(filepath)
    ckpt = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test')
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {})
    assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'
Beispiel #5
0
def test_model_checkpoint_format_checkpoint_name(tmpdir):
    """Check the checkpoint names produced by ``ModelCheckpoint`` formatting.

    Exercises ``ModelCheckpoint._format_checkpoint_name`` (positional epoch
    and step plus a metrics dict) with empty/``None`` filenames, templates
    without groups, format specs, a custom join character and
    ``auto_insert_metric_name=False``; and ``format_checkpoint_name`` with
    ``dirpath`` values, a version suffix and metric names containing slashes.
    """
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
    assert ckpt_name == 'epoch=3-step=2'

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-epoch=3-step=2'

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
    assert ckpt_name == 'test-ckpt'

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
    assert ckpt_name == 'epoch=003-acc=0.03'

    # prefix
    # CHECKPOINT_JOIN_CHAR is class-level state; restore it in ``finally`` so
    # a failing assertion cannot leak the custom join character to other tests.
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
    try:
        ckpt_name = ModelCheckpoint._format_checkpoint_name(
            '{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test'
        )
        assert ckpt_name == 'test@epoch=3,acc=0.03000'
    finally:
        ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None)
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {})
    assert ckpt_name == 'epoch=3-step=2.ckpt'
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath='')
    ckpt_name = ckpt.format_checkpoint_name(5, 4, {})
    assert ckpt_name == 'epoch=5-step=4.ckpt'

    # CWD
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath='.')
    ckpt_name = ckpt.format_checkpoint_name(3, 4, {})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')

    # with version
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name')
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3)
    assert ckpt_name == tmpdir / 'name-v3.ckpt'

    # using slashes
    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}')
    ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
    assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

    # auto_insert_metric_name=False
    ckpt_name = ModelCheckpoint._format_checkpoint_name(
        'epoch={epoch:03d}-val_acc={val/acc}', 3, 2, {'val/acc': 0.03}, auto_insert_metric_name=False
    )
    assert ckpt_name == 'epoch=003-val_acc=0.03'