Beispiel #1
0
def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
    """Check that ``ModelCheckpoint.to_yaml`` writes out ``best_k_models`` faithfully.

    A short fit is run with a checkpoint callback monitoring ``early_stop_on``.
    The callback's ``best_k_models`` mapping is then dumped to a YAML file, read
    back, and compared with the in-memory mapping.

    Args:
        tmpdir: directory used for the checkpoints, the trainer root dir and
            the YAML file.
        save_top_k: forwarded to ``ModelCheckpoint(save_top_k=...)``; presumably
            supplied by a parametrize decorator outside this block.
    """
    tutils.reset_seed()
    model = EvalModelTemplate()

    checkpoint = ModelCheckpoint(filepath=tmpdir, monitor='early_stop_on', save_top_k=save_top_k)

    trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2)
    trainer.fit(model)

    path_yaml = os.path.join(tmpdir, 'best_k_models.yaml')
    checkpoint.to_yaml(path_yaml)

    # Read the file inside a context manager so the handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(path_yaml, 'r') as yaml_file:
        d = yaml.full_load(yaml_file)

    # The in-memory scores are unwrapped with ``.item()`` (presumably scalar
    # tensors) so they compare equal to the plain values loaded from YAML.
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    assert d == best_k
Beispiel #2
0
def test_model_checkpoint_to_yaml(tmpdir, save_top_k: int):
    """Check that ``ModelCheckpoint.to_yaml`` writes out ``best_k_models`` faithfully.

    A short fit is run with a checkpoint callback monitoring ``early_stop_on``.
    The callback's ``best_k_models`` mapping is then dumped to a YAML file, read
    back, and compared with the in-memory mapping.

    Args:
        tmpdir: directory used for the checkpoints, the trainer root dir and
            the YAML file.
        save_top_k: forwarded to ``ModelCheckpoint(save_top_k=...)``; presumably
            supplied by a parametrize decorator outside this block.
    """
    tutils.reset_seed()
    model = LogInTwoMethods()

    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=save_top_k)

    trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=2)
    trainer.fit(model)

    path_yaml = os.path.join(tmpdir, "best_k_models.yaml")
    checkpoint.to_yaml(path_yaml)

    # Read the file inside a context manager so the handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(path_yaml) as yaml_file:
        d = yaml.full_load(yaml_file)

    # Compare against a plain-dict copy of the callback's mapping.
    best_k = dict(checkpoint.best_k_models.items())
    assert d == best_k