def test_collection_defaults_to_hook_config():
    """Check that hook-level configs fill in what a collection leaves unset.

    Scenario: the collection "foo" only configures
    ModeKeys.EVAL (save_interval=20), while the hook configures
    ModeKeys.TRAIN (save_interval=10) and a reduction config. After the hook
    prepares its collections, the intent is for "foo" to end up with
    {ModeKeys.TRAIN: save_interval=10, ModeKeys.EVAL: save_interval=20}.

    Only the TRAIN save interval and the reduction config are asserted here.
    """
    manager = CollectionManager()
    manager.create_collection("foo")

    # Configure the collection: match everything, but only set an EVAL config.
    foo = manager.get("foo")
    foo.include_regex = "*"
    foo.save_config = {ModeKeys.EVAL: SaveConfigMode(save_interval=20)}

    # Timestamped output directory so repeated runs do not collide.
    out_dir = "/tmp/test_collections/" + str(datetime.datetime.now())
    hook = Hook(
        out_dir=out_dir,
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = manager

    # Before preparation the collection has no TRAIN config and no reduction.
    before = manager.get("foo")
    assert before.save_config.mode_save_configs[ModeKeys.TRAIN] is None
    assert before.reduction_config is None

    hook._prepare_collections()

    # After preparation the hook's values must have been applied.
    after = manager.get("foo")
    train_config = after.save_config.mode_save_configs[ModeKeys.TRAIN]
    assert train_config.save_interval == 10
    assert after.reduction_config.save_raw_tensor is True
def test_invalid_collection_config_exception():
    """Check when preparing collections raises InvalidCollectionConfiguration.

    A freshly created collection "foo" with nothing configured on it must
    make ``hook._prepare_collections()`` raise; once ``include_regex`` is set
    on that collection, the same call must succeed.
    """
    manager = CollectionManager()
    manager.create_collection("foo")

    # Timestamped output directory so repeated runs do not collide.
    out_dir = "/tmp/test_collections/" + str(datetime.datetime.now())
    hook = Hook(
        out_dir=out_dir,
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = manager

    # Phase 1: the unconfigured collection must be rejected.
    rejected = False
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration:
        rejected = True
    assert rejected, "Invalid Collection Name did not raise error"

    # Phase 2: with an include_regex set, preparation must go through.
    manager.get("foo").include_regex = "*"
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration:
        assert False, "Valid Collection Name raised an error"