def construct_group_dict(group_path, config):
    """
    Compute a data group's merged parameters from a nested config.

    Given a config and a path that points to a data group, gather the
    parameters defined at every level along that path and merge them into a
    single parameter set for the group.

    The group_path is a list of keys and indices, e.g.
    ['train', 'datasets', 1, 'groups', 0], that can be followed to reach a
    group's config.

    Levels are merged in this order:
      1. the top-level config, without its 'train', 'val' and 'test' keys;
      2. each intermediate dict along the path, without the key that leads
         further down the path;
      3. the dict found at group_path itself.
    When a key occurs at more than one level and every value is a mapping,
    the values are merged (one level deep) into a Munch; otherwise the value
    from the last level that defines the key is used.
    """
    split_keys = ('train', 'val', 'test')

    def merge_values(values):
        # merge_with hands over the list of all values collected for one key,
        # ordered from the outermost level to the innermost one.
        # A generator is passed to all() so that this works both with the
        # builtin and with a predicate-accepting all() that may shadow it.
        if all(is_mapping(value) for value in values):
            return Munch(tz.merge(values))
        return tz.last(values)

    # Visit every interior position of the path. group_path[:depth] addresses
    # a container and group_path[depth] is the child that is followed next.
    # The first element is skipped because the top-level config is handled
    # separately below; the final element is skipped because the container it
    # is taken from is never one of the merged levels.
    mid_level_dicts = []
    for depth in range(1, len(group_path) - 1):
        child_key = group_path[depth]
        if not isinstance(child_key, str):
            # A non-string step is an index into a list, so the container at
            # this depth is a list and holds no parameters of its own.
            continue
        container = tz.get_in(list(group_path[:depth]), config)
        mid_level_dicts.append(
            tz.keyfilter(lambda key, skip=child_key: key != skip, container))

    top_level_dict = tz.keyfilter(lambda key: key not in split_keys, config)
    bottom_level_dict = tz.get_in(group_path, config)

    # Merge parameters at all levels to get a single parameter set for the
    # group.
    return tz.merge_with(
        merge_values,
        top_level_dict,
        *mid_level_dicts,
        bottom_level_dict
    )
def test_merge_with_list():
    """merge_with takes the dicts as one list and sums values sharing a key."""
    dicts = [{"a": 1}, {"a": 2}]
    expected = {"a": 3}
    assert merge_with(sum, dicts) == expected
def test_merge_with():
    """merge_with called with only the function returns a callable merger."""
    merge_by_sum = merge_with(sum)
    merged = merge_by_sum({1: 1}, {1: 2})
    assert merged == {1: 3}
# NOTE(review): a function named test_merge_with_list is also defined earlier
# in this file; this later definition replaces it at import time, so only one
# of the two is collected by the test runner.
def test_merge_with_list():
    """Values stored under the same key in a list of dicts are summed."""
    inputs = [{'a': 1}, {'a': 2}]
    result = merge_with(sum, inputs)
    assert result == {'a': 3}