Пример #1
0
def create_choice_param_from_choice_override(override: Override) -> Dict[str, Any]:
    """Build a ``choice`` parameter description from a choice-sweep override.

    The override key becomes ``name``; every value yielded by the sweep,
    encoded with ``Transformer.encode``, is collected into ``values``.
    """
    name = override.get_key_element()
    encoded_values = [
        item for item in override.sweep_iterator(transformer=Transformer.encode)
    ]
    return {"name": name, "type": "choice", "values": encoded_values}
Пример #2
0
 def is_matching(override: Override, default: DefaultElement) -> bool:
     """Return whether ``override`` targets the given defaults element.

     Both objects are asserted to refer to the same config group. A delete
     override matches when its subject package equals the element's package.
     Any other override matches when its ``pkg1`` equals the element's
     package, or when ``pkg1`` is empty and the element has no package.
     """
     assert override.key_or_group == default.config_group
     if override.is_delete():
         return override.get_subject_package() == default.package
     # Re-checked explicitly: the assert above disappears under ``python -O``.
     same_group = override.key_or_group == default.config_group
     return same_group and (
         override.pkg1 == default.package
         or (override.pkg1 == "" and default.package is None)
     )
Пример #3
0
def create_fixed_param_from_element_override(override: Override) -> Dict[str, Any]:
    """Build a ``fixed`` parameter description from a plain element override.

    The override key becomes ``name`` and ``override.value()`` becomes ``value``.
    """
    name = override.get_key_element()
    fixed_value = override.value()
    return {"name": name, "type": "fixed", "value": fixed_value}
Пример #4
0
def create_choice_param_from_range_override(override: Override) -> Dict[str, Any]:
    """Build a ``choice`` parameter description listing every value of a range sweep.

    Values come from ``override.sweep_iterator()`` called with no transformer
    argument.
    """
    name = override.get_key_element()
    return dict(name=name, type="choice", values=list(override.sweep_iterator()))
Пример #5
0
def create_range_param_using_interval_override(override: Override) -> Dict[str, Any]:
    """Build a ``range`` parameter description from an interval-sweep override.

    The interval's ``start`` and ``end`` become the two-element ``bounds`` list.
    When assertions are enabled, raises AssertionError if the override's value
    is not an ``IntervalSweep``.
    """
    name = override.get_key_element()
    interval = override.value()
    assert isinstance(interval, IntervalSweep)
    lower, upper = interval.start, interval.end
    return {"name": name, "type": "range", "bounds": [lower, upper]}
Пример #6
0
def create_nevergrad_parameter_from_override(override: Override) -> Any:
    """Convert an override into a Nevergrad (``ng.p``) parameter.

    - Non-sweep overrides are returned as their plain ``override.value()``.
    - Choice sweeps become ``ng.p.TransitionChoice`` when tagged ``ordered``,
      otherwise ``ng.p.Choice``, over the encoded sweep values.
    - Range sweeps become ``ng.p.Choice`` over every encoded value in the range.
    - Interval sweeps become ``ng.p.Log`` when tagged ``log``, otherwise
      ``ng.p.Scalar``. Integer casting is enabled when the interval start is
      an int.

    Raises:
        NotImplementedError: for any other sweep kind. Such overrides
            previously fell through and silently produced ``None``.
    """
    val = override.value()
    if not override.is_sweep_override():
        return val

    if override.is_choice_sweep():
        assert isinstance(val, ChoiceSweep)
        choices = list(override.sweep_iterator(transformer=Transformer.encode))
        if "ordered" in val.tags:
            return ng.p.TransitionChoice(choices)
        return ng.p.Choice(choices)

    if override.is_range_sweep():
        # The range is enumerated up front and exposed as an ng.p.Choice
        # (not a TransitionChoice).
        return ng.p.Choice(
            list(override.sweep_iterator(transformer=Transformer.encode))
        )

    if override.is_interval_sweep():
        assert isinstance(val, IntervalSweep)
        if "log" in val.tags:
            scalar = ng.p.Log(lower=val.start, upper=val.end)
        else:
            scalar = ng.p.Scalar(lower=val.start,
                                 upper=val.end)  # type: ignore
        # An int start bound is taken to mean an integer-valued parameter.
        if isinstance(val.start, int):
            scalar.set_integer_casting()
        return scalar

    # Any other sweep kind has no Nevergrad mapping. Fail loudly instead of
    # implicitly returning None.
    raise NotImplementedError(
        f"{override} is not supported by Nevergrad sweeper.")
def create_optuna_distribution_from_override(override: Override) -> Any:
    """Map an override onto an Optuna distribution.

    - Non-sweep overrides are returned as their plain ``override.value()``.
    - Choice and range sweeps become a ``CategoricalDistribution`` over the
      sweep values encoded with ``Transformer.encode``.
    - Interval sweeps use the log-uniform family when tagged ``log`` and the
      uniform family otherwise. In either family, the integer variant is used
      when tagged ``int``.

    Raises:
        NotImplementedError: for any other sweep kind.
    """
    value = override.value()
    if not override.is_sweep_override():
        return value

    if override.is_choice_sweep():
        assert isinstance(value, ChoiceSweep)
        return CategoricalDistribution(
            list(override.sweep_iterator(transformer=Transformer.encode))
        )

    if override.is_range_sweep():
        return CategoricalDistribution(
            list(override.sweep_iterator(transformer=Transformer.encode))
        )

    if override.is_interval_sweep():
        assert isinstance(value, IntervalSweep)
        is_log = "log" in value.tags
        is_int = "int" in value.tags
        if is_log and is_int:
            return IntLogUniformDistribution(value.start, value.end)
        if is_log:
            return LogUniformDistribution(value.start, value.end)
        if is_int:
            return IntUniformDistribution(value.start, value.end)
        return UniformDistribution(value.start, value.end)

    raise NotImplementedError(f"{override} is not supported by Optuna sweeper.")
Пример #8
0
    def visitOverride(self, ctx: OverrideParser.OverrideContext) -> Override:
        """Build an ``Override`` from a parsed override node.

        The children of ``ctx`` are consumed strictly in order:

        1. an optional prefix token: ``+`` (add), ``++`` (force add) or
           ``~`` (delete);
        2. the key node;
        3. either ``=`` followed by an optional value, or, for ``~`` deletions
           only, end-of-input.

        Args:
            ctx: the override context node produced by the parser.

        Returns:
            An ``Override`` whose ``type`` reflects the prefix and whose
            ``value_type`` classifies the parsed value. A ``key=`` with
            nothing after it yields an empty-string ``ELEMENT`` value. A bare
            ``~key`` yields ``None`` for both value and value type.
        """
        override_type = OverrideType.CHANGE
        children = ctx.getChildren()
        first_node = next(children)
        if isinstance(first_node, TerminalNodeImpl):
            # A leading terminal token is a prefix operator in front of the key.
            symbol_text = first_node.symbol.text
            if symbol_text == "+":
                override_type = OverrideType.ADD
                key_node = next(children)
                # A second PLUS token ("++") upgrades the add to a forced add.
                if self.is_matching_terminal(key_node, OverrideLexer.PLUS):
                    override_type = OverrideType.FORCE_ADD
                    key_node = next(children)

            elif symbol_text == "~":
                override_type = OverrideType.DEL
                key_node = next(children)
            else:
                # NOTE(review): presumably unreachable if the grammar only
                # admits "+" / "~" prefixes. Under `python -O` this assert is
                # stripped, and key_node would then be unbound below.
                assert False
        else:
            # No prefix: the first child is the key itself.
            key_node = first_node

        key = self.visitKey(key_node)
        value: Union[ChoiceSweep, RangeSweep, IntervalSweep, ParsedElementType]
        # The node after the key is either "=" or, for a bare deletion like
        # "~key", the EOF token.
        eq_node = next(children)
        if (override_type == OverrideType.DEL
                and isinstance(eq_node, TerminalNode)
                and eq_node.symbol.type == Token.EOF):
            value = None
            value_type = None
        else:
            assert self.is_matching_terminal(eq_node, OverrideLexer.EQUAL)
            if ctx.value() is None:
                # "key=" with nothing after it: an explicit empty-string value.
                value = ""
                value_type = ValueType.ELEMENT
            else:
                value = self.visitValue(ctx.value())
                # Classify the parsed value so callers can tell sweeps from
                # plain elements without re-inspecting its type.
                if isinstance(value, ChoiceSweep):
                    if value.simple_form:
                        value_type = ValueType.SIMPLE_CHOICE_SWEEP
                    else:
                        value_type = ValueType.CHOICE_SWEEP
                elif isinstance(value, Glob):
                    value_type = ValueType.GLOB_CHOICE_SWEEP
                elif isinstance(value, IntervalSweep):
                    value_type = ValueType.INTERVAL_SWEEP
                elif isinstance(value, RangeSweep):
                    value_type = ValueType.RANGE_SWEEP
                else:
                    value_type = ValueType.ELEMENT

        return Override(
            type=override_type,
            key_or_group=key.key_or_group,
            _value=value,
            value_type=value_type,
            package=key.package,
        )
Пример #9
0
def _encoded_categorical_choices(override: Override) -> List[CategoricalChoiceType]:
    """Collect the sweep values of ``override``, encoded with ``Transformer.encode``.

    When assertions are enabled, raises AssertionError if any encoded value is
    not a str, int, float or bool.
    """
    choices: List[CategoricalChoiceType] = []
    for x in override.sweep_iterator(transformer=Transformer.encode):
        assert isinstance(
            x, (str, int, float, bool)
        ), f"A choice sweep expects str, int, float, or bool type. Got {type(x)}."
        choices.append(x)
    return choices


def create_optuna_distribution_from_override(override: Override) -> Any:
    """Map an override onto an Optuna distribution.

    - Non-sweep overrides are returned as strings via
      ``override.get_value_element_as_str()``.
    - Choice sweeps, and shuffled range sweeps, become a
      ``CategoricalDistribution`` over the encoded sweep values.
    - Non-shuffled range sweeps become an ``IntUniformDistribution`` built from
      the range's start, stop and step, each cast to ``int``.
    - Interval sweeps use the log-uniform family when tagged ``log`` and the
      uniform family otherwise. The integer variant is used when both bounds
      are ints.

    Raises:
        NotImplementedError: for any other sweep kind.
    """
    if not override.is_sweep_override():
        return override.get_value_element_as_str()

    value = override.value()
    if override.is_choice_sweep():
        assert isinstance(value, ChoiceSweep)
        return CategoricalDistribution(_encoded_categorical_choices(override))

    if override.is_range_sweep():
        assert isinstance(value, RangeSweep)
        assert value.start is not None
        assert value.stop is not None
        if value.shuffle:
            # A shuffled range is enumerated and treated as categorical.
            return CategoricalDistribution(_encoded_categorical_choices(override))
        # NOTE(review): ``stop`` is passed as the distribution's upper bound,
        # while the shuffled branch enumerates values via ``sweep_iterator``.
        # If that iterator treats ``stop`` as exclusive, the two branches
        # disagree by one step; confirm the intended semantics. A fractional
        # ``step`` is also truncated by ``int()``.
        return IntUniformDistribution(int(value.start),
                                      int(value.stop),
                                      step=int(value.step))

    if override.is_interval_sweep():
        assert isinstance(value, IntervalSweep)
        assert value.start is not None
        assert value.end is not None
        both_int = isinstance(value.start, int) and isinstance(value.end, int)
        if "log" in value.tags:
            if both_int:
                return IntLogUniformDistribution(int(value.start),
                                                 int(value.end))
            return LogUniformDistribution(value.start, value.end)
        if both_int:
            return IntUniformDistribution(value.start, value.end)
        return UniformDistribution(value.start, value.end)

    raise NotImplementedError(
        f"{override} is not supported by Optuna sweeper.")
Пример #10
0
 def is_searchpath_override(v: Override) -> bool:
     """Return True when ``v``'s key element is exactly ``hydra.searchpath``."""
     key = v.get_key_element()
     return key == "hydra.searchpath"