def test_initialize_compat_version_base(hydra_restore_singletons: Any) -> None:
    """Calling initialize() without version_base must warn and fall back to the compat version."""
    assert not GlobalHydra().is_initialized()

    expected = f"Will assume defaults for version {version.__compat_version__}"
    with raises(UserWarning, match=expected):
        initialize()

    compat = str(version.__compat_version__)
    assert version.base_at_least(compat)
def _update_overrides(
    defaults_list: List[InputDefault],
    overrides: Overrides,
    parent: InputDefault,
    interpolated_subtree: bool,
) -> None:
    """Re-parent the entries of ``defaults_list`` under ``parent`` and collect their overrides.

    Every non-self entry is re-parented with ``parent``'s group path and final
    package. Group defaults flagged as overrides are registered in
    ``overrides``. Non-override entries may not follow an override.

    :param defaults_list: the defaults of ``parent``; ``_self_`` entries are skipped.
    :param overrides: collector that receives each group override found.
    :param parent: the default whose defaults list is being processed.
    :param interpolated_subtree: True when ``parent`` is inside the subtree of an
        interpolated config group, where overrides are rejected.
    :raises ConfigCompositionException: if a non-override entry follows an
        override, or an override appears inside an interpolated subtree.
    """
    override_encountered = False
    previous_override = None
    for default in defaults_list:
        if default.is_self():
            continue
        default.update_parent(parent.get_group_path(), parent.get_final_package())

        is_group = isinstance(default, GroupDefault)

        # Before version_base 1.2, a ``hydra/...`` group entry without the
        # 'override' keyword is still treated as an override (with a warning below).
        implicit_hydra_override = False
        if is_group:
            assert default.group is not None
            if not version.base_at_least("1.2"):
                implicit_hydra_override = (
                    not default.is_override() and default.group.startswith("hydra/")
                )

        if override_encountered and not (
            default.is_override()
            or default.is_external_append()
            or implicit_hydra_override
        ):
            assert isinstance(previous_override, GroupDefault)
            pcp = parent.get_config_path()
            okey = previous_override.get_override_key()
            oval = previous_override.get_name()
            raise ConfigCompositionException(
                dedent(
                    f"""\
                    In {pcp}: Override '{okey} : {oval}' is defined before '{default.get_override_key()}: {default.get_name()}'.
                    Overrides must be at the end of the defaults list"""
                )
            )

        if not is_group:
            continue

        if implicit_hydra_override:
            default.override = True
            url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/defaults_list_override"
            deprecation_warning(
                dedent(
                    f"""\
                    In {parent.get_config_path()}: Invalid overriding of {default.group}:
                    Default list overrides requires 'override' keyword.
                    See {url} for more information.
                    """
                )
            )

        if not default.override:
            continue

        # Implicit legacy overrides do not start the "overrides only from here on" region.
        if not implicit_hydra_override:
            override_encountered = True
        previous_override = default

        if interpolated_subtree:
            # Interpolations are deferred until all config groups are set, so
            # their subtree may not contain config group overrides.
            raise ConfigCompositionException(
                dedent(
                    f"""\
                    {parent.get_config_path()}: Default List Overrides are not allowed in the subtree
                    of an in interpolated config group (override {default.get_override_key()}={default.get_name()}).
                    """
                )
            )
        overrides.add_override(parent.get_config_path(), default)
def _normalize_file_name(filename: str) -> str:
    """Return ``filename`` with a config extension, appending ``.yaml`` when needed.

    ``.yaml`` is always accepted as-is. ``.yml`` is accepted (with a deprecation
    warning) only when version_base is below 1.2. From 1.2 on, a ``.yml`` name is
    treated like any other name and gets ``.yaml`` appended.
    """
    legacy = not version.base_at_least("1.2")
    accepted = (".yaml", ".yml") if legacy else (".yaml",)

    if legacy and filename.endswith(".yml"):
        deprecation_warning(
            "Support for .yml files is deprecated. Use .yaml extension for Hydra config files"
        )

    if filename.endswith(accepted):
        return filename
    return filename + ".yaml"
def compose(
    config_name: Optional[str] = None,
    overrides: Optional[List[str]] = None,
    return_hydra_config: bool = False,
    strict: Optional[bool] = None,
) -> DictConfig:
    """
    :param config_name: the name of the config (usually the file name without the .yaml extension)
    :param overrides: list of overrides for config file; None (the default) means no overrides
    :param return_hydra_config: True to return the hydra config node in the result
    :param strict: DEPRECATED. If false, returned config has struct mode disabled.
    :return: the composed config
    :raises TypeError: if ``strict`` is passed while version_base is 1.2 or newer
    """
    assert (
        GlobalHydra().is_initialized()
    ), "GlobalHydra is not initialized, use @hydra.main() or call one of the hydra initialization methods first"

    # A fresh list per call: a shared mutable default would leak any mutation
    # made downstream into every later call.
    if overrides is None:
        overrides = []

    gh = GlobalHydra.instance()
    assert gh.hydra is not None
    cfg = gh.hydra.compose_config(
        config_name=config_name,
        overrides=overrides,
        run_mode=RunMode.RUN,
        from_shell=False,
        with_log_configuration=False,
    )
    assert isinstance(cfg, DictConfig)

    if not return_hydra_config:
        if "hydra" in cfg:
            with open_dict(cfg):
                del cfg["hydra"]

    if strict is not None:
        if version.base_at_least("1.2"):
            raise TypeError("got an unexpected 'strict' argument")
        deprecation_warning(
            dedent(
                """
                The strict flag in the compose API is deprecated.
                See https://hydra.cc/docs/upgrades/0.11_to_1.0/strict_mode_flag_deprecated for more info.
                """
            )
        )
        OmegaConf.set_struct(cfg, strict)

    return cfg
def __init__(
    self,
    config_path: Optional[str] = _UNSPECIFIED_,
    version_base: Optional[str] = _UNSPECIFIED_,
    job_name: Optional[str] = None,
    caller_stack_depth: int = 1,
) -> None:
    """Set up global Hydra state relative to the file or module that called ``initialize()``.

    :param config_path: config directory, which must be relative. When left
        unspecified it resolves to ``None`` for version_base >= 1.2 and to ``"."``
        otherwise. The ``"."`` case emits a deprecation warning if version_base
        was also left unspecified.
    :param version_base: Hydra version whose defaults apply; forwarded to ``version.setbase``.
    :param job_name: job name; detected from the calling file/module when ``None``.
    :param caller_stack_depth: stack depth passed (plus one) to the caller detection.
    :raises HydraException: if ``config_path`` is an absolute path.
    """
    # Snapshot of the existing GlobalHydra state; presumably restored elsewhere
    # in this class (not visible here).
    self._gh_backup = get_gh_backup()
    version.setbase(version_base)

    if config_path is _UNSPECIFIED_:
        if version.base_at_least("1.2"):
            config_path = None
        else:
            if version_base is _UNSPECIFIED_:
                url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path"
                deprecation_warning(
                    message=dedent(
                        f"""\
                        config_path is not specified in hydra.initialize().
                        See {url} for more information."""
                    ),
                    stacklevel=2,
                )
            config_path = "."

    if config_path is not None and os.path.isabs(config_path):
        raise HydraException("config_path in initialize() must be relative")

    # Must be called directly from this frame: the +1 presumably skips __init__
    # itself so detection starts at the caller.
    calling_file, calling_module = detect_calling_file_or_module_from_stack_frame(
        caller_stack_depth + 1
    )
    if job_name is None:
        job_name = detect_task_name(
            calling_file=calling_file, calling_module=calling_module
        )

    Hydra.create_main_hydra_file_or_module(
        calling_file=calling_file,
        calling_module=calling_module,
        config_path=config_path,
        job_name=job_name,
    )
def test_initialize_cur_version_base(hydra_restore_singletons: Any) -> None:
    """After initialize(version_base=None), the active base must be at least __version__."""
    assert not GlobalHydra().is_initialized()
    initialize(version_base=None)

    running_version = __version__
    assert version.base_at_least(running_version)
def test_initialize_dev_version_base(hydra_restore_singletons: Any) -> None:
    """A dev pre-release version_base must still satisfy base_at_least("1.2")."""
    assert not GlobalHydra().is_initialized()

    # packaging orders "1.2.0.dev2" before "1.2", so the version handling must
    # cope with dev releases for this assertion to hold.
    dev_base = "1.2.0.dev2"
    initialize(version_base=dev_base)

    assert version.base_at_least("1.2")
def run_job(
    task_function: TaskFunction,
    config: DictConfig,
    job_dir_key: str,
    job_subdir_key: Optional[str],
    hydra_context: HydraContext,
    configure_logging: bool = True,
) -> "JobReturn":
    """Run ``task_function`` once for ``config`` and collect the outcome.

    The output directory is read from ``job_dir_key`` (joined with
    ``job_subdir_key`` when given), written into
    ``config.hydra.runtime.output_dir`` and created. The working directory may
    be changed into it, per ``hydra.job.chdir``. When ``hydra.output_subdir`` is
    set, the task, hydra and overrides configs are saved there. The task is
    then called with the ``hydra`` node removed from its config.

    :param task_function: the user task, called with the task config.
    :param config: full config including the ``hydra`` node;
        ``config.hydra.runtime.output_dir`` is written in place.
    :param job_dir_key: config key selecting the job output directory.
    :param job_subdir_key: optional config key selecting a subdirectory of it.
    :param hydra_context: provides the callbacks invoked around the job.
    :param configure_logging: when True, configure logging from ``config.hydra.job_logging``.
    :return: a ``JobReturn``. An exception raised by ``task_function`` is
        stored in ``return_value`` with status ``FAILED`` instead of propagating.

    The previous ``HydraConfig`` and, if it was changed, the working directory
    are restored on exit, including when an error is raised.
    """
    _check_hydra_context(hydra_context)
    callbacks = hydra_context.callbacks

    old_cwd = os.getcwd()
    orig_hydra_cfg = HydraConfig.instance().cfg
    _chdir = None
    # The try starts before the first HydraConfig mutation so that a failure
    # while resolving the output directory still restores the global config.
    try:
        # init Hydra config for config evaluation
        HydraConfig.instance().set_config(config)

        output_dir = str(OmegaConf.select(config, job_dir_key))
        if job_subdir_key is not None:
            # evaluate job_subdir_key lazily.
            # this is running on the client side in sweep and contains things such as job:id which
            # are only available there.
            subdir = str(OmegaConf.select(config, job_subdir_key))
            output_dir = os.path.join(output_dir, subdir)

        with read_write(config.hydra.runtime):
            with open_dict(config.hydra.runtime):
                config.hydra.runtime.output_dir = os.path.abspath(output_dir)

        # update Hydra config
        HydraConfig.instance().set_config(config)

        ret = JobReturn()
        task_cfg = copy.deepcopy(config)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]

        ret.cfg = task_cfg
        hydra_cfg = copy.deepcopy(HydraConfig.instance().cfg)
        assert isinstance(hydra_cfg, DictConfig)
        ret.hydra_cfg = hydra_cfg
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # An unset chdir means "don't chdir" from version_base 1.2 on; older
        # bases keep the legacy chdir behaviour and warn about the change.
        _chdir = hydra_cfg.hydra.job.chdir
        if _chdir is None and version.base_at_least("1.2"):
            _chdir = False
        if _chdir is None:
            url = "https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_job_working_dir/"
            deprecation_warning(
                message=dedent(
                    f"""\
                    Future Hydra versions will no longer change working directory at job runtime by default.
                    See {url} for more information."""
                ),
                stacklevel=2,
            )
            _chdir = True

        if _chdir:
            os.chdir(output_dir)
            ret.working_dir = output_dir
        else:
            ret.working_dir = os.getcwd()

        if configure_logging:
            configure_log(config.hydra.job_logging, config.hydra.verbose)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.runtime.output_dir) / Path(
                config.hydra.output_subdir
            )
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml", hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            callbacks.on_job_start(config=config)
            try:
                ret.return_value = task_function(task_cfg)
                ret.status = JobStatus.COMPLETED
            except Exception as e:
                # Deliberate catch-all: task failures are reported via the JobReturn.
                ret.return_value = e
                ret.status = JobStatus.FAILED

        ret.task_name = JobRuntime.instance().get("name")

        _flush_loggers()

        callbacks.on_job_end(config=config, job_return=ret)

        return ret
    finally:
        HydraConfig.instance().cfg = orig_hydra_cfg
        if _chdir:
            os.chdir(old_cwd)
def _create_defaults_list(
    self,
    config_path: str,
    defaults: ListConfig,
) -> List[InputDefault]:
    """Convert the raw entries of a defaults list into ``InputDefault`` objects.

    String entries become ``ConfigDefault``. Single-key dict entries
    (``group: value``) become ``GroupDefault``, whose value is a string, a list
    of strings, or None.

    :param config_path: path of the config that owns ``defaults``; used in messages.
    :param defaults: the raw defaults list, iterated without resolving interpolations.
    :return: one ``InputDefault`` per entry, in the original order.
    :raises ValueError: for unsupported entry types, dict entries with zero or
        several keys, or unsupported values.
    """

    def warn_deprecated_name() -> None:
        # DEPRECATED: remove in 1.2
        url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header"
        deprecation_warning(
            message=dedent(
                f"""\
                In {config_path}: Defaults List contains deprecated keyword _name_, see {url}
                """
            ),
        )

    converted: List[InputDefault] = []
    for entry in defaults._iter_ex(resolve=False):
        if isinstance(entry, str):
            path, package, _ = self._split_group(entry)
            if package is not None and "_name_" in package:
                warn_deprecated_name()
            converted.append(ConfigDefault(path=path, package=package))
            continue

        if not isinstance(entry, DictConfig):
            raise ValueError(f"Unsupported type in defaults : {type(entry).__name__}")

        # Pre-1.2 bases accept a separate 'optional' key next to the group key;
        # it is popped from the entry so that only the group key remains.
        legacy = not version.base_at_least("1.2")
        old_optional = None
        if legacy and len(entry) > 1 and "optional" in entry:
            old_optional = entry.pop("optional")

        keys = list(entry.keys())
        if len(keys) > 1:
            raise ValueError(f"In {config_path}: Too many keys in default item {entry}")
        if not keys:
            raise ValueError(f"In {config_path}: Missing group name in {entry}")
        key = keys[0]
        assert isinstance(key, str)

        config_group, package, _ = self._split_group(key)
        keywords = ConfigRepository.Keywords()
        self._extract_keywords_from_config_group(config_group, keywords)
        if legacy and not keywords.optional and old_optional is not None:
            keywords.optional = old_optional

        node = entry._get_node(key)
        assert node is not None and isinstance(node, Node)
        value = node._value()

        if legacy and old_optional is not None:
            deprecation_warning(
                dedent(
                    f"""
                    In {config_path}: 'optional: true' is deprecated.
                    Use 'optional {key}: {value}' instead.
                    Support for the old style is removed for Hydra version_base >= 1.2"""
                )
            )

        if value is not None and not isinstance(value, (str, list)):
            raise ValueError(
                f"Unsupported item value in defaults : {type(value).__name__}."
                " Supported: string or list"
            )
        if isinstance(value, list):
            # List elements are nodes; unwrap each and require plain strings.
            names = []
            for element in value:
                raw = element._value()
                if not isinstance(raw, str):
                    raise ValueError(
                        f"Unsupported item value in defaults : {type(raw).__name__},"
                        " nested list items must be strings"
                    )
                names.append(raw)
            value = names

        if package is not None and "_name_" in package:
            warn_deprecated_name()

        converted.append(
            GroupDefault(
                group=keywords.group,
                value=value,
                package=package,
                optional=keywords.optional,
                override=keywords.override,
            )
        )
    return converted