def test_save_config_modes(out_dir):
    """Save the ``weights`` collection with a different interval per mode.

    TRAIN uses ``save_interval=2`` and EVAL uses ``save_interval=3``; the
    shared helper drives the model and checks the saved steps.
    """
    pre_test_clean_up()
    session_hook = SessionHook(out_dir=out_dir, include_collections=["weights"])
    weights_collection = session_hook.get_collection("weights")
    train_config = SaveConfigMode(save_interval=2)
    eval_config = SaveConfigMode(save_interval=3)
    weights_collection.save_config = {modes.TRAIN: train_config, modes.EVAL: eval_config}
    helper_save_config_modes(out_dir, session_hook)
def test_multi_collection_match(out_dir):
    """Run the multi-collection checks where ``loss:0`` is matched twice.

    The tensor is selected both through ``include_regex`` and by explicitly
    adding it to the ``trial`` collection.
    """
    pre_test_clean_up()
    every_second_step = SaveConfig(save_interval=2)
    session_hook = SessionHook(
        out_dir=out_dir,
        include_regex=["loss:0"],
        include_collections=["default", "trial"],
        save_config=every_second_step,
    )
    trial_collection = session_hook.get_collection("trial")
    trial_collection.include("loss:0")
    helper_test_multi_collection_match(out_dir, session_hook)
def test_save_config_json(out_dir, monkeypatch):
    """Run the save-config checks with a hook built from ``test_save_config.json``.

    The config path is supplied through the ``CONFIG_FILE_PATH_ENV_STR``
    environment variable, which ``monkeypatch`` undoes after the test.
    """
    pre_test_clean_up()
    json_path = "tests/tensorflow/hooks/test_json_configs/test_save_config.json"
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, json_path)
    json_hook = SessionHook.create_from_json_file()
    helper_test_save_config(out_dir, json_hook)
def test_save_config_modes_json(out_dir, monkeypatch):
    """Run the per-mode save-config checks with a hook built from a JSON file.

    The JSON file is located through the ``CONFIG_FILE_PATH_ENV_STR``
    environment variable; ``monkeypatch`` restores it after the test.
    """
    # Every sibling test (including the non-JSON ``test_save_config_modes``)
    # cleans up before building its hook; without this, state left by an
    # earlier test can leak into this one. NOTE(review): exact state reset is
    # whatever ``pre_test_clean_up`` does (presumably the TF default graph).
    pre_test_clean_up()
    monkeypatch.setenv(
        CONFIG_FILE_PATH_ENV_STR,
        "tests/tensorflow/hooks/test_json_configs/test_save_config_modes_config_coll.json",
    )
    hook = SessionHook.create_from_json_file()
    helper_save_config_modes(out_dir, hook)
# Example #5
# 0
def test_hook_config_json(out_dir, monkeypatch):
    """Run the full save-all checks with a hook built from a JSON config file.

    The JSON path is supplied through ``CONFIG_FILE_PATH_ENV_STR`` (restored
    by ``monkeypatch`` after the test), and the resulting hook is handed to
    ``test_save_all_full``, which runs the model and inspects the output.
    """
    # Clean up *before* the hook is created, as the other JSON-config tests
    # do; the graph reset inside ``test_save_all_full`` only happens after
    # this hook already exists.
    pre_test_clean_up()
    monkeypatch.setenv(
        CONFIG_FILE_PATH_ENV_STR,
        "tests/tensorflow/hooks/test_json_configs/test_hook_from_json_config.json",
    )
    hook = SessionHook.create_from_json_file()
    test_save_all_full(out_dir, hook)
def test_save_config_start_and_end(out_dir):
    """Save with ``save_interval=2`` bounded by ``start_step=8`` and ``end_step=14``."""
    pre_test_clean_up()
    bounded_config = SaveConfig(save_interval=2, start_step=8, end_step=14)
    session_hook = SessionHook(out_dir=out_dir, save_all=False, save_config=bounded_config)
    helper_save_config_start_and_end(out_dir, session_hook)
def test_simple_include_regex_json(out_dir, monkeypatch):
    """Run the include-regex checks with a hook built from ``test_simple_include_regex.json``."""
    pre_test_clean_up()
    json_path = "tests/tensorflow/hooks/test_json_configs/test_simple_include_regex.json"
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, json_path)
    json_hook = SessionHook.create_from_json_file()
    helper_test_simple_include_regex(out_dir, json_hook)
def test_simple_include(out_dir):
    """Save the ``default`` and ``losses`` collections every 2 steps and check them."""
    pre_test_clean_up()
    interval_config = SaveConfig(save_interval=2)
    wanted_collections = ["default", "losses"]
    session_hook = SessionHook(
        out_dir=out_dir,
        save_config=interval_config,
        include_collections=wanted_collections,
    )
    helper_test_simple_include(out_dir, session_hook)
def test_simple_include_regex(out_dir):
    """Select tensors only via ``include_regex=["loss:0"]`` (no collections)."""
    pre_test_clean_up()
    regex_only_hook = SessionHook(
        out_dir=out_dir,
        include_regex=["loss:0"],
        include_collections=[],  # empty on purpose: the regex is the only selector
        save_config=SaveConfig(save_interval=2),
    )
    helper_test_simple_include_regex(out_dir, regex_only_hook)
def test_multi_collection_match_json(out_dir, monkeypatch):
    """JSON-configured variant of the multi-collection-match test.

    After loading the hook from ``test_multi_collection_match.json``,
    ``loss:0`` is explicitly added to the ``trial`` collection.
    """
    pre_test_clean_up()
    json_path = "tests/tensorflow/hooks/test_json_configs/test_multi_collection_match.json"
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, json_path)
    json_hook = SessionHook.create_from_json_file()
    trial_collection = json_hook.get_collection("trial")
    trial_collection.include("loss:0")
    helper_test_multi_collection_match(out_dir, json_hook)
def test_hook_write(out_dir):
    """Run the write helper with ``save_all=True`` and require saved weights."""
    pre_test_clean_up()
    sparse_config = SaveConfig(save_interval=999)
    # out_dir is passed positionally, exactly as the original call did.
    session_hook = SessionHook(
        out_dir, save_all=True, include_collections=None, save_config=sparse_config
    )
    helper_hook_write(out_dir, session_hook)
    trial = create_trial_fast_refresh(out_dir)
    print(trial.tensor_names(collection="weights"))
    assert len(trial.tensor_names(collection="weights")) > 0
def test_save_config_disable(out_dir, monkeypatch):
    """Hook from ``test_save_config_disable.json`` must leave nothing in ``out_dir``.

    The trial read back from ``out_dir`` is required to have zero steps and
    zero tensor names.
    """
    pre_test_clean_up()
    json_path = "tests/tensorflow/hooks/test_json_configs/test_save_config_disable.json"
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, json_path)
    json_hook = SessionHook.create_from_json_file()
    simple_model(json_hook)
    trial = create_trial(out_dir)
    step_count = len(trial.steps())
    assert step_count == 0
    tensor_count = len(trial.tensor_names())
    assert tensor_count == 0
# Example #13
# 0
def test_save_all_full(out_dir, hook=None):
    """Run ``simple_model`` with a save-all hook and check the saved tensors.

    Args:
        out_dir: directory the trial is read back from.
        hook: optional pre-built hook; when omitted, a ``save_all=True`` hook
            with ``save_interval=2`` writing to ``out_dir`` is created.
    """
    tf.reset_default_graph()
    active_hook = hook
    if active_hook is None:
        active_hook = SessionHook(
            out_dir=out_dir, save_all=True, save_config=SaveConfig(save_interval=2)
        )

    simple_model(active_hook)
    trial = create_trial_fast_refresh(out_dir)
    assert len(trial.tensor_names()) > 50
    print(trial.tensor_names(collection="weights"))
    # Each of these collections is expected to hold exactly one tensor.
    for collection_name in ("weights", "gradients", "losses"):
        assert len(trial.tensor_names(collection=collection_name)) == 1
def test_save_config_skip_steps(out_dir):
    """Save with ``save_interval=2`` starting at ``start_step=8`` and check skipped steps."""
    pre_test_clean_up()
    delayed_config = SaveConfig(save_interval=2, start_step=8)
    session_hook = SessionHook(out_dir=out_dir, save_all=False, save_config=delayed_config)
    helper_save_config_skip_steps(out_dir, session_hook)
def test_save_config(out_dir):
    """Run the basic save-config checks with ``save_interval=2`` and ``save_all=False``."""
    pre_test_clean_up()
    interval_config = SaveConfig(save_interval=2)
    session_hook = SessionHook(out_dir=out_dir, save_all=False, save_config=interval_config)
    helper_test_save_config(out_dir, session_hook)