Example #1
0
def test_nested_flag_override() -> None:
    """A nested override on a child node is undone when the contexts exit."""
    cfg = OmegaConf.create({"a": {"b": 1}})
    with flag_override(cfg, "test", True):
        assert cfg._get_flag("test") is True
        # The child gets its own explicit value that shadows the parent's.
        with flag_override(cfg.a, "test", False):
            assert cfg.a._get_flag("test") is False
    # After both contexts exit, the child resolves no value for the flag.
    assert cfg.a._get_flag("test") is None
Example #2
0
 def get_sanitized_hydra_cfg(src_cfg: DictConfig) -> DictConfig:
     """Return a deep copy of ``src_cfg`` reduced to its ``hydra`` node.

     Every top-level key except ``hydra`` is removed from the copy, and the
     ``hydra_help`` and ``help`` entries are deleted from its ``hydra`` node.
     ``src_cfg`` itself is left untouched because only the copy is edited.
     """
     sanitized = copy.deepcopy(src_cfg)
     # Lift struct/readonly on the root so top-level keys can be deleted.
     with flag_override(sanitized, ["struct", "readonly"], [False, False]):
         doomed = [key for key in sanitized.keys() if key != "hydra"]
         for key in doomed:
             del sanitized[key]
     hydra_node = sanitized.hydra
     with flag_override(hydra_node, ["struct", "readonly"], False):
         del hydra_node["hydra_help"]
         del hydra_node["help"]
     return sanitized
Example #3
0
    def _set_value_impl(
        self, value: Any, flags: Optional[Dict[str, bool]] = None
    ) -> None:
        """Replace this node's content with ``value``.

        Assigning a node to itself is a no-op. None, interpolation strings and
        the ``"???"`` marker are stored verbatim with ``object_type`` cleared.
        Structured configs, DictConfigs and plain dicts are copied key by key
        through ``__setitem__``; any other value raises ValidationError.

        :param value: new content for this node; must not be a ValueNode
            (checked with an assert).
        :param flags: when ``value`` is a DictConfig, these flags replace the
            flags of the metadata copied from ``value``. Defaults to ``{}``.
        """
        from omegaconf import OmegaConf, flag_override

        # Self-assignment: the copy below would first empty _content and then
        # read from that same (now empty) dict, losing the data.
        if id(self) == id(value):
            return

        if flags is None:
            flags = {}

        assert not isinstance(value, ValueNode)
        self._validate_set(key=None, value=value)

        if OmegaConf.is_none(value):
            self.__dict__["_content"] = None
            self._metadata.object_type = None
        elif _is_interpolation(value):
            # Interpolations are stored unresolved.
            self.__dict__["_content"] = value
            self._metadata.object_type = None
        elif value == "???":
            # "???" marks a mandatory-missing value.
            self.__dict__["_content"] = "???"
            self._metadata.object_type = None
        else:
            # Container values: start from an empty dict and copy keys in.
            self.__dict__["_content"] = {}
            if is_structured_config(value):
                # object_type is cleared while keys are being set and only
                # assigned afterwards. NOTE(review): presumably so __setitem__
                # does not validate against the target type mid-construction;
                # confirm. Unlike the DictConfig branch, struct/readonly are
                # not lifted here.
                self._metadata.object_type = None
                data = get_structured_config_data(
                    value,
                    allow_objects=self._get_flag("allow_objects"),
                )
                for k, v in data.items():
                    self.__setitem__(k, v)
                self._metadata.object_type = get_type_of(value)
            elif isinstance(value, DictConfig):
                # Take over the source's metadata wholesale, but install the
                # caller-provided flags instead of the source's flags.
                self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
                self._metadata.flags = copy.deepcopy(flags)
                # disable struct and readonly for the construction phase
                # retaining other flags like allow_objects. The previous flag
                # values are restored when the flag_override blocks exit.
                with flag_override(self, "struct", False):
                    with flag_override(self, "readonly", False):
                        # Iterate the raw child nodes (no interpolation resolution).
                        for k, v in value.__dict__["_content"].items():
                            self.__setitem__(k, v)

            elif isinstance(value, dict):
                # NOTE(review): object_type is not touched in this branch; it
                # keeps whatever value the node had before.
                for k, v in value.items():
                    self.__setitem__(k, v)
            else:  # pragma: no cover
                msg = f"Unsupported value type : {value}"
                raise ValidationError(msg)
Example #4
0
    def _generate_runners(self, run_mode: RunMode) -> List[TrainingRunner]:
        """
        Generates training or rollout runner(s).

        Loads the configuration(s) for ``run_mode`` via :py:class:`ConfigurationLoader`, stores the resulting working
        directories and configurations on this instance, and instantiates and sets up one runner per
        (working directory, configuration) pair.

        :param run_mode: Run mode. See :py:class:`~maze.maze.api.RunMode`.
        :return: List of instantiated and set-up runners, one per loaded configuration. Runners are built with base
                 type ``TrainingRunner`` if ``run_mode`` is ``RunMode.TRAINING`` and ``RolloutRunner`` otherwise.
        """
        auditor = self._auditors[run_mode]
        cl = ConfigurationLoader(
            _run_mode=run_mode,
            _kwargs=auditor.kwargs,
            _overrides=auditor.overrides,
            _ephemeral_init_kwargs=auditor.ephemeral_init_kwargs,
        )
        cl.load()

        self._workdirs = cl.workdirs
        self._configs[run_mode] = cl.configs
        # The runner base type depends only on the run mode, so resolve it once instead of per configuration.
        base_type = TrainingRunner if run_mode == RunMode.TRAINING else RolloutRunner
        runners: List[TrainingRunner] = []

        for workdir, config in zip(self._workdirs, self._configs[run_mode]):
            # Change to correct working directory (necessary due to being outside of Hydra scope).
            with working_directory(workdir):
                # Allow non-primitives in the Hydra config while the runner is instantiated and set up.
                with omegaconf.flag_override(config, "allow_objects", True) as cfg:
                    runner = Factory(base_type=base_type).instantiate(cfg.runner)
                    runner.setup(cfg)
                    runners.append(runner)

        return runners
Example #5
0
    def _compose_config_from_defaults_list(
        self,
        defaults: List[ResultDefault],
        repo: IConfigRepository,
    ) -> DictConfig:
        """Merge every config named in ``defaults`` into one DictConfig.

        Configs are loaded from ``repo`` and merged in list order. Any
        OmegaConfBaseException raised by a merge is re-raised as a
        ConfigCompositionException naming the offending config path and the
        original exception type. Afterwards, the ``defaults`` key is popped
        from every DictConfig node carrying the
        ``HYDRA_REMOVE_TOP_LEVEL_DEFAULTS`` flag.
        """
        composed = OmegaConf.create()
        with flag_override(composed, "no_deepcopy_set_nodes", True):
            for default in defaults:
                single = self._load_single_config(default=default, repo=repo)
                try:
                    composed.merge_with(single.config)
                except OmegaConfBaseException as e:
                    raise ConfigCompositionException(
                        f"In '{default.config_path}': {type(e).__name__} raised while composing config:\n{e}"
                    ).with_traceback(e.__traceback__)

        def strip_defaults(node: Any) -> None:
            # Only real (non-missing, non-None) DictConfig nodes are inspected.
            if not isinstance(node, DictConfig):
                return
            if node._is_missing() or node._is_none():
                return
            with flag_override(node, ["readonly", "struct"], False):
                if node._get_flag("HYDRA_REMOVE_TOP_LEVEL_DEFAULTS"):
                    # Reset the marker flag and drop this node's defaults entry.
                    node._set_flag("HYDRA_REMOVE_TOP_LEVEL_DEFAULTS", None)
                    node.pop("defaults", None)
            # Recurse outside the override so children see the original flags.
            for _, child in node.items_ex(resolve=False):
                strip_defaults(child)

        strip_defaults(composed)
        return composed
Example #6
0
def test_multiple_flags_override() -> None:
    """flag_override accepts a list of names, including several flags at once."""
    cfg = OmegaConf.create({"foo": "bar"})

    # readonly alone: assigning to an existing key is rejected.
    with flag_override(cfg, ["readonly"], True), pytest.raises(ReadonlyConfigError):
        cfg.foo = 10

    # struct alone: adding an unknown key is rejected.
    with flag_override(cfg, ["struct"], True), pytest.raises(ConfigAttributeError):
        cfg.x = 10

    # Both flags together: each kind of violation raises its own error.
    with flag_override(cfg, ["struct", "readonly"], True):
        with pytest.raises(ConfigAttributeError):
            cfg.x = 10
        with pytest.raises(ReadonlyConfigError):
            cfg.foo = 20
Example #7
0
    def _set_value_impl(self,
                        value: Any,
                        flags: Optional[Dict[str, bool]] = None) -> None:
        """Replace this node's content with ``value``.

        None, interpolation strings and the MISSING marker are stored as-is
        with no object type. Structured configs, DictConfigs and plain dicts
        are copied key by key while struct/readonly are temporarily lifted,
        and the node's object type is then taken from the source. Any other
        value raises ValidationError.

        :param value: new content for this node; must not be a ValueNode.
        :param flags: when ``value`` is a DictConfig, replaces this node's
            flags; defaults to an empty dict.
        """
        from omegaconf import MISSING, flag_override

        flags = {} if flags is None else flags

        assert not isinstance(value, ValueNode)
        self._validate_set(key=None, value=value)

        def _store_verbatim(content: Any) -> None:
            # Leaf-like content: keep it unchanged and drop any object type.
            self.__dict__["_content"] = content
            self._metadata.object_type = None

        def _copy_items(source: Dict[Any, Any]) -> None:
            # Lift struct/readonly so every key can be (re)created on this node.
            with flag_override(self, ["struct", "readonly"], False):
                for key, item in source.items():
                    self.__setitem__(key, item)

        if _is_none(value, resolve=True):
            _store_verbatim(None)
            return
        if _is_interpolation(value, strict_interpolation_validation=True):
            _store_verbatim(value)
            return
        if _is_missing_value(value):
            _store_verbatim(MISSING)
            return

        self.__dict__["_content"] = {}
        if is_structured_config(value):
            # Object type stays unset while the fields are copied in.
            self._metadata.object_type = None
            fields = get_structured_config_data(
                value, allow_objects=self._get_flag("allow_objects"))
            _copy_items(fields)
            self._metadata.object_type = get_type_of(value)
        elif isinstance(value, DictConfig):
            self._metadata.flags = copy.deepcopy(flags)
            _copy_items(value.__dict__["_content"])
            self._metadata.object_type = value._metadata.object_type
        elif isinstance(value, dict):
            _copy_items(value)
            self._metadata.object_type = dict
        else:  # pragma: no cover
            raise ValidationError(f"Unsupported value type: {value}")
Example #8
0
 def get_sanitized_cfg(self, cfg: DictConfig, cfg_type: str) -> DictConfig:
     """Return ``cfg`` trimmed according to ``cfg_type``.

     - ``"job"``: the ``hydra`` node is deleted from ``cfg`` in place and
       ``cfg`` itself is returned.
     - ``"hydra"``: the result of ``self.get_sanitized_hydra_cfg(cfg)``.
     - ``"all"``: ``cfg`` unchanged.
     """
     assert cfg_type in ("job", "hydra", "all")
     if cfg_type == "hydra":
         return self.get_sanitized_hydra_cfg(cfg)
     if cfg_type == "job":
         # Lift struct/readonly just long enough to drop the hydra node.
         with flag_override(cfg, ["struct", "readonly"], [False, False]):
             del cfg["hydra"]
     return cfg
Example #9
0
def test_dict_assign_illegal_value_nested() -> None:
    """Unsupported values are rejected in nested nodes unless allow_objects is set."""
    cfg = OmegaConf.create({"a": {}})
    illegal = IllegalType()

    # The error message names the full dotted key of the failed assignment.
    with pytest.raises(UnsupportedValueType, match=re.escape("key: a.b")):
        cfg.a.b = illegal

    # Enabling allow_objects on the root lets the nested node accept it.
    with flag_override(cfg, "allow_objects", True):
        cfg.a.b = illegal
    assert cfg.a.b == illegal
Example #10
0
    def load(self) -> None:
        """
        Loads Hydra configuration and post-processes it according to specified RunContext.

        This also detects inconsistencies in the specification and raises errors to prevent errors at run time. There
        are the following sources for such inconsistencies:

        * Codependent components. Some components, e.g. environments and algorithms, are completely independent
          from each other - each environment can be run with each algorithm. Others have a codependency with each
          other, e.g. Runner to Trainer: A runner is specific to a particular trainer. If a component and another
          component to it are specified, these have to be consistent with each other. E.g.: Running
          A2CMultistepTrainer and ESTrainingRunner will fail.
        * Derived config groups. Some config groups depend on other attributes or qualifiers, e.g.
          algorithm_configuration (qualifiers: algorithm, configuration) or algorithm_runner (algorithm, runner).
          Hydra loads these automatically w.r.t. the set qualifiers. This is not the case when RunContext injects
          instantiated objects though, since the value for the underlying Hydra configuration group is not changed
          automatically.
          E.g.: If an A2CAlgorithmConfig is passed, the expected behaviour would be to load the corresponding A2C
          trainer config in algorithm_config. Since the default value for the config group "algorithm" is "ES"
          however, Hydra will load the configuration for the module es-dev or es-local. This results in combining an
          ES trainer with a A2CAlgorithmConfig, which will fail.
        * Nested elements. If a super- and at least one sub-component are specified as instantiated components or
          DictConfigs, RunContext attempts to inject them in the loaded configuration w.r.t. the respective
          hierarchy levels. This entails that more specialized attributes will overwrite more general ones.
          E.g.: If "model" and "policy" are specified, the model is injected first, then the policy. That way the
          "model" object will overwrite the loaded model and the "policy" object will overwrite the policy in the
          "model" object.
          This doesn't occur in the CLI, since no non-primitive value can be passed.

        Furthermore, it resolves proxy arguments w.r.t. the current run mode: Non-top level attributes
        exposed in :py:class:`maze.api.run_context.RunContext` (e.g. "critic").
        """

        # 1. Load Hydra configuration for this algorithm and environment.
        self._load_hydra_config()

        # Each configuration is processed once, inside the working directory it is paired with. Iterating over every
        # (workdir, config) combination would inject and post-process each configuration once per working directory.
        # NOTE(review): runner generation pairs ``workdirs`` and ``configs`` with zip as well; both lists are expected
        # to have equal length.
        for workdir, config in zip(self._workdirs, self._configs):
            # Change to correct working directory (necessary due to being outside of Hydra scope).
            with working_directory(workdir):
                # Allow non-primitives in Hydra config.
                with omegaconf.flag_override(config, "allow_objects", True) as cfg:
                    OmegaConf.set_struct(cfg, False)

                    # 2. Inject instantiated objects.
                    self._inject_nonprimitive_instances_into_hydra_config(cfg)

                    # 3. Resolve proxy arguments in-place.
                    self._resolve_proxy_arguments(cfg)

                    # 4. Postprocess loaded configuration.
                    self._postprocess_config(cfg)

                # 5. Re-enable struct mode now that post-processing is done (leaving the flag_override context above
                # has already restored the previous allow_objects value).
                OmegaConf.set_struct(cfg, True)
Example #11
0
def test_flag_override(src, flag_name, flag_value, func, expectation):
    """With ``flag_name`` on, ``func`` meets ``expectation``; under an override it succeeds."""
    cfg = OmegaConf.create(src)
    cfg._set_flag(flag_name, True)
    with expectation:
        func(cfg)

    with does_not_raise(), flag_override(cfg, flag_name, flag_value):
        func(cfg)
Example #12
0
def test_flag_override(src: Dict[str, Any], flag_name: str, func: Any,
                       expectation: Any) -> None:
    """With ``flag_name`` on, ``func`` meets ``expectation``; overriding it to False lets it pass."""
    cfg = OmegaConf.create(src)
    cfg._set_flag(flag_name, True)
    with expectation:
        func(cfg)

    with does_not_raise(), flag_override(cfg, flag_name, False):
        func(cfg)
Example #13
0
    def _set_value_impl(self,
                        value: Any,
                        flags: Optional[Dict[str, bool]] = None) -> None:
        """Replace this ListConfig's content with ``value``.

        Assigning a node to itself is a no-op. None is accepted only for
        optional nodes (ValidationError otherwise). The ``"???"`` marker and
        interpolation strings are stored verbatim. A ListConfig or primitive
        list is copied element by element via ``append``; any other value
        raises ValidationError.

        :param value: new content for this node.
        :param flags: when ``value`` is a ListConfig, these flags replace the
            flags of the metadata copied from ``value``. Defaults to ``{}``.
        """
        from omegaconf import OmegaConf, flag_override

        # Self-assignment: the copy below would first empty _content and then
        # iterate over that same (now empty) node, losing the data.
        if id(self) == id(value):
            return

        if flags is None:
            flags = {}

        if OmegaConf.is_none(value):
            if not self._is_optional():
                raise ValidationError(
                    "Non optional ListConfig cannot be constructed from None")
            self.__dict__["_content"] = None
        elif get_value_kind(value) == ValueKind.MANDATORY_MISSING:
            self.__dict__["_content"] = "???"
        elif get_value_kind(value) in (
                ValueKind.INTERPOLATION,
                ValueKind.STR_INTERPOLATION,
        ):
            # Interpolations are stored unresolved.
            self.__dict__["_content"] = value
        else:
            # Only ListConfig and primitive lists can become list content.
            if not (is_primitive_list(value) or isinstance(value, ListConfig)):
                type_ = type(value)
                msg = f"Invalid value assigned : {type_.__name__} is not a ListConfig, list or tuple."
                raise ValidationError(msg)

            self.__dict__["_content"] = []
            if isinstance(value, ListConfig):
                # Take over the source's metadata wholesale, but install the
                # caller-provided flags instead of the source's flags.
                self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
                self._metadata.flags = copy.deepcopy(flags)
                # disable struct and readonly for the construction phase
                # retaining other flags like allow_objects. The previous flag
                # values are restored when the flag_override blocks exit.
                with flag_override(self, "struct", False):
                    with flag_override(self, "readonly", False):
                        # Iterate raw items (no interpolation resolution).
                        for item in value._iter_ex(resolve=False):
                            self.append(item)
            elif is_primitive_list(value):
                # Always taken when value is not a ListConfig, given the check above.
                for item in value:
                    self.append(item)
Example #14
0
def test_insert_throws_not_changing_list() -> None:
    """A rejected insert leaves the list untouched; allow_objects lets it through."""
    lst = OmegaConf.create([])
    illegal = IllegalType()
    with pytest.raises(ValueError):
        lst.insert(0, illegal)
    # The failed insert must not leave a partial element behind.
    assert len(lst) == 0
    assert lst == []

    with flag_override(lst, "allow_objects", True):
        lst.insert(0, illegal)
    assert lst == [illegal]
Example #15
0
        def strip_defaults(cfg: Any) -> None:
            """Recursively pop ``defaults`` from DictConfig nodes flagged HYDRA_REMOVE_TOP_LEVEL_DEFAULTS.

            Non-DictConfig values and missing/None DictConfigs are ignored.
            """
            if not isinstance(cfg, DictConfig) or cfg._is_missing() or cfg._is_none():
                return
            with flag_override(cfg, ["readonly", "struct"], False):
                if cfg._get_flag("HYDRA_REMOVE_TOP_LEVEL_DEFAULTS"):
                    # Reset the marker flag and drop this node's defaults entry.
                    cfg._set_flag("HYDRA_REMOVE_TOP_LEVEL_DEFAULTS", None)
                    cfg.pop("defaults", None)
            # Recurse outside the override so children see the original flags.
            for _, child in cfg.items_ex(resolve=False):
                strip_defaults(child)
Example #16
0
def test_append_throws_not_changing_list() -> None:
    """A rejected append leaves the list untouched; allow_objects lets it through."""
    lst = OmegaConf.create([])
    illegal = IllegalType()
    with pytest.raises(ValueError):
        lst.append(illegal)
    # The failed append must leave the list empty and validate_list_keys happy.
    assert len(lst) == 0
    assert lst == []
    validate_list_keys(lst)

    with flag_override(lst, "allow_objects", True):
        lst.append(illegal)
    assert lst == [illegal]
Example #17
0
        def test_allow_objects(self, module: Any) -> None:
            """Structured configs reject unsupported values unless allow_objects is enabled."""
            illegal = IllegalType()

            # Without the flag, the assignment is rejected.
            cfg = OmegaConf.structured(module.Plugin)
            with raises(UnsupportedValueType):
                cfg.params = illegal

            # Enabling the flag at creation time ...
            cfg = OmegaConf.structured(module.Plugin, flags={"allow_objects": True})
            cfg.params = illegal
            assert cfg.params == illegal

            # ... or temporarily through flag_override.
            cfg = OmegaConf.structured(module.Plugin)
            with flag_override(cfg, "allow_objects", True):
                cfg.params = illegal
                assert cfg.params == illegal

            # The same holds for a structured value that itself holds an object.
            cfg = OmegaConf.structured({"plugin": module.Plugin})
            plugin = module.Plugin(name="foo", params=illegal)
            with raises(UnsupportedValueType):
                cfg.plugin = plugin

            with flag_override(cfg, "allow_objects", True):
                cfg.plugin = plugin
                assert cfg.plugin == plugin
Example #18
0
    def app_help(self, config_name: Optional[str], args_parser: ArgumentParser,
                 args: Any) -> None:
        """Compose the app config and print its help text.

        The help template is ``cfg.hydra.help``; the config handed to
        ``get_help`` is a deep copy with the ``hydra`` node removed.
        """
        cfg = self.compose_config(
            config_name=config_name,
            overrides=args.overrides,
            run_mode=RunMode.RUN,
            with_log_configuration=True,
        )
        help_cfg = cfg.hydra.help

        # Edit a copy so the composed config keeps its hydra node.
        user_cfg = copy.deepcopy(cfg)
        with flag_override(user_cfg, ["struct", "readonly"], [False, False]):
            del user_cfg["hydra"]

        print(self.get_help(help_cfg, user_cfg, args_parser))
    def _compose_config_from_defaults_list(
        self,
        defaults: List[ResultDefault],
        repo: IConfigRepository,
    ) -> DictConfig:
        """Merge the configs listed in ``defaults`` (in order) into a fresh DictConfig.

        A ValidationError raised while merging is re-raised as a
        ConfigCompositionException that names the config path being merged.
        """
        merged = OmegaConf.create()
        with flag_override(merged, "no_deepcopy_set_nodes", True):
            for default in defaults:
                single = self._load_single_config(default=default, repo=repo)
                try:
                    merged.merge_with(single.config)
                except ValidationError as e:
                    raise ConfigCompositionException(
                        f"In '{default.config_path}': Validation error while composing config:\n{e}"
                    ).with_traceback(e.__traceback__)
        return merged
Example #20
0
    def _print_config_info(self, config_name: Optional[str],
                           overrides: List[str]) -> None:
        """Log the search path, the defaults tree/list and the job config as YAML."""
        assert log is not None
        for printer in (self._print_search_path,
                        self._print_defaults_tree,
                        self._print_defaults_list):
            printer(config_name=config_name, overrides=overrides)

        def compose():
            return self._get_cfg(
                config_name=config_name,
                overrides=overrides,
                cfg_type="all",
                with_log_configuration=False,
            )

        cfg = run_and_report(compose)
        self._log_header(header="Config", filler="*")
        # Only the job part is printed: lift struct/readonly to drop ``hydra``.
        with flag_override(cfg, ["struct", "readonly"], [False, False]):
            del cfg["hydra"]
        log.info(OmegaConf.to_yaml(cfg))
Example #21
0
 def _get_cfg(
     self,
     config_name: Optional[str],
     overrides: List[str],
     cfg_type: str,
     with_log_configuration: bool,
 ) -> DictConfig:
     """Compose a config and trim it according to ``cfg_type``.

     ``"job"`` deletes the ``hydra`` node from the composed config,
     ``"hydra"`` returns ``self.get_sanitized_hydra_cfg`` of it, and ``"all"``
     returns the composed config unchanged.
     """
     assert cfg_type in ("job", "hydra", "all")
     composed = self.compose_config(
         config_name=config_name,
         overrides=overrides,
         run_mode=RunMode.RUN,
         with_log_configuration=with_log_configuration,
     )
     if cfg_type == "hydra":
         return self.get_sanitized_hydra_cfg(composed)
     if cfg_type == "job":
         # Lift struct/readonly just long enough to drop the hydra node.
         with flag_override(composed, ["struct", "readonly"], [False, False]):
             del composed["hydra"]
     return composed
Example #22
0
    def _print_config_info(
        self,
        config_name: Optional[str],
        overrides: List[str],
        run_mode: RunMode = RunMode.RUN,
    ) -> None:
        """Log the search path, the defaults tree/list and the composed job config.

        The full composed config (including ``hydra``) is passed to
        ``HydraConfig.instance().set_config`` before the ``hydra`` node is
        removed and the remainder is logged as YAML.
        """
        assert log is not None
        self._print_search_path(config_name=config_name,
                                overrides=overrides,
                                run_mode=run_mode)
        for printer in (self._print_defaults_tree, self._print_defaults_list):
            printer(config_name=config_name, overrides=overrides)

        def compose():
            return self.compose_config(
                config_name=config_name,
                overrides=overrides,
                run_mode=run_mode,
                with_log_configuration=False,
            )

        cfg = run_and_report(compose)
        HydraConfig.instance().set_config(cfg)
        self._log_header(header="Config", filler="*")
        # Only the job part is printed: lift struct/readonly to drop ``hydra``.
        with flag_override(cfg, ["struct", "readonly"], [False, False]):
            del cfg["hydra"]
        log.info(OmegaConf.to_yaml(cfg))