예제 #1
0
파일: config.py 프로젝트: linhaobuaa/seml
def convert_parameter_collections(input_config: dict):
    """Rewrite deprecated parameter collections into plain dot-notation.

    A parameter collection is marked by a flattened key (e.g. ``sub1.sub2.type``)
    whose value is the string ``"parameter_collection"``; its parameters live
    below ``<parent>.params``. For every marker, the marker key is removed and
    each key ``<parent>.params.<rest>`` is moved to ``<parent>.<rest>``
    (``sub1.sub2.params.lr`` -> ``sub1.sub2.lr``). Nested collections are
    resolved by rescanning for markers until none remain.

    Parameters
    ----------
    input_config: dict
        Configuration to convert. It is flattened with ``flatten`` and the
        result is re-nested with ``unflatten``.

    Returns
    -------
    dict
        The converted configuration.

    Notes
    -----
    A deprecation warning is logged when at least one collection is found.
    If moving a key would overwrite an existing key, an error is logged and
    the process exits via ``sys.exit(1)``.
    """
    flattened_dict = flatten(input_config)

    def _find_collection_markers():
        # Flattened keys whose value marks a (deprecated) parameter collection.
        return [key for key, value in flattened_dict.items()
                if value == "parameter_collection"]

    parameter_collection_keys = _find_collection_markers()
    if parameter_collection_keys:
        logging.warning(
            "Parameter collections are deprecated. Use dot-notation for nested parameters instead."
        )
    while parameter_collection_keys:
        marker_key = parameter_collection_keys[0]
        del flattened_dict[marker_key]
        # sub1.sub2.type ==> sub1.sub2
        parent = marker_key.rpartition(".")[0]
        params_prefix = f"{parent}.params" if parent else "params"
        child_prefix = f"{params_prefix}."
        # Match whole path components only, so that keys such as
        # "sub1.sub2.paramsX.y" or an unrelated "other.params.y" are untouched.
        collection_params = [
            key for key in flattened_dict
            if key == params_prefix or key.startswith(child_prefix)
        ]
        for old_key in collection_params:
            # "" when old_key is exactly the prefix, otherwise "rest.of.path".
            remainder = old_key[len(child_prefix):]
            new_key = ".".join(part for part in (parent, remainder) if part)
            if new_key in flattened_dict:
                logging.error(
                    f"Could not convert parameter collections due to key collision: {new_key}."
                )
                sys.exit(1)
            flattened_dict[new_key] = flattened_dict.pop(old_key)
        # Moving keys may surface nested markers, so rescan.
        parameter_collection_keys = _find_collection_markers()
    return unflatten(flattened_dict)
예제 #2
0
def extract_parameter_set(input_config: dict, key: str):
    """Return the set of parameter names defined in section ``key``.

    For the ``'fixed'`` section every flattened key is itself a parameter name.
    For any other section, each flattened key has its last dot-component
    stripped (``lr.type`` -> ``lr``); presumably these sections describe each
    parameter through sub-keys such as ``type``/``options`` (NOTE(review):
    confirm). Entries whose value is ``'parameter_collection'`` are skipped.
    Keys without a dot map to the empty string.

    Parameters
    ----------
    input_config: dict
        Configuration holding the section; a missing section counts as empty.
    key: str
        Name of the section to inspect.

    Returns
    -------
    set of str
    """
    flat_section = flatten(input_config.get(key, {}))
    if key == 'fixed':
        return set(flat_section.keys())
    return {
        flat_key.rpartition(".")[0]
        for flat_key, value in flat_section.items()
        if value != 'parameter_collection'
    }
예제 #3
0
파일: config.py 프로젝트: linhaobuaa/seml
def standardize_config(config: dict):
    """Bring a config into a canonical per-section layout.

    The config is flattened and re-nested with ``unflatten(..., levels=[0])``.
    The result is then reduced to a dict with exactly one entry per name in
    ``RESERVED_KEYS``; sections that are absent become ``{}``. The ``"fixed"``
    section is taken as-is. Every other section is passed through
    ``unflatten(..., levels=[-1])`` (NOTE(review): the exact effect of
    ``levels`` lives in ``unflatten``; verify there).

    Parameters
    ----------
    config: dict
        A single level of the experiment configuration.

    Returns
    -------
    dict
        Mapping of every reserved section name to its standardized content.
    """
    top_level = unflatten(flatten(config), levels=[0])

    def _standardize_section(section_name):
        section = top_level.get(section_name, {})
        if section_name == "fixed":
            return section
        return unflatten(section, levels=[-1])

    return {name: _standardize_section(name) for name in RESERVED_KEYS}
예제 #4
0
파일: config.py 프로젝트: linhaobuaa/seml
def generate_configs(experiment_config):
    """Generate parameter configurations based on an input configuration.

    Input is a nested configuration where on each level there can be 'fixed', 'grid', and 'random' parameters.

    In essence, we take the cartesian product of all the `grid` parameters and take random samples for the random
    parameters. The nested structure makes it possible to define different parameter spaces e.g. for different datasets.
    Parameter definitions lower in the hierarchy overwrite parameters defined closer to the root (a warning is
    logged for every such redefinition).

    The number of random samples for a leaf is read from the `samples` entry of the leaf's merged `random`
    section. NOTE(review): how `samples` values along the path are combined (e.g. taking the maximum) is decided
    by `merge_dicts`; confirm there.

    For each configuration of the `grid` parameters we then create `num_samples` configurations of the random
    parameters, i.e. leading to `num_samples * len(grid_configurations)` configurations.

    See Also `examples/example_config.yaml`.

    Parameters
    ----------
    experiment_config: dict
        Dictionary that specifies the "search space" of parameters that will be enumerated. Should be
        parsed from a YAML file.

    Returns
    -------
    all_configs: list of dicts
        Contains the individual combinations of the parameters.
    """
    import collections

    reserved, next_level = unpack_config(experiment_config)
    reserved = standardize_config(reserved)
    detect_duplicate_parameters(invert_config(reserved), None)

    # Breadth-first traversal of the sub-configuration tree. Each entry carries
    # the dotted name of the level, its raw values, and the merged config of
    # its parent level, so the two can never get out of step.
    pending = collections.deque([('', next_level, reserved)])
    leaf_configs = []

    while pending:
        current_sub_name, sub_vals, config_above = pending.popleft()
        sub_config, sub_levels = unpack_config(sub_vals)
        sub_config = standardize_config(sub_config)

        inverted_sub_config = invert_config(sub_config)
        detect_duplicate_parameters(inverted_sub_config, current_sub_name)

        inverted_config_above = invert_config(config_above)
        redefined_parameters = set(inverted_sub_config.keys()).intersection(
            set(inverted_config_above.keys()))

        if redefined_parameters:
            logging.warning(
                f"Found redefined parameters in {current_sub_name}: {redefined_parameters}. "
                f"Redefinitions of parameters override earlier ones.")
            # Copy before deleting: the parent's merged config object is shared
            # by all of its sub-levels.
            config_above = copy.deepcopy(config_above)
            for p in redefined_parameters:
                for section in inverted_config_above[p]:
                    del config_above[section][p]

        config = merge_dicts(config_above, sub_config)

        if not sub_levels:
            leaf_configs.append((current_sub_name, config))

        for sub_name, child_vals in sub_levels.items():
            child_name = f'{current_sub_name}.{sub_name}' if current_sub_name != '' else sub_name
            pending.append((child_name, child_vals, config))

    all_configs = []
    for _, conf in leaf_configs:
        conf = standardize_config(conf)
        random_params = conf.get('random', {})
        fixed_params = flatten(conf['fixed']) if 'fixed' in conf else {}
        grid_params = conf.get('grid', {})

        if random_params:
            num_samples = random_params['samples']
            root_seed = random_params.get('seed', None)
            random_sampled = sample_random_configs(flatten(random_params),
                                                   seed=root_seed,
                                                   samples=num_samples)

        grids = [
            generate_grid(v, parent_key=k) for k, v in grid_params.items()
        ]
        grid_configs = {name: values for grid in grids for name, values in grid}
        grid_product = list(cartesian_product_dict(grid_configs))

        with_fixed = [{**d, **fixed_params} for d in grid_product]
        if random_params:
            # Every grid/fixed combination is paired with every random sample.
            with_random = [{
                **grid_config,
                **sampled
            } for grid_config in with_fixed for sampled in random_sampled]
        else:
            with_random = with_fixed
        all_configs.extend(with_random)

    # Cast NumPy integers to normal integers since PyMongo doesn't like them,
    # then restore the nested structure.
    return [
        unflatten({
            k: int(v) if isinstance(v, np.integer) else v
            for k, v in flat_config.items()
        })
        for flat_config in all_configs
    ]