def test_save_config_modes(out_dir):
    """Save the "weights" collection on a per-mode schedule.

    TRAIN mode uses a save interval of 2 and EVAL mode an interval of 3.
    The output is then checked by ``helper_save_config_modes``.
    """
    pre_test_clean_up()
    session_hook = SessionHook(out_dir=out_dir, include_collections=["weights"])
    per_mode_config = {
        modes.TRAIN: SaveConfigMode(save_interval=2),
        modes.EVAL: SaveConfigMode(save_interval=3),
    }
    weights_collection = session_hook.get_collection("weights")
    weights_collection.save_config = per_mode_config
    helper_save_config_modes(out_dir, session_hook)
def test_multi_collection_match(out_dir):
    """Select "loss:0" both through ``include_regex`` and the "trial" collection.

    The hook also includes the "default" collection and uses a save interval
    of 2. The result is verified by ``helper_test_multi_collection_match``.
    """
    pre_test_clean_up()
    interval_config = SaveConfig(save_interval=2)
    session_hook = SessionHook(
        out_dir=out_dir,
        include_regex=["loss:0"],
        include_collections=["default", "trial"],
        save_config=interval_config,
    )
    trial_collection = session_hook.get_collection("trial")
    trial_collection.include("loss:0")
    helper_test_multi_collection_match(out_dir, session_hook)
def test_save_config_json(out_dir, monkeypatch):
    """Build the hook from ``test_save_config.json`` and run the save-config checks."""
    pre_test_clean_up()
    config_path = "tests/tensorflow/hooks/test_json_configs/test_save_config.json"
    # The JSON config file is located through this environment variable.
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, config_path)
    session_hook = SessionHook.create_from_json_file()
    helper_test_save_config(out_dir, session_hook)
def test_save_config_modes_json(out_dir, monkeypatch):
    """Build the hook from a JSON config with per-mode collection save configs.

    This is the JSON-driven counterpart of ``test_save_config_modes``; both run
    ``helper_save_config_modes``.

    ``pre_test_clean_up()`` is now called first, matching
    ``test_save_config_modes`` and the other JSON-config tests in this file.
    Presumably it clears state left by earlier tests; confirm against its
    definition.
    """
    pre_test_clean_up()
    monkeypatch.setenv(
        CONFIG_FILE_PATH_ENV_STR,
        "tests/tensorflow/hooks/test_json_configs/test_save_config_modes_config_coll.json",
    )
    hook = SessionHook.create_from_json_file()
    helper_save_config_modes(out_dir, hook)
def test_hook_config_json(out_dir, monkeypatch):
    """Build the hook from ``test_hook_from_json_config.json`` and run the save-all checks.

    The hook is passed to ``test_save_all_full``, which also serves as the
    check helper.

    ``pre_test_clean_up()`` is now called first, matching the other
    JSON-config tests in this file. Presumably it clears state left by earlier
    tests; confirm against its definition.
    """
    pre_test_clean_up()
    monkeypatch.setenv(
        CONFIG_FILE_PATH_ENV_STR,
        "tests/tensorflow/hooks/test_json_configs/test_hook_from_json_config.json",
    )
    hook = SessionHook.create_from_json_file()
    test_save_all_full(out_dir, hook)
def test_save_config_start_and_end(out_dir):
    """Use a save interval of 2 with start_step=8 and end_step=14, save_all off."""
    pre_test_clean_up()
    windowed_config = SaveConfig(save_interval=2, start_step=8, end_step=14)
    session_hook = SessionHook(
        out_dir=out_dir,
        save_all=False,
        save_config=windowed_config,
    )
    helper_save_config_start_and_end(out_dir, session_hook)
def test_simple_include_regex_json(out_dir, monkeypatch):
    """Build the hook from ``test_simple_include_regex.json`` and run the regex-include checks."""
    pre_test_clean_up()
    config_path = "tests/tensorflow/hooks/test_json_configs/test_simple_include_regex.json"
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, config_path)
    session_hook = SessionHook.create_from_json_file()
    helper_test_simple_include_regex(out_dir, session_hook)
def test_simple_include(out_dir):
    """Include the "default" and "losses" collections with a save interval of 2."""
    pre_test_clean_up()
    wanted_collections = ["default", "losses"]
    session_hook = SessionHook(
        out_dir=out_dir,
        save_config=SaveConfig(save_interval=2),
        include_collections=wanted_collections,
    )
    helper_test_simple_include(out_dir, session_hook)
def test_simple_include_regex(out_dir):
    """Select tensors only by the regex "loss:0", with no collections included.

    Uses a save interval of 2.
    """
    pre_test_clean_up()
    interval_config = SaveConfig(save_interval=2)
    session_hook = SessionHook(
        out_dir=out_dir,
        include_regex=["loss:0"],
        include_collections=[],  # explicitly empty: only the regex selects tensors
        save_config=interval_config,
    )
    helper_test_simple_include_regex(out_dir, session_hook)
def test_multi_collection_match_json(out_dir, monkeypatch):
    """JSON-config counterpart of ``test_multi_collection_match``.

    After the hook is built from the JSON file, "loss:0" is additionally added
    to the "trial" collection.
    """
    pre_test_clean_up()
    config_path = "tests/tensorflow/hooks/test_json_configs/test_multi_collection_match.json"
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, config_path)
    session_hook = SessionHook.create_from_json_file()
    trial_collection = session_hook.get_collection("trial")
    trial_collection.include("loss:0")
    helper_test_multi_collection_match(out_dir, session_hook)
def test_hook_write(out_dir):
    """Run ``helper_hook_write`` with save_all on and check for weights tensors.

    The hook uses save_all=True, no explicit collections, and a save interval
    of 999. The trial read back from ``out_dir`` must report at least one
    tensor in the "weights" collection.
    """
    pre_test_clean_up()
    sparse_config = SaveConfig(save_interval=999)
    session_hook = SessionHook(
        out_dir,
        save_all=True,
        include_collections=None,
        save_config=sparse_config,
    )
    helper_hook_write(out_dir, session_hook)
    trial = create_trial_fast_refresh(out_dir)
    print(trial.tensor_names(collection="weights"))
    weights_count = len(trial.tensor_names(collection="weights"))
    assert weights_count
def test_save_config_disable(out_dir, monkeypatch):
    """Check that the hook from ``test_save_config_disable.json`` records nothing.

    After ``simple_model`` runs, the trial at ``out_dir`` must report zero
    steps and zero tensor names.
    """
    pre_test_clean_up()
    config_path = "tests/tensorflow/hooks/test_json_configs/test_save_config_disable.json"
    monkeypatch.setenv(CONFIG_FILE_PATH_ENV_STR, config_path)
    session_hook = SessionHook.create_from_json_file()
    simple_model(session_hook)
    trial = create_trial(out_dir)
    recorded_steps = trial.steps()
    assert len(recorded_steps) == 0
    recorded_names = trial.tensor_names()
    assert len(recorded_names) == 0
def test_save_all_full(out_dir, hook=None):
    """Run ``simple_model`` under a save-all hook and check the saved tensor counts.

    ``test_hook_config_json`` also calls this function as a helper and passes
    its own ``hook``. When ``hook`` is None, a save_all hook with a save
    interval of 2 is built here.

    Expects more than 50 tensor names overall and exactly one each in the
    "weights", "gradients" and "losses" collections.
    """
    tf.reset_default_graph()
    session_hook = hook if hook is not None else SessionHook(
        out_dir=out_dir, save_all=True, save_config=SaveConfig(save_interval=2)
    )
    simple_model(session_hook)
    trial = create_trial_fast_refresh(out_dir)
    assert len(trial.tensor_names()) > 50
    print(trial.tensor_names(collection="weights"))
    for collection_name in ("weights", "gradients", "losses"):
        assert len(trial.tensor_names(collection=collection_name)) == 1
def test_save_config_skip_steps(out_dir):
    """Use a save interval of 2 with start_step=8, save_all off."""
    pre_test_clean_up()
    skip_config = SaveConfig(save_interval=2, start_step=8)
    session_hook = SessionHook(
        out_dir=out_dir,
        save_all=False,
        save_config=skip_config,
    )
    helper_save_config_skip_steps(out_dir, session_hook)
def test_save_config(out_dir):
    """Use a plain save interval of 2 with save_all off."""
    pre_test_clean_up()
    interval_config = SaveConfig(save_interval=2)
    session_hook = SessionHook(
        out_dir=out_dir,
        save_all=False,
        save_config=interval_config,
    )
    helper_test_save_config(out_dir, session_hook)