Ejemplo n.º 1
0
def create_nevergrad_parameter_from_override(override: Override) -> Any:
    """Translate a parsed override into a Nevergrad parameter.

    A non-sweep override is returned as its plain value. Sweeps are mapped
    as follows:

    * choice sweep -> ``ng.p.TransitionChoice`` when the sweep carries the
      ``ordered`` tag, otherwise ``ng.p.Choice``;
    * range sweep -> ``ng.p.Choice`` over the enumerated values;
    * interval sweep -> ``ng.p.Log`` when the sweep carries the ``log``
      tag, otherwise ``ng.p.Scalar``; ``set_integer_casting()`` is called
      on the result when the interval start is an ``int``.

    Args:
        override: The override to convert.

    Returns:
        The override's value for a non-sweep override, otherwise the
        Nevergrad parameter built for the sweep.

    Raises:
        NotImplementedError: If the override is a sweep that is neither a
            choice, a range nor an interval sweep.
    """
    val = override.value()
    if not override.is_sweep_override():
        return val

    if override.is_choice_sweep():
        assert isinstance(val, ChoiceSweep)
        vals = list(override.sweep_iterator(transformer=Transformer.encode))
        if "ordered" in val.tags:
            return ng.p.TransitionChoice(vals)
        return ng.p.Choice(vals)

    if override.is_range_sweep():
        vals = list(override.sweep_iterator(transformer=Transformer.encode))
        return ng.p.Choice(vals)

    if override.is_interval_sweep():
        assert isinstance(val, IntervalSweep)
        if "log" in val.tags:
            scalar = ng.p.Log(lower=val.start, upper=val.end)
        else:
            scalar = ng.p.Scalar(lower=val.start, upper=val.end)  # type: ignore
        # The type of the lower bound decides whether values are cast to int.
        if isinstance(val.start, int):
            scalar.set_integer_casting()
        return scalar

    # Fail loudly instead of implicitly returning None for a sweep kind
    # that none of the branches above knows how to convert.
    raise NotImplementedError(
        f"{override} is not supported by Nevergrad sweeper.")
def create_optuna_distribution_from_override(override: Override) -> Any:
    """Translate a parsed override into an Optuna distribution.

    A non-sweep override is returned as its plain value. Choice and range
    sweeps both become a ``CategoricalDistribution`` over the encoded sweep
    values. An interval sweep becomes one of the four uniform distributions,
    selected by the presence of the ``log`` and ``int`` tags on the sweep.

    Args:
        override: The override to convert.

    Returns:
        The override's value for a non-sweep override, otherwise the
        distribution built for the sweep.

    Raises:
        NotImplementedError: If the sweep is not a choice, range or
            interval sweep.
    """
    sweep = override.value()
    if not override.is_sweep_override():
        return sweep

    if override.is_choice_sweep():
        assert isinstance(sweep, ChoiceSweep)
        return CategoricalDistribution(
            list(override.sweep_iterator(transformer=Transformer.encode)))

    if override.is_range_sweep():
        return CategoricalDistribution(
            list(override.sweep_iterator(transformer=Transformer.encode)))

    if override.is_interval_sweep():
        assert isinstance(sweep, IntervalSweep)
        log_scale = "log" in sweep.tags
        integral = "int" in sweep.tags
        # Pick the distribution class from the (log, int) tag combination.
        if log_scale:
            distribution = (IntLogUniformDistribution
                            if integral else LogUniformDistribution)
        else:
            distribution = (IntUniformDistribution
                            if integral else UniformDistribution)
        return distribution(sweep.start, sweep.end)

    message = "{} is not supported by Optuna sweeper.".format(override)
    raise NotImplementedError(message)
Ejemplo n.º 3
0
def create_optuna_distribution_from_override(override: Override) -> Any:
    """Translate a parsed override into an Optuna distribution.

    A non-sweep override is returned as its plain value. Otherwise:

    * choice sweep -> ``CategoricalDistribution`` over the encoded values;
    * range sweep -> ``CategoricalDistribution`` over the encoded values
      when ``shuffle`` is set, else ``IntUniformDistribution`` built from
      the integer-cast ``start``, ``stop`` and ``step``;
    * interval sweep -> a log-uniform distribution when tagged ``log``,
      else a uniform one; the integer variant is used when both bounds
      are ``int`` instances.

    Args:
        override: The override to convert.

    Returns:
        The override's value for a non-sweep override, otherwise the
        distribution built for the sweep.

    Raises:
        NotImplementedError: If the sweep is not a choice, range or
            interval sweep.
    """
    sweep = override.value()
    if not override.is_sweep_override():
        return sweep

    def _encoded_choices() -> Any:
        # Gather the encoded sweep values, asserting each is a primitive.
        collected: List[CategoricalChoiceType] = []
        for x in override.sweep_iterator(transformer=Transformer.encode):
            assert isinstance(
                x, (str, int, float, bool)
            ), f"A choice sweep expects str, int, float, or bool type. Got {type(x)}."
            collected.append(x)
        return collected

    if override.is_choice_sweep():
        assert isinstance(sweep, ChoiceSweep)
        return CategoricalDistribution(_encoded_choices())

    if override.is_range_sweep():
        assert isinstance(sweep, RangeSweep)
        assert sweep.start is not None
        assert sweep.stop is not None
        if not sweep.shuffle:
            low, high = int(sweep.start), int(sweep.stop)
            return IntUniformDistribution(low, high, step=int(sweep.step))
        # A shuffled range is treated as an unordered set of values.
        return CategoricalDistribution(_encoded_choices())

    if override.is_interval_sweep():
        assert isinstance(sweep, IntervalSweep)
        assert sweep.start is not None
        assert sweep.end is not None
        log_scale = "log" in sweep.tags
        int_bounds = (isinstance(sweep.start, int)
                      and isinstance(sweep.end, int))
        if log_scale and int_bounds:
            return IntLogUniformDistribution(int(sweep.start), int(sweep.end))
        if log_scale:
            return LogUniformDistribution(sweep.start, sweep.end)
        if int_bounds:
            return IntUniformDistribution(sweep.start, sweep.end)
        return UniformDistribution(sweep.start, sweep.end)

    raise NotImplementedError(
        f"{override} is not supported by Optuna sweeper.")