예제 #1
0
def test_collection_defaults_to_hook_config():
    """Test that hook save configs propagate to collection defaults.

    For example, if the hook sets ``ModeKeys.TRAIN: save_interval=10`` and a
    collection sets ``ModeKeys.EVAL: save_interval=20``, the collection should
    be finalized as ``{ModeKeys.TRAIN: save_interval=10,
    ModeKeys.EVAL: save_interval=20}``. The hook's reduction config is likewise
    expected to fill in the collection's unset reduction config.
    """
    cm = CollectionManager()
    cm.create_collection("foo")
    cm.get("foo").include_regex = "*"
    cm.get("foo").save_config = {
        ModeKeys.EVAL: SaveConfigMode(save_interval=20)
    }

    hook = Hook(
        out_dir="/tmp/test_collections/" + str(datetime.datetime.now()),
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = cm

    # Before finalization the collection has no TRAIN save config and no
    # reduction config of its own.
    assert cm.get("foo").save_config.mode_save_configs[ModeKeys.TRAIN] is None
    assert cm.get("foo").reduction_config is None

    hook._prepare_collections()

    # TRAIN is inherited from the hook's save_config...
    assert cm.get("foo").save_config.mode_save_configs[
        ModeKeys.TRAIN].save_interval == 10
    # ...while the collection's explicit EVAL setting must survive the merge,
    # which is the other half of the behaviour promised by the docstring.
    assert cm.get("foo").save_config.mode_save_configs[
        ModeKeys.EVAL].save_interval == 20
    # The reduction config is inherited from the hook as well.
    assert cm.get("foo").reduction_config.save_raw_tensor is True
def test_save_config_hookjson_config():
    """Run the shared ``test_save_config`` check on a hook built from JSON config.

    The config path is exported through ``CONFIG_FILE_PATH_ENV_STR`` only for
    the duration of this test, so the setting does not leak into other tests.
    The output directory is removed before the run and also afterwards, even
    if the check fails.
    """
    from unittest import mock

    from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR

    # NOTE(review): this directory name matches test_hook_from_json_config_full;
    # confirm it is the out_dir declared in
    # test_save_config_hookjson_config.json, otherwise cleanup misses the output.
    out_dir = "/tmp/test_hook_from_json_config_full"
    config_path = "tests/mxnet/test_json_configs/test_save_config_hookjson_config.json"
    shutil.rmtree(out_dir, True)
    try:
        # patch.dict restores os.environ on exit, unlike a bare assignment.
        with mock.patch.dict(os.environ, {CONFIG_FILE_PATH_ENV_STR: config_path}):
            hook = t_hook.create_from_json_file()
            test_save_config(hook=hook)
    finally:
        shutil.rmtree(out_dir, True)
예제 #3
0
def test_hook_from_json_config_full():
    """Train/evaluate a Gluon MNIST model with a hook built from a full JSON config.

    The config path is exported through ``CONFIG_FILE_PATH_ENV_STR`` only for
    the duration of this test, so the setting does not leak into other tests.
    The output directory is removed before the run and also afterwards, even
    if an assertion fails.
    """
    from unittest import mock

    # Imported locally, like the sibling JSON-config tests, instead of relying
    # on a module-level name.
    from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR

    out_dir = "/tmp/newlogsRunTest2/test_hook_from_json_config_full"
    config_path = "tests/mxnet/test_json_configs/test_hook_from_json_config_full.json"
    shutil.rmtree(out_dir, True)
    try:
        # patch.dict restores os.environ on exit, unlike a bare assignment.
        with mock.patch.dict(os.environ, {CONFIG_FILE_PATH_ENV_STR: config_path}):
            hook = t_hook.create_from_json_file()
            # The output dir was just wiped, so training must not be reported
            # as ended yet.
            assert not has_training_ended(out_dir)
            run_mnist_gluon_model(hook=hook,
                                  num_steps_train=10,
                                  num_steps_eval=10,
                                  register_to_loss_block=True)
    finally:
        shutil.rmtree(out_dir, True)
예제 #4
0
def test_save_all_hook_from_json():
    """Run the shared ``test_save_all`` check on a hook built from JSON config.

    The config path is exported through ``CONFIG_FILE_PATH_ENV_STR`` only for
    the duration of this test, so the setting does not leak into other tests.
    The output directory is removed before the run and also afterwards, even
    if the check fails.
    """
    import os
    from unittest import mock

    from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR

    out_dir = "/tmp/newlogsRunTest2/test_hook_save_all_hook_from_json"
    config_path = "tests/mxnet/test_json_configs/test_hook_save_all_hook.json"
    shutil.rmtree(out_dir, True)
    try:
        # patch.dict restores os.environ on exit, unlike a bare assignment.
        with mock.patch.dict(os.environ, {CONFIG_FILE_PATH_ENV_STR: config_path}):
            hook = t_hook.create_from_json_file()
            test_save_all(hook, out_dir)
    finally:
        # delete output, even when the shared check fails
        shutil.rmtree(out_dir, True)
예제 #5
0
def test_modes_hook_from_json_config():
    """Run the shared ``test_modes`` check on a hook built from JSON config.

    The config path is exported through ``CONFIG_FILE_PATH_ENV_STR`` only for
    the duration of this test, so the setting does not leak into other tests.
    The output directory is removed before the run and also afterwards, even
    if the check fails.
    """
    import os
    import shutil
    from unittest import mock

    from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR

    out_dir = "/tmp/test_modes_hookjson"
    config_path = "tests/mxnet/test_json_configs/test_modes_hook.json"
    shutil.rmtree(out_dir, True)
    try:
        # patch.dict restores os.environ on exit, unlike a bare assignment.
        with mock.patch.dict(os.environ, {CONFIG_FILE_PATH_ENV_STR: config_path}):
            hook = t_hook.create_from_json_file()
            test_modes(hook, out_dir)
    finally:
        shutil.rmtree(out_dir, True)
예제 #6
0
def test_invalid_collection_config_exception():
    """Check that ``hook._prepare_collections`` validates collection config.

    The test expects ``InvalidCollectionConfiguration`` while collection
    "foo" has only been created (no ``include_regex`` assigned). It expects
    no such error after ``include_regex`` is set.
    """
    cm = CollectionManager()
    # "foo" is created but no include pattern is assigned yet.
    cm.create_collection("foo")

    hook = Hook(
        out_dir="/tmp/test_collections/" + str(datetime.datetime.now()),
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = cm
    # First pass: the incomplete collection must be rejected.
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration:
        pass
    else:
        # NOTE(review): the message mentions a "Collection Name", but what is
        # exercised is the missing include_regex; also, `assert False` is
        # stripped under `python -O` (`raise AssertionError(...)` is not).
        assert False, "Invalid Collection Name did not raise error"

    # Second pass: once an include pattern is set, preparation must succeed.
    cm.get("foo").include_regex = "*"
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration:
        assert False, "Valid Collection Name raised an error"