Example #1
0
def test_restore_session(shared_datadir, script_runner, mocker, tmp_path):
    r"""Test session restore

    xfuse is run once with ``--save-path`` pointing at ``tmp_path`` and the
    resulting module and parameter state is captured with
    ``get_state_dict()``. Global state is then reset and xfuse is run a
    second time with ``xfuse.__main__._run`` patched out. Inside the patched
    function the test checks that the training step is greater than 1 and
    that every captured module/parameter value is present and equal in the
    state seen during the second invocation.
    """
    config_path = str(shared_datadir / "test_restore_session.toml")

    first_run = script_runner.run(
        "xfuse",
        "run",
        f"--save-path={tmp_path}",
        config_path,
    )
    # If the initial run fails, ``state_dict`` below may be empty and the
    # ``all(...)`` checks in ``_mock_run`` would then pass vacuously.
    assert first_run.success, first_run.stderr

    state_dict = get_state_dict()
    reset_state()

    # Records each invocation of the patched ``_run`` so that a patch which
    # never takes effect (and hence never runs the checks) is detected.
    mock_calls = []

    def _same(new_value, old_value):
        # ``(a == b).all()`` on its own broadcasts, so values of different
        # shapes could compare equal; require identical shapes first.
        return new_value.shape == old_value.shape and bool(
            (new_value == old_value).all()
        )

    def _mock_run(*_args, **_kwargs):
        mock_calls.append(None)
        with Session(panic=Unset()):
            assert get("training_data").step > 1
            new_state_dict = get_state_dict()
            assert all(
                _same(
                    new_state_dict.modules[module_name][param_name],
                    param_value,
                )
                for module_name, module_state in state_dict.modules.items()
                for param_name, param_value in module_state.items()
            )
            assert all(
                _same(new_state_dict.params[param_name], param_value)
                for param_name, param_value in state_dict.params.items()
            )

    mocker.patch("xfuse.__main__._run", _mock_run)

    ret = script_runner.run(
        "xfuse",
        "run",
        f"--save-path={tmp_path}",
        config_path,
    )
    # Assertion errors raised inside ``_mock_run`` surface as a failed run;
    # include stderr so the traceback is visible in the test report.
    assert ret.success, ret.stderr
    assert mock_calls, (
        "xfuse.__main__._run was not invoked by the second run"
    )
Example #2
0
def test_restore_session(
    config, shared_datadir, script_runner, mocker, tmp_path
):
    r"""Test session restore

    xfuse is first run in ``--debug`` mode, inside a ``Session`` holding a
    fresh ``TrainingData``, with ``--save-path`` pointing at ``tmp_path``
    and the configuration file ``shared_datadir / config``. The resulting
    module and parameter state is captured and global state is reset.

    xfuse is then run again with ``--session`` pointing at
    ``tmp_path / "final.session"`` and ``xfuse.__main__._run`` patched out.
    Inside the patched function the test checks that the training step is
    greater than 0 and that the captured module values and non-empty
    parameter values are present and equal in the restored state.
    """
    config_path = str(shared_datadir / config)

    with Session(training_data=TrainingData()):
        first_run = script_runner.run(
            "xfuse",
            "run",
            "--debug",
            f"--save-path={tmp_path}",
            config_path,
        )
    # If the initial run fails, ``state_dict`` below may be empty and the
    # ``all(...)`` checks in ``_mock_run`` would then pass vacuously.
    assert first_run.success, first_run.stderr

    state_dict = get_state_dict()
    reset_state()

    # Records each invocation of the patched ``_run`` so that a patch which
    # never takes effect (and hence never runs the checks) is detected.
    mock_calls = []

    def _same(new_value, old_value):
        # ``(a == b).all()`` on its own broadcasts, so values of different
        # shapes could compare equal; require identical shapes first.
        return new_value.shape == old_value.shape and bool(
            (new_value == old_value).all()
        )

    def _mock_run(*_args, **_kwargs):
        mock_calls.append(None)
        with Session(panic=Unset()):
            assert get("training_data").step > 0
            new_state_dict = get_state_dict()
            assert all(
                _same(
                    new_state_dict.modules[module_name][param_name],
                    param_value,
                )
                for module_name, module_state in state_dict.modules.items()
                for param_name, param_value in module_state.items()
            )
            # NOTE(review): empty parameters are skipped here; the reason is
            # not visible in this test — confirm before removing the filter.
            assert all(
                _same(
                    new_state_dict.params[param_name].data, param_value.data
                )
                for param_name, param_value in state_dict.params.items()
                if param_value.data.nelement() > 0
            )

    mocker.patch("xfuse.__main__._run", _mock_run)

    ret = script_runner.run(
        "xfuse",
        "run",
        f"--save-path={tmp_path}",
        config_path,
        "--session=" + str(tmp_path / "final.session"),
    )
    # Assertion errors raised inside ``_mock_run`` surface as a failed run;
    # include stderr so the traceback is visible in the test report.
    assert ret.success, ret.stderr
    assert mock_calls, (
        "xfuse.__main__._run was not invoked by the second run"
    )
Example #3
0
File: conftest.py, project: ludvb/xfuse
def pytest_runtest_setup(item):
    """Reset global state before each test; optionally seed torch's RNG.

    Pyro's parameter store is cleared and ``reset_state()`` is called, in
    that order, before every test item. Items carrying the ``fix_rng``
    marker (on themselves or a parent node) additionally get
    ``torch.manual_seed(0)`` for reproducible random draws.
    """
    pyro.clear_param_store()
    reset_state()

    fix_rng_marker = item.get_closest_marker("fix_rng")
    if fix_rng_marker is None:
        return
    torch.manual_seed(0)