コード例 #1
0
ファイル: test_functional.py プロジェクト: roromaniac/xfuse
def test_restore_session(shared_datadir, script_runner, mocker, tmp_path):
    r"""Test session restore.

    Runs ``xfuse run`` once with ``--save-path`` pointing at ``tmp_path`` and
    snapshots the resulting global state dict. The global state is then reset
    and ``xfuse run`` is invoked again with the same save path, with
    ``xfuse.__main__._run`` patched so that, instead of training, it checks
    that the training step counter is greater than 1 and that every module
    parameter and every global parameter equals the snapshot.
    """
    config_path = str(shared_datadir / "test_restore_session.toml")

    first_run = script_runner.run(
        "xfuse",
        "run",
        f"--save-path={tmp_path}",
        config_path,
    )
    # Fail early with the CLI's stderr: if the initial run fails, the
    # snapshot below does not describe a trained model and the parameter
    # comparisons in `_mock_run` could pass vacuously on empty state dicts.
    assert first_run.success, first_run.stderr

    state_dict = get_state_dict()
    reset_state()

    def _mock_run(*_args, **_kwargs):
        # Stand-in for `xfuse.__main__._run` in the second CLI invocation:
        # compares the state loaded by that invocation with `state_dict`.
        # `panic` is overridden with `Unset()` while the checks run
        # (presumably so a failing assertion is not intercepted by a panic
        # handler -- TODO confirm).
        with Session(panic=Unset()):
            assert get("training_data").step > 1
            new_state_dict = get_state_dict()

            # Collect every mismatch rather than stopping at the first, so a
            # failure names exactly which parameters were not restored.
            mismatched_module_params = [
                f"{module_name}.{param_name}"
                for module_name, module_state in state_dict.modules.items()
                for param_name, param_value in module_state.items()
                if not (
                    new_state_dict.modules[module_name][param_name]
                    == param_value
                ).all()
            ]
            assert not mismatched_module_params, (
                "Module parameters differ after restore: "
                f"{mismatched_module_params}"
            )

            mismatched_params = [
                param_name
                for param_name, param_value in state_dict.params.items()
                if not (new_state_dict.params[param_name] == param_value).all()
            ]
            assert not mismatched_params, (
                f"Parameters differ after restore: {mismatched_params}"
            )

    mocker.patch("xfuse.__main__._run", _mock_run)

    ret = script_runner.run(
        "xfuse",
        "run",
        f"--save-path={tmp_path}",
        config_path,
    )
    # `_mock_run` executes inside this invocation, so its assertion failures
    # are expected to show up as an unsuccessful run; stderr shows why.
    assert ret.success, ret.stderr
コード例 #2
0
ファイル: test_functional.py プロジェクト: roromaniac/xfuse
 def _mock_run(*_args, **_kwargs):
     """Check that the restored state matches a previously taken snapshot.

     Intended as a patched replacement for ``xfuse.__main__._run``. It
     asserts that the training step counter is greater than 1 and that every
     module parameter and every global parameter equals the corresponding
     entry in ``state_dict``. ``state_dict`` is a free variable, presumably a
     snapshot taken by an enclosing test that is not visible in this snippet.
     """
     # `panic` is overridden with `Unset()` while the checks run (presumably
     # so a failing assertion is not intercepted by a panic handler -- TODO
     # confirm).
     with Session(panic=Unset()):
         assert get("training_data").step > 1
         new_state_dict = get_state_dict()

         # Collect every mismatch rather than stopping at the first, so a
         # failure names exactly which parameters were not restored.
         mismatched_module_params = [
             f"{module_name}.{param_name}"
             for module_name, module_state in state_dict.modules.items()
             for param_name, param_value in module_state.items()
             if not (
                 new_state_dict.modules[module_name][param_name]
                 == param_value
             ).all()
         ]
         assert not mismatched_module_params, (
             "Module parameters differ after restore: "
             f"{mismatched_module_params}"
         )

         mismatched_params = [
             param_name
             for param_name, param_value in state_dict.params.items()
             if not (new_state_dict.params[param_name] == param_value).all()
         ]
         assert not mismatched_params, (
             f"Parameters differ after restore: {mismatched_params}"
         )
コード例 #3
0
def test_restore_session(
    config, shared_datadir, script_runner, mocker, tmp_path
):
    r"""Test session restore.

    Runs ``xfuse run --debug`` with the given ``config`` inside a session
    holding a fresh ``TrainingData``, then snapshots the global state dict.
    The global state is then reset and ``xfuse run`` is invoked again with
    ``--session`` pointing at ``tmp_path / "final.session"``. In that second
    run ``xfuse.__main__._run`` is patched so that, instead of training, it
    checks that the training step is positive and that every module parameter
    and every non-empty global parameter equals the snapshot.

    ``config`` is presumably a (parametrized) fixture naming a config file in
    ``shared_datadir`` -- confirm against the fixture definition.
    """
    config_path = str(shared_datadir / config)

    with Session(training_data=TrainingData()):
        first_run = script_runner.run(
            "xfuse",
            "run",
            "--debug",
            f"--save-path={tmp_path}",
            config_path,
        )
    # Fail early with the CLI's stderr: if the initial run fails, the
    # snapshot below does not describe a trained model and the parameter
    # comparisons in `_mock_run` could pass vacuously on empty state dicts.
    assert first_run.success, first_run.stderr

    state_dict = get_state_dict()
    reset_state()

    def _mock_run(*_args, **_kwargs):
        # Stand-in for `xfuse.__main__._run` in the second CLI invocation:
        # compares the state loaded by that invocation with `state_dict`.
        # `panic` is overridden with `Unset()` while the checks run
        # (presumably so a failing assertion is not intercepted by a panic
        # handler -- TODO confirm).
        with Session(panic=Unset()):
            assert get("training_data").step > 0
            new_state_dict = get_state_dict()

            # Collect every mismatch rather than stopping at the first, so a
            # failure names exactly which parameters were not restored.
            mismatched_module_params = [
                f"{module_name}.{param_name}"
                for module_name, module_state in state_dict.modules.items()
                for param_name, param_value in module_state.items()
                if not (
                    new_state_dict.modules[module_name][param_name]
                    == param_value
                ).all()
            ]
            assert not mismatched_module_params, (
                "Module parameters differ after restore: "
                f"{mismatched_module_params}"
            )

            # Zero-element parameters are excluded from the comparison.
            mismatched_params = [
                param_name
                for param_name, param_value in state_dict.params.items()
                if param_value.data.nelement() > 0
                and not (
                    new_state_dict.params[param_name].data == param_value.data
                ).all()
            ]
            assert not mismatched_params, (
                f"Parameters differ after restore: {mismatched_params}"
            )

    mocker.patch("xfuse.__main__._run", _mock_run)

    ret = script_runner.run(
        "xfuse",
        "run",
        f"--save-path={tmp_path}",
        config_path,
        "--session=" + str(tmp_path / "final.session"),
    )
    # `_mock_run` executes inside this invocation, so its assertion failures
    # are expected to show up as an unsuccessful run; stderr shows why.
    assert ret.success, ret.stderr