def test_flatten_dict(nested_dict):
    """Flattening a nested dict yields one CompoundKey (the key path) per leaf value."""
    actual = collections.dict_to_flatdict(nested_dict)

    # (key path, leaf value) pairs expected from the `nested_dict` fixture
    expected_leaves = [
        ([1], 2),
        ([2, 1], 2),
        ([2, 3], 4),
        ([3, 1], 2),
        ([3, 3, 4], 5),
        ([3, 3, 6, 7], 8),
    ]
    expected = {collections.CompoundKey(path): leaf for path, leaf in expected_leaves}

    assert actual == expected
def interpolate_config(config: dict, env_var_prefix: str = None) -> Config:
    """
    Processes a config dictionary, such as the one loaded from `load_toml`.

    The dictionary is flattened into compound keys and two kinds of interpolation
    are applied:

    1. Environment-variable overrides: when `env_var_prefix` is truthy, every
       environment variable named `[ENV_VAR_PREFIX]__[Section]__...__[Key]` sets
       the matching (lower-cased) config key. Backslash escapes in its value are
       decoded, then the value is passed through `interpolate_env_vars` and
       `string_to_type`. Afterwards `interpolate_env_vars` is applied to every
       config value.
    2. Key references: string values containing `${section.key}` are replaced by
       the referenced value, or by "" if that key does not exist. Each pass
       replaces one reference per value, and at most 10 passes are made, which
       bounds circular or long chains of references.

    Args:
        - config (dict): a (possibly nested) configuration dictionary
        - env_var_prefix (str, optional): prefix of environment variables that
            override config values; if falsy, no environment variables are read

    Returns:
        - Config: the interpolated configuration, rebuilt into a nested mapping
            with `Config` as the dict class
    """

    # toml supports nested dicts, so we work with a flattened representation to do any
    # requested interpolation
    flat_config = collections.dict_to_flatdict(config)

    # --------------------- Interpolate env vars -----------------------
    # check if any env var sets a configuration value with the format:
    #     [ENV_VAR_PREFIX]__[Section]__[Optional Sub-Sections...]__[Key] = Value
    # and if it does, add it to the config file.

    if env_var_prefix:
        prefix = env_var_prefix + "__"
        for env_var, env_var_value in os.environ.items():
            if not env_var.startswith(prefix):
                continue

            # strip the prefix off the env var
            env_var_option = env_var[len(prefix) :]

            # an env var that is exactly the prefix names no key at all, so skip it.
            # Single-segment names (PREFIX__KEY) are accepted and set a top-level key.
            if not env_var_option:
                continue

            # env vars with escaped characters are interpreted as literal "\", which
            # Python helpfully escapes with a second "\". This step makes sure that
            # escaped characters are properly interpreted.
            # "unicode_escape" decodes bytes as Latin-1, so the text is encoded as
            # Latin-1 (with backslashreplace for anything outside it) rather than
            # UTF-8; encoding as UTF-8 would garble non-ASCII values ("é" -> "Ã©").
            # Values containing an invalid escape sequence (e.g. a Windows path like
            # C:\Users, or a trailing backslash) are used verbatim instead of
            # aborting config loading.
            try:
                value = env_var_value.encode("latin-1", "backslashreplace").decode(
                    "unicode_escape"
                )
            except UnicodeDecodeError:
                value = env_var_value

            # place the env var in the flat config as a compound key
            config_option = collections.CompoundKey(env_var_option.lower().split("__"))
            flat_config[config_option] = string_to_type(
                cast(str, interpolate_env_vars(value))
            )

    # interpolate any env vars referenced
    for k, v in list(flat_config.items()):
        flat_config[k] = interpolate_env_vars(v)

    # --------------------- Interpolate other config keys -----------------
    # TOML doesn't support references to other keys... but we do!
    # This has the potential to lead to nasty recursions, so we check at most 10 times.
    # we use a set called "keys_to_check" to track only the ones of interest, so we aren't
    # checking every key every time.

    keys_to_check = set(flat_config.keys())

    for _ in range(10):
        # nothing left that could still contain a reference
        if not keys_to_check:
            break

        # iterate over every key and value to check if the value uses interpolation
        for k in list(keys_to_check):

            # if the value isn't a string, it can't be a reference, so we exit
            if not isinstance(flat_config[k], str):
                keys_to_check.remove(k)
                continue

            # see if the ${...} syntax was used in the value and exit if it wasn't
            match = INTERPOLATION_REGEX.search(flat_config[k])
            if not match:
                keys_to_check.remove(k)
                continue

            # the matched_string includes "${}"; the matched_key is just the inner value
            matched_string = match.group(0)
            matched_key = match.group(1)

            # get the referenced key from the config value
            ref_key = collections.CompoundKey(matched_key.split("."))
            # get the value corresponding to the referenced key
            ref_value = flat_config.get(ref_key, "")

            # if the matched was the entire value, replace it with the interpolated value
            # (this keeps the referenced value's type instead of stringifying it)
            if flat_config[k] == matched_string:
                flat_config[k] = ref_value
            # if it was a partial match, then drop the interpolated value into the string
            else:
                flat_config[k] = flat_config[k].replace(
                    matched_string, str(ref_value), 1
                )

    return cast(Config, collections.flatdict_to_dict(flat_config, dct_class=Config))