def create_choice_param_from_choice_override(override: Override) -> Dict[str, Any]:
    """Build a ``"choice"`` parameter description from a choice override.

    Args:
        override: The override whose key element and sweep values are read.

    Returns:
        A dict with three entries: ``"name"`` (the override's key element),
        ``"type"`` (always ``"choice"``) and ``"values"`` (every item yielded
        by the override's sweep iterator using ``Transformer.encode``).
    """
    # Dict entries are evaluated top to bottom, so the key element is read
    # before the sweep iterator is consumed.
    return {
        "name": override.get_key_element(),
        "type": "choice",
        "values": list(override.sweep_iterator(transformer=Transformer.encode)),
    }
def is_matching(override: Override, default: DefaultElement) -> bool:
    """Report whether ``override`` lines up with ``default``.

    The caller is expected to pass an override and a default that share the
    same config group; this is asserted up front.

    Args:
        override: The override being matched.
        default: The default element it is compared against.

    Returns:
        For a delete override: whether the override's subject package equals
        ``default.package``.
        Otherwise: whether the groups are equal and either ``override.pkg1``
        equals ``default.package`` or ``override.pkg1`` is the empty string
        while ``default.package`` is ``None``.
    """
    assert override.key_or_group == default.config_group

    if override.is_delete():
        return override.get_subject_package() == default.package

    # Kept in addition to the assert above: asserts vanish under ``-O``.
    if not (override.key_or_group == default.config_group):
        return False

    if override.pkg1 == default.package:
        return True

    # An empty package on the override matches a default without a package.
    return override.pkg1 == "" and default.package is None
def create_fixed_param_from_element_override(override: Override) -> Dict[str, Any]:
    """Build a ``"fixed"`` parameter description from a single-value override.

    Args:
        override: The override whose key element and value are read.

    Returns:
        A dict with ``"name"`` (the override's key element), ``"type"``
        (always ``"fixed"``) and ``"value"`` (the result of
        ``override.value()``).
    """
    return {
        "name": override.get_key_element(),
        "type": "fixed",
        "value": override.value(),
    }
def create_choice_param_from_range_override(override: Override) -> Dict[str, Any]:
    """Build a ``"choice"`` parameter description from a range override.

    Args:
        override: The override whose key element and sweep values are read.

    Returns:
        A dict with ``"name"`` (the override's key element), ``"type"``
        (always ``"choice"``) and ``"values"`` (a list of every item yielded
        by ``override.sweep_iterator()`` called without arguments).
    """
    name = override.get_key_element()
    # Materialize the sweep so the result holds a plain list, not an iterator.
    swept_values = list(override.sweep_iterator())
    return {"name": name, "type": "choice", "values": swept_values}
def create_range_param_using_interval_override(override: Override) -> Dict[str, Any]:
    """Build a ``"range"`` parameter description from an interval override.

    Args:
        override: The override to convert. Its value must be an
            ``IntervalSweep``; this is asserted.

    Returns:
        A dict with ``"name"`` (the override's key element), ``"type"``
        (always ``"range"``) and ``"bounds"`` (a two-item list holding the
        interval's ``start`` and ``end``).
    """
    name = override.get_key_element()
    interval = override.value()
    assert isinstance(interval, IntervalSweep)
    return {
        "name": name,
        "type": "range",
        "bounds": [interval.start, interval.end],
    }
def create_nevergrad_parameter_from_override(override: Override) -> Any:
    """Translate an override into a Nevergrad parameter (or a plain value).

    Args:
        override: The override to convert.

    Returns:
        * The override's value unchanged when it is not a sweep override.
        * ``ng.p.TransitionChoice`` for a choice sweep carrying the
          ``"ordered"`` tag, ``ng.p.Choice`` for any other choice sweep.
        * ``ng.p.Choice`` for a range sweep.
        * ``ng.p.Log`` for an interval sweep carrying the ``"log"`` tag,
          ``ng.p.Scalar`` otherwise; integer casting is enabled on the
          scalar when the interval's ``start`` is an ``int``.

    Raises:
        NotImplementedError: If the override is a sweep of a kind that none
            of the branches above handles. Previously this case fell off the
            end of the function and silently produced ``None``.
    """
    val = override.value()
    if not override.is_sweep_override():
        return val

    if override.is_choice_sweep():
        assert isinstance(val, ChoiceSweep)
        vals = list(override.sweep_iterator(transformer=Transformer.encode))
        if "ordered" in val.tags:
            return ng.p.TransitionChoice(vals)
        return ng.p.Choice(vals)

    if override.is_range_sweep():
        vals = list(override.sweep_iterator(transformer=Transformer.encode))
        return ng.p.Choice(vals)

    if override.is_interval_sweep():
        assert isinstance(val, IntervalSweep)
        if "log" in val.tags:
            scalar = ng.p.Log(lower=val.start, upper=val.end)
        else:
            scalar = ng.p.Scalar(lower=val.start, upper=val.end)  # type: ignore
        # Only the type of ``start`` decides whether results are cast to int.
        if isinstance(val.start, int):
            scalar.set_integer_casting()
        return scalar

    # Fail loudly instead of handing ``None`` to the optimizer as a parameter.
    raise NotImplementedError(f"{override} is not supported by Nevergrad sweeper.")
def create_optuna_distribution_from_override(override: Override) -> Any:
    """Translate an override into an Optuna distribution (or a plain value).

    Args:
        override: The override to convert.

    Returns:
        * The override's value unchanged when it is not a sweep override.
        * ``CategoricalDistribution`` over the encoded sweep items for choice
          and range sweeps.
        * For interval sweeps, one of ``IntLogUniformDistribution``,
          ``LogUniformDistribution``, ``IntUniformDistribution`` or
          ``UniformDistribution``, selected by the presence of the ``"log"``
          and ``"int"`` tags and built from the interval's ``start``/``end``.

    Raises:
        NotImplementedError: For any other kind of sweep override.
    """
    value = override.value()
    if not override.is_sweep_override():
        return value

    # Choice and range sweeps are both enumerated into a categorical
    # distribution; only the choice case additionally checks the value type.
    is_choice = override.is_choice_sweep()
    if is_choice or override.is_range_sweep():
        if is_choice:
            assert isinstance(value, ChoiceSweep)
        encoded = list(override.sweep_iterator(transformer=Transformer.encode))
        return CategoricalDistribution(encoded)

    if override.is_interval_sweep():
        assert isinstance(value, IntervalSweep)
        log_scale = "log" in value.tags
        integral = "int" in value.tags
        if log_scale:
            distribution = (
                IntLogUniformDistribution if integral else LogUniformDistribution
            )
        else:
            distribution = IntUniformDistribution if integral else UniformDistribution
        return distribution(value.start, value.end)

    raise NotImplementedError(
        "{} is not supported by Optuna sweeper.".format(override)
    )
def visitOverride(self, ctx: OverrideParser.OverrideContext) -> Override:
    """Build an ``Override`` from an override parse-tree context.

    The children of ``ctx`` are consumed in order: an optional prefix
    terminal, the key node, then the node following the key.

    Args:
        ctx: The parse-tree context for one override.

    Returns:
        An ``Override`` carrying the detected override type, the key's
        ``key_or_group`` and ``package``, the parsed value and its value
        type. For a delete override whose node after the key is the EOF
        token, both the value and the value type are ``None``.
    """
    nodes = ctx.getChildren()
    head = next(nodes)

    # Without a leading terminal the override is a plain change and the
    # first child is already the key.
    kind = OverrideType.CHANGE
    if not isinstance(head, TerminalNodeImpl):
        key_node = head
    else:
        prefix = head.symbol.text
        if prefix == "+":
            kind = OverrideType.ADD
            key_node = next(nodes)
            # A second PLUS terminal right after the first upgrades the
            # override to a forced add; the key is then the next child.
            if self.is_matching_terminal(key_node, OverrideLexer.PLUS):
                kind = OverrideType.FORCE_ADD
                key_node = next(nodes)
        elif prefix == "~":
            kind = OverrideType.DEL
            key_node = next(nodes)
        else:
            assert False

    key = self.visitKey(key_node)
    value: Union[ChoiceSweep, RangeSweep, IntervalSweep, ParsedElementType]
    eq_node = next(nodes)

    # A delete override may end right after the key, with no "=value" part.
    delete_without_value = (
        kind == OverrideType.DEL
        and isinstance(eq_node, TerminalNode)
        and eq_node.symbol.type == Token.EOF
    )
    if delete_without_value:
        value = None
        value_type = None
    else:
        assert self.is_matching_terminal(eq_node, OverrideLexer.EQUAL)
        if ctx.value() is None:
            # "key=" with nothing after the equals sign is an empty element.
            value = ""
            value_type = ValueType.ELEMENT
        else:
            value = self.visitValue(ctx.value())
            if isinstance(value, ChoiceSweep):
                value_type = (
                    ValueType.SIMPLE_CHOICE_SWEEP
                    if value.simple_form
                    else ValueType.CHOICE_SWEEP
                )
            elif isinstance(value, Glob):
                value_type = ValueType.GLOB_CHOICE_SWEEP
            elif isinstance(value, IntervalSweep):
                value_type = ValueType.INTERVAL_SWEEP
            elif isinstance(value, RangeSweep):
                value_type = ValueType.RANGE_SWEEP
            else:
                value_type = ValueType.ELEMENT

    return Override(
        type=kind,
        key_or_group=key.key_or_group,
        _value=value,
        value_type=value_type,
        package=key.package,
    )
def create_optuna_distribution_from_override(override: Override) -> Any:
    """Translate an override into an Optuna distribution (or a plain string).

    Args:
        override: The override to convert.

    Returns:
        * ``override.get_value_element_as_str()`` when the override is not a
          sweep.
        * ``CategoricalDistribution`` over the encoded sweep items for a
          choice sweep and for a range sweep with ``shuffle`` set.
        * ``IntUniformDistribution`` built from the int-converted ``start``,
          ``stop`` and ``step`` for a range sweep without ``shuffle``.
        * For an interval sweep, the log or non-log uniform distribution; the
          integer variant is used when both ``start`` and ``end`` are ``int``.

    Raises:
        NotImplementedError: For any other kind of sweep override.
    """
    if not override.is_sweep_override():
        return override.get_value_element_as_str()

    value = override.value()

    def categorical() -> Any:
        # Enumerate the encoded sweep items, checking each one's type.
        # The assertion text mentions "choice sweep" even when this is
        # reached for a shuffled range sweep.
        choices: List[CategoricalChoiceType] = []
        for x in override.sweep_iterator(transformer=Transformer.encode):
            assert isinstance(
                x, (str, int, float, bool)
            ), f"A choice sweep expects str, int, float, or bool type. Got {type(x)}."
            choices.append(x)
        return CategoricalDistribution(choices)

    if override.is_choice_sweep():
        assert isinstance(value, ChoiceSweep)
        return categorical()

    if override.is_range_sweep():
        assert isinstance(value, RangeSweep)
        assert value.start is not None
        assert value.stop is not None
        if value.shuffle:
            return categorical()
        # NOTE(review): ``value.stop`` is handed over as the upper bound
        # unchanged - confirm whether the range's stop is meant to be
        # included by the resulting distribution.
        low = int(value.start)
        high = int(value.stop)
        return IntUniformDistribution(low, high, step=int(value.step))

    if override.is_interval_sweep():
        assert isinstance(value, IntervalSweep)
        assert value.start is not None
        assert value.end is not None
        integral = isinstance(value.start, int) and isinstance(value.end, int)
        if "log" in value.tags:
            if integral:
                return IntLogUniformDistribution(int(value.start), int(value.end))
            return LogUniformDistribution(value.start, value.end)
        if integral:
            return IntUniformDistribution(value.start, value.end)
        return UniformDistribution(value.start, value.end)

    raise NotImplementedError(f"{override} is not supported by Optuna sweeper.")
def is_searchpath_override(v: Override) -> bool:
    """Report whether the override's key element is ``hydra.searchpath``.

    Args:
        v: The override to inspect.

    Returns:
        The result of comparing ``v.get_key_element()`` with the literal
        ``"hydra.searchpath"``.
    """
    key_element = v.get_key_element()
    return key_element == "hydra.searchpath"