Example #1
0
def standardize_config(config: dict):
    """Return a dict with one entry per reserved section, in a uniform layout.

    The input is first flattened and then re-nested on its first key level
    only (``levels=[0]``), so every section sits directly under its top-level
    name. The result contains exactly the names listed in ``RESERVED_KEYS``;
    a section missing from the input becomes an empty dict.

    Parameters
    ----------
    config: dict
        Configuration whose top-level keys are section names.

    Returns
    -------
    dict
        ``"fixed"`` is taken over as-is; every other section is additionally
        passed through ``unflatten(..., levels=[-1])``.
    """
    by_section = unflatten(flatten(config), levels=[0])
    return {
        section: (by_section.get(section, {}) if section == "fixed" else
                  unflatten(by_section.get(section, {}), levels=[-1]))
        for section in RESERVED_KEYS
    }
Example #2
0
    def test_basic(self):
        """Dotted keys are expanded fully, with and without `recursive`."""
        flat = {'a.b.c': 111, 'a.d': 22}
        nested = {'a': {'b': {'c': 111}, 'd': 22}}

        # None of the values is a dict, so `recursive` should not matter here.
        results = [
            utils.unflatten(flat, sep=".", recursive=flag)
            for flag in (False, True)
        ]

        for result in results:
            self.assertEqual(nested, result)
Example #3
0
    def test_out_of_bounds(self):
        """Levels 5 and -5 are expected to raise IndexError for these keys."""
        flat = {'a.b.c.d.e': 111, 'a.b.c.d.f': 222, 'a.b.c.g.h': 333}

        for bad_level in (5, -5):
            with self.assertRaises(IndexError):
                utils.unflatten(flat,
                                sep='.',
                                recursive=False,
                                levels=[bad_level])
Example #4
0
def convert_parameter_collections(input_config: dict):
    """Rewrite deprecated "parameter collection" entries into plain nested keys.

    The config is flattened; every key whose value equals the string
    ``"parameter_collection"`` marks a collection (e.g. ``sub1.sub2.type``).
    That marker key is removed and every key below ``<collection>.params`` is
    moved one level up, i.e. ``sub1.sub2.params.x`` becomes ``sub1.sub2.x``.
    The scan is repeated until no marker is left, then the result is
    unflattened again.

    Only keys that equal ``<collection>.params`` or start with
    ``<collection>.params.`` are moved. A plain substring test would also hit
    unrelated keys such as ``<collection>.paramsfoo`` or
    ``<collection>.x.<collection>.params.y``.

    Parameters
    ----------
    input_config: dict
        Configuration that may contain parameter collections.

    Returns
    -------
    dict
        The converted configuration, as returned by ``unflatten``.

    Notes
    -----
    Logs a deprecation warning if at least one collection is found. On a key
    collision an error is logged and the process exits via ``sys.exit(1)``.
    """
    flattened_dict = flatten(input_config)

    def _marker_keys():
        # Keys whose value tags their parent as a parameter collection.
        return [
            key for key, value in flattened_dict.items()
            if value == "parameter_collection"
        ]

    parameter_collection_keys = _marker_keys()
    if parameter_collection_keys:
        logging.warning(
            "Parameter collections are deprecated. Use dot-notation for nested parameters instead."
        )

    while parameter_collection_keys:
        marker = parameter_collection_keys[0]
        del flattened_dict[marker]
        # sub1.sub2.type ==> sub1.sub2
        collection = marker.rpartition(".")[0]
        params_prefix = f"{collection}.params"

        # Iterate over a snapshot because the dict is modified inside the loop.
        for p in list(flattened_dict):
            # Anchor the match at the start of the key and at a dot boundary.
            if p != params_prefix and not p.startswith(f"{params_prefix}."):
                continue
            new_key = collection + p[len(params_prefix):]
            if new_key in flattened_dict:
                logging.error(
                    f"Could not convert parameter collections due to key collision: {new_key}."
                )
                sys.exit(1)
            flattened_dict[new_key] = flattened_dict.pop(p)

        # Re-scan: the marker just handled is gone, others may remain.
        parameter_collection_keys = _marker_keys()

    return unflatten(flattened_dict)
Example #5
0
    def test_unflatten_multiple_levels(self):
        """Splitting only at selected separator positions of the dotted keys."""
        flat = {'a.b.c.d.e': 111, 'a.b.c.d.f': 222, 'a.b.c.g.h': 333}

        def split_at(levels):
            # Non-recursive unflatten of `flat` at the given levels.
            return utils.unflatten(flat,
                                   sep='.',
                                   recursive=False,
                                   levels=levels)

        # First and last separator only: the middle part stays one dotted key.
        self.assertEqual(
            split_at([0, -1]), {
                'a': {
                    'b.c.d': {
                        'e': 111,
                        'f': 222,
                    },
                    'b.c.g': {
                        'h': 333,
                    }
                }
            })

        self.assertEqual(
            split_at([0, 1, 3]), {
                'a': {
                    'b': {
                        'c.d': {
                            'e': 111,
                            'f': 222
                        },
                        'c.g': {
                            'h': 333
                        }
                    }
                }
            })

        # Listing every level must match the call without `levels`.
        self.assertEqual(split_at([0, 1, 2, 3]),
                         utils.unflatten(flat, sep=".", recursive=False))

        # Level 4 is expected to leave the input unchanged.
        self.assertEqual(split_at([4]), flat)

        # For these keys, level -2 and level 2 should give the same result.
        self.assertEqual(split_at([-2]), split_at([2]))
Example #6
0
    def test_unflatten_single_level(self):
        """A single level may be given as a one-element list or a bare int."""
        flat = {'a.b.c': 111, 'a.b': {'c': 222}}
        common = dict(sep='.', recursive=True)

        last_as_list = utils.unflatten(flat, levels=[-1], **common)
        last_as_int = utils.unflatten(flat, levels=-1, **common)
        self.assertEqual(last_as_list, {
            'a.b': {
                'c': 111
            },
            'a': {
                'b': {
                    'c': 222
                }
            }
        })
        self.assertEqual(last_as_list, last_as_int)

        first_only = utils.unflatten(flat, levels=[0], **common)
        self.assertEqual(first_only, {'a': {'b.c': 111, 'b': {'c': 222}}})
Example #7
0
    def test_recursive(self):
        """`recursive` controls whether dotted keys inside dict values are split.

        With ``recursive=True`` the nested key ``'f.g'`` is expected to be
        expanded as well; with ``recursive=False`` it is expected to stay as is.
        """
        flattened = {'a.b.c': 111, 'a.d': {'e': {'f.g': 333}}}

        unflattened_recursive = utils.unflatten(flattened,
                                                sep=".",
                                                recursive=True)
        expected_recursive = {
            'a': {
                'b': {
                    'c': 111
                },
                'd': {
                    'e': {
                        'f': {
                            'g': 333
                        }
                    }
                }
            }
        }
        # Only assertEqual is used: a bare `assert` would be stripped under -O
        # and would hide unittest's diff output on failure.
        self.assertEqual(unflattened_recursive, expected_recursive)

        unflattened_nonrecursive = utils.unflatten(flattened,
                                                   sep='.',
                                                   recursive=False)
        expected_nonrecursive = {
            'a': {
                'b': {
                    'c': 111
                },
                'd': {
                    'e': {
                        'f.g': 333
                    }
                }
            }
        }
        self.assertEqual(unflattened_nonrecursive, expected_nonrecursive)
Example #8
0
def sample_random_configs(random_config, samples=1, seed=None):
    """
    Draw random configurations from a search-space description.

    Parameters
    ----------
    random_config: dict
        Maps parameter keys to the description of how their value is drawn.
        The entries ``"samples"`` and ``"seed"`` are ignored here; everything
        else is grouped with ``unflatten(..., levels=-1)`` and handed to
        ``sample_parameter``.
    samples: int
        Number of values to draw per parameter, and thus the number of
        configurations returned.
    seed: int or None
        Seed forwarded to ``sample_parameter``. Defaults to None.

    Returns
    -------
    random_configurations: list of dicts
        One dict per sample; each maps every sampled parameter name to the
        value drawn for that sample. An empty ``random_config`` yields
        ``[{}]``.

    """

    if len(random_config) == 0:
        return [{}]

    # Drop the bookkeeping entries, keep only actual parameter definitions.
    search_space = {
        name: spec
        for name, spec in random_config.items()
        if name not in ["samples", "seed"]
    }
    grouped = unflatten(search_space, levels=-1)

    # Each call yields (name, values) pairs; collect them into one mapping.
    batches = [
        sample_parameter(spec, samples, seed, parent_key=parent)
        for parent, spec in grouped.items()
    ]
    drawn = dict(pair for batch in batches for pair in batch)

    # Transpose: the ix-th configuration takes the ix-th value of each name.
    return [{name: values[ix]
             for name, values in drawn.items()}
            for ix in range(samples)]
Example #9
0
    def test_conflicting_keys(self):
        """When entries collide on the same path, later entries overwrite earlier ones."""
        # (input, expected) pairs; the key order inside each input matters.
        cases = [
            # dotted key first, dict value second: the dict value wins
            ({'a.b.c': 111, 'a.b': {'c': 222}},
             {'a': {'b': {'c': 222}}}),
            # same entries in the opposite order: the dotted key wins
            ({'a.b': {'c': 222}, 'a.b.c': 111},
             {'a': {'b': {'c': 111}}}),
            # a non-dict value on the path is replaced by the later nested entry
            ({'a.b': ['not_dict'], 'a.b.c': 111},
             {'a': {'b': {'c': 111}}}),
            # ... and the other way round
            ({'a.b.c': 111, 'a.b': ['not_dict']},
             {'a': {'b': ['not_dict']}}),
            # same two situations with the non-dict value nested inside 'a'
            ({'a': {'b': ['not_dict']}, 'a.b.c': 111},
             {'a': {'b': {'c': 111}}}),
            ({'a.b.c': 111, 'a': {'b': ['not_dict']}},
             {'a': {'b': ['not_dict']}}),
        ]

        for flat, nested in cases:
            self.assertEqual(utils.unflatten(flat, sep='.', recursive=True),
                             nested)
Example #10
0
    def test_recursive_with_levels(self):
        """Combining `levels` with `recursive` on input that has a nested 'a' dict."""
        flat = {
            'a.b.c.d.e': 111,
            'a.b.c.d.f': 222,
            'a.b.c.g.h': 333,
            # nested dict under the same top-level key, added last
            'a': {
                'b.c.d.e': 777,
                'b.c.d.i': 999
            },
        }

        # recursive=True: the nested dict's keys are split at level 0 as well.
        self.assertEqual(
            utils.unflatten(flat, sep=".", recursive=True, levels=0), {
                'a': {
                    'b.c.d.e': 111,
                    'b.c.d.f': 222,
                    'b.c.g.h': 333,
                    'b': {
                        'c.d.e': 777,
                        'c.d.i': 999,
                    }
                }
            })

        # recursive=False: the nested dict's keys stay dotted; 777 replaces 111.
        self.assertEqual(
            utils.unflatten(flat, sep=".", recursive=False, levels=0), {
                'a': {
                    'b.c.d.e': 777,
                    'b.c.d.f': 222,
                    'b.c.g.h': 333,
                    'b.c.d.i': 999,
                }
            })

        # Level 1 is expected to raise IndexError in both modes.
        for recursive in (True, False):
            with self.assertRaises(IndexError):
                utils.unflatten(flat, sep=".", recursive=recursive, levels=1)
Example #11
0
def generate_configs(experiment_config):
    """Generate parameter configurations based on an input configuration.

    Input is a nested configuration where on each level there can be 'fixed', 'grid', and 'random' parameters.

    In essence, we take the cartesian product of all the `grid` parameters and take random samples for the random
    parameters. The nested structure makes it possible to define different parameter spaces e.g. for different datasets.
    Parameter definitions lower in the hierarchy overwrite parameters defined closer to the root.

    For each leaf configuration we take the maximum of all num_samples values on the path since we need to have the same
    number of samples for each random parameter.

    For each configuration of the `grid` parameters we then create `num_samples` configurations of the random
    parameters, i.e. leading to `num_samples * len(grid_configurations)` configurations.

    See also `examples/example_config.yaml`.

    Parameters
    ----------
    experiment_config: dict
        Dictionary that specifies the "search space" of parameters that will be enumerated. Should be
        parsed from a YAML file.

    Returns
    -------
    all_configs: list of dicts
        Contains the individual combinations of the parameters.


    """

    # NOTE(review): unpack_config presumably splits a level into its reserved
    # sections and its named sub-levels -- confirm against its definition.
    reserved, next_level = unpack_config(experiment_config)
    reserved = standardize_config(reserved)
    # level_stack and config_levels are parallel FIFO queues: both are popped
    # from the front and appended to together, so entry i of config_levels is
    # the merged configuration that sits above entry i of level_stack.
    level_stack = [('', next_level)]
    config_levels = [reserved]
    final_configs = []

    detect_duplicate_parameters(invert_config(reserved), None)

    while len(level_stack) > 0:
        current_sub_name, sub_vals = level_stack.pop(0)
        sub_config, sub_levels = unpack_config(sub_vals)
        sub_config = standardize_config(sub_config)
        config_above = config_levels.pop(0)

        inverted_sub_config = invert_config(sub_config)
        detect_duplicate_parameters(inverted_sub_config, current_sub_name)

        # Parameters defined both at this level and somewhere above it.
        inverted_config_above = invert_config(config_above)
        redefined_parameters = set(inverted_sub_config.keys()).intersection(
            set(inverted_config_above.keys()))

        if len(redefined_parameters) > 0:
            logging.warning(
                f"Found redefined parameters in {current_sub_name}: {redefined_parameters}. "
                f"Redefinitions of parameters override earlier ones.")
            # The same parent object is queued once per sibling sub-level
            # (see the append at the end of the loop), so work on a copy
            # before deleting entries from it.
            config_above = copy.deepcopy(config_above)
            for p in redefined_parameters:
                sections = inverted_config_above[p]
                for s in sections:
                    del config_above[s][p]

        config = merge_dicts(config_above, sub_config)

        # A level without further sub-levels is a leaf and yields configurations.
        if len(sub_levels) == 0:
            final_configs.append((current_sub_name, config))

        for sub_name, sub_vals in sub_levels.items():
            # Build the dotted path of the sub-level; the root has an empty name.
            new_sub_name = f'{current_sub_name}.{sub_name}' if current_sub_name != '' else sub_name
            level_stack.append((new_sub_name, sub_vals))
            config_levels.append(config)

    all_configs = []
    for subconfig_name, conf in final_configs:
        conf = standardize_config(conf)
        random_params = conf['random'] if 'random' in conf else {}
        fixed_params = flatten(conf['fixed']) if 'fixed' in conf else {}
        grid_params = conf['grid'] if 'grid' in conf else {}

        # random_sampled is only bound when there are random parameters; it is
        # only read further down under the same condition.
        if len(random_params) > 0:
            num_samples = random_params['samples']
            root_seed = random_params.get('seed', None)
            random_sampled = sample_random_configs(flatten(random_params),
                                                   seed=root_seed,
                                                   samples=num_samples)

        # Each generate_grid call yields pairs that are collected into one
        # dict, which is then expanded by cartesian_product_dict.
        grids = [
            generate_grid(v, parent_key=k) for k, v in grid_params.items()
        ]
        grid_configs = dict([sub for item in grids for sub in item])
        grid_product = list(cartesian_product_dict(grid_configs))

        # Fixed values are merged last, so they win over grid values on equal keys.
        with_fixed = [{**d, **fixed_params} for d in grid_product]
        if len(random_params) > 0:
            # Pair every grid/fixed combination with every random sample;
            # random values win over grid/fixed values on equal keys.
            with_random = [{
                **grid,
                **random
            } for grid in with_fixed for random in random_sampled]
        else:
            with_random = with_fixed
        all_configs.extend(with_random)

    # Cast NumPy integers to normal integers since PyMongo doesn't like them
    all_configs = [{
        k: int(v) if isinstance(v, np.integer) else v
        for k, v in config.items()
    } for config in all_configs]

    # The configurations were built with flat keys; nest them before returning.
    all_configs = [unflatten(conf) for conf in all_configs]
    return all_configs
Example #12
0
 def test_empty(self):
     """Unflattening an empty dict gives an empty dict."""
     self.assertEqual(utils.unflatten({}), {})
Example #13
0
    def test_errors(self):
        """Invalid `levels` arguments: empty list -> ValueError, float -> TypeError."""
        invalid_levels = (
            (ValueError, []),
            (TypeError, 1.2),
        )
        for expected_error, levels in invalid_levels:
            with self.assertRaises(expected_error):
                utils.unflatten({}, levels=levels)
Example #14
0
 def test_merge_duplicate_keys(self):
     """A dotted key and a nested dict sharing a path are merged, not replaced."""
     flat = {'a.b.c': 111, 'a': {'b': {'d': 222}}}
     merged = utils.unflatten(flat, sep=".", recursive=True)
     self.assertEqual(merged, {'a': {'b': {'c': 111, 'd': 222}}})