Esempio n. 1
0
    def load_configuration(
        self,
        config_file: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
    ) -> DictConfig:
        """Compose the config from hydra.yaml, the job config and overrides.

        Args:
            config_file: Primary (job) config file, or None when there is none.
            overrides: Command line overrides; a deep copy is used internally.
            strict: Struct flag applied to the composed config. When None,
                ``self.default_strict`` is used instead.

        Returns:
            The merged config with ``hydra.overrides.task``,
            ``hydra.overrides.hydra``, ``hydra.job.override_dirname`` and
            ``hydra.job.config_file`` filled in, plus ``hydra.job.name`` when
            it was not already present.

        Raises:
            MissingConfigException: ``config_file`` is given but
                ``exists_in_search_path`` does not find it.
        """
        assert config_file is None or isinstance(config_file, str)
        assert strict is None or isinstance(strict, bool)
        assert isinstance(overrides, list)
        if strict is None:
            strict = self.default_strict

        # The helpers below receive a copy rather than the caller's list.
        # `or []` keeps this a list even if None slips through while
        # assertions are disabled (python -O).
        overrides = copy.deepcopy(overrides) or []

        if config_file is not None and not self.exists_in_search_path(
                config_file):
            search_path = "\n".join([
                "\t{} (from {})".format(src.path, src.provider)
                for src in self.repository.get_sources()
            ])
            raise MissingConfigException(
                missing_cfg_file=config_file,
                message="Cannot find primary config file: {}\nSearch path:\n{}"
                .format(config_file, search_path),
            )

        # Load hydra config
        hydra_cfg = self._create_cfg(cfg_filename="hydra.yaml")

        # Load job config
        job_cfg = self._create_cfg(cfg_filename=config_file, record_load=False)

        defaults = ConfigLoaderImpl._get_defaults(hydra_cfg)
        if config_file is not None:
            defaults.append(config_file)
        # Length recorded before the job's own defaults are merged in; it is
        # handed to _merge_defaults below.
        split_at = len(defaults)
        job_defaults = ConfigLoaderImpl._get_defaults(job_cfg)
        ConfigLoaderImpl._merge_default_lists(defaults, job_defaults)
        consumed = self._apply_defaults_overrides(overrides, defaults)

        consumed_free_job_defaults = self._apply_free_defaults(
            defaults, overrides)

        ConfigLoaderImpl._validate_defaults(defaults)

        # Load the defaults and merge them into cfg
        cfg = self._merge_defaults(hydra_cfg, defaults, split_at)
        OmegaConf.set_struct(cfg.hydra, True)
        OmegaConf.set_struct(cfg, strict)

        # Merge all command line overrides after enabling strict flag
        all_consumed = consumed + consumed_free_job_defaults
        remaining_overrides = [x for x in overrides if x not in all_consumed]
        cfg.merge_with_dotlist(remaining_overrides)

        # Consumed overrides first, then the ones merged as a dotlist.
        remaining = all_consumed + remaining_overrides

        # Overrides addressing the hydra node ("hydra.") or a hydra config
        # group ("hydra/") are recorded separately from the task overrides.
        hydra_prefixes = ("hydra.", "hydra/")
        cfg.hydra.overrides.task = [
            x for x in remaining if not x.startswith(hydra_prefixes)
        ]
        cfg.hydra.overrides.hydra = [
            x for x in remaining if x.startswith(hydra_prefixes)
        ]

        # hydra is in struct mode at this point; open_dict wraps the writes
        # to hydra.job below.
        with open_dict(cfg.hydra.job):
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                input_list=cfg.hydra.overrides.task,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.
                exclude_keys,
            )
            cfg.hydra.job.config_file = config_file

        return cfg
Esempio n. 2
0
    def load_configuration(
        self,
        config_name: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
    ) -> DictConfig:
        """Compose the config from hydra_config, the job config and overrides.

        Args:
            config_name: Name of the primary (job) config, or None when there
                is none.
            overrides: Command line overrides; a deep copy is used internally.
            strict: Struct flag applied to the composed config. When None,
                ``self.default_strict`` is used instead.

        Returns:
            The merged config with ``hydra.overrides.task``,
            ``hydra.overrides.hydra``, ``hydra.job.override_dirname`` and
            ``hydra.job.config_name`` filled in, plus ``hydra.job.name`` when
            it was not already present.

        Raises:
            MissingConfigException: ``config_name`` is given but
                ``exists_in_search_path`` does not find it.
        """
        assert config_name is None or isinstance(config_name, str)
        assert strict is None or isinstance(strict, bool)
        assert isinstance(overrides, list)
        if strict is None:
            strict = self.default_strict

        # The helpers below receive a copy rather than the caller's list.
        # `or []` keeps this a list even if None slips through while
        # assertions are disabled (python -O).
        overrides = copy.deepcopy(overrides) or []

        if config_name is not None and not self.exists_in_search_path(
                config_name):
            # TODO: handle schema as a special case
            lines = "\n".join(
                "\t{} (from {})".format(src.path, src.provider)
                for src in self.repository.get_sources()
            )
            raise MissingConfigException(
                missing_cfg_file=config_name,
                message=(
                    f"Cannot find primary config file: {config_name}\nSearch path:\n{lines}"
                ),
            )

        # Load hydra config
        hydra_cfg, _load_trace = self._create_cfg(cfg_filename="hydra_config")

        # Load job config
        job_cfg, job_cfg_load_trace = self._create_cfg(
            cfg_filename=config_name, record_load=False)

        # Both defaults lists are read before hydra_cfg is promoted or has
        # its "defaults" key removed below.
        job_defaults = ConfigLoaderImpl._get_defaults(job_cfg)
        defaults = ConfigLoaderImpl._get_defaults(hydra_cfg)

        # When the job config has a non-dict type, hydra_cfg is promoted to
        # that type so the merged result carries it.
        job_cfg_type = OmegaConf.get_type(job_cfg)
        if job_cfg_type is not None and not issubclass(job_cfg_type, dict):
            hydra_cfg._promote(job_cfg_type)
            # this is breaking encapsulation a bit. can potentially be implemented in OmegaConf
            hydra_cfg._metadata.ref_type = job_cfg._metadata.ref_type

        # If a "defaults" key is re-introduced by the promotion, remove it.
        if "defaults" in hydra_cfg:
            with open_dict(hydra_cfg):
                del hydra_cfg["defaults"]

        if config_name is not None:
            # Placeholder entry standing in for the primary config itself.
            defaults.append("__SELF__")
        # Length recorded before the job's own defaults are merged in; it is
        # handed to _merge_defaults below.
        split_at = len(defaults)

        ConfigLoaderImpl._merge_default_lists(defaults, job_defaults)
        consumed = self._apply_defaults_overrides(overrides, defaults)

        consumed_free_job_defaults = self._apply_free_defaults(
            defaults, overrides)

        ConfigLoaderImpl._validate_defaults(defaults)

        # Load the defaults and merge them into cfg
        cfg = self._merge_defaults(hydra_cfg, job_cfg, job_cfg_load_trace,
                                   defaults, split_at)
        OmegaConf.set_struct(cfg.hydra, True)
        OmegaConf.set_struct(cfg, strict)

        # Merge all command line overrides after enabling strict flag
        all_consumed = consumed + consumed_free_job_defaults
        remaining_overrides = [x for x in overrides if x not in all_consumed]
        cfg.merge_with_dotlist(remaining_overrides)

        # Consumed overrides first, then the ones merged as a dotlist.
        remaining = all_consumed + remaining_overrides

        # Overrides addressing the hydra node ("hydra.") or a hydra config
        # group ("hydra/") are recorded separately from the task overrides.
        hydra_prefixes = ("hydra.", "hydra/")
        cfg.hydra.overrides.task = [
            x for x in remaining if not x.startswith(hydra_prefixes)
        ]
        cfg.hydra.overrides.hydra = [
            x for x in remaining if x.startswith(hydra_prefixes)
        ]

        # hydra is in struct mode at this point; open_dict wraps the writes
        # to hydra.job below.
        with open_dict(cfg.hydra.job):
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                input_list=cfg.hydra.overrides.task,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.
                exclude_keys,
            )
            cfg.hydra.job.config_name = config_name

        return cfg
Esempio n. 3
0
    def load_configuration(
        self,
        config_name: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
    ) -> DictConfig:
        """Compose the config from hydra_config, the job config and overrides.

        Args:
            config_name: Name of the primary (job) config, or None when there
                is none.
            overrides: Command line overrides; a deep copy is used internally.
            strict: Struct flag applied to the composed config. When None,
                ``self.default_strict`` is used instead.

        Returns:
            The merged config with ``hydra.overrides.task``,
            ``hydra.overrides.hydra``, ``hydra.job.override_dirname`` and
            ``hydra.job.config_name`` filled in, ``hydra.job.name`` set when
            it was absent, and every key of ``hydra.job.env_copy`` copied
            from ``os.environ`` into ``hydra.job.env_set``.

        Raises:
            MissingConfigException: ``config_name`` is given but
                ``exists_in_search_path`` does not find it.
            HydraException: merging the remaining overrides raised an
                ``OmegaConfBaseException``.
            KeyError: a key listed in ``hydra.job.env_copy`` is not present
                in ``os.environ``.
        """
        strict = self.default_strict if strict is None else strict

        overrides = copy.deepcopy(overrides) or []

        if config_name is not None:
            if not self.exists_in_search_path(config_name):
                # TODO: handle schema as a special case
                lines = "\n".join(
                    f"\t{src.path} (from {src.provider})"
                    for src in self.repository.get_sources()
                )
                raise MissingConfigException(
                    missing_cfg_file=config_name,
                    message=(
                        f"Cannot find primary config file: {config_name}\nSearch path:\n{lines}"
                    ),
                )

        # Hydra's own config first, then the job config (its load is not
        # recorded).
        hydra_cfg, _ = self._create_cfg(cfg_filename="hydra_config")
        job_cfg, job_trace = self._create_cfg(
            cfg_filename=config_name,
            record_load=False,
        )

        # Parse both defaults lists before hydra_cfg is touched below.
        job_defaults = self._parse_defaults(job_cfg)
        defaults = self._parse_defaults(hydra_cfg)

        job_type = OmegaConf.get_type(job_cfg)
        if not (job_type is None or issubclass(job_type, dict)):
            hydra_cfg._promote(job_type)
            # Reaches into OmegaConf internals; could arguably live in
            # OmegaConf itself.
            hydra_cfg._metadata.ref_type = job_cfg._metadata.ref_type

        # The promotion may bring a "defaults" key back; drop it if so.
        if "defaults" in hydra_cfg:
            with open_dict(hydra_cfg):
                del hydra_cfg["defaults"]

        if config_name is not None:
            self_element = DefaultElement(
                config_group=None, config_name="__SELF__")
            defaults.append(self_element)
        # Length recorded before the job defaults are combined in.
        split_at = len(defaults)

        self._combine_default_lists(defaults, job_defaults)
        consumed = self._apply_overrides_to_defaults(overrides, defaults)

        # Load the defaults and merge them into the config.
        cfg = self._merge_defaults_into_config(
            hydra_cfg, job_cfg, job_trace, defaults, split_at)
        OmegaConf.set_struct(cfg.hydra, True)
        OmegaConf.set_struct(cfg, strict)

        # Command line overrides are merged only after the struct flags are
        # set.
        remaining_overrides = [
            item for item in overrides if item not in consumed
        ]
        try:
            dotlist_cfg = OmegaConf.from_dotlist(remaining_overrides)
            merged = OmegaConf.merge(cfg, dotlist_cfg)
        except OmegaConfBaseException as ex:
            raise HydraException("Error merging overrides") from ex
        else:
            assert isinstance(merged, DictConfig)
            cfg = merged

        all_overrides = consumed + remaining_overrides

        # Split into hydra-targeted overrides and task overrides.
        hydra_prefixes = ("hydra.", "hydra/")
        cfg.hydra.overrides.task = [
            item for item in all_overrides
            if not item.startswith(hydra_prefixes)
        ]
        cfg.hydra.overrides.hydra = [
            item for item in all_overrides
            if item.startswith(hydra_prefixes)
        ]

        with open_dict(cfg.hydra.job):
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")

            dirname_cfg = cfg.hydra.job.config.override_dirname
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                input_list=cfg.hydra.overrides.task,
                kv_sep=dirname_cfg.kv_sep,
                item_sep=dirname_cfg.item_sep,
                exclude_keys=dirname_cfg.exclude_keys,
            )
            cfg.hydra.job.config_name = config_name

            # Copy the listed environment variables into env_set.
            for key in cfg.hydra.job.env_copy:
                cfg.hydra.job.env_set[key] = os.environ[key]

        return cfg