def test_collection_defaults_to_hook_config():
    """Check that a hook's per-mode save configs fill in a collection's gaps.

    The collection configures only ModeKeys.EVAL (save_interval=20) while the
    hook configures only ModeKeys.TRAIN (save_interval=10). Once the hook
    prepares its collections, the collection is expected to carry the hook's
    TRAIN setting alongside its own EVAL one, i.e. it is finalized as
    {ModeKeys.TRAIN: save_interval=10, ModeKeys.EVAL: save_interval=20}.
    It should also pick up the hook's reduction config, since it had none.
    """
    manager = CollectionManager()
    manager.create_collection("foo")
    foo = manager.get("foo")
    foo.include_regex = "*"
    foo.save_config = {ModeKeys.EVAL: SaveConfigMode(save_interval=20)}

    run_dir = "/tmp/test_collections/" + str(datetime.datetime.now())
    hook = Hook(
        out_dir=run_dir,
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = manager

    # Before preparation, the collection has no TRAIN entry and no
    # reduction config of its own.
    assert foo.save_config.mode_save_configs[ModeKeys.TRAIN] is None
    assert foo.reduction_config is None

    hook._prepare_collections()

    # After preparation, the hook-level defaults have been merged in.
    assert foo.save_config.mode_save_configs[ModeKeys.TRAIN].save_interval == 10
    assert foo.reduction_config.save_raw_tensor is True
def test_save_config_hookjson_config():
    """Run ``test_save_config`` with a hook created from a JSON config file.

    The JSON path is handed to ``t_hook.create_from_json_file`` through the
    environment variable named by ``CONFIG_FILE_PATH_ENV_STR``. The variable's
    previous value is restored, and the output directory is removed, even if
    the test fails, so no state leaks into other tests.
    """
    from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR

    # NOTE(review): this directory name looks copied from
    # test_hook_from_json_config_full; confirm it matches the out_dir set in
    # test_save_config_hookjson_config.json.
    out_dir = "/tmp/test_hook_from_json_config_full"
    shutil.rmtree(out_dir, True)

    previous_config_path = os.environ.get(CONFIG_FILE_PATH_ENV_STR)
    os.environ[CONFIG_FILE_PATH_ENV_STR] = (
        "tests/mxnet/test_json_configs/test_save_config_hookjson_config.json"
    )
    try:
        hook = t_hook.create_from_json_file()
        test_save_config(hook=hook)
    finally:
        # Undo the environment change so later tests don't silently pick up
        # this test's JSON config.
        if previous_config_path is None:
            os.environ.pop(CONFIG_FILE_PATH_ENV_STR, None)
        else:
            os.environ[CONFIG_FILE_PATH_ENV_STR] = previous_config_path
        shutil.rmtree(out_dir, True)
def test_hook_from_json_config_full():
    """Train/evaluate the Gluon MNIST model with a hook built from a full JSON config.

    The JSON path is handed to ``t_hook.create_from_json_file`` through the
    environment variable named by ``CONFIG_FILE_PATH_ENV_STR``. The variable's
    previous value is restored, and the output directory is removed, even if
    the test fails.
    """
    # Imported locally like the sibling JSON-config tests, so this test does
    # not rely on a module-level import of the constant.
    from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR

    out_dir = "/tmp/newlogsRunTest2/test_hook_from_json_config_full"
    shutil.rmtree(out_dir, True)

    previous_config_path = os.environ.get(CONFIG_FILE_PATH_ENV_STR)
    os.environ[CONFIG_FILE_PATH_ENV_STR] = (
        "tests/mxnet/test_json_configs/test_hook_from_json_config_full.json"
    )
    try:
        hook = t_hook.create_from_json_file()
        # The directory was just removed, so training must not be reported
        # as ended yet.
        assert not has_training_ended(out_dir)
        run_mnist_gluon_model(
            hook=hook,
            num_steps_train=10,
            num_steps_eval=10,
            register_to_loss_block=True,
        )
    finally:
        # Undo the environment change so later tests don't silently pick up
        # this test's JSON config.
        if previous_config_path is None:
            os.environ.pop(CONFIG_FILE_PATH_ENV_STR, None)
        else:
            os.environ[CONFIG_FILE_PATH_ENV_STR] = previous_config_path
        shutil.rmtree(out_dir, True)
def test_save_all_hook_from_json():
    """Run ``test_save_all`` with a hook created from the save-all JSON config.

    The JSON path is handed to ``t_hook.create_from_json_file`` through the
    environment variable named by ``CONFIG_FILE_PATH_ENV_STR``. The variable's
    previous value is restored, and the output directory is deleted, even if
    the test fails.
    """
    from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR
    import os

    out_dir = "/tmp/newlogsRunTest2/test_hook_save_all_hook_from_json"
    shutil.rmtree(out_dir, True)

    previous_config_path = os.environ.get(CONFIG_FILE_PATH_ENV_STR)
    os.environ[CONFIG_FILE_PATH_ENV_STR] = (
        "tests/mxnet/test_json_configs/test_hook_save_all_hook.json"
    )
    try:
        hook = t_hook.create_from_json_file()
        test_save_all(hook, out_dir)
    finally:
        # Undo the environment change so later tests don't silently pick up
        # this test's JSON config.
        if previous_config_path is None:
            os.environ.pop(CONFIG_FILE_PATH_ENV_STR, None)
        else:
            os.environ[CONFIG_FILE_PATH_ENV_STR] = previous_config_path
        # delete output
        shutil.rmtree(out_dir, True)
def test_modes_hook_from_json_config():
    """Run ``test_modes`` with a hook created from the modes JSON config.

    The JSON path is handed to ``t_hook.create_from_json_file`` through the
    environment variable named by ``CONFIG_FILE_PATH_ENV_STR``. The variable's
    previous value is restored, and the output directory is removed, even if
    the test fails.
    """
    from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR
    import shutil
    import os

    out_dir = "/tmp/test_modes_hookjson"
    shutil.rmtree(out_dir, True)

    previous_config_path = os.environ.get(CONFIG_FILE_PATH_ENV_STR)
    os.environ[CONFIG_FILE_PATH_ENV_STR] = "tests/mxnet/test_json_configs/test_modes_hook.json"
    try:
        hook = t_hook.create_from_json_file()
        test_modes(hook, out_dir)
    finally:
        # Undo the environment change so later tests don't silently pick up
        # this test's JSON config.
        if previous_config_path is None:
            os.environ.pop(CONFIG_FILE_PATH_ENV_STR, None)
        else:
            os.environ[CONFIG_FILE_PATH_ENV_STR] = previous_config_path
        shutil.rmtree(out_dir, True)
def test_invalid_collection_config_exception():
    """A collection with nothing to include is rejected; adding a regex fixes it.

    Preparing collections must raise ``InvalidCollectionConfiguration`` while
    collection "foo" has no include pattern. Once ``include_regex`` is set,
    preparation must succeed without that error.
    """
    manager = CollectionManager()
    manager.create_collection("foo")
    run_dir = "/tmp/test_collections/" + str(datetime.datetime.now())
    hook = Hook(
        out_dir=run_dir,
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = manager

    # Phase 1: an empty collection must be reported as invalid.
    rejected = False
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration:
        rejected = True
    assert rejected, "Invalid Collection Name did not raise error"

    # Phase 2: once it includes something, preparation must go through.
    manager.get("foo").include_regex = "*"
    wrongly_rejected = False
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration:
        wrongly_rejected = True
    assert not wrongly_rejected, "Valid Collection Name raised an error"