コード例 #1
0
    def _merge_config(self, cfg: DictConfig, family: str, name: str,
                      required: bool) -> DictConfig:
        """Load the config ``family/name`` (or ``name`` when ``family`` is
        empty) and merge it on top of ``cfg``.

        :param cfg: the config accumulated so far.
        :param family: config group of the config to load; "" for a config
            that is not in a group.
        :param name: name of the config to load.
        :param required: if True, a config that cannot be loaded is an error;
            if False, ``cfg`` is returned unchanged in that case.
        :return: the merge of ``cfg`` and the loaded config.
        :raises MissingConfigException: if ``required`` and the config could
            not be loaded. For a grouped config the message lists the
            group's available options, when there are any.
        :raises HydraException: if an ``OmegaConfBaseException`` is raised
            while loading or merging.
        """
        # Label for the wrapped error: "family=name" for a grouped config,
        # plain "name" otherwise (avoids the malformed "=name").
        label = f"{family}={name}" if family != "" else name
        try:
            if family != "":
                new_cfg = f"{family}/{name}"
            else:
                new_cfg = name

            loaded_cfg, _ = self._load_config_impl(new_cfg)
            if loaded_cfg is None:
                if not required:
                    return cfg
                if family == "":
                    msg = f"Could not load {new_cfg}"
                    raise MissingConfigException(msg, new_cfg)
                options = self.get_group_options(family)
                if options:
                    lst = "\n\t".join(options)
                    msg = f"Could not load {new_cfg}, available options:\n{family}:\n\t{lst}"
                else:
                    msg = f"Could not load {new_cfg}"
                raise MissingConfigException(msg, new_cfg, options)

            ret = OmegaConf.merge(cfg, loaded_cfg)
            assert isinstance(ret, DictConfig)
            return ret
        except OmegaConfBaseException as ex:
            raise HydraException(f"Error merging {label}") from ex
コード例 #2
0
    def _merge_config(
        self, cfg: DictConfig, family: str, name: str, required: bool
    ) -> DictConfig:
        """Merge the config identified by ``family`` and ``name`` into ``cfg``.

        ``family`` is the config group; an empty string means the config is
        addressed by ``name`` alone. If the config cannot be loaded:

        * ``required`` False: ``cfg`` is returned as-is;
        * ``required`` True: ``MissingConfigException`` is raised, and for a
          grouped config the message lists the group's available options
          when there are any.
        """
        config_path = name if family == "" else "{}/{}".format(family, name)

        loaded = self._load_config_impl(config_path)
        if loaded is not None:
            merged = OmegaConf.merge(cfg, loaded)
            assert isinstance(merged, DictConfig)
            return merged

        if not required:
            return cfg

        if family == "":
            raise MissingConfigException(
                "Could not load {}".format(config_path), config_path
            )

        options = self.get_group_options(family)
        if not options:
            message = "Could not load {}".format(config_path)
        else:
            message = "Could not load {}, available options:\n{}:\n\t{}".format(
                config_path, family, "\n\t".join(options)
            )
        raise MissingConfigException(message, config_path, options)
コード例 #3
0
def missing_config_error(repo: IConfigRepository, element: DefaultElement) -> None:
    """Raise a ``MissingConfigException`` describing why ``element`` was not found.

    For an element with a config group, the message names the group and
    lists the options ``repo`` reports for it (these are also attached to the
    exception). Otherwise the message names the element's config path.
    Either way, one line per source in ``repo`` is appended as the config
    search path.

    This function never returns normally.
    """
    group = element.config_group
    if group is None:
        options = None
        message = dedent(
            f"""\
        Could not load {element.config_path()}.
        """
        )
    else:
        options = repo.get_group_options(group, ObjectType.CONFIG)
        option_lines = "\n".join("\t" + option for option in options)
        message = (
            f"Could not find '{element.config_name}' in the config group '{group}'"
            f"\nAvailable options:\n{option_lines}\n"
        )

    search_path = "\n".join(f"\t{repr(source)}" for source in repo.get_sources())
    message += "\nConfig search path:" + f"\n{search_path}"

    raise MissingConfigException(
        missing_cfg_file=element.config_path(),
        message=message,
        options=options,
    )
コード例 #4
0
ファイル: config_loader_impl.py プロジェクト: roopeshvs/hydra
    def _merge_config(
        self,
        cfg: DictConfig,
        config_group: str,
        name: str,
        required: bool,
        is_primary_config: bool,
        package_override: Optional[str],
    ) -> DictConfig:
        """Load ``config_group/name`` (or ``name`` when ``config_group`` is
        empty) and merge it on top of ``cfg``.

        :param cfg: the config accumulated so far.
        :param config_group: group of the config to load; "" if the config is
            not in a group.
        :param name: name of the config to load.
        :param required: if True, failing to load the config is an error; if
            False, ``cfg`` is returned unchanged in that case.
        :param is_primary_config: forwarded to ``_load_config_impl``.
        :param package_override: forwarded to ``_load_config_impl``.
        :return: the merge of ``cfg`` and the loaded config.
        :raises MissingConfigException: if ``required`` and the config could
            not be loaded. For a grouped config the message lists the
            group's available options, when there are any.
        :raises ConfigCompositionException: if an ``OmegaConfBaseException``
            is raised while loading or merging.
        """
        # Label for the wrapped error: "group=name" for a grouped config,
        # plain "name" otherwise (avoids the malformed "=name").
        label = f"{config_group}={name}" if config_group != "" else name
        try:
            if config_group != "":
                new_cfg = f"{config_group}/{name}"
            else:
                new_cfg = name

            loaded_cfg, _ = self._load_config_impl(
                new_cfg,
                is_primary_config=is_primary_config,
                package_override=package_override,
            )
            if loaded_cfg is None:
                if not required:
                    return cfg
                if config_group == "":
                    msg = f"Could not load {new_cfg}"
                    raise MissingConfigException(msg, new_cfg)
                options = self.get_group_options(config_group)
                if options:
                    opt_list = "\n".join(["\t" + x for x in options])
                    msg = (
                        f"Could not load {new_cfg}.\nAvailable options:"
                        f"\n{opt_list}"
                    )
                else:
                    msg = f"Could not load {new_cfg}"
                raise MissingConfigException(msg, new_cfg, options)

            ret = OmegaConf.merge(cfg, loaded_cfg)
            assert isinstance(ret, DictConfig)
            return ret
        except OmegaConfBaseException as ex:
            raise ConfigCompositionException(
                f"Error merging {label}"
            ) from ex
コード例 #5
0
    def _missing_config_error(self, config_name: Optional[str], msg: str,
                              with_search_path: bool) -> None:
        """Raise a ``MissingConfigException`` for ``config_name``.

        :param config_name: the config that could not be found; passed to the
            exception as ``missing_cfg_file``.
        :param msg: the error message.
        :param with_search_path: if True, append a "Search path:" section to
            ``msg`` listing every repository source whose provider is not
            "schema".
        :raises MissingConfigException: always; this method never returns
            normally.
        """
        if with_search_path:
            # Enumerate sources only when the listing is actually used.
            lines = "\n".join(
                f"\t{repr(src)}"
                for src in self.repository.get_sources()
                if src.provider != "schema"
            )
            msg = msg + "\nSearch path:" + f"\n{lines}"

        raise MissingConfigException(missing_cfg_file=config_name,
                                     message=msg)
コード例 #6
0
    def __init__(self, task_name: str, config_loader: ConfigLoader) -> None:
        """Set up globals, validate the config search path and register the task name.

        :param task_name: task name; stored in ``JobRuntime`` under "name"
        :param config_loader: config loader whose sources are checked
        :raises MissingConfigException: if a source with provider "main" does
            not exist
        """
        setup_globals()
        self.config_loader = config_loader

        # A configured primary ("main") search path must exist.
        main_sources = (
            src for src in config_loader.get_sources() if src.provider == "main"
        )
        for src in main_sources:
            if src.exists(""):
                continue
            raise MissingConfigException(
                missing_cfg_file=src.path,
                message=f"Primary config dir not found: {src}",
            )

        JobRuntime().set("name", task_name)
コード例 #7
0
ファイル: hydra.py プロジェクト: pseeth/hydra
    def compose_config(
        self,
        config_name: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
        with_log_configuration: bool = False,
    ) -> DictConfig:
        """Compose the configuration and stamp hydra runtime information into it.

        :param config_name: name of the primary config, or None
        :param overrides: override strings forwarded to the config loader
        :param strict: None for default behavior (default to true for config file, false if no config file).
                       otherwise forces specific behavior.
        :param with_log_configuration: True to configure logging subsystem from the loaded config
        :return: the composed config, with ``hydra.runtime.version`` and
                 ``hydra.runtime.cwd`` set
        :raises MissingConfigException: if a search-path source with provider
                 "main" does not exist
        """
        # This function may rebind the module-level logger below.
        global log

        for src in self.config_loader.get_sources():
            # Only a configured "main" search path is required to exist.
            if src.provider == "main" and not src.exists(""):
                raise MissingConfigException(
                    missing_cfg_file=src.path,
                    message=f"Primary config dir not found: {src}",
                )

        cfg = self.config_loader.load_configuration(
            config_name=config_name, overrides=overrides, strict=strict
        )

        with open_dict(cfg):
            from hydra import __version__

            cfg.hydra.runtime.version = __version__
            cfg.hydra.runtime.cwd = os.getcwd()

        if not with_log_configuration:
            return cfg

        configure_log(cfg.hydra.hydra_logging, cfg.hydra.verbose)
        log = logging.getLogger(__name__)
        self._print_debug_info()
        return cfg
コード例 #8
0
    def load_configuration(
        self,
        config_file: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
    ) -> DictConfig:
        """Compose the configuration from hydra.yaml, ``config_file`` and the
        command-line ``overrides``.

        :param config_file: primary config file, or None to compose without one.
        :param overrides: override strings. None is accepted and treated as an
            empty list. A deep copy is used, so the caller's list is not shared.
        :param strict: struct flag applied to the composed config; None means
            ``self.default_strict``.
        :return: the composed config, with ``hydra.overrides.task``,
            ``hydra.overrides.hydra`` and ``hydra.job`` fields populated.
        :raises MissingConfigException: if ``config_file`` is given but is not
            found in the search path.
        """
        assert config_file is None or isinstance(config_file, str)
        assert strict is None or isinstance(strict, bool)
        # overrides may be None; it is normalized to [] just below.
        assert overrides is None or isinstance(overrides, list)
        if strict is None:
            strict = self.default_strict

        overrides = copy.deepcopy(overrides) or []

        if config_file is not None and not self.exists_in_search_path(config_file):
            search_path = "\n".join(
                [
                    "\t{} (from {})".format(src.path, src.provider)
                    for src in self.repository.sources
                ]
            )
            raise MissingConfigException(
                missing_cfg_file=config_file,
                message="Cannot find primary config file: {}\nSearch path:\n{}".format(
                    config_file, search_path
                ),
            )

        # Load hydra config
        hydra_cfg = self._create_cfg(cfg_filename="hydra.yaml")

        # Load job config
        job_cfg = self._create_cfg(cfg_filename=config_file, record_load=False)

        defaults = ConfigLoader._get_defaults(hydra_cfg)
        if config_file is not None:
            defaults.append(config_file)
        # Length of hydra.yaml's defaults (plus the primary config, if any),
        # taken before the job's defaults are merged into the list.
        split_at = len(defaults)
        job_defaults = ConfigLoader._get_defaults(job_cfg)
        ConfigLoader._merge_default_lists(defaults, job_defaults)
        consumed = self._apply_defaults_overrides(overrides, defaults)

        consumed_free_job_defaults = self._apply_free_defaults(defaults, overrides)

        ConfigLoader._validate_defaults(defaults)

        # Load the defaults and merge them into cfg
        cfg = self._merge_defaults(hydra_cfg, defaults, split_at)
        OmegaConf.set_struct(cfg.hydra, True)
        OmegaConf.set_struct(cfg, strict)

        # Overrides not consumed by the defaults list are merged as dotlist
        # entries, after the struct flag has been set.
        all_consumed = consumed + consumed_free_job_defaults
        remaining_overrides = [x for x in overrides if x not in all_consumed]
        cfg.merge_with_dotlist(remaining_overrides)

        remaining = all_consumed + remaining_overrides

        def is_hydra(x: str) -> bool:
            return x.startswith("hydra.") or x.startswith("hydra/")

        cfg.hydra.overrides.task = [x for x in remaining if not is_hydra(x)]
        cfg.hydra.overrides.hydra = [x for x in remaining if is_hydra(x)]

        with open_dict(cfg.hydra.job):
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                input_list=cfg.hydra.overrides.task,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
            )
            cfg.hydra.job.config_file = config_file

        return cfg
コード例 #9
0
    def load_configuration(
        self,
        config_name: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
    ) -> DictConfig:
        """Compose the job configuration from hydra's config, the primary
        config ``config_name`` and the command-line ``overrides``.

        :param config_name: name of the primary config, or None to compose
            without one.
        :param overrides: override strings; each is parsed with
            ``self._parse_override``.
        :param strict: struct flag applied to the composed config; None means
            ``self.default_strict``.
        :return: the composed config. Every override line is appended to
            ``hydra.overrides.hydra`` (keys starting with "hydra." or
            "hydra/") or ``hydra.overrides.task`` (all others), and
            ``hydra.job`` is filled in (name, override_dirname, config_name,
            env_set).
        :raises MissingConfigException: if ``config_name`` is given but the
            repository does not contain it.
        """
        if strict is None:
            strict = self.default_strict

        parsed_overrides = [
            self._parse_override(override) for override in overrides
        ]

        if config_name is not None and not self.repository.config_exists(
                config_name):
            # TODO: handle schema as a special case
            descs = [
                f"\t{src.path} (from {src.provider})"
                for src in self.repository.get_sources()
            ]
            lines = "\n".join(descs)
            raise MissingConfigException(
                missing_cfg_file=config_name,
                message=
                f"Cannot find primary config file: {config_name}\nSearch path:\n{lines}",
            )

        # Load hydra config
        hydra_cfg, _load_trace = self._load_primary_config(
            cfg_filename="hydra_config")

        # Load job config
        job_cfg, job_cfg_load_trace = self._load_primary_config(
            cfg_filename=config_name, record_load=False)

        job_defaults = self._parse_defaults(job_cfg)
        defaults = self._parse_defaults(hydra_cfg)

        # If the job config has a structured (non-dict) type, promote
        # hydra_cfg to that type.
        job_cfg_type = OmegaConf.get_type(job_cfg)
        if job_cfg_type is not None and not issubclass(job_cfg_type, dict):
            hydra_cfg._promote(job_cfg_type)
            # this is breaking encapsulation a bit. can potentially be implemented in OmegaConf
            hydra_cfg._metadata.ref_type = job_cfg._metadata.ref_type

        # if defaults are re-introduced by the promotion, remove them.
        if "defaults" in hydra_cfg:
            with open_dict(hydra_cfg):
                del hydra_cfg["defaults"]

        # "__SELF__" is presumably a placeholder for where the primary config
        # itself is merged; confirm in _merge_defaults_into_config.
        if config_name is not None:
            defaults.append(
                DefaultElement(config_group=None, config_name="__SELF__"))
        # Length of the hydra defaults (plus "__SELF__"), taken before the
        # job defaults are combined into the list below.
        split_at = len(defaults)

        # Config-group overrides edit the defaults list; the remaining config
        # overrides are applied to the composed config further down.
        config_group_overrides, config_overrides = self.split_overrides(
            parsed_overrides)
        self._combine_default_lists(defaults, job_defaults)
        ConfigLoaderImpl._apply_overrides_to_defaults(config_group_overrides,
                                                      defaults)

        # Load the defaults and merge them into cfg
        cfg = self._merge_defaults_into_config(hydra_cfg, job_cfg,
                                               job_cfg_load_trace, defaults,
                                               split_at)
        OmegaConf.set_struct(cfg.hydra, True)
        OmegaConf.set_struct(cfg, strict)

        # Merge the config (non-group) command line overrides after enabling
        # the struct flag.
        ConfigLoaderImpl._apply_overrides_to_config(
            [x.input_line for x in config_overrides], cfg)

        # Record every override line as a hydra or task override. Only the
        # task (non-hydra) overrides feed hydra.job.override_dirname.
        app_overrides = []
        for pwl in parsed_overrides:
            override = pwl.override
            assert override.key is not None
            key = override.key
            if key.startswith("hydra.") or key.startswith("hydra/"):
                cfg.hydra.overrides.hydra.append(pwl.input_line)
            else:
                cfg.hydra.overrides.task.append(pwl.input_line)
                app_overrides.append(pwl)

        with open_dict(cfg.hydra.job):
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                input_list=app_overrides,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.
                exclude_keys,
            )
            cfg.hydra.job.config_name = config_name

            # Copy the listed environment variables into env_set.
            # NOTE(review): os.environ[key] raises KeyError when a variable
            # named in env_copy is unset; confirm that is the intended failure.
            for key in cfg.hydra.job.env_copy:
                cfg.hydra.job.env_set[key] = os.environ[key]

        return cfg
コード例 #10
0
    def load_configuration(
        self,
        config_name: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
    ) -> DictConfig:
        """Compose the configuration from hydra's config, the primary config
        ``config_name`` and the command-line ``overrides``.

        :param config_name: name of the primary config, or None to compose
            without one.
        :param overrides: override strings. None is accepted and treated as an
            empty list. A deep copy is used, so the caller's list is not shared.
        :param strict: struct flag applied to the composed config; None means
            ``self.default_strict``.
        :return: the composed config, with ``hydra.overrides.task``,
            ``hydra.overrides.hydra`` and ``hydra.job`` fields populated.
        :raises MissingConfigException: if ``config_name`` is given but is not
            found in the search path.
        """
        assert config_name is None or isinstance(config_name, str)
        assert strict is None or isinstance(strict, bool)
        # overrides may be None; it is normalized to [] just below.
        assert overrides is None or isinstance(overrides, list)
        if strict is None:
            strict = self.default_strict

        overrides = copy.deepcopy(overrides) or []

        if config_name is not None and not self.exists_in_search_path(
                config_name):
            # TODO: handle schema as a special case
            descs = [
                "\t{} (from {})".format(src.path, src.provider)
                for src in self.repository.get_sources()
            ]
            lines = "\n".join(descs)
            raise MissingConfigException(
                missing_cfg_file=config_name,
                message=(
                    f"Cannot find primary config file: {config_name}"
                    f"\nSearch path:\n{lines}"
                ),
            )

        # Load hydra config
        hydra_cfg, _load_trace = self._create_cfg(cfg_filename="hydra_config")

        # Load job config
        job_cfg, job_cfg_load_trace = self._create_cfg(
            cfg_filename=config_name, record_load=False)

        job_defaults = ConfigLoaderImpl._get_defaults(job_cfg)
        defaults = ConfigLoaderImpl._get_defaults(hydra_cfg)

        # If the job config has a structured (non-dict) type, promote
        # hydra_cfg to that type.
        job_cfg_type = OmegaConf.get_type(job_cfg)
        if job_cfg_type is not None and not issubclass(job_cfg_type, dict):
            hydra_cfg._promote(job_cfg_type)
            # this is breaking encapsulation a bit. can potentially be implemented in OmegaConf
            hydra_cfg._metadata.ref_type = job_cfg._metadata.ref_type

        # if defaults are re-introduced by the promotion, remove them.
        if "defaults" in hydra_cfg:
            with open_dict(hydra_cfg):
                del hydra_cfg["defaults"]

        if config_name is not None:
            defaults.append("__SELF__")
        # Length of the hydra defaults (plus "__SELF__"), taken before the
        # job defaults are merged into the list.
        split_at = len(defaults)

        ConfigLoaderImpl._merge_default_lists(defaults, job_defaults)
        consumed = self._apply_defaults_overrides(overrides, defaults)

        consumed_free_job_defaults = self._apply_free_defaults(
            defaults, overrides)

        ConfigLoaderImpl._validate_defaults(defaults)

        # Load the defaults and merge them into cfg
        cfg = self._merge_defaults(hydra_cfg, job_cfg, job_cfg_load_trace,
                                   defaults, split_at)
        OmegaConf.set_struct(cfg.hydra, True)
        OmegaConf.set_struct(cfg, strict)

        # Overrides not consumed by the defaults list are merged as dotlist
        # entries, after the struct flag has been set.
        all_consumed = consumed + consumed_free_job_defaults
        remaining_overrides = [x for x in overrides if x not in all_consumed]
        cfg.merge_with_dotlist(remaining_overrides)

        remaining = all_consumed + remaining_overrides

        def is_hydra(x: str) -> bool:
            return x.startswith("hydra.") or x.startswith("hydra/")

        cfg.hydra.overrides.task = [x for x in remaining if not is_hydra(x)]
        cfg.hydra.overrides.hydra = [x for x in remaining if is_hydra(x)]

        with open_dict(cfg.hydra.job):
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                input_list=cfg.hydra.overrides.task,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
            )
            cfg.hydra.job.config_name = config_name

        return cfg
コード例 #11
0
ファイル: config_loader_impl.py プロジェクト: edraizen/hydra
    def load_configuration(
        self,
        config_name: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
    ) -> DictConfig:
        """Compose the configuration from hydra's config, the primary config
        ``config_name`` and the command-line ``overrides``.

        :param config_name: name of the primary config, or None to compose
            without one.
        :param overrides: override strings; a deep copy is used, and a falsy
            value is replaced by an empty list.
        :param strict: struct flag applied to the composed config; None means
            ``self.default_strict``.
        :return: the composed config, with ``hydra.overrides.task``,
            ``hydra.overrides.hydra`` and ``hydra.job`` fields populated.
        :raises MissingConfigException: if ``config_name`` is given but is not
            found in the search path.
        :raises HydraException: if merging the leftover overrides raises an
            ``OmegaConfBaseException``.
        """
        if strict is None:
            strict = self.default_strict

        overrides = copy.deepcopy(overrides) or []

        if config_name is not None and not self.exists_in_search_path(
                config_name):
            # TODO: handle schema as a special case
            descs = [
                f"\t{src.path} (from {src.provider})"
                for src in self.repository.get_sources()
            ]
            lines = "\n".join(descs)
            raise MissingConfigException(
                missing_cfg_file=config_name,
                message=
                f"Cannot find primary config file: {config_name}\nSearch path:\n{lines}",
            )

        # Load hydra config
        hydra_cfg, _load_trace = self._create_cfg(cfg_filename="hydra_config")

        # Load job config
        job_cfg, job_cfg_load_trace = self._create_cfg(
            cfg_filename=config_name, record_load=False)

        job_defaults = self._parse_defaults(job_cfg)
        defaults = self._parse_defaults(hydra_cfg)

        # If the job config has a structured (non-dict) type, promote
        # hydra_cfg to that type.
        job_cfg_type = OmegaConf.get_type(job_cfg)
        if job_cfg_type is not None and not issubclass(job_cfg_type, dict):
            hydra_cfg._promote(job_cfg_type)
            # this is breaking encapsulation a bit. can potentially be implemented in OmegaConf
            hydra_cfg._metadata.ref_type = job_cfg._metadata.ref_type

        # if defaults are re-introduced by the promotion, remove them.
        if "defaults" in hydra_cfg:
            with open_dict(hydra_cfg):
                del hydra_cfg["defaults"]

        # "__SELF__" is presumably a placeholder for where the primary config
        # itself is merged; confirm in _merge_defaults_into_config.
        if config_name is not None:
            defaults.append(
                DefaultElement(config_group=None, config_name="__SELF__"))
        # Length of the hydra defaults (plus "__SELF__"), taken before the
        # job defaults are combined into the list below.
        split_at = len(defaults)

        self._combine_default_lists(defaults, job_defaults)
        # Overrides applied to the defaults list are returned as "consumed".
        consumed = self._apply_overrides_to_defaults(overrides, defaults)

        # Load the defaults and merge them into cfg
        cfg = self._merge_defaults_into_config(hydra_cfg, job_cfg,
                                               job_cfg_load_trace, defaults,
                                               split_at)
        OmegaConf.set_struct(cfg.hydra, True)
        OmegaConf.set_struct(cfg, strict)

        # Merge the overrides not consumed by the defaults list, after
        # enabling the struct flag.
        remaining_overrides = [x for x in overrides if x not in consumed]
        try:
            merged = OmegaConf.merge(
                cfg, OmegaConf.from_dotlist(remaining_overrides))
            assert isinstance(merged, DictConfig)
            cfg = merged
        except OmegaConfBaseException as ex:
            raise HydraException("Error merging overrides") from ex

        remaining = consumed + remaining_overrides

        def is_hydra(x: str) -> bool:
            return x.startswith("hydra.") or x.startswith("hydra/")

        cfg.hydra.overrides.task = [x for x in remaining if not is_hydra(x)]
        cfg.hydra.overrides.hydra = [x for x in remaining if is_hydra(x)]

        with open_dict(cfg.hydra.job):
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                input_list=cfg.hydra.overrides.task,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.
                exclude_keys,
            )
            cfg.hydra.job.config_name = config_name

            # Copy the listed environment variables into env_set.
            # NOTE(review): os.environ[key] raises KeyError when a variable
            # named in env_copy is unset; confirm that is the intended failure.
            for key in cfg.hydra.job.env_copy:
                cfg.hydra.job.env_set[key] = os.environ[key]

        return cfg