Example #1
0
    def sweep(self, arguments: List[str]) -> Any:
        """Expand ``arguments`` into the cartesian product of all overrides and launch it.

        Every sweep override contributes one ``key=value`` string per choice and every
        other override contributes a single ``key=value`` string. The resulting jobs
        are launched in batches of at most ``self.max_batch_size`` jobs (``None`` or
        ``-1`` means a single batch holding every job). The sweep config is saved to
        ``multirun.yaml`` in the sweep directory before anything is launched.

        :param arguments: raw command-line override strings.
        :return: a list with one entry per launched batch, each being what
            ``self.launcher.launch`` returned for that batch.
        :raises ValueError: if ``self.max_batch_size`` is 0 or negative (other than -1).
        """
        assert self.config is not None
        assert self.launcher is not None
        log.info(f"ExampleSweeper (foo={self.foo}, bar={self.bar}) sweeping")
        log.info(f"Sweep output dir : {self.config.hydra.sweep.dir}")

        # Save sweep run config in top level sweep working directory
        sweep_dir = Path(self.config.hydra.sweep.dir)
        sweep_dir.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(self.config, sweep_dir / "multirun.yaml")

        parser = OverridesParser()
        parsed = parser.parse_overrides(arguments)

        lists = []
        for override in parsed:
            if override.is_sweep_override():
                # Sweepers must manipulate only overrides that return true to is_sweep_override().
                # This syntax is shared across all sweepers, so it may be limiting.
                # Sweepers must respect this though: failing to do so will cause all sorts of hard-to-debug issues.
                # If you would like to propose an extension to the grammar (enabling new types of sweep overrides),
                # please file an issue and describe the use case and the proposed syntax.
                # Be aware that syntax extensions can break compatibility for existing users, and the
                # use case will be scrutinized heavily before the syntax is changed.
                sweep_choices = override.choices_as_strings()
                assert isinstance(sweep_choices, list)
                key = override.get_key_element()
                sweep = [f"{key}={val}" for val in sweep_choices]
                lists.append(sweep)
            else:
                key = override.get_key_element()
                value = override.get_value_element()
                lists.append([f"{key}={value}"])
        batches = list(itertools.product(*lists))

        # Some sweepers launch multiple batches.
        # Such sweepers must pass the proper initial_job_idx when launching each batch
        # (see the loop below). This ensures that every job gets a unique job id
        # (which in turn can be used for other things, like the working directory).
        def chunks(lst: Sequence[Sequence[str]],
                   n: Optional[int]) -> Iterable[Sequence[Sequence[str]]]:
            """
            Split input to chunks of up to n items each.

            ``None`` or ``-1`` means no limit: a single chunk holding every item.

            :raises ValueError: if ``n`` is 0 or a negative number other than -1.
            """
            if n is None or n == -1:
                # max(..., 1): range() rejects a zero step, which an empty list would give.
                n = max(len(lst), 1)
            elif n < 1:
                # A negative step would make range() yield nothing, silently launching no jobs.
                raise ValueError(
                    f"max_batch_size must be a positive integer, -1 or None, got {n}")
            for i in range(0, len(lst), n):
                yield lst[i:i + n]

        chunked_batches = list(chunks(batches, self.max_batch_size))

        returns = []
        initial_job_idx = 0
        for batch in chunked_batches:
            results = self.launcher.launch(batch,
                                           initial_job_idx=initial_job_idx)
            # Offset the next batch so job indices stay unique across batches.
            initial_job_idx += len(batch)
            returns.append(results)
        return returns
Example #2
0
    def split_arguments(
            arguments: List[str],
            max_batch_size: Optional[int]) -> List[List[List[str]]]:
        """Turn raw override strings into batches of per-job override lists.

        A sweep override contributes one ``key=choice`` string per choice; any other
        override contributes exactly one ``key=value`` string. Each job is one element
        of the cartesian product of those options.

        :param arguments: raw override strings.
        :param max_batch_size: ``None`` to return every job in a single batch, or a
            positive batch size passed to ``BasicSweeper.split_overrides_to_chunks``.
        :return: the list of batches, each a list of jobs.
        """
        parsed = OverridesParser().parse_overrides(arguments)

        options_per_override: List[List[str]] = []
        for ov in parsed:
            if not ov.is_sweep_override():
                name = ov.get_key_element()
                val = ov.get_value_element()
                options_per_override.append([f"{name}={val}"])
                continue
            choices = ov.choices_as_strings()
            assert isinstance(choices, list)
            name = ov.get_key_element()
            options_per_override.append([f"{name}={c}" for c in choices])

        jobs = [list(combo) for combo in itertools.product(*options_per_override)]
        assert max_batch_size is None or max_batch_size > 0
        if max_batch_size is not None:
            return list(BasicSweeper.split_overrides_to_chunks(jobs, max_batch_size))
        return [jobs]
def test_overrides_parser() -> None:
    """Parsing plain, list and choice-sweep overrides yields the expected Override objects."""
    overrides = ["x=10", "y=[1,2]", "z=a,b,c"]
    # (key, value type, parsed value) for each input line, in order.
    cases = [
        ("x", ValueType.ELEMENT, 10),
        ("y", ValueType.ELEMENT, [1, 2]),
        ("z", ValueType.CHOICE_SWEEP, ["a", "b", "c"]),
    ]
    expected = [
        Override(
            type=OverrideType.CHANGE,
            key_or_group=key,
            value_type=value_type,
            _value=value,
            input_line=line,
        )
        for line, (key, value_type, value) in zip(overrides, cases)
    ]
    ret = OverridesParser().parse_overrides(overrides)
    assert ret == expected
Example #4
0
def test_apply_overrides_to_config(input_cfg: Any, overrides: List[str],
                                   expected: Any) -> None:
    """Apply parsed overrides to a struct-mode config.

    ``expected`` is either the resulting dict, or a context manager that must
    capture the failure raised while applying the overrides.
    """
    cfg = OmegaConf.create(input_cfg)
    OmegaConf.set_struct(cfg, True)
    parsed = OverridesParser().parse_overrides(overrides=overrides)

    if not isinstance(expected, dict):
        with expected:
            ConfigLoaderImpl._apply_overrides_to_config(overrides=parsed,
                                                        cfg=cfg)
        return

    ConfigLoaderImpl._apply_overrides_to_config(overrides=parsed, cfg=cfg)
    assert cfg == expected
def test_primitive(value: str, prefix: str, suffix: str,
                   expected: Any) -> None:
    """Parse ``prefix + value + suffix`` with the ``primitive`` rule and compare it."""
    parsed = OverridesParser.parse_rule(prefix + value + suffix, "primitive")
    # A quoted string must round-trip to exactly the spelling it was parsed from.
    if isinstance(parsed, QuotedString):
        assert parsed.with_quotes() == value

    assert eq(parsed, expected)
Example #6
0
def test_split(args: List[str], max_batch_size: Optional[int],
               expected: List[List[List[str]]]) -> None:
    """``BasicSweeper.split_arguments`` groups the parsed overrides into the expected batches."""
    parsed = OverridesParser.create().parse_overrides(args)
    batches = BasicSweeper.split_arguments(parsed, max_batch_size=max_batch_size)
    # Normalise each batch to a plain list before comparing.
    assert list(map(list, batches)) == expected
Example #7
0
    def sweep(self, arguments: List[str]) -> Any:
        """Launch every job described by ``arguments``, one batch at a time.

        The parsed overrides are split via ``self.split_arguments`` (using
        ``self.max_batch_size``) and stored in ``self.overrides``. The sweep config
        is saved to ``multirun.yaml`` in the sweep directory. Each batch is
        validated before it is launched, and job indices keep increasing across
        batches.

        :param arguments: raw command-line override strings.
        :return: the launcher's results for each batch, in launch order.
        """
        assert self.config is not None
        assert self.launcher is not None

        parser = OverridesParser.create(config_loader=self.config_loader)
        overrides = parser.parse_overrides(arguments)

        self.overrides = self.split_arguments(overrides, self.max_batch_size)
        returns: List[Sequence[JobReturn]] = []

        # Save sweep run config in top level sweep working directory
        sweep_dir = Path(self.config.hydra.sweep.dir)
        sweep_dir.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(self.config, sweep_dir / "multirun.yaml")

        initial_job_idx = 0
        while not self.is_done():
            batch = self.get_job_batch()
            tic = time.perf_counter()
            # Validate that jobs can be safely composed. This catches composition errors early.
            # This can be a bit slow for large jobs. can potentially allow disabling from the config.
            self.validate_batch_is_legal(batch)
            elapsed = time.perf_counter() - tic
            # A zero elapsed time (coarse clock, tiny batch) must not crash the sweep.
            rate = len(batch) / elapsed if elapsed > 0 else float("inf")
            log.debug(
                f"Validated configs of {len(batch)} jobs in {elapsed:0.2f} seconds, {rate:.2f} / second"
            )
            results = self.launcher.launch(batch, initial_job_idx=initial_job_idx)
            # Offset the next batch so job indices stay unique across batches.
            initial_job_idx += len(batch)
            returns.append(results)
        return returns
Example #8
0
def test_create_nevergrad_parameter_from_override(
    input: Any,
    expected: Any,
) -> None:
    """A single parsed override converts to the expected Nevergrad parameter."""
    overrides = OverridesParser.create().parse_overrides([input])
    param = _impl.create_nevergrad_parameter_from_override(overrides[0])
    assert_ng_param_equals(param, expected)
Example #9
0
def test_delete_by_assigning_null_is_deprecated() -> None:
    """Assigning ``null`` to a defaults-list group emits the deprecation warning
    and leaves the defaults list empty."""
    msg = ("\nRemoving from the defaults list by assigning 'null' "
           "is deprecated and will be removed in Hydra 1.1."
           "\nUse ~db")
    defaults = ConfigLoaderImpl._parse_defaults(
        OmegaConf.create({"defaults": [{"db": "mysql"}]}))
    override = OverridesParser().parse_override("db=null")

    with pytest.warns(expected_warning=UserWarning, match=re.escape(msg)):
        ConfigLoaderImpl._apply_overrides_to_defaults(overrides=[override],
                                                      defaults=defaults)
    assert defaults == []
Example #10
0
def test_apply_overrides_to_defaults(input_defaults: List[str],
                                     overrides: List[str],
                                     expected: Any) -> None:
    """Apply parsed overrides to a defaults list.

    ``expected`` is either the expected defaults list (as raw entries), or a
    context manager that must capture a failure raised while parsing or applying
    the overrides.
    """
    defaults = ConfigLoaderImpl._parse_defaults(
        OmegaConf.create({"defaults": input_defaults}))
    parser = OverridesParser()

    if not isinstance(expected, list):
        # Parsing happens inside the context manager too: it may be what fails.
        with expected:
            parsed = parser.parse_overrides(overrides=overrides)
            ConfigLoaderImpl._apply_overrides_to_defaults(overrides=parsed,
                                                          defaults=defaults)
        return

    parsed = parser.parse_overrides(overrides=overrides)
    expected_defaults = ConfigLoaderImpl._parse_defaults(
        OmegaConf.create({"defaults": expected}))
    ConfigLoaderImpl._apply_overrides_to_defaults(overrides=parsed,
                                                  defaults=defaults)
    assert defaults == expected_defaults
    def sweep(self, arguments: List[str]) -> None:
        """Run an Optuna study over the search space described by ``arguments``.

        Every override is converted into an Optuna distribution. Trials are asked
        for in batches of ``optuna_config.n_jobs`` (the final batch may be smaller)
        until ``optuna_config.n_trials`` have run. Each batch is launched, and every
        job's return value is told back to the study. The best trial is written to
        ``optimization_results.yaml`` in the sweep directory and logged.
        """
        parsed = OverridesParser.create().parse_overrides(arguments)

        search_space = {}
        for ov in parsed:
            distribution = create_optuna_distribution_from_override(ov)
            search_space[ov.get_key_element()] = distribution

        opt_cfg = self.optuna_config
        study = optuna.create_study(study_name=opt_cfg.study_name,
                                    storage=opt_cfg.storage,
                                    direction=opt_cfg.direction)

        batch_size = opt_cfg.n_jobs
        remaining = opt_cfg.n_trials

        while remaining > 0:
            # Never ask for more trials than are still left to run.
            batch_size = min(remaining, batch_size)

            trials = [study._ask() for _ in range(batch_size)]
            for trial in trials:
                for name, dist in search_space.items():
                    trial._suggest(name, dist)

            job_overrides = [
                tuple(f"{key}={val}" for key, val in trial.params.items())
                for trial in trials
            ]

            returns = self.launcher.launch(job_overrides,
                                           initial_job_idx=trials[0].number)
            for trial, ret in zip(trials, returns):
                study._tell(trial, optuna.trial.TrialState.COMPLETE,
                            ret.return_value)
            remaining -= batch_size

        best = study.best_trial
        OmegaConf.save(
            OmegaConf.create({
                "name": "optuna",
                "best_params": best.params,
                "best_value": best.value,
            }),
            f"{self.config.hydra.sweep.dir}/optimization_results.yaml",
        )
        log.info(f"Best parameters: {best.params}")
        log.info(f"Best value: {best.value}")
        log.info(f"Storage: {opt_cfg.storage}")
        log.info(f"Study name: {study.study_name}")
Example #12
0
    def initialize_arguments(self, arguments: List[str]) -> None:
        """Parse ``arguments`` into job batches and store them in ``self.overrides``.

        A sweep override contributes one ``key=choice`` string per choice and any
        other override a single ``key=value`` string. Jobs are the cartesian
        product of those options. With ``self.max_batch_size`` set to ``None``,
        every job goes into one batch. Otherwise the jobs are split with
        ``self.split_overrides_to_chunks``.
        """
        parsed = OverridesParser().parse_overrides(arguments)

        options_per_override = []
        for ov in parsed:
            if ov.is_sweep_override():
                choices = ov.choices_as_strings()
                assert isinstance(choices, list)
                name = ov.get_key_element()
                options_per_override.append([f"{name}={c}" for c in choices])
            else:
                name = ov.get_key_element()
                options_per_override.append([f"{name}={ov.get_value_element()}"])
        all_batches = list(itertools.product(*options_per_override))

        limit = self.max_batch_size
        assert limit is None or limit > 0
        self.overrides = ([all_batches] if limit is None else list(
            self.split_overrides_to_chunks(all_batches, limit)))
Example #13
0
 def compute_defaults_list(
     self,
     config_name: Optional[str],
     overrides: List[str],
     run_mode: RunMode,
 ) -> DefaultsList:
     """Build the defaults list for ``config_name`` with ``overrides`` applied.

     The Hydra config is prepended, and ``skip_missing`` is enabled only for
     MULTIRUN runs. Lookups go through a ``CachingConfigRepository`` wrapper
     around ``self.repository``.
     """
     parser = OverridesParser.create()
     caching_repo = CachingConfigRepository(self.repository)
     parsed = parser.parse_overrides(overrides=overrides)
     return create_defaults_list(
         repo=caching_repo,
         config_name=config_name,
         overrides_list=parsed,
         prepend_hydra=True,
         skip_missing=run_mode == RunMode.MULTIRUN,
     )
Example #14
0
def test_override_parsing(
    prefix: str,
    value: str,
    override_type: OverrideType,
    expected_key: str,
    expected_value: Any,
    expected_value_type: ValueType,
) -> None:
    """The ``override`` grammar rule parses ``prefix + value`` into the expected Override."""
    line = prefix + value
    parsed = OverridesParser.parse_rule(line, "override")
    assert parsed == Override(
        input_line=line,
        type=override_type,
        key_or_group=expected_key,
        _value=expected_value,
        value_type=expected_value_type,
    )
Example #15
0
def _test_defaults_tree_impl(
    config_name: Optional[str],
    input_overrides: List[str],
    expected: Any,
    prepend_hydra: bool = False,
    skip_missing: bool = False,
) -> Optional[DefaultsList]:
    """Build a defaults tree for ``config_name`` and check it against ``expected``.

    There are two modes, depending on ``expected``:

    * ``expected`` is ``None`` or a ``DefaultsTreeNode``. The built tree must
      equal it, and it is returned wrapped in a ``DefaultsList``.
    * Otherwise ``expected`` is a context manager that must capture the failure.
      In that case ``None`` is returned.

    In both modes all overrides and deletions must have been used.
    """
    parser = OverridesParser.create()
    repo = create_repo()
    root = _create_root(config_name=config_name, with_hydra=prepend_hydra)
    overrides = Overrides(
        repo=repo,
        overrides_list=parser.parse_overrides(overrides=input_overrides),
    )

    def build_and_check_usage() -> Any:
        # Build the tree, then verify every override/deletion was consumed.
        tree = _create_defaults_tree(
            repo=repo,
            root=root,
            overrides=overrides,
            is_root_config=True,
            interpolated_subtree=False,
            skip_missing=skip_missing,
        )
        overrides.ensure_overrides_used()
        overrides.ensure_deletions_used()
        return tree

    if expected is not None and not isinstance(expected, DefaultsTreeNode):
        with expected:
            build_and_check_usage()
        return None

    result = build_and_check_usage()
    assert result == expected
    return DefaultsList(defaults=[],
                        defaults_tree=result,
                        overrides=overrides,
                        config_overrides=[])
Example #16
0
    def parse_commandline_args(
        self, arguments: List[str]
    ) -> List[Dict[str, Union[ax_types.TParamValue, List[ax_types.TParamValue]]]]:
        """Method to parse the command line arguments and convert them into Ax parameters

        * Choice sweeps become choice parameters.
        * Range sweeps become choice parameters.
        * Interval sweeps become range parameters.
        * Plain overrides become fixed parameters.
        * Non-sweep overrides of Hydra's own configuration are skipped.

        Raises:
            ValueError: if a sweep override is not a choice, range or interval sweep.
        """
        parser = OverridesParser.create()
        parsed = parser.parse_overrides(arguments)
        parameters: List[Dict[str, Any]] = []
        for override in parsed:
            if override.is_sweep_override():
                if override.is_choice_sweep():
                    param = create_choice_param_from_choice_override(override)
                elif override.is_range_sweep():
                    param = create_choice_param_from_range_override(override)
                elif override.is_interval_sweep():
                    param = create_range_param_using_interval_override(override)
                else:
                    # Without this, ``param`` would be stale from the previous
                    # iteration (or unbound on the first one).
                    raise ValueError(
                        f"Unsupported sweep override: '{override.input_line}'")
            elif override.is_hydra_override():
                # Hydra's own settings are not Ax parameters. Skip them instead of
                # re-appending the previous iteration's ``param``.
                continue
            else:
                param = create_fixed_param_from_element_override(override)
            parameters.append(param)

        return parameters
Example #17
0
    def parse_overrides(
        overrides: List[str],
        run_mode: RunMode,
        from_shell: bool,
    ) -> List[Override]:
        """Parse ``overrides`` and return only those that apply directly to the config.

        Sweep overrides are handled according to ``run_mode``:

        * MULTIRUN: they are dropped from the result, because the sweeper handles
          them. Sweeps over Hydra's own configuration are rejected.
        * RUN: any sweep override is an error. An ambiguous ``key=a,b`` gets a
          message explaining how to express a list, a quoted string, or a sweep.

        :param overrides: raw override strings.
        :param run_mode: ``RunMode.RUN`` or ``RunMode.MULTIRUN``.
        :param from_shell: whether the quoting example in the error message is
            shown with shell escaping.
        :return: the non-sweep overrides, in input order.
        :raises ConfigCompositionException: for sweep overrides that are not allowed.
        :raises AssertionError: if ``run_mode`` is neither RUN nor MULTIRUN.
        """
        parser = OverridesParser.create()
        parsed_overrides = parser.parse_overrides(overrides=overrides)
        config_overrides = []
        for x in parsed_overrides:
            if x.is_sweep_override():
                if run_mode == RunMode.MULTIRUN:
                    if x.is_hydra_override():
                        raise ConfigCompositionException(
                            f"Sweeping over Hydra's configuration is not supported : '{x.input_line}'"
                        )
                    # do not process sweep overrides in multirun mode.
                    # They will be handled directly by the sweeper
                elif run_mode == RunMode.RUN:
                    if x.value_type == ValueType.SIMPLE_CHOICE_SWEEP:
                        vals = "value1,value2"
                        if from_shell:
                            example_override = f"key=\\'{vals}\\'"
                        else:
                            example_override = f"key='{vals}'"

                        msg = dedent(f"""\
                            Ambiguous value for argument '{x.input_line}'
                            1. To use it as a list, use key=[value1,value2]
                            2. To use it as string, quote the value: {example_override}
                            3. To sweep over it, add --multirun to your command line"""
                                     )
                        raise ConfigCompositionException(msg)
                    else:
                        raise ConfigCompositionException(
                            f"Sweep parameters '{x.input_line}' requires --multirun"
                        )
                else:
                    # Explicit raise instead of ``assert False``. An assert is stripped
                    # under ``python -O``, which would silently drop this override.
                    raise AssertionError(f"Unexpected run_mode: {run_mode}")
            else:
                config_overrides.append(x)
        return config_overrides
Example #18
0
    def apply_overrides(cfg, overrides: List[str]):
        """
        Apply Hydra-style override strings to ``cfg`` in place.

        Args:
            cfg: an omegaconf config object
            overrides: list of strings in the format of "a=b" to override configs.
                See https://hydra.cc/docs/next/advanced/override_grammar/basic/
                for syntax.

        Returns:
            the same cfg object, after modification

        Raises:
            KeyError: if an existing prefix of an overridden dotted key is not a
                config node.
            NotImplementedError: if an override is a deletion.
        """

        def update_checked(target, dotted_key, new_value):
            # Walk every proper prefix of the dotted key. Stop at the first one that
            # does not exist yet, and refuse to descend through an existing
            # non-config value.
            segments = dotted_key.split(".")
            for depth in range(1, len(segments)):
                prefix = ".".join(segments[:depth])
                node = OmegaConf.select(target, prefix, default=None)
                if node is None:
                    break
                if not OmegaConf.is_config(node):
                    raise KeyError(
                        f"Trying to update key {dotted_key}, but {prefix} "
                        f"is not a config, but has type {type(node)}."
                    )
            OmegaConf.update(target, dotted_key, new_value, merge=True)

        from hydra.core.override_parser.overrides_parser import OverridesParser

        for o in OverridesParser.create().parse_overrides(overrides):
            key, value = o.key_or_group, o.value()
            if o.is_delete():
                # TODO support this
                raise NotImplementedError("deletion is not yet a supported override")
            update_checked(cfg, key, value)
        return cfg
Example #19
0
    def apply_overrides(cfg, overrides: List[str]):
        """
        In-place override contents of cfg.

        Args:
            cfg: an omegaconf config object
            overrides: list of strings in the format of "a=b" to override configs.
                See https://hydra.cc/docs/next/advanced/override_grammar/basic/
                for syntax.

        Returns:
            the cfg object

        Raises:
            NotImplementedError: if an override is a deletion.
        """
        from hydra.core.override_parser.overrides_parser import OverridesParser

        parser = OverridesParser.create()
        overrides = parser.parse_overrides(overrides)
        for o in overrides:
            # Use an explicit raise, not ``assert``. Under ``python -O`` an assert is
            # stripped, and the deletion would fall through to OmegaConf.update.
            if o.is_delete():
                # TODO seems nice to support this
                raise NotImplementedError("deletion is not a supported override")
            OmegaConf.update(cfg, o.key_or_group, o.value(), merge=True)
        return cfg
Example #20
0
    def sweep(self, arguments: List[str]) -> None:
        """Optimize over ``arguments`` with an Optuna study, launching trials in batches.

        The search space starts as a copy of ``self.search_space``:

        * Overrides that map to an Optuna ``BaseDistribution`` are added to it.
        * All other overrides are fixed values. They are passed to every trial and
          removed from the search space.

        The sampler is selected by ``optuna_config.sampler.name``, and the study is
        created with ``load_if_exists=True``. Trials run in batches of ``n_jobs``
        until ``n_trials`` have run. The best trial is then written to
        ``optimization_results.yaml`` in the sweep directory.

        Raises:
            NotImplementedError: if the sampler name is not ``tpe``, ``random``
                or ``cmaes``.
        """
        assert self.config is not None
        assert self.launcher is not None
        assert self.job_idx is not None

        parsed = OverridesParser.create().parse_overrides(arguments)

        search_space = dict(self.search_space)
        fixed_params = dict()
        for ov in parsed:
            dist_or_value = create_optuna_distribution_from_override(ov)
            destination = (search_space if isinstance(dist_or_value, BaseDistribution)
                           else fixed_params)
            destination[ov.get_key_element()] = dist_or_value
        # A fixed value wins over any distribution registered for the same key.
        for name in fixed_params:
            search_space.pop(name, None)

        sampler_paths = {
            "tpe": "optuna.samplers.TPESampler",
            "random": "optuna.samplers.RandomSampler",
            "cmaes": "optuna.samplers.CmaEsSampler",
        }
        opt_cfg = self.optuna_config
        if opt_cfg.sampler.name not in sampler_paths:
            raise NotImplementedError(
                f"{opt_cfg.sampler} is not supported by Optuna sweeper."
            )
        sampler = get_class(sampler_paths[opt_cfg.sampler.name])(seed=opt_cfg.seed)

        # TODO (toshihikoyanase): Remove type-ignore when optuna==2.4.0 is released.
        study = optuna.create_study(  # type: ignore
            study_name=opt_cfg.study_name,
            storage=opt_cfg.storage,
            sampler=sampler,
            direction=opt_cfg.direction.name,
            load_if_exists=True,
        )
        log.info(f"Study name: {study.study_name}")
        log.info(f"Storage: {opt_cfg.storage}")
        log.info(f"Sampler: {opt_cfg.sampler.name}")
        log.info(f"Direction: {opt_cfg.direction.name}")

        batch_size = opt_cfg.n_jobs
        remaining = opt_cfg.n_trials

        while remaining > 0:
            # Never ask for more trials than are still left to run.
            batch_size = min(remaining, batch_size)

            trials = [study._ask() for _ in range(batch_size)]
            job_overrides = []
            for trial in trials:
                for name, dist in search_space.items():
                    trial._suggest(name, dist)
                merged = {**trial.params, **fixed_params}
                job_overrides.append(
                    tuple(f"{key}={val}" for key, val in merged.items()))

            returns = self.launcher.launch(job_overrides,
                                           initial_job_idx=self.job_idx)
            self.job_idx += len(returns)
            for trial, ret in zip(trials, returns):
                # TODO (toshihikoyanase): Remove type-ignore when optuna==2.4.0 is released.
                study._tell(trial, optuna.trial.TrialState.COMPLETE,
                            ret.return_value)  # type: ignore
            remaining -= batch_size

        best = study.best_trial
        OmegaConf.save(
            OmegaConf.create({
                "name": "optuna",
                "best_params": best.params,
                "best_value": best.value,
            }),
            f"{self.config.hydra.sweep.dir}/optimization_results.yaml",
        )
        log.info(f"Best parameters: {best.params}")
        log.info(f"Best value: {best.value}")
Example #21
0
    def sweep(self, arguments: List[str]) -> None:
        """Run a single- or multi-objective Optuna study over ``arguments``.

        The search space starts as a copy of ``self.search_space``:

        * Overrides that map to an Optuna ``BaseDistribution`` are added to it.
        * All other overrides are fixed values. They are passed to every trial and
          removed from the search space.

        Trials are asked for in batches of ``self.n_jobs`` until ``self.n_trials``
        have run. Each job's return value is converted to one float per direction
        and told back to the study.

        Results go to ``optimization_results.yaml`` in the sweep directory. A
        single objective records the best trial; multiple objectives record the
        Pareto front (``study.best_trials``).

        Raises:
            ValueError: if a job's return value cannot be converted to floats, or
                does not provide one value per direction. The offending trial is
                reported to the study as FAIL before the error propagates.
        """
        assert self.config is not None
        assert self.launcher is not None
        assert self.hydra_context is not None
        assert self.job_idx is not None

        parser = OverridesParser.create()
        parsed = parser.parse_overrides(arguments)

        search_space = dict(self.search_space)
        fixed_params = dict()
        for override in parsed:
            value = create_optuna_distribution_from_override(override)
            if isinstance(value, BaseDistribution):
                search_space[override.get_key_element()] = value
            else:
                fixed_params[override.get_key_element()] = value
        # Remove fixed parameters from Optuna search space.
        for param_name in fixed_params:
            if param_name in search_space:
                del search_space[param_name]

        directions: List[str]
        if isinstance(self.direction, MutableSequence):
            directions = [
                d.name if isinstance(d, Direction) else d
                for d in self.direction
            ]
        elif isinstance(self.direction, str):
            directions = [self.direction]
        else:
            directions = [self.direction.name]

        def to_objective_values(return_value: Any) -> List[float]:
            """Convert a job's return value into one float per optimization direction."""
            if len(directions) == 1:
                try:
                    return [float(return_value)]
                except (ValueError, TypeError) as err:
                    raise ValueError(
                        f"Return value must be float-castable. Got '{return_value}'."
                    ) from err
            try:
                values = [float(v) for v in return_value]
            except (ValueError, TypeError) as err:
                raise ValueError(
                    "Return value must be a list or tuple of float-castable values."
                    f" Got '{return_value}'.") from err
            if len(values) != len(directions):
                raise ValueError(
                    "The number of the values and the number of the objectives are"
                    f" mismatched. Expect {len(directions)}, but actually {len(values)}."
                )
            return values

        study = optuna.create_study(
            study_name=self.study_name,
            storage=self.storage,
            sampler=self.sampler,
            directions=directions,
            load_if_exists=True,
        )
        log.info(f"Study name: {study.study_name}")
        log.info(f"Storage: {self.storage}")
        log.info(f"Sampler: {type(self.sampler).__name__}")
        log.info(f"Directions: {directions}")

        batch_size = self.n_jobs
        n_trials_to_go = self.n_trials

        while n_trials_to_go > 0:
            # Never ask for more trials than are still left to run.
            batch_size = min(n_trials_to_go, batch_size)

            trials = [study._ask() for _ in range(batch_size)]
            overrides = []
            for trial in trials:
                for param_name, distribution in search_space.items():
                    trial._suggest(param_name, distribution)

                params = dict(trial.params)
                params.update(fixed_params)
                overrides.append(
                    tuple(f"{name}={val}" for name, val in params.items()))

            returns = self.launcher.launch(overrides,
                                           initial_job_idx=self.job_idx)
            self.job_idx += len(returns)
            for trial, ret in zip(trials, returns):
                try:
                    values = to_objective_values(ret.return_value)
                    study._tell(trial, optuna.trial.TrialState.COMPLETE, values)
                except Exception:
                    # Report the failed trial without values. Partially converted
                    # values (e.g. a list of the wrong length) could make this report
                    # raise as well, leaving the trial unfinished in the study.
                    study._tell(trial, optuna.trial.TrialState.FAIL, None)
                    raise

            n_trials_to_go -= batch_size

        results_to_serialize: Dict[str, Any]
        if len(directions) < 2:
            best_trial = study.best_trial
            results_to_serialize = {
                "name": "optuna",
                "best_params": best_trial.params,
                "best_value": best_trial.value,
            }
            log.info(f"Best parameters: {best_trial.params}")
            log.info(f"Best value: {best_trial.value}")
        else:
            best_trials = study.best_trials
            pareto_front = [{
                "params": t.params,
                "values": t.values
            } for t in best_trials]
            results_to_serialize = {
                "name": "optuna",
                "solutions": pareto_front,
            }
            log.info(f"Number of Pareto solutions: {len(best_trials)}")
            for t in best_trials:
                log.info(f"    Values: {t.values}, Params: {t.params}")
        OmegaConf.save(
            OmegaConf.create(results_to_serialize),
            f"{self.config.hydra.sweep.dir}/optimization_results.yaml",
        )
def test_create_optuna_distribution_from_override(input: Any,
                                                  expected: Any) -> None:
    """Check that a single parsed override maps to the expected distribution.

    ``input`` is one raw override string; it is parsed on its own and the
    resulting override is converted by
    ``_impl.create_optuna_distribution_from_override`` before being compared
    against ``expected`` with ``check_distribution``.
    """
    override = OverridesParser.create().parse_overrides([input])[0]
    distribution = _impl.create_optuna_distribution_from_override(override)
    check_distribution(expected, distribution)
    def _load_configuration_impl(
        self,
        config_name: Optional[str],
        overrides: List[str],
        run_mode: RunMode,
        from_shell: bool = True,
    ) -> DictConfig:
        """Compose the full config (user config plus Hydra node) for one run.

        The override strings are parsed once. The parsed list is used to
        process the config search path, to validate sweep overrides, to build
        the defaults list, and to record hydra/task overrides on the result.

        :param config_name: name of the primary config, or None.
        :param overrides: raw command-line override strings.
        :param run_mode: the run mode; ``skip_missing`` is passed to
            ``create_defaults_list`` as True only for ``RunMode.MULTIRUN``.
        :param from_shell: forwarded to ``validate_sweep_overrides_legal``.
        :return: the composed config with ``hydra.overrides``,
            ``hydra.runtime`` and ``hydra.job`` fields populated.
        :raises KeyError: if a key listed in ``hydra.job.env_copy`` is not set
            in the process environment.
        """
        from hydra import __version__

        self.ensure_main_config_source_available()
        repo = CachingConfigRepository(self.repository)

        parsed = OverridesParser.create().parse_overrides(overrides=overrides)

        self._process_config_searchpath(config_name, parsed, repo)
        self.validate_sweep_overrides_legal(
            overrides=parsed, run_mode=run_mode, from_shell=from_shell
        )

        is_multirun = run_mode == RunMode.MULTIRUN
        defaults = create_defaults_list(
            repo=repo,
            config_name=config_name,
            overrides_list=parsed,
            prepend_hydra=True,
            skip_missing=is_multirun,
        )
        overrides_for_config = defaults.config_overrides

        cfg = self._compose_config_from_defaults_list(
            defaults=defaults.defaults, repo=repo
        )

        # Close the root config. This also closes dictionaries annotated as
        # Dict[K, V]; new keys must then be added with a "+" override.
        OmegaConf.set_struct(cfg, True)
        # Keep the Hydra node writable even if the root config is read-only.
        OmegaConf.set_readonly(cfg.hydra, False)

        # Command-line overrides are applied only once struct mode is on.
        ConfigLoaderImpl._apply_overrides_to_config(overrides_for_config, cfg)

        # Record every override line under hydra.overrides, split into Hydra
        # overrides and task (application) overrides.
        task_overrides = []
        for ov in parsed:
            if not ov.is_hydra_override():
                cfg.hydra.overrides.task.append(ov.input_line)
                task_overrides.append(ov)
                continue
            cfg.hydra.overrides.hydra.append(ov.input_line)

        with open_dict(cfg.hydra):
            cfg.hydra.runtime.choices.update(defaults.overrides.known_choices)
            for env_key in cfg.hydra.job.env_copy:
                cfg.hydra.job.env_set[env_key] = os.environ[env_key]

        cfg.hydra.runtime.version = __version__
        cfg.hydra.runtime.cwd = os.getcwd()

        sources = [
            ConfigSourceInfo(path=src.path,
                             schema=src.scheme(),
                             provider=src.provider)
            for src in repo.get_sources()
        ]
        cfg.hydra.runtime.config_sources = sources

        if "name" not in cfg.hydra.job:
            cfg.hydra.job.name = JobRuntime().get("name")

        dirname_cfg = cfg.hydra.job.config.override_dirname
        cfg.hydra.job.override_dirname = get_overrides_dirname(
            overrides=task_overrides,
            kv_sep=dirname_cfg.kv_sep,
            item_sep=dirname_cfg.item_sep,
            exclude_keys=dirname_cfg.exclude_keys,
        )
        cfg.hydra.job.config_name = config_name

        return cfg
Example #24
0
    def _load_configuration_impl(
        self,
        config_name: Optional[str],
        overrides: List[str],
        run_mode: RunMode,
        from_shell: bool = True,
    ) -> DictConfig:
        """Compose the full config (user config plus Hydra node) for one run.

        The override strings are parsed once. The same parsed list is used for
        sweep validation, for building the defaults list, and for recording the
        hydra/task overrides on the composed config.

        :param config_name: name of the primary config, or None.
        :param overrides: raw command-line override strings.
        :param run_mode: the run mode; ``skip_missing`` is passed to
            ``create_defaults_list`` as True only for ``RunMode.MULTIRUN``.
        :param from_shell: forwarded to ``validate_sweep_overrides_legal``.
        :return: the composed config with ``hydra.overrides``,
            ``hydra.choices``, ``hydra.runtime`` and ``hydra.job`` populated.
        :raises KeyError: if a key listed in ``hydra.job.env_copy`` is not set
            in the process environment.
        """
        self.ensure_main_config_source_available()
        caching_repo = CachingConfigRepository(self.repository)

        parser = OverridesParser.create()
        parsed_overrides = parser.parse_overrides(overrides=overrides)

        self.validate_sweep_overrides_legal(overrides=parsed_overrides,
                                            run_mode=run_mode,
                                            from_shell=from_shell)

        # Reuse the overrides parsed above rather than parsing the same strings
        # a second time.
        defaults_list = create_defaults_list(
            repo=caching_repo,
            config_name=config_name,
            overrides_list=parsed_overrides,
            prepend_hydra=True,
            skip_missing=run_mode == RunMode.MULTIRUN,
        )

        config_overrides = defaults_list.config_overrides

        cfg = self._compose_config_from_defaults_list(
            defaults=defaults_list.defaults, repo=caching_repo)

        # set struct mode on user config if the root node is not typed.
        if OmegaConf.get_type(cfg) is dict:
            OmegaConf.set_struct(cfg, True)

        # Turn off struct mode on the Hydra node. It gets its type safety from it being a Structured Config.
        # This enables adding fields to nested dicts like hydra.job.env_set without having to use + to append.
        OmegaConf.set_struct(cfg.hydra, False)

        # Make the Hydra node writeable (The user may have made the primary node read-only).
        OmegaConf.set_readonly(cfg.hydra, False)

        # Apply command line overrides after enabling strict flag
        ConfigLoaderImpl._apply_overrides_to_config(config_overrides, cfg)

        # Record every override line under hydra.overrides, split into Hydra
        # overrides and task (application) overrides.
        app_overrides = []
        for override in parsed_overrides:
            if override.is_hydra_override():
                cfg.hydra.overrides.hydra.append(override.input_line)
            else:
                cfg.hydra.overrides.task.append(override.input_line)
                app_overrides.append(override)

        cfg.hydra.choices.update(defaults_list.overrides.known_choices)

        with open_dict(cfg.hydra):
            from hydra import __version__

            cfg.hydra.runtime.version = __version__
            cfg.hydra.runtime.cwd = os.getcwd()

            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                overrides=app_overrides,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.
                exclude_keys,
            )
            cfg.hydra.job.config_name = config_name

            # Copy the environment variables listed in env_copy into env_set;
            # raises KeyError if one of them is not set.
            for key in cfg.hydra.job.env_copy:
                cfg.hydra.job.env_set[key] = os.environ[key]

        return cfg
Example #25
0
def test_value(value: str, expected: Any) -> None:
    """Parsing ``value`` with the ``value`` grammar rule yields ``expected``."""
    assert OverridesParser.parse_rule(value, "value") == expected
Example #26
0
def test_override_value_method(override: str, expected: str) -> None:
    """``value()`` of an override parsed from ``override`` equals ``expected``."""
    parsed = OverridesParser.parse_rule(override, "override")
    actual_value = parsed.value()
    assert actual_value == expected
Example #27
0
def test_override_get_value_element_method(override: str, expected: str,
                                           space_after_sep: bool) -> None:
    """``get_value_element`` renders the override's value as ``expected``.

    ``space_after_sep`` is forwarded unchanged to ``get_value_element``.
    """
    parsed = OverridesParser.parse_rule(override, "override")
    rendered = parsed.get_value_element(space_after_sep=space_after_sep)
    assert rendered == expected
Example #28
0
def test_get_key_element(override: str, expected: str) -> None:
    """``get_key_element`` of the parsed override equals ``expected``."""
    key_element = OverridesParser.parse_rule(override,
                                             "override").get_key_element()
    assert key_element == expected
Example #29
0
def test_override_del(value: str, expected: Any) -> None:
    """Parsing a delete override ``value`` yields an object equal to ``expected``."""
    # The expected object is given the raw input line so that it can compare
    # equal to the override produced by the parser.
    expected.input_line = value
    assert OverridesParser.parse_rule(value, "override") == expected
Example #30
0
def test_element(value: str, expected: Any) -> None:
    """Parsing ``value`` with the ``element`` rule matches ``expected`` under ``eq``."""
    element = OverridesParser.parse_rule(value, "element")
    assert eq(element, expected)