예제 #1
0
def test_initialize_compat_version_base(hydra_restore_singletons: Any) -> None:
    """initialize() without a version_base falls back to the compat version base.

    The fallback is announced through a UserWarning that this test expects to
    surface from ``raises`` (presumably the suite promotes warnings to errors;
    TODO confirm against the pytest configuration).
    """
    already_initialized = GlobalHydra().is_initialized()
    assert not already_initialized

    expected_message = (
        f"Will assume defaults for version {version.__compat_version__}"
    )
    with raises(UserWarning, match=expected_message):
        initialize()

    compat_version = str(version.__compat_version__)
    assert version.base_at_least(compat_version)
예제 #2
0
def _update_overrides(
    defaults_list: List[InputDefault],
    overrides: Overrides,
    parent: InputDefault,
    interpolated_subtree: bool,
) -> None:
    """Re-parent ``defaults_list`` under ``parent`` and register its overrides.

    Every non-self entry gets ``parent``'s group path and final package via
    ``update_parent``. Group defaults flagged as overrides are recorded in
    ``overrides`` under ``parent``'s config path.

    :param defaults_list: the defaults of ``parent``, in declaration order.
    :param overrides: accumulator; each override is passed to ``add_override``.
    :param parent: the default whose defaults list is being processed.
    :param interpolated_subtree: True when ``parent`` lies under an
        interpolated config group; any override there is rejected.
    :raises ConfigCompositionException: if an ordinary entry (not an override,
        not an external append, not a legacy hydra override) follows an
        override, or if an override appears in an interpolated subtree.
    """
    # Becomes True once a (non-legacy) override has been processed; from then on
    # only override-like entries may follow.
    seen_override = False
    # Most recent override, used only to build the ordering error message.
    last_override_seen = None
    for d in defaults_list:
        if d.is_self():
            continue
        d.update_parent(parent.get_group_path(), parent.get_final_package())

        # Before version_base 1.2, a "hydra/..." group default that lacks the
        # 'override' keyword is still treated as an override (with a
        # deprecation warning further below).
        legacy_hydra_override = False
        if isinstance(d, GroupDefault):
            assert d.group is not None
            if not version.base_at_least("1.2"):
                legacy_hydra_override = not d.is_override(
                ) and d.group.startswith("hydra/")

        # Overrides must form the tail of the list: an ordinary entry after an
        # override is an error.
        if seen_override and not (d.is_override() or d.is_external_append()
                                  or legacy_hydra_override):
            assert isinstance(last_override_seen, GroupDefault)
            pcp = parent.get_config_path()
            okey = last_override_seen.get_override_key()
            oval = last_override_seen.get_name()
            raise ConfigCompositionException(
                dedent(f"""\
                    In {pcp}: Override '{okey} : {oval}' is defined before '{d.get_override_key()}: {d.get_name()}'.
                    Overrides must be at the end of the defaults list"""))

        if isinstance(d, GroupDefault):
            if legacy_hydra_override:
                # Promote the legacy entry to a real override, but warn.
                d.override = True
                url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/defaults_list_override"
                msg = dedent(f"""\
                    In {parent.get_config_path()}: Invalid overriding of {d.group}:
                    Default list overrides requires 'override' keyword.
                    See {url} for more information.
                    """)
                deprecation_warning(msg)

            if d.override:
                # Legacy implicit overrides do not trigger the ordering check
                # for the entries that follow them.
                if not legacy_hydra_override:
                    seen_override = True
                last_override_seen = d
                if interpolated_subtree:
                    # Interpolations are deferred until all config groups are set,
                    # so their subtree may not contain config group overrides.
                    raise ConfigCompositionException(
                        dedent(f"""\
                            {parent.get_config_path()}: Default List Overrides are not allowed in the subtree
                            of an in interpolated config group (override {d.get_override_key()}={d.get_name()}).
                            """))
                overrides.add_override(parent.get_config_path(), d)
예제 #3
0
 def _normalize_file_name(filename: str) -> str:
     """Make sure ``filename`` ends with a supported config extension.

     ``.yaml`` is always accepted. When the version base is below 1.2,
     ``.yml`` is accepted as well, but a ``.yml`` name triggers a deprecation
     warning. A name ending in no accepted extension gets ``.yaml`` appended.

     :param filename: config file name, with or without an extension.
     :return: the name unchanged, or with ``.yaml`` appended.
     """
     allow_yml = not version.base_at_least("1.2")
     if allow_yml and filename.endswith(".yml"):
         deprecation_warning(
             "Support for .yml files is deprecated. Use .yaml extension for Hydra config files"
         )
     accepted = (".yaml", ".yml") if allow_yml else (".yaml",)
     return filename if filename.endswith(accepted) else f"{filename}.yaml"
예제 #4
0
def compose(
    config_name: Optional[str] = None,
    overrides: Optional[List[str]] = None,
    return_hydra_config: bool = False,
    strict: Optional[bool] = None,
) -> DictConfig:
    """
    Compose a config using the already-initialized global Hydra instance.

    :param config_name: the name of the config
           (usually the file name without the .yaml extension)
    :param overrides: list of overrides for config file; ``None`` (the default)
           means no overrides
    :param return_hydra_config: True to return the hydra config node in the result
    :param strict: DEPRECATED. If false, returned config has struct mode disabled.
           Not accepted when the version base is 1.2 or newer.
    :return: the composed config
    :raises TypeError: if ``strict`` is given and the version base is >= 1.2
    """
    assert (
        GlobalHydra().is_initialized()
    ), "GlobalHydra is not initialized, use @hydra.main() or call one of the hydra initialization methods first"

    # Reject the removed ``strict`` argument before doing any composition work,
    # instead of composing the whole config only to raise afterwards.
    if strict is not None and version.base_at_least("1.2"):
        raise TypeError("got an unexpected 'strict' argument")

    # ``None`` sentinel instead of a ``[]`` default: a mutable default is created
    # once and shared by every call, so any downstream mutation would leak.
    if overrides is None:
        overrides = []

    gh = GlobalHydra.instance()
    assert gh.hydra is not None
    cfg = gh.hydra.compose_config(
        config_name=config_name,
        overrides=overrides,
        run_mode=RunMode.RUN,
        from_shell=False,
        with_log_configuration=False,
    )
    assert isinstance(cfg, DictConfig)

    if not return_hydra_config and "hydra" in cfg:
        # The node may be struct-locked; open_dict allows deleting it.
        with open_dict(cfg):
            del cfg["hydra"]

    if strict is not None:
        # Legacy path: version bases >= 1.2 were rejected above.
        deprecation_warning(
            dedent("""
                The strict flag in the compose API is deprecated.
                See https://hydra.cc/docs/upgrades/0.11_to_1.0/strict_mode_flag_deprecated for more info.
                """))
        OmegaConf.set_struct(cfg, strict)

    return cfg
예제 #5
0
    def __init__(
        self,
        config_path: Optional[str] = _UNSPECIFIED_,
        version_base: Optional[str] = _UNSPECIFIED_,
        job_name: Optional[str] = None,
        caller_stack_depth: int = 1,
    ) -> None:
        """Set the version base and create the main Hydra for the caller.

        :param config_path: config directory; must be relative. When left
            unspecified it becomes ``None`` for version base >= 1.2 and ``"."``
            otherwise (with a deprecation warning if ``version_base`` was
            also left unspecified).
        :param version_base: passed to ``version.setbase``.
        :param job_name: job name; detected from the caller when ``None``.
        :param caller_stack_depth: stack depth of the caller, used (plus one)
            to locate the calling file or module.
        :raises HydraException: if ``config_path`` is an absolute path.
        """
        # Backup of the global Hydra state taken before anything changes
        # (presumably restored elsewhere in this class; not visible here).
        self._gh_backup = get_gh_backup()
        version.setbase(version_base)

        if config_path is _UNSPECIFIED_:
            if version.base_at_least("1.2"):
                config_path = None
            else:
                # Legacy default is ".". Warn only if the caller relied on
                # defaults for version_base as well.
                if version_base is _UNSPECIFIED_:
                    url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path"
                    deprecation_warning(
                        message=dedent(f"""\
                            config_path is not specified in hydra.initialize().
                            See {url} for more information."""),
                        stacklevel=2,
                    )
                config_path = "."

        is_absolute = config_path is not None and os.path.isabs(config_path)
        if is_absolute:
            raise HydraException(
                "config_path in initialize() must be relative")

        # Presumably the +1 accounts for this __init__ frame itself.
        frame_depth = caller_stack_depth + 1
        calling_file, calling_module = detect_calling_file_or_module_from_stack_frame(
            frame_depth)
        job_name = (
            job_name
            if job_name is not None
            else detect_task_name(calling_file=calling_file,
                                  calling_module=calling_module)
        )

        Hydra.create_main_hydra_file_or_module(
            calling_file=calling_file,
            calling_module=calling_module,
            config_path=config_path,
            job_name=job_name,
        )
예제 #6
0
def test_initialize_cur_version_base(hydra_restore_singletons: Any) -> None:
    """version_base=None should give a version base of at least ``__version__``."""
    already_initialized = GlobalHydra().is_initialized()
    assert not already_initialized

    initialize(version_base=None)

    current = __version__
    assert version.base_at_least(current)
예제 #7
0
def test_initialize_dev_version_base(hydra_restore_singletons: Any) -> None:
    """A dev pre-release version_base must still count as its release version.

    packaging orders "1.2.0.dev2" before "1.2", so this guards against the
    dev base being treated as older than 1.2.
    """
    already_initialized = GlobalHydra().is_initialized()
    assert not already_initialized

    dev_base = "1.2.0.dev2"
    initialize(version_base=dev_base)

    assert version.base_at_least("1.2")
예제 #8
0
def run_job(
    task_function: TaskFunction,
    config: DictConfig,
    job_dir_key: str,
    job_subdir_key: Optional[str],
    hydra_context: HydraContext,
    configure_logging: bool = True,
) -> "JobReturn":
    """Run ``task_function`` once for ``config`` and collect the outcome.

    Steps:
    - Resolve the job output directory from ``job_dir_key`` (joined with
      ``job_subdir_key`` when given) and store it in
      ``config.hydra.runtime.output_dir`` (mutating ``config`` in place).
    - Create that directory and optionally change into it.
    - Save the task/hydra/overrides configs under ``hydra.output_subdir``.
    - Call the task inside ``hydra.job.env_set``, between the
      ``on_job_start`` / ``on_job_end`` callbacks.

    ``Exception`` subclasses raised by ``task_function`` are stored in the
    returned ``JobReturn`` (status FAILED) instead of propagating.

    On exit, whether or not an exception escapes, the global ``HydraConfig``
    is restored. The original working directory is also restored if this
    function changed it.

    :param task_function: user task; receives a copy of ``config`` without
        the ``hydra`` node.
    :param config: full config including the ``hydra`` node.
    :param job_dir_key: config key selecting the job output directory.
    :param job_subdir_key: optional config key of a subdirectory to append.
    :param hydra_context: supplies the callbacks invoked around the job.
    :param configure_logging: whether to apply ``hydra.job_logging``.
    :return: the populated ``JobReturn``.
    """
    _check_hydra_context(hydra_context)
    callbacks = hydra_context.callbacks

    old_cwd = os.getcwd()
    orig_hydra_cfg = HydraConfig.instance().cfg
    # Bound before the ``try`` so the ``finally`` clause can always read it.
    _chdir = None
    # All global-state changes happen inside try/finally. Previously,
    # set_config, output-dir resolution (OmegaConf.select can raise, e.g. on a
    # bad interpolation) and the hydra.runtime update ran before the ``try``,
    # so a failure there left HydraConfig pointing at this job's config.
    try:
        # init Hydra config for config evaluation
        HydraConfig.instance().set_config(config)

        output_dir = str(OmegaConf.select(config, job_dir_key))
        if job_subdir_key is not None:
            # evaluate job_subdir_key lazily.
            # this is running on the client side in sweep and contains things such as job:id which
            # are only available there.
            subdir = str(OmegaConf.select(config, job_subdir_key))
            output_dir = os.path.join(output_dir, subdir)

        with read_write(config.hydra.runtime), open_dict(config.hydra.runtime):
            config.hydra.runtime.output_dir = os.path.abspath(output_dir)

        # update Hydra config so it reflects hydra.runtime.output_dir
        HydraConfig.instance().set_config(config)

        ret = JobReturn()
        # The task gets its own copy of the config, without the hydra node.
        task_cfg = copy.deepcopy(config)
        with read_write(task_cfg), open_dict(task_cfg):
            del task_cfg["hydra"]

        ret.cfg = task_cfg
        hydra_cfg = copy.deepcopy(HydraConfig.instance().cfg)
        assert isinstance(hydra_cfg, DictConfig)
        ret.hydra_cfg = hydra_cfg
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # hydra.job.chdir unset: version base >= 1.2 defaults to not changing
        # directory; older bases change directory but warn about it.
        _chdir = hydra_cfg.hydra.job.chdir
        if _chdir is None:
            if version.base_at_least("1.2"):
                _chdir = False
            else:
                url = "https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_job_working_dir/"
                deprecation_warning(
                    message=dedent(f"""\
                        Future Hydra versions will no longer change working directory at job runtime by default.
                        See {url} for more information."""),
                    stacklevel=2,
                )
                _chdir = True

        if _chdir:
            os.chdir(output_dir)
            ret.working_dir = output_dir
        else:
            ret.working_dir = os.getcwd()

        if configure_logging:
            configure_log(config.hydra.job_logging, config.hydra.verbose)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.runtime.output_dir) / Path(
                config.hydra.output_subdir)
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml",
                         hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            callbacks.on_job_start(config=config)
            try:
                ret.return_value = task_function(task_cfg)
                ret.status = JobStatus.COMPLETED
            except Exception as e:
                # Task failures are reported through the JobReturn, not raised.
                ret.return_value = e
                ret.status = JobStatus.FAILED

        ret.task_name = JobRuntime.instance().get("name")

        _flush_loggers()

        callbacks.on_job_end(config=config, job_return=ret)

        return ret
    finally:
        HydraConfig.instance().cfg = orig_hydra_cfg
        if _chdir:
            os.chdir(old_cwd)
예제 #9
0
    def _create_defaults_list(
        self,
        config_path: str,
        defaults: ListConfig,
    ) -> List[InputDefault]:
        """Convert the raw ``defaults`` list of a config into ``InputDefault`` objects.

        Each item is converted as follows:
        - A string item becomes a ``ConfigDefault``.
        - A DictConfig item must hold exactly one ``{group: value}`` entry and
          becomes a ``GroupDefault``. ``value`` may be a string, a list of
          strings, or None.

        For version base < 1.2, the legacy form ``{group: value, optional: ...}``
        is still accepted, with a deprecation warning.

        :param config_path: path of the config declaring ``defaults``; used in messages.
        :param defaults: the defaults list, iterated without resolving interpolations.
        :return: one ``InputDefault`` per item, in the original order.
        :raises ValueError: for a dict item with zero or several keys, an
            unsupported value type, a non-string item in a value list, or an
            item that is neither a DictConfig nor a string.
        """

        def issue_deprecated_name_warning() -> None:
            # DEPRECATED: remove in 1.2
            url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header"
            deprecation_warning(message=dedent(f"""\
                    In {config_path}: Defaults List contains deprecated keyword _name_, see {url}
                    """))

        res: List[InputDefault] = []
        for item in defaults._iter_ex(resolve=False):
            default: InputDefault
            if isinstance(item, DictConfig):
                # Evaluate the legacy check once per item; it used to be
                # re-queried at three separate points. ``old_optional`` is now
                # always bound: previously it existed only inside one guarded
                # branch and was read under two other, separately evaluated
                # guards.
                legacy = not version.base_at_least("1.2")
                old_optional = None
                if legacy and len(item) > 1 and "optional" in item:
                    # Legacy sibling ``optional`` key. NOTE: ``pop`` removes it
                    # from ``item`` itself, i.e. it mutates ``defaults``.
                    old_optional = item.pop("optional")
                keys = list(item.keys())

                if len(keys) > 1:
                    raise ValueError(
                        f"In {config_path}: Too many keys in default item {item}"
                    )
                if len(keys) == 0:
                    raise ValueError(
                        f"In {config_path}: Missing group name in {item}")

                key = keys[0]
                assert isinstance(key, str)
                config_group, package, _package2 = self._split_group(key)
                keywords = ConfigRepository.Keywords()
                self._extract_keywords_from_config_group(
                    config_group, keywords)

                # ``old_optional`` can only be set on the legacy path, so no
                # separate version check is needed here. The keyword form on
                # the group key takes precedence.
                if not keywords.optional and old_optional is not None:
                    keywords.optional = old_optional

                node = item._get_node(key)
                assert node is not None and isinstance(node, Node)
                config_value = node._value()

                if old_optional is not None:
                    msg = dedent(f"""
                        In {config_path}: 'optional: true' is deprecated.
                        Use 'optional {key}: {config_value}' instead.
                        Support for the old style is removed for Hydra version_base >= 1.2""")
                    deprecation_warning(msg)

                if config_value is not None and not isinstance(
                        config_value, (str, list)):
                    raise ValueError(
                        f"Unsupported item value in defaults : {type(config_value).__name__}."
                        " Supported: string or list")

                if isinstance(config_value, list):
                    # List elements are nodes; unwrap each and require a string.
                    options = []
                    for v in config_value:
                        vv = v._value()
                        if not isinstance(vv, str):
                            raise ValueError(
                                f"Unsupported item value in defaults : {type(vv).__name__},"
                                " nested list items must be strings")
                        options.append(vv)
                    config_value = options

                if package is not None and "_name_" in package:
                    issue_deprecated_name_warning()

                default = GroupDefault(
                    group=keywords.group,
                    value=config_value,
                    package=package,
                    optional=keywords.optional,
                    override=keywords.override,
                )

            elif isinstance(item, str):
                path, package, _package2 = self._split_group(item)
                if package is not None and "_name_" in package:
                    issue_deprecated_name_warning()

                default = ConfigDefault(path=path, package=package)
            else:
                raise ValueError(
                    f"Unsupported type in defaults : {type(item).__name__}")
            res.append(default)
        return res