# Code example #1
# 0
def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Check that a checkpoint with extra weights loads only when ``strict=False``.

    The model is trained with an additional ``c_d3`` layer that a freshly
    constructed ``EvalModelTemplate`` does not have, then saved to disk.
    Loading that checkpoint (from the local path, or from an HTTP URL built
    from the ``tmpdir_server`` fixture when ``url_ckpt`` is set) must fail
    with the default strict loading and succeed with ``strict=False``.
    """
    import pytest

    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir;
    # pass a str because monkeypatch.setenv warns on non-str values
    monkeypatch.setenv("TORCH_HOME", str(tmpdir))

    model = EvalModelTemplate()
    # Extra layer that the checkpoint will contain but the plain template lacks
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(
        tutils.get_data_path(logger, path_dir=tmpdir), "hparams.yaml"
    )
    ckpt_path = (
        f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
        if url_ckpt
        else new_weights_path
    )

    # Default (strict) loading must reject the unexpected c_d3 parameters.
    # Pinning the exception type and message, instead of catching any
    # Exception, keeps unrelated failures (e.g. a failed download of the URL
    # checkpoint) from making this check pass by accident.
    unexpected_keys_msg = r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'
    with pytest.raises(RuntimeError, match=unexpected_keys_msg):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
        )

    # With strict=False the extra keys are ignored, so loading must succeed;
    # any exception propagates and fails the test with its full traceback.
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )
# Code example #2
# 0
def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server,
                                       url_ckpt):
    """Check strict vs. non-strict loading of a checkpoint with extra parameters.

    The trained model carries an additional ``c_d3`` layer that a freshly
    constructed ``EvalModelTemplate`` does not have. Loading the saved
    checkpoint (from disk, or from an HTTP URL built from the
    ``tmpdir_server`` fixture when ``url_ckpt`` is set) must succeed with
    ``strict=False`` and raise a ``RuntimeError`` naming the unexpected
    ``c_d3`` keys with ``strict=True``.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir;
    # monkeypatch.setenv expects a str and warns when given a path object
    monkeypatch.setenv('TORCH_HOME', str(tmpdir))

    model = EvalModelTemplate()
    # Extra layer that the checkpoint will contain but the plain template lacks
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir),
                                'hparams.yaml')
    # Checkpoint URL from the tmpdir_server fixture's (host, port); presumably
    # that fixture serves tmpdir over HTTP — confirm against its definition.
    ckpt_url = (f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/'
                f'{os.path.basename(new_weights_path)}')
    ckpt_path = ckpt_url if url_ckpt else new_weights_path

    # Non-strict loading ignores the unexpected c_d3 keys; any exception here
    # propagates and fails the test.
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )

    # Strict loading must reject the checkpoint and name the extra keys.
    unexpected_keys_msg = r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'
    with pytest.raises(RuntimeError, match=unexpected_keys_msg):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
            strict=True,
        )