def visitFunction(self, ctx: OverrideParser.FunctionContext) -> Any:
    """Evaluate a function-call node of the override grammar.

    Positional and keyword arguments are collected from the parse tree,
    wrapped in a FunctionCall and evaluated through ``self.functions.eval``.

    :param ctx: the function-call parse-tree node.
    :return: the value returned by the evaluated function.
    :raises HydraException: when a positional argument follows a keyword
        argument, or when evaluating the function raises any exception.
    """
    args = []
    kwargs = {}
    children = ctx.getChildren()
    func_name = next(children).getText()
    # Consume the opening parenthesis OUTSIDE the assert: assert statements
    # are stripped under ``python -O`` and the next() call would then never
    # run, shifting every following token by one.
    popen = next(children)
    assert self.is_matching_terminal(popen, OverrideLexer.POPEN)
    in_kwargs = False
    while True:
        cur = next(children)
        if self.is_matching_terminal(cur, OverrideLexer.PCLOSE):
            break

        if isinstance(cur, OverrideParser.ArgNameContext):
            # ``name=`` prefix: the element that follows is its value.
            in_kwargs = True
            name = cur.getChild(0).getText()
            cur = next(children)
            value = self.visitElement(cur)
            kwargs[name] = value
        else:
            if self.is_matching_terminal(cur, OverrideLexer.COMMA):
                continue
            if in_kwargs:
                raise HydraException("positional argument follows keyword argument")
            value = self.visitElement(cur)
            args.append(value)

    function = FunctionCall(name=func_name, args=args, kwargs=kwargs)
    try:
        return self.functions.eval(function)
    except Exception as e:
        raise HydraException(
            f"{type(e).__name__} while evaluating '{ctx.getText()}': {e}"
        ) from e
def _apply_overrides_to_config(overrides: List[str], cfg: DictConfig) -> None:
    """Apply override strings directly to ``cfg`` in place.

    Each string is parsed with ``ConfigLoaderImpl._parse_config_override``.
    The override value (if any) is parsed with ``yaml.load`` using the loader
    from ``_utils.get_yaml_loader()``. Then:
      * delete (``~key[=value]``): the key must exist, and if a value was
        given it must equal the current value; the key is then removed;
      * add (``+key=value``): the key must not already exist;
      * otherwise: the existing key is updated.

    :param overrides: raw override strings.
    :param cfg: config modified in place.
    :raises HydraException: on a missing/mismatched delete target, an
        existing append target, an unknown override key, or any
        OmegaConfBaseException raised while applying an override.
    """
    loader = _utils.get_yaml_loader()

    def get_value(val_: Optional[str]) -> Any:
        # NOTE(review): yaml.load on command-line text; safety depends on the
        # loader returned by _utils.get_yaml_loader() -- confirm it is a safe loader.
        return yaml.load(val_, Loader=loader) if val_ is not None else None

    for line in overrides:
        override = ConfigLoaderImpl._parse_config_override(line)
        try:
            value = get_value(override.value)
            if override.is_delete():
                val = OmegaConf.select(cfg, override.key, throw_on_missing=False)
                # A key whose current value is None is also reported as missing.
                if val is None:
                    raise HydraException(
                        f"Could not delete from config. '{override.key}' does not exist."
                    )
                elif value is not None and value != val:
                    raise HydraException(
                        f"Could not delete from config."
                        f" The value of '{override.key}' is {val} and not {override.value}."
                    )

                key = override.key
                last_dot = key.rfind(".")
                # open_dict lifts struct mode so the key can be removed.
                with open_dict(cfg):
                    if last_dot == -1:
                        del cfg[key]
                    else:
                        # Dotted key: remove the last component from its parent node.
                        node = OmegaConf.select(cfg, key[0:last_dot])
                        del node[key[last_dot + 1:]]

            elif override.is_add():
                if (OmegaConf.select(cfg, override.key, throw_on_missing=False) is None):
                    # open_dict allows adding a key that the struct config does not have.
                    with open_dict(cfg):
                        OmegaConf.update(cfg, override.key, value)
                else:
                    raise HydraException(
                        f"Could not append to config. An item is already at '{override.key}'."
                    )
            else:
                try:
                    OmegaConf.update(cfg, override.key, value)
                except (ConfigAttributeError, ConfigKeyError) as ex:
                    raise HydraException(
                        f"Could not override '{override.key}'. No match in config."
                        f"\nTo append to your config use +{line}") from ex
        except OmegaConfBaseException as ex:
            raise HydraException(f"Error merging override {line}") from ex
def syntaxError(
    self,
    recognizer: Any,
    offending_symbol: Any,
    line: Any,
    column: Any,
    msg: Any,
    e: Any,
) -> None:
    """Error-listener hook: convert a parser syntax error into a HydraException.

    The parser-supplied ``msg`` is used when it is not None; otherwise the
    text of ``e`` is used. ``e`` is attached as the exception cause.

    :raises HydraException: always.
    """
    text = str(e) if msg is None else msg
    raise HydraException(text) from e
def __init__(
    self,
    config_path: Optional[str] = None,
    job_name: Optional[str] = None,
    strict: Optional[bool] = None,
    caller_stack_depth: int = 1,
) -> None:
    """Set up Hydra for the calling file or module.

    :param config_path: config directory; must be relative (absolute paths
        are rejected).
    :param job_name: job name; detected from the caller when None.
    :param strict: forwarded to ``Hydra.create_main_hydra_file_or_module``.
    :param caller_stack_depth: stack depth of the caller, passed (plus one)
        to the frame-detection helper.
    :raises HydraException: if ``config_path`` is absolute.
    """
    # Presumably a snapshot of global Hydra state for later restoration -- confirm.
    self._gh_backup = get_gh_backup()

    if config_path is not None and os.path.isabs(config_path):
        raise HydraException("config_path in initialize() must be relative")

    frame_depth = caller_stack_depth + 1
    calling_file, calling_module = detect_calling_file_or_module_from_stack_frame(
        frame_depth
    )

    resolved_job_name = (
        job_name
        if job_name is not None
        else detect_task_name(calling_file=calling_file, calling_module=calling_module)
    )

    Hydra.create_main_hydra_file_or_module(
        calling_file=calling_file,
        calling_module=calling_module,
        config_path=config_path,
        job_name=resolved_job_name,
        strict=strict,
    )
def split_arguments(
    overrides: List[Override], max_batch_size: Optional[int]
) -> List[List[List[str]]]:
    """Expand overrides into the cartesian product of their values, in batches.

    A non-sweep override contributes a single ``key=value`` string; a
    discrete sweep contributes one ``key=value`` string per sweep element.
    Every combination becomes one job (a list of override strings).

    :param overrides: parsed overrides.
    :param max_batch_size: None for a single batch holding every job;
        otherwise a positive batch size passed to
        ``BasicSweeper.split_overrides_to_chunks``.
    :return: list of batches, each a list of jobs.
    :raises HydraException: for a sweep override that is not a discrete sweep.
    """
    choices_per_override: List[List[str]] = []
    for override in overrides:
        if not override.is_sweep_override():
            key = override.get_key_element()
            value = override.get_value_element_as_str()
            choices_per_override.append([f"{key}={value}"])
            continue

        if not override.is_discrete_sweep():
            assert override.value_type is not None
            raise HydraException(
                f"{BasicSweeper.__name__} does not support sweep type : {override.value_type.name}"
            )

        key = override.get_key_element()
        choices_per_override.append(
            [f"{key}={val}" for val in override.sweep_string_iterator()]
        )

    all_batches = list(map(list, itertools.product(*choices_per_override)))
    assert max_batch_size is None or max_batch_size > 0
    if max_batch_size is None:
        return [all_batches]
    return list(BasicSweeper.split_overrides_to_chunks(all_batches, max_batch_size))
def register(self, name: str, func: Callable[..., Any]) -> None:
    """Register ``func`` under ``name``.

    Stores the function's ``inspect.signature`` in ``self.definitions`` and
    the function itself in ``self.functions``.

    :raises HydraException: if ``name`` is already registered.
    """
    if name not in self.definitions:
        self.definitions[name] = inspect.signature(func)
        self.functions[name] = func
        return
    raise HydraException(f"Function named '{name}' is already registered")
def __init__( self, config_path: Optional[str] = _UNSPECIFIED_, job_name: Optional[str] = None, caller_stack_depth: int = 1, ) -> None: self._gh_backup = get_gh_backup() # DEPRECATED: remove in 1.2 # in 1.2, the default config_path should be changed to None if config_path is _UNSPECIFIED_: url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path" deprecation_warning( message=dedent(f"""\ config_path is not specified in hydra.initialize(). See {url} for more information."""), stacklevel=2, ) config_path = "." if config_path is not None and os.path.isabs(config_path): raise HydraException( "config_path in initialize() must be relative") calling_file, calling_module = detect_calling_file_or_module_from_stack_frame( caller_stack_depth + 1) if job_name is None: job_name = detect_task_name(calling_file=calling_file, calling_module=calling_module) Hydra.create_main_hydra_file_or_module( calling_file=calling_file, calling_module=calling_module, config_path=config_path, job_name=job_name, )
def _parse_config_override(override: str) -> ParsedConfigOverride:
    """Parse a single config override string.

    Accepted forms:
      update: key=value
      append: +key=value
      delete: ~key=value | ~key
    regex code and tests: https://regex101.com/r/JAPVdx/9

    :raises HydraException: if the string does not match an accepted form.
    """
    pattern = r"^(?P<prefix>[+~])?(?P<key>[\w\.@]+)(?:=(?P<value>.*))?$"
    error_msg = (
        f"Error parsing config override : '{override}'"
        f"\nAccepted forms:"
        f"\n\tOverride: key=value"
        f"\n\tAppend: +key=value"
        f"\n\tDelete: ~key=value, ~key"
    )

    found = re.search(pattern, override)
    if not found:
        raise HydraException(error_msg)

    prefix = found.group("prefix")
    key = found.group("key")
    value = found.group("value")

    if prefix == "~":
        # Delete: the value is optional.
        ok = key is not None
    elif prefix in (None, "+"):
        # Update / append: a value is required.
        ok = key is not None and value is not None
    else:
        ok = True

    if not ok:
        raise HydraException(error_msg)

    assert key is not None
    return ParsedConfigOverride(prefix=prefix, key=key, value=value)
def sweep_iterator(
    self, transformer: TransformerType = Transformer.identity
) -> Iterator[ElementType]:
    """
    Converts CHOICE_SWEEP, SIMPLE_CHOICE_SWEEP, GLOB_CHOICE_SWEEP and
    RANGE_SWEEP to a List[Elements] that can be used in the value component
    of overrides (the part after the =).
    A transformer may be provided for converting each element to support the
    needs of different sweepers. It is applied to every element, including
    the option names matched by a glob sweep.

    :raises HydraException: if this override is not a choice/range/glob
        sweep, or if it is a glob sweep and no config loader is set.
    """
    if self.value_type not in (
        ValueType.CHOICE_SWEEP,
        ValueType.SIMPLE_CHOICE_SWEEP,
        ValueType.GLOB_CHOICE_SWEEP,
        ValueType.RANGE_SWEEP,
    ):
        raise HydraException(
            f"Can only enumerate CHOICE and RANGE sweeps, type is {self.value_type}"
        )

    lst: Any
    if isinstance(self._value, list):
        lst = self._value
    elif isinstance(self._value, ChoiceSweep):
        if self._value.shuffle:
            # Shuffle a copy so the stored choice list keeps its order.
            lst = copy(self._value.list)
            shuffle(lst)
        else:
            lst = self._value.list
    elif isinstance(self._value, RangeSweep):
        if self._value.shuffle:
            lst = list(self._value.range())
            shuffle(lst)
        else:
            lst = self._value.range()
    elif isinstance(self._value, Glob):
        if self.config_loader is None:
            raise HydraException("ConfigLoader is not set")

        ret = self.config_loader.get_group_options(
            self.key_or_group, results_filter=ObjectType.CONFIG
        )
        # Previously returned unconverted; run glob matches through the
        # transformer like every other sweep element.
        lst = self._value.filter(ret)
    else:
        assert False

    return map(transformer, lst)
def call(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An object describing what to call and what params to use. Must have a _target_ field.
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method
    :raises InstantiationException: if a TargetConf's _target_ is still "???",
        or if the underlying call raises one (re-raised unchanged).
    :raises HydraException: for an unsupported config type, or wrapping any
        other exception raised while locating/calling the target.
    """
    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    supported = (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    )
    if not supported:
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    # make a copy to ensure we do not change the provided object
    copied = OmegaConf.structured(config)
    if OmegaConf.is_config(config):
        # Re-attach the copy to the original parent node.
        copied._set_parent(config._get_parent())
    config = copied

    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        OmegaConf.set_readonly(config, False)
        OmegaConf.set_struct(config, False)
        cls = _get_cls_name(config)
        type_or_callable = _locate(cls)
        if isinstance(type_or_callable, type):
            return _instantiate_class(type_or_callable, config, *args, **kwargs)
        assert callable(type_or_callable)
        return _call_callable(type_or_callable, config, *args, **kwargs)
    except InstantiationException:
        raise
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
def _raise_parse_override_error(override: str) -> None:
    """Raise a HydraException describing the accepted config group override forms.

    Never returns normally. The annotation was ``-> str`` even though no
    value is ever returned.

    :param override: the override string that failed to parse.
    :raises HydraException: always.
    """
    msg = (
        f"Error parsing config group override : '{override}'"
        f"\nAccepted forms:"
        f"\n\tOverride: key=value, key@package=value, key@src_pkg:dest_pkg=value, key@src_pkg:dest_pkg"
        f"\n\tAppend: +key=value, +key@package=value"
        f"\n\tDelete: ~key, ~key@pkg, ~key=value, ~key@pkg=value")
    raise HydraException(msg)
def syntaxError(
    self,
    recognizer: Any,
    offending_symbol: Any,
    line: Any,
    column: Any,
    msg: Any,
    e: Any,
) -> None:
    """Error-listener hook: convert a parser syntax error into a HydraException.

    Uses the parser-supplied ``msg``, falling back to ``str(e)`` when ``msg``
    is None (previously the exception would carry a bare ``None`` message).
    The recognition exception ``e`` is chained as the cause so the original
    parser error stays visible in tracebacks.

    :raises HydraException: always.
    """
    if msg is not None:
        raise HydraException(msg) from e
    raise HydraException(str(e)) from e
def sweep_string_iterator(self) -> Iterator[str]:
    """
    Converts the sweep_choices from a List[ParsedElements] to a List[str]
    that can be used in the value component of overrides (the part after
    the =). Option names matched by a glob sweep are converted the same way
    as every other sweep element.

    :raises HydraException: if this override is not a choice/range/glob
        sweep, or if it is a glob sweep and no config loader is set.
    """
    if self.value_type not in (
        ValueType.CHOICE_SWEEP,
        ValueType.SIMPLE_CHOICE_SWEEP,
        ValueType.GLOB_CHOICE_SWEEP,
        ValueType.RANGE_SWEEP,
    ):
        raise HydraException(
            f"Can only enumerate CHOICE and RANGE sweeps, type is {self.value_type}"
        )

    lst: Any
    if isinstance(self._value, list):
        lst = self._value
    elif isinstance(self._value, ChoiceSweep):
        if self._value.shuffle:
            # Shuffle a copy so the stored choice list keeps its order.
            lst = copy(self._value.list)
            shuffle(lst)
        else:
            lst = self._value.list
    elif isinstance(self._value, RangeSweep):
        if self._value.shuffle:
            lst = list(self._value.range())
            shuffle(lst)
        else:
            lst = self._value.range()
    elif isinstance(self._value, Glob):
        if self.config_loader is None:
            raise HydraException("ConfigLoader is not set")

        ret = self.config_loader.get_group_options(
            self.key_or_group, results_filter=ObjectType.CONFIG
        )
        # Previously returned unconverted; route glob matches through the
        # same string conversion as every other sweep element.
        lst = self._value.filter(ret)
    else:
        assert False

    return map(Override._get_value_element_as_str, lst)
def _raise_parse_override_error(override: Optional[str]) -> None:
    """Raise a HydraException listing the accepted config group override forms.

    Never returns normally.

    :param override: the override string that failed to parse (may be None).
    :raises HydraException: always.
    """
    lines = [
        f"Error parsing config group override : '{override}'",
        "Accepted forms:",
        "\tOverride: key=value, key@package=value, key@src_pkg:dest_pkg=value, key@src_pkg:dest_pkg",
        "\tAppend: +key=value, +key@package=value",
        "\tDelete: ~key, ~key@pkg, ~key=value, ~key@pkg=value",
        "",
        "See https://hydra.cc/docs/next/advanced/command_line_syntax for details",
    ]
    raise HydraException("\n".join(lines))
def __init__(self, config_dir: str, job_name: str = "app") -> None:
    """Set up Hydra using an absolute config directory.

    :param config_dir: absolute path of the config directory.
    :param job_name: task name passed to ``Hydra.create_main_hydra2``.
    :raises HydraException: if ``config_dir`` is not absolute.
    """
    self._gh_backup = get_gh_backup()

    # A relative path would be interpreted relative to the current working
    # directory, whose meaning depends on when this runs; requiring an
    # absolute path avoids that confusion. Converting it with
    # hydra.utils.to_absolute_path() could be considered later if there is demand.
    if not os.path.isabs(config_dir):
        raise HydraException(
            "initialize_config_dir() requires an absolute config_dir as input"
        )

    search_path = create_config_search_path(search_path_dir=config_dir)
    Hydra.create_main_hydra2(task_name=job_name, config_search_path=search_path)
def visitPrimitive(
    self, ctx: OverrideParser.PrimitiveContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
    """Convert a ``primitive`` parse-tree node into a Python value.

    Leading and trailing whitespace children are ignored. If more than one
    token remains, the node's text is returned stripped of surrounding
    whitespace. A single token is converted by its type:
      * QUOTED_VALUE -> QuotedString (outer quotes removed, escaped quote
        character of the same kind un-escaped)
      * ID / INTERPOLATION -> str
      * INT -> int, FLOAT -> float, NULL -> None
      * BOOL -> bool (compared case-insensitively)
      * anything else -> the token text

    :raises HydraException: if the primitive consists only of whitespace.
    """
    ret: Optional[Union[int, bool, float, str]]
    first_idx = 0
    last_idx = ctx.getChildCount()
    # skip first if whitespace
    if self.is_ws(ctx.getChild(0)):
        if last_idx == 1:
            # Only whitespaces => this is not allowed.
            raise HydraException(
                "Trying to parse a primitive that is all whitespaces"
            )
        first_idx = 1
    # skip last if whitespace
    if self.is_ws(ctx.getChild(-1)):
        last_idx = last_idx - 1
    num = last_idx - first_idx
    if num > 1:
        # Multiple tokens: keep the raw text minus surrounding whitespace.
        ret = ctx.getText().strip()
    else:
        node = ctx.getChild(first_idx)
        if node.symbol.type == OverrideLexer.QUOTED_VALUE:
            text = node.getText()
            qc = text[0]
            # Drop the enclosing quote characters.
            text = text[1:-1]
            if qc == "'":
                quote = Quote.single
                text = text.replace("\\'", "'")
            elif qc == '"':
                quote = Quote.double
                text = text.replace('\\"', '"')
            else:
                assert False
            return QuotedString(text=text, quote=quote)
        elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
            ret = node.symbol.text
        elif node.symbol.type == OverrideLexer.INT:
            ret = int(node.symbol.text)
        elif node.symbol.type == OverrideLexer.FLOAT:
            ret = float(node.symbol.text)
        elif node.symbol.type == OverrideLexer.NULL:
            ret = None
        elif node.symbol.type == OverrideLexer.BOOL:
            text = node.getText().lower()
            if text == "true":
                ret = True
            elif text == "false":
                ret = False
            else:
                assert False
        else:
            return node.getText()  # type: ignore
    return ret
def get_value_element_as_str(self, space_after_sep: bool = False) -> str:
    """
    Return the string form of this override's value (similar to the part
    after the = in the input string).

    :param space_after_sep: True to append space after commas and colons
    :return: the value rendered as a string
    :raises HydraException: if the value is a sweep.
    """
    value = self._value
    if not isinstance(value, Sweep):
        return Override._get_value_element_as_str(
            value, space_after_sep=space_after_sep
        )
    # This should not be called for sweeps
    raise HydraException("Cannot convert sweep to str")
def parse_overrides(self, overrides: List[str]) -> List[Override]:
    """Parse each override string with the ``override`` grammar rule.

    :param overrides: raw override strings.
    :return: the parsed Override objects, in input order.
    :raises HydraException: if any string fails to parse. The message names
        the offending override, and the parser error is chained as the cause.
    """
    ret: List[Override] = []
    for override in overrides:
        try:
            parsed = self.parse_rule(override, "override")
        except HydraException as e:
            # Chain explicitly so the traceback reads "direct cause" rather
            # than "during handling of the above exception".
            raise HydraException(
                f"Error parsing override '{override}' : {e}"
                f"\nSee https://hydra.cc/docs/next/advanced/command_line_syntax for details"
            ) from e
        assert isinstance(parsed, Override)
        ret.append(parsed)
    return ret
def _get_cls_name(config: Union[ObjectConf, DictConfig]) -> str:
    """Return the class/callable name that ``config`` points to.

    The fields ``target``, ``cls`` and ``class`` are checked in that order.
    A field counts only if it exists and its value is not ``"???"``.
    A deprecation UserWarning is emitted when ``class`` (or, failing that,
    ``cls``) is present, even if ``target`` is also set.

    :param config: a DictConfig or ObjectConf.
    :return: the value of the first present field.
    :raises HydraException: if none of the three fields is present.
    """
    def _warn(field: str) -> None:
        # DictConfig messages include the full key path; ObjectConf ones only the field.
        if isinstance(config, DictConfig):
            warnings.warn(
                f"""Config key '{config._get_full_key(field)}' is deprecated since Hydra 1.0 and will be removed in Hydra 1.1. Use 'target' instead of '{field}'.""",
                category=UserWarning,
            )
        else:
            warnings.warn(
                f""" ObjectConf field '{field}' is deprecated since Hydra 1.0 and will be removed in Hydra 1.1. Use 'target' instead of '{field}'.""",
                category=UserWarning,
            )

    def _getcls(field: str) -> str:
        # Read the field as an attribute; it must hold a string.
        classname = getattr(config, field)
        assert isinstance(classname, str)
        return classname

    def _has_field(field: str) -> bool:
        # Present and not the "???" placeholder (treated as unset).
        if isinstance(config, DictConfig):
            if field in config:
                ret = config[field] != "???"
                assert isinstance(ret, bool)
                return ret
            else:
                return False
        else:
            if hasattr(config, field):
                ret = getattr(config, field) != "???"
                assert isinstance(ret, bool)
                return ret
            else:
                return False

    if _has_field(field="class"):
        _warn(field="class")
    elif _has_field(field="cls"):
        _warn(field="cls")

    if _has_field(field="target"):
        return _getcls(field="target")
    elif _has_field(field="cls"):
        return _getcls(field="cls")
    elif _has_field(field="class"):
        return _getcls(field="class")
    else:
        raise HydraException("Input config does not have a `target` field")
def _apply_overrides_to_config(overrides: List[Override], cfg: DictConfig) -> None:
    """Apply parsed (non config-group) overrides directly to ``cfg`` in place.

    For each override:
      * delete: the key must exist, and if a value was given it must equal
        the current value; the key is then removed;
      * add: the key must not already exist;
      * otherwise: the existing key is updated.

    :param overrides: parsed overrides.
    :param cfg: config modified in place.
    :raises HydraException: for an override that names a package (treated as
        a config group override whose group does not exist), a
        missing/mismatched delete target, an existing append target, an
        unknown override key, or any OmegaConfBaseException raised while
        applying an override.
    """
    for override in overrides:
        # A subject package only makes sense for config group overrides.
        if override.get_subject_package() is not None:
            raise HydraException(
                f"Override {override.input_line} looks like a config group override, "
                f"but config group '{override.key_or_group}' does not exist."
            )

        key = override.key_or_group
        value = override.value()
        try:
            if override.is_delete():
                config_val = OmegaConf.select(cfg, key, throw_on_missing=False)
                # A key whose current value is None is also reported as missing.
                if config_val is None:
                    raise HydraException(
                        f"Could not delete from config. '{override.key_or_group}' does not exist."
                    )
                elif value is not None and value != config_val:
                    raise HydraException(
                        f"Could not delete from config."
                        f" The value of '{override.key_or_group}' is {config_val} and not {value}."
                    )

                last_dot = key.rfind(".")
                # open_dict lifts struct mode so the key can be removed.
                with open_dict(cfg):
                    if last_dot == -1:
                        del cfg[key]
                    else:
                        # Dotted key: remove the last component from its parent node.
                        node = OmegaConf.select(cfg, key[0:last_dot])
                        del node[key[last_dot + 1:]]

            elif override.is_add():
                if OmegaConf.select(cfg, key, throw_on_missing=False) is None:
                    # open_dict allows adding a key that the struct config does not have.
                    with open_dict(cfg):
                        OmegaConf.update(cfg, key, value)
                else:
                    raise HydraException(
                        f"Could not append to config. An item is already at '{override.key_or_group}'."
                    )
            else:
                try:
                    OmegaConf.update(cfg, key, value)
                except (ConfigAttributeError, ConfigKeyError) as ex:
                    raise HydraException(
                        f"Could not override '{override.key_or_group}'. No match in config."
                        f"\nTo append to your config use +{override.input_line}"
                    ) from ex
        except OmegaConfBaseException as ex:
            raise HydraException(
                f"Error merging override {override.input_line}") from ex
def _merge_config(
    self,
    cfg: DictConfig,
    config_group: str,
    name: str,
    required: bool,
    is_primary_config: bool,
    package_override: Optional[str],
) -> DictConfig:
    """Load ``config_group/name`` (or just ``name``) and merge it into ``cfg``.

    :return: the merged config, or ``cfg`` unchanged when the config could
        not be loaded and ``required`` is False.
    :raises MissingConfigException: when a required config cannot be loaded;
        for a config group, the message lists the group's available options.
    :raises HydraException: wrapping any OmegaConfBaseException raised while
        loading or merging.
    """
    try:
        new_cfg = name if config_group == "" else f"{config_group}/{name}"
        loaded_cfg, _ = self._load_config_impl(
            new_cfg,
            is_primary_config=is_primary_config,
            package_override=package_override,
        )

        if loaded_cfg is not None:
            merged = OmegaConf.merge(cfg, loaded_cfg)
            assert isinstance(merged, DictConfig)
            return merged

        if not required:
            return cfg

        if config_group == "":
            raise MissingConfigException(f"Could not load {new_cfg}", new_cfg)

        options = self.get_group_options(config_group)
        if options:
            opt_list = "\n".join("\t" + opt for opt in options)
            msg = f"Could not load {new_cfg}.\nAvailable options:\n{opt_list}"
        else:
            msg = f"Could not load {new_cfg}"
        raise MissingConfigException(msg, new_cfg, options)
    except OmegaConfBaseException as ex:
        raise HydraException(f"Error merging {config_group}={name}") from ex
def eval(self, func: FunctionCall) -> Any:
    """Evaluate a parsed function call against the registered functions.

    QuotedString arguments (positional and keyword) are unwrapped to their
    text. The arguments are then bound to the registered signature and
    type-checked against the parameter annotations. For ``*args`` every
    element is checked; for ``**kwargs`` every value is checked.

    :param func: the parsed call (name, args, kwargs).
    :return: the registered function's return value.
    :raises HydraException: if no function named ``func.name`` is registered.
    :raises TypeError: if binding fails or an argument's type does not match.
    """
    if func.name not in self.definitions:
        raise HydraException(
            f"Unknown function '{func.name}'"
            f"\nAvailable: {','.join(sorted(self.definitions.keys()))}\n"
        )
    sig = self.definitions[func.name]

    # unquote strings in args
    args = [arg.text if isinstance(arg, QuotedString) else arg for arg in func.args]

    # Unquote strings in kwargs values
    kwargs = {
        key: val.text if isinstance(val, QuotedString) else val
        for key, val in func.kwargs.items()
    }

    bound = sig.bind(*args, **kwargs)

    for name, value in bound.arguments.items():
        param = sig.parameters[name]
        expected_type = param.annotation
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            # *args are bound as a tuple; the annotation applies per element.
            for iidx, v in enumerate(value):
                if not is_type_matching(v, expected_type):
                    raise TypeError(
                        f"mismatch type argument {name}[{iidx}]:"
                        f" {type_str(type(v))} is incompatible with {type_str(expected_type)}"
                    )
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            # **kwargs are bound as a dict; the annotation applies per value.
            # Previously the whole dict was checked against it, which always failed.
            for kname, v in value.items():
                if not is_type_matching(v, expected_type):
                    raise TypeError(
                        f"mismatch type argument {name}[{kname}]:"
                        f" {type_str(type(v))} is incompatible with {type_str(expected_type)}"
                    )
        else:
            if not is_type_matching(value, expected_type):
                raise TypeError(
                    f"mismatch type argument {name}:"
                    f" {type_str(type(value))} is incompatible with {type_str(expected_type)}"
                )

    return self.functions[func.name](*bound.args, **bound.kwargs)
def call(config: Union[ObjectConf, DictConfig], *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An ObjectConf or DictConfig describing what to call and what params to use
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method
    :raises HydraException: wrapping any exception raised while resolving or
        calling the target.
    """
    if OmegaConf.is_none(config):
        return None
    # Bind before the try so the handler can always format its message. If
    # _get_cls_name itself raised, ``cls`` would otherwise be unbound, and an
    # UnboundLocalError would hide the real error.
    cls = "<unknown>"
    try:
        cls = _get_cls_name(config)
        type_or_callable = _locate(cls)
        if isinstance(type_or_callable, type):
            return _instantiate_class(type_or_callable, config, *args, **kwargs)
        else:
            assert callable(type_or_callable)
            return _call_callable(type_or_callable, config, *args, **kwargs)
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
def call(
    config: Union[ObjectConf, TargetConf, DictConfig, Dict[Any, Any]],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    :param config: An object describing what to call and what params to use
    :param args: optional positional parameters pass-through
    :param kwargs: optional named parameters pass-through
    :return: the return value from the specified class or method
    :raises InstantiationException: if a TargetConf's _target_ is still "???",
        or if the underlying call raises one (re-raised unchanged).
    :raises HydraException: wrapping any other exception raised while
        resolving or calling the target.
    """
    target_missing = isinstance(config, TargetConf) and config._target_ == "???"
    if target_missing:
        raise InstantiationException(
            f"Missing _target_ value. Check that you specified it in '{type(config).__name__}'"
            f" and that the type annotation is correct: '_target_: str = ...'"
        )

    if isinstance(config, (dict, ObjectConf, TargetConf)):
        config = OmegaConf.structured(config)

    if OmegaConf.is_none(config):
        return None

    cls = "<unknown>"
    try:
        assert isinstance(config, DictConfig)
        # make a copy to ensure we do not change the provided object
        working = copy.deepcopy(config)
        OmegaConf.set_readonly(working, False)
        OmegaConf.set_struct(working, False)
        cls = _get_cls_name(working)
        resolved = _locate(cls)
        if isinstance(resolved, type):
            return _instantiate_class(resolved, working, *args, **kwargs)
        assert callable(resolved)
        return _call_callable(resolved, working, *args, **kwargs)
    except InstantiationException:
        raise
    except Exception as e:
        raise HydraException(f"Error calling '{cls}' : {e}") from e
def __next__(self) -> float:
    """Return the current range value as a float and advance by ``step``.

    ``start``, ``stop`` and ``step`` must be ``decimal.Decimal``. ``stop`` is
    exclusive: iteration ends with StopIteration once ``start`` reaches or
    passes ``stop`` in the direction of ``step``.

    :raises StopIteration: when the range is exhausted.
    :raises HydraException: if ``step`` is zero.
    """
    assert isinstance(self.start, decimal.Decimal)
    assert isinstance(self.stop, decimal.Decimal)
    assert isinstance(self.step, decimal.Decimal)

    if self.step > 0:
        has_more = self.start < self.stop
    elif self.step < 0:
        has_more = self.start > self.stop
    else:
        raise HydraException(
            f"Invalid range values (start:{self.start}, stop:{self.stop}, step:{self.step})"
        )

    if not has_more:
        raise StopIteration

    current = float(self.start)
    self.start += self.step
    return current
def _update_package_in_header(
    self,
    header: Dict[str, str],
    normalized_config_path: str,
    is_primary_config: bool,
    package_override: Optional[str],
) -> None:
    """Resolve the effective package of a config and store it in ``header``.

    ``header["package"]`` ends up holding the package returned by
    ``ConfigSource._resolve_package`` for this config path (with its last
    five characters, the ``.yaml`` suffix, removed), the header, and
    ``package_override``.

    :param header: config header; ``header["package"]`` is written in place.
    :param normalized_config_path: config path, expected to end in ``.yaml``
        (the suffix length is stripped unconditionally).
    :param is_primary_config: True when loading the primary config.
    :param package_override: optional package override forwarded to the resolver.
    :raises HydraException: if a primary config's header declares a package
        and the effective package is not the empty string.
    """
    config_without_ext = normalized_config_path[0:-len(".yaml")]

    package = ConfigSource._resolve_package(
        config_without_ext=config_without_ext,
        header=header,
        package_override=package_override,
    )

    if is_primary_config:
        if "package" not in header:
            header["package"] = "_global_"
        else:
            if package != "":
                raise HydraException(
                    f"Primary config '{config_without_ext}' must be "
                    f"in the _global_ package; effective package : '{package}'"
                )
    else:
        if "package" not in header:
            # Loading a config group option.
            # Hydra 1.0: default to _global_ and warn.
            # Hydra 1.1: default will change to _package_ and the warning will be removed.
            header["package"] = "_global_"
            msg = (
                f"\nMissing @package directive {normalized_config_path} in {self.full_path()}.\n"
                f"See https://hydra.cc/docs/next/upgrades/0.11_to_1.0/adding_a_package_directive"
            )
            warnings.warn(message=msg, category=UserWarning)

    # NOTE(review): this unconditional assignment overwrites the "_global_"
    # values stored above, making those stores dead unless warnings.warn
    # raises (warnings-as-errors). The indentation of this line was lost in
    # this copy -- confirm its scope against the upstream source.
    header["package"] = package
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: An config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    :raises InstantiationException: if a TargetConf's _target_ is still "???".
    :raises HydraException: for an unsupported config type.
    Exceptions raised while building the kwargs or calling the target are
    re-raised as the same exception type with the target name prepended to
    the message. If that type cannot be constructed from a single message,
    the original exception is re-raised unchanged.
    """
    if OmegaConf.is_none(config):
        return None

    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden."
            f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"
        )

    if not (
        isinstance(config, dict)
        or OmegaConf.is_config(config)
        or is_structured_config(config)
    ):
        raise HydraException(f"Unsupported config type : {type(config).__name__}")

    if isinstance(config, dict):
        # NOTE(review): dict.copy() is shallow; if
        # _convert_container_targets_to_strings mutates nested containers,
        # the caller's nested dicts may still change -- confirm.
        configc = config.copy()
        _convert_container_targets_to_strings(configc)
        config = configc

    kwargsc = kwargs.copy()
    _convert_container_targets_to_strings(kwargsc)
    kwargs = kwargsc

    # make a copy to ensure we do not change the provided object
    config_copy = OmegaConf.structured(config, flags={"allow_objects": True})
    if OmegaConf.is_config(config):
        # Re-attach the copy to the original parent node.
        config_copy._set_parent(config._get_parent())
    config = config_copy

    assert OmegaConf.is_config(config)
    OmegaConf.set_readonly(config, False)
    OmegaConf.set_struct(config, False)

    target = _get_target_type(config, kwargs)
    try:
        config._set_flag("allow_objects", True)
        final_kwargs = _get_kwargs(config, **kwargs)
        convert = _pop_convert_mode(final_kwargs)
        if convert == ConvertMode.PARTIAL:
            final_kwargs = OmegaConf.to_container(
                final_kwargs, resolve=True, exclude_structured_configs=True
            )
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.ALL:
            final_kwargs = OmegaConf.to_container(final_kwargs, resolve=True)
            return target(*args, **final_kwargs)
        elif convert == ConvertMode.NONE:
            return target(*args, **final_kwargs)
        else:
            assert False
    except Exception as e:
        msg = f"Error instantiating/calling '{_convert_target_to_string(target)}' : {e}"
        # Not every exception type accepts a single message argument
        # (e.g. UnicodeDecodeError requires five). Previously that raised a
        # TypeError from the constructor and masked the real error.
        try:
            wrapped = type(e)(msg)
        except Exception:
            wrapped = None
        if wrapped is None:
            # Re-raise the original exception untouched.
            raise
        # preserve the original exception backtrace
        raise wrapped.with_traceback(e.__traceback__) from e
def _apply_overrides_to_defaults(
    overrides: List[Override],
    defaults: List[DefaultElement],
) -> None:
    """Apply config group overrides to the defaults list in place.

    Supports delete (``~group[=name]``), add (``+group=name``), override
    (``group=name``) and package rename. An override whose value is None
    (``group=null``) is converted to a delete and a deprecation UserWarning
    is emitted.

    :param overrides: parsed config group overrides.
    :param defaults: defaults list, modified in place (items removed,
        appended, or their config_name/package changed).
    :raises HydraException: for invalid override forms, list/dict values,
        a delete/override/rename with no matching default, a delete whose
        value does not match, or an add whose group is already present.
    """
    key_to_defaults: Dict[str, List[IndexedDefaultElement]] = defaultdict(list)

    # Index defaults by config group, remembering each item's list position.
    for idx, default in enumerate(defaults):
        if default.config_group is not None:
            key_to_defaults[default.config_group].append(
                IndexedDefaultElement(idx=idx, default=default))
    for override in overrides:
        value = override.value()
        if value is None:
            # Adding with a null value is invalid.
            if override.is_add():
                ConfigLoaderImpl._raise_parse_override_error(
                    override.input_line)
            # Legacy 'group=null' form: convert it to a delete.
            if not override.is_delete():
                override.type = OverrideType.DEL
                msg = (
                    "\nRemoving from the defaults list by assigning 'null' "
                    "is deprecated and will be removed in Hydra 1.1."
                    f"\nUse ~{override.key_or_group}")
                warnings.warn(category=UserWarning, message=msg)
        if (not (override.is_delete() or override.is_package_rename()) and value is None):
            ConfigLoaderImpl._raise_parse_override_error(
                override.input_line)

        if override.is_add() and override.is_package_rename():
            raise HydraException(
                "Add syntax does not support package rename, remove + prefix"
            )

        matches = ConfigLoaderImpl.find_matches(key_to_defaults, override)

        if isinstance(value, (list, dict)):
            raise HydraException(
                f"Config group override value type cannot be a {type(value).__name__}"
            )

        if override.is_delete():
            src = override.get_source_item()
            if len(matches) == 0:
                raise HydraException(
                    f"Could not delete. No match for '{src}' in the defaults list."
                )
            for pair in matches:
                if value is not None and value != defaults[
                        pair.idx].config_name:
                    raise HydraException(
                        f"Could not delete. No match for '{src}={value}' in the defaults list."
                    )

                # NOTE(review): deleting by the idx recorded in key_to_defaults
                # shifts later items. With several matches, or with a later
                # override reusing the stale indices, the wrong element may
                # be removed -- confirm whether find_matches can return >1 item.
                del defaults[pair.idx]
        elif override.is_add():
            if len(matches) > 0:
                src = override.get_source_item()
                raise HydraException(
                    f"Could not add. An item matching '{src}' is already in the defaults list."
                )
            assert value is not None
            defaults.append(
                DefaultElement(
                    config_group=override.key_or_group,
                    config_name=str(value),
                    package=override.get_subject_package(),
                ))
        else:
            assert value is not None
            # override
            for match in matches:
                default = match.default
                default.config_name = str(value)
                if override.is_package_rename():
                    default.package = override.get_subject_package()

            if len(matches) == 0:
                src = override.get_source_item()
                if override.is_package_rename():
                    msg = f"Could not rename package. No match for '{src}' in the defaults list."
                else:
                    msg = (
                        f"Could not override '{src}'. No match in the defaults list."
                        f"\nTo append to your default list use +{override.input_line}"
                    )
                raise HydraException(msg)
def _createPrimitive(
    self, ctx: ParserRuleContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
    """Convert a primitive parse-tree node into a Python value.

    Leading and trailing whitespace children are ignored. If more than one
    token remains, their texts are concatenated into a str, with ESC tokens
    un-escaped. A single token is converted by its type:
      * QUOTED_VALUE -> QuotedString (outer quotes removed, escaped quote
        character of the same kind un-escaped)
      * ID / INTERPOLATION -> str
      * INT -> int, FLOAT -> float, NULL -> None
      * BOOL -> bool (compared case-insensitively)
      * ESC -> str, un-escaped
      * anything else -> the token text

    :raises HydraException: if the primitive consists only of whitespace.
    """
    ret: Optional[Union[int, bool, float, str]]
    first_idx = 0
    last_idx = ctx.getChildCount()
    # skip first if whitespace
    if self.is_ws(ctx.getChild(0)):
        if last_idx == 1:
            # Only whitespaces => this is not allowed.
            raise HydraException(
                "Trying to parse a primitive that is all whitespaces"
            )
        first_idx = 1
    # skip last if whitespace
    if self.is_ws(ctx.getChild(-1)):
        last_idx = last_idx - 1
    num = last_idx - first_idx
    if num > 1:
        # Concatenate, while un-escaping as needed.
        tokens = []
        for i, n in enumerate(ctx.getChildren()):
            if n.symbol.type == OverrideLexer.WS and (
                i < first_idx or i >= last_idx
            ):
                # Skip leading / trailing whitespaces.
                continue
            tokens.append(
                n.symbol.text[1::2]  # un-escape by skipping every other char
                if n.symbol.type == OverrideLexer.ESC
                else n.symbol.text
            )
        ret = "".join(tokens)
    else:
        node = ctx.getChild(first_idx)
        if node.symbol.type == OverrideLexer.QUOTED_VALUE:
            text = node.getText()
            qc = text[0]
            # Drop the enclosing quote characters.
            text = text[1:-1]
            if qc == "'":
                quote = Quote.single
                text = text.replace("\\'", "'")
            elif qc == '"':
                quote = Quote.double
                text = text.replace('\\"', '"')
            else:
                assert False
            return QuotedString(text=text, quote=quote)
        elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
            ret = node.symbol.text
        elif node.symbol.type == OverrideLexer.INT:
            ret = int(node.symbol.text)
        elif node.symbol.type == OverrideLexer.FLOAT:
            ret = float(node.symbol.text)
        elif node.symbol.type == OverrideLexer.NULL:
            ret = None
        elif node.symbol.type == OverrideLexer.BOOL:
            text = node.getText().lower()
            if text == "true":
                ret = True
            elif text == "false":
                ret = False
            else:
                assert False
        elif node.symbol.type == OverrideLexer.ESC:
            # Same un-escaping as above: keep every other character.
            ret = node.symbol.text[1::2]
        else:
            return node.getText()  # type: ignore
    return ret
def _apply_overrides_to_defaults(
    overrides: List[ParsedOverrideWithLine],
    defaults: List[DefaultElement],
) -> None:
    """Apply parsed config group overrides to the defaults list in place.

    Supports delete (``~group[=name]``), add (``+group=name``), override
    (``group=name``) and package rename. A ``null`` value is converted to a
    delete and a deprecation UserWarning is emitted. A value containing a
    comma (a multirun sweep) is replaced with ``"_SKIP_"``.

    :param overrides: parsed overrides paired with their input lines.
    :param defaults: defaults list, modified in place (items removed,
        appended, or their config_name/package changed).
    :raises HydraException: for invalid override forms, a
        delete/override/rename with no matching default, a delete whose
        value does not match, or an add whose group is already present.
    """
    key_to_defaults: Dict[str, List[IndexedDefaultElement]] = defaultdict(list)

    # Index defaults by config group, remembering each item's list position.
    for idx, default in enumerate(defaults):
        if default.config_group is not None:
            key_to_defaults[default.config_group].append(
                IndexedDefaultElement(idx=idx, default=default)
            )
    for owl in overrides:
        override = owl.override
        if override.value == "null":
            # 'null' is only accepted as a plain or delete override ('+' is rejected).
            if override.prefix not in (None, "~"):
                ConfigLoaderImpl._raise_parse_override_error(owl.input_line)
            # Rewrite the legacy 'group=null' form into a delete.
            override.prefix = "~"
            override.value = None

            msg = (
                "\nRemoving from the defaults list by assigning 'null' "
                "is deprecated and will be removed in Hydra 1.1."
                f"\nUse ~{override.key}"
            )
            warnings.warn(category=UserWarning, message=msg)

        if (
            not (override.is_delete() or override.is_package_rename())
            and override.value is None
        ):
            ConfigLoaderImpl._raise_parse_override_error(owl.input_line)

        if override.is_add() and override.is_package_rename():
            raise HydraException(
                "Add syntax does not support package rename, remove + prefix"
            )

        if override.value is not None and "," in override.value:
            # If this is a multirun config (comma separated list), flag the default to prevent it from being
            # loaded until we are constructing the config for individual jobs.
            override.value = "_SKIP_"

        matches = ConfigLoaderImpl.find_matches(key_to_defaults, override)

        if override.is_delete():
            src = override.get_source_item()
            if len(matches) == 0:
                raise HydraException(
                    f"Could not delete. No match for '{src}' in the defaults list."
                )
            for pair in matches:
                if (
                    override.value is not None
                    and override.value != defaults[pair.idx].config_name
                ):
                    raise HydraException(
                        f"Could not delete. No match for '{src}={override.value}' in the defaults list."
                    )

                # NOTE(review): deleting by the idx recorded in key_to_defaults
                # shifts later items. With several matches, or with a later
                # override reusing the stale indices, the wrong element may
                # be removed -- confirm whether find_matches can return >1 item.
                del defaults[pair.idx]
        elif override.is_add():
            if len(matches) > 0:
                src = override.get_source_item()
                raise HydraException(
                    f"Could not add. An item matching '{src}' is already in the defaults list."
                )
            assert override.value is not None
            defaults.append(
                DefaultElement(
                    config_group=override.key,
                    config_name=override.value,
                    package=override.get_subject_package(),
                )
            )
        else:
            # override
            for match in matches:
                default = match.default
                if override.value is not None:
                    default.config_name = override.value
                if override.pkg1 is not None:
                    default.package = override.get_subject_package()

            if len(matches) == 0:
                src = override.get_source_item()
                if override.is_package_rename():
                    msg = f"Could not rename package. No match for '{src}' in the defaults list."
                else:
                    msg = (
                        f"Could not override '{src}'. No match in the defaults list."
                        f"\nTo append to your default list use +{owl.input_line}"
                    )
                raise HydraException(msg)