示例#1
0
def get_preset_variants(
    spec: Dict,
    config: Dict,
    constant_grid_search: bool = False,
    random_state: "RandomState" = None,
):
    """Get variants according to a spec, initialized with a config.

    Variables from the spec are overwritten by the variables in the config.
    Thus, we may end up with less sampled parameters.

    This function also checks if values used to overwrite search space
    parameters are valid, and logs a warning if not.
    """
    spec = copy.deepcopy(spec)

    resolved, _, _ = parse_spec_vars(config)

    for path, val in resolved:
        try:
            domain = _get_value(spec["config"], path)
            if isinstance(domain, dict):
                if "grid_search" in domain:
                    domain = Categorical(domain["grid_search"])
                else:
                    # If users want to overwrite an entire subdict,
                    # let them do it.
                    domain = None
        except IndexError as exc:
            raise ValueError(
                f"Pre-set config key `{'/'.join(path)}` does not correspond "
                f"to a valid key in the search space definition. Please add "
                f"this path to the `config` variable passed to `tune.run()`."
            ) from exc

        if domain:
            if isinstance(domain, Domain):
                if not domain.is_valid(val):
                    logger.warning(
                        f"Pre-set value `{val}` is not within valid values of "
                        f"parameter `{'/'.join(path)}`: {domain.domain_str}"
                    )
            else:
                # domain is actually a fixed value
                if domain != val:
                    logger.warning(
                        f"Pre-set value `{val}` is not equal to the value of "
                        f"parameter `{'/'.join(path)}`: {domain}"
                    )
        assign_value(spec["config"], path, val)

    return _generate_variants(
        spec, constant_grid_search=constant_grid_search, random_state=random_state
    )
示例#2
0
def _try_resolve(v) -> Tuple[bool, Any]:
    """Classify one spec value as a constant or as something to sample.

    Returns a ``(is_resolved, value)`` pair:

    * a ``Domain`` is returned unchanged, marked unresolved;
    * a single-key dict ``{"eval": expr}`` becomes an unresolved ``Function``
      that evaluates ``expr`` lazily with ``spec`` bound as a local name;
    * a single-key dict ``{"grid_search": values}`` becomes the unresolved
      ``Categorical(values).grid()``;
    * anything else is a constant and is returned as resolved.
    """
    if isinstance(v, Domain):
        # Already a sampleable domain; leave it for the sampler.
        return False, v

    if isinstance(v, dict) and len(v) == 1:
        if "eval" in v:
            # NOTE(review): eval() executes arbitrary code taken from the
            # config; only use this syntax with trusted specs.
            resolver = Function(
                lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec})
            )
            return False, resolver
        if "grid_search" in v:
            return False, Categorical(v["grid_search"]).grid()

    # Plain constant: nothing left to resolve.
    return True, v
示例#3
0
def _try_resolve(v) -> Tuple[bool, Any]:
    """Decide whether a spec value is final or still has to be resolved.

    Returns ``(True, v)`` for constants. Returns ``(False, domain)`` for:
    a ``Domain`` (returned as-is), a single-key ``{"eval": expr}`` dict
    (wrapped in a lazily-evaluated ``Function``), or a single-key
    ``{"grid_search": values}`` dict (turned into ``Categorical(values).grid()``).

    Raises:
        TuneError: if the ``grid_search`` values are not a ``list``.
    """
    if isinstance(v, Domain):
        return False, v

    # Only single-key dicts can carry the special eval/grid_search syntax.
    if not isinstance(v, dict) or len(v) != 1:
        return True, v

    if "eval" in v:
        # NOTE(review): eval() runs arbitrary code from the config; only use
        # this syntax with trusted specs.
        return False, Function(
            lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec}))

    if "grid_search" in v:
        values = v["grid_search"]
        if isinstance(values, list):
            return False, Categorical(values).grid()
        raise TuneError(
            "Grid search expected list of values, got: {}".format(values))

    return True, v