def test_collection_defaults_to_hook_config():
    """Check that the hook's save_config fills in modes a collection leaves unset.

    Collection "foo" only configures ModeKeys.EVAL (save_interval=20), while the
    hook configures ModeKeys.TRAIN (save_interval=10). After the hook prepares
    its collections, "foo" is expected to carry the hook's TRAIN setting and
    the hook's reduction config in addition to its own EVAL setting.
    """
    manager = CollectionManager()
    manager.create_collection("foo")
    foo = manager.get("foo")
    foo.include_regex = "*"
    foo.save_config = {ModeKeys.EVAL: SaveConfigMode(save_interval=20)}

    timestamp = str(datetime.datetime.now())
    hook = Hook(
        out_dir="/tmp/test_collections/" + timestamp,
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = manager

    # Before preparation the collection knows nothing about TRAIN or reductions.
    before = manager.get("foo")
    assert before.save_config.mode_save_configs[ModeKeys.TRAIN] is None
    assert before.reduction_config is None

    hook._prepare_collections()

    # Re-fetch after preparation and check the hook's defaults were applied.
    after = manager.get("foo")
    train_config = after.save_config.mode_save_configs[ModeKeys.TRAIN]
    assert train_config.save_interval == 10
    assert after.reduction_config.save_raw_tensor is True
def test_manager_export_load():
    """Round-trip a CollectionManager through export() and load().

    The manager is exported into a private temporary directory instead of a
    fixed ``/tmp/dummy_trial``. This keeps concurrent or repeated test runs from
    colliding on, or reading stale files from, a shared path, and the files are
    removed once the check is done.
    """
    import tempfile

    cm = CollectionManager()
    cm.create_collection("default")
    cm.get("default").include("loss")
    cm.add(Collection("trial1"))
    cm.add("trial2")
    cm.get("trial2").include("total_loss")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Use a not-yet-existing subdirectory as the trial path, matching the
        # original test's situation of exporting to a fresh trial location.
        trial_dir = os.path.join(tmp_dir, "dummy_trial")
        cm.export(trial_dir, DEFAULT_COLLECTIONS_FILE_NAME)
        # Load while the directory still exists; it is deleted on exit.
        cm2 = CollectionManager.load(
            os.path.join(get_path_to_collections(trial_dir),
                         DEFAULT_COLLECTIONS_FILE_NAME))
    assert cm == cm2
def test_manager():
    """Exercise basic CollectionManager bookkeeping: create, add, include, lookup."""
    manager = CollectionManager()
    manager.create_collection("default")
    default = manager.get("default")
    default.include("loss")
    default.add_tensor_name("assaas")
    manager.add(Collection("trial1"))
    manager.add("trial2")
    manager.get("trial2").include("total_loss")

    by_name = manager.collections
    # "default", "trial1" and "trial2" should all be registered.
    assert len(by_name) == 3
    assert manager.get("default") == by_name["default"]
    assert "loss" in manager.get("default").include_regex
    assert len(manager.get("default").tensor_names) > 0
    assert "total_loss" in by_name["trial2"].include_regex
def test_invalid_collection_config_exception():
    """The hook must reject a collection that has nothing to include.

    _prepare_collections() is expected to raise InvalidCollectionConfiguration
    while collection "foo" has no include_regex, and to succeed once one is set.
    """
    manager = CollectionManager()
    manager.create_collection("foo")

    hook = Hook(
        out_dir="/tmp/test_collections/" + str(datetime.datetime.now()),
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = manager

    def _rejects_config():
        # True if preparation fails with the configuration error; any other
        # exception type propagates to the test runner unchanged.
        try:
            hook._prepare_collections()
        except InvalidCollectionConfiguration:
            return True
        return False

    assert _rejects_config(), "Invalid Collection Name did not raise error"

    manager.get("foo").include_regex = "*"
    assert not _rejects_config(), "Valid Collection Name raised an error"
def write_dummy_collection_file(trial):
    """Export a minimal collections file for the given ``trial`` path.

    The exported manager holds a "default" collection plus one collection named
    after ``trial``. It is written via CollectionManager.export using
    DEFAULT_COLLECTIONS_FILE_NAME as the file name.
    """
    manager = CollectionManager()
    manager.create_collection("default")
    trial_collection = Collection(trial)
    manager.add(trial_collection)
    manager.export(trial, DEFAULT_COLLECTIONS_FILE_NAME)