Example #1
0
def test_initialize() -> None:
    """Initializing without a config dir marks GlobalHydra as initialized.

    Global Hydra state is cleared afterwards even if an assertion fails.
    """
    try:
        initialized_before = GlobalHydra().is_initialized()
        assert not initialized_before
        initialize(config_dir=None, strict=True)
        initialized_after = GlobalHydra().is_initialized()
        assert initialized_after
    finally:
        GlobalHydra().clear()
Example #2
0
def test_initialize_with_config_path(hydra_restore_singletons: Any) -> None:
    """Initializing with a config_path registers a search-path entry from provider 'main'."""
    assert not GlobalHydra().is_initialized()
    initialize(config_path="../hydra/test_utils/configs")
    assert GlobalHydra().is_initialized()

    hydra_obj = GlobalHydra.instance().hydra
    assert hydra_obj is not None
    search_path = hydra_obj.config_loader.get_search_path()
    assert isinstance(search_path, ConfigSearchPathImpl)
    main_query = SearchPathQuery(provider="main", path=None)
    match_index = search_path.find_first_match(main_query)
    assert match_index != -1
Example #3
0
def test_initialize_with_config_dir() -> None:
    """Initializing from a config dir registers a search-path entry from provider 'main'.

    Global Hydra state is cleared afterwards even if an assertion fails.
    """
    try:
        assert not GlobalHydra().is_initialized()
        initialize(config_dir="../hydra/test_utils/configs", strict=True)
        assert GlobalHydra().is_initialized()

        hydra_obj = GlobalHydra.instance().hydra
        assert hydra_obj is not None
        search_path = hydra_obj.config_loader.get_search_path()
        assert isinstance(search_path, ConfigSearchPathImpl)
        # NOTE(review): the SearchPathQuery keyword (``search_path`` here, ``path``
        # elsewhere) appears to depend on the Hydra version; confirm it matches.
        main_query = SearchPathQuery(provider="main", search_path=None)
        match_index = search_path.find_first_match(main_query)
        assert match_index != -1
    finally:
        GlobalHydra().clear()
Example #4
0
def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
    """Build a struct-mode ``DictConfig`` from a legacy argparse ``Namespace``.

    Overrides and deletions derived from ``args`` (via ``override_module_args``)
    are applied while composing the ``config`` Hydra config. Hydra is initialized
    with ``config_path="../../config"`` only if it is not already initialized.

    Each of the ``task``, ``model``, ``optimizer``, ``lr_scheduler`` and
    ``criterion`` groups may still be ``None`` after composition. Such a group is
    replaced by a ``Namespace`` copy of ``args`` when the matching attribute is
    set: ``task``, ``arch``, ``optimizer``, ``lr_scheduler`` or ``criterion``.
    Legacy defaults are then filled in from the corresponding registry, and
    ``_name`` is set to that attribute's value.

    Args:
        args: parsed command-line arguments.

    Returns:
        The composed config with struct mode enabled.
    """
    # Here we are using field values provided in args to override counterparts inside config object
    overrides, deletes = override_module_args(args)

    cfg_name = "config"
    cfg_path = f"../../{cfg_name}"

    if not GlobalHydra().is_initialized():
        initialize(config_path=cfg_path)

    composed_cfg = compose(cfg_name, overrides=overrides, strict=False)
    # keys requested for deletion are blanked out rather than removed
    for k in deletes:
        composed_cfg[k] = None

    # round-trip through a plain container so interpolations are resolved
    cfg = OmegaConf.create(
        OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True))

    # hack to be able to set Namespace in dict config. this should be removed when we update to newer
    # omegaconf version that supports object flags, or when we migrate all existing models
    from omegaconf import _utils

    old_primitive = _utils.is_primitive_type
    _utils.is_primitive_type = lambda _: True
    try:
        if cfg.task is None and getattr(args, "task", None):
            cfg.task = Namespace(**vars(args))
            from fairseq.tasks import TASK_REGISTRY

            _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
            cfg.task._name = args.task
        if cfg.model is None and getattr(args, "arch", None):
            cfg.model = Namespace(**vars(args))
            from fairseq.models import ARCH_MODEL_REGISTRY

            _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
            cfg.model._name = args.arch
        if cfg.optimizer is None and getattr(args, "optimizer", None):
            cfg.optimizer = Namespace(**vars(args))
            from fairseq.optim import OPTIMIZER_REGISTRY

            _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
            cfg.optimizer._name = args.optimizer
        if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
            cfg.lr_scheduler = Namespace(**vars(args))
            from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY

            _set_legacy_defaults(cfg.lr_scheduler,
                                 LR_SCHEDULER_REGISTRY[args.lr_scheduler])
            cfg.lr_scheduler._name = args.lr_scheduler
        if cfg.criterion is None and getattr(args, "criterion", None):
            cfg.criterion = Namespace(**vars(args))
            from fairseq.criterions import CRITERION_REGISTRY

            _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
            cfg.criterion._name = args.criterion
    finally:
        # Always undo the process-wide monkeypatch, even if a registry lookup or
        # _set_legacy_defaults raises, so later OmegaConf use is unaffected.
        _utils.is_primitive_type = old_primitive

    OmegaConf.set_struct(cfg, True)
    return cfg
Example #5
0
def compose(
    config_name: Optional[str] = None,
    overrides: Optional[List[str]] = None,
    strict: Optional[bool] = None,
    return_hydra_config: bool = False,
) -> DictConfig:
    """
    :param config_name: the name of the config
           (usually the file name without the .yaml extension)
    :param overrides: list of overrides for config file; ``None`` (the default)
           means no overrides. A copy is passed on, never the caller's list.
    :param strict: optionally override the default strict mode
    :param return_hydra_config: True to return the hydra config node in the result
    :return: the composed config
    :raises AssertionError: if GlobalHydra has not been initialized
    """
    assert GlobalHydra().is_initialized(), (
        "GlobalHydra is not initialized, use @hydra.main()"
        " or call one of the hydra.experimental initialize methods first")

    # Fresh list per call: avoids a mutable default shared between calls and
    # keeps the caller's own list out of Hydra's hands.
    overrides = [] if overrides is None else list(overrides)

    gh = GlobalHydra.instance()
    assert gh.hydra is not None
    cfg = gh.hydra.compose_config(
        config_name=config_name,
        overrides=overrides,
        run_mode=RunMode.RUN,
        strict=strict,
        from_shell=False,
    )
    assert isinstance(cfg, DictConfig)

    # the "hydra" node is stripped unless explicitly requested
    if not return_hydra_config:
        if "hydra" in cfg:
            with open_dict(cfg):
                del cfg["hydra"]
    return cfg
Example #6
0
def test_initialize_old_version_base(hydra_restore_singletons: Any) -> None:
    """A version_base older than the compat version is rejected with HydraException."""
    assert not GlobalHydra().is_initialized()
    expected_message = f'version_base must be >= "{version.__compat_version__}"'
    with raises(HydraException, match=expected_message):
        initialize(version_base="1.0")
Example #7
0
    def __enter__(self) -> "SweepTaskFunction":
        """Run the task function as a Hydra multirun sweep and return ``self``.

        The sweep directory is ``self.temp_dir`` when set (created if missing),
        otherwise a new temporary directory. Sweep results are stored in
        ``self.returns``. Global Hydra state is cleared afterwards even if the
        sweep raises.
        """
        sweep_overrides = copy.deepcopy(self.overrides)
        assert sweep_overrides is not None

        if not self.temp_dir:
            self.temp_dir = tempfile.mkdtemp()
        else:
            Path(self.temp_dir).mkdir(parents=True, exist_ok=True)
        sweep_overrides.append(f"hydra.sweep.dir={self.temp_dir}")

        try:
            config_dir, config_name = split_config_path(
                self.config_path, self.config_name)
            task_name = detect_task_name(self.calling_file, self.calling_module)
            sweep_hydra = Hydra.create_main_hydra_file_or_module(
                calling_file=self.calling_file,
                calling_module=self.calling_module,
                config_path=config_dir,
                job_name=task_name,
                strict=self.strict,
            )
            self.returns = sweep_hydra.multirun(
                config_name=config_name,
                task_function=self,
                overrides=sweep_overrides,
                with_log_configuration=self.configure_logging,
            )
        finally:
            GlobalHydra().clear()

        return self
Example #8
0
    def __enter__(self) -> "TaskTestFunction":
        """Run the task function once through Hydra and return ``self``.

        The job runs with ``hydra.run.dir`` set to a fresh temporary directory
        (``self.temp_dir``), and its result is stored in ``self.job_ret``.
        Global Hydra state is cleared when this method exits, whether the run
        succeeded or raised.
        """
        try:
            config_dir, config_name = split_config_path(
                self.config_path, self.config_name)
            task_name = detect_task_name(self.calling_file, self.calling_module)
            self.hydra = Hydra.create_main_hydra_file_or_module(
                calling_file=self.calling_file,
                calling_module=self.calling_module,
                config_path=config_dir,
                job_name=task_name,
                strict=self.strict,
            )

            self.temp_dir = tempfile.mkdtemp()
            run_overrides = copy.deepcopy(self.overrides)
            assert run_overrides is not None
            run_overrides.append(f"hydra.run.dir={self.temp_dir}")

            self.job_ret = self.hydra.run(
                config_name=config_name,
                task_function=self,
                overrides=run_overrides,
                with_log_configuration=self.configure_logging,
            )
        finally:
            GlobalHydra().clear()
        return self
Example #9
0
def test_initialize_bad_version_base(hydra_restore_singletons: Any) -> None:
    """A non-string version_base is rejected with a TypeError."""
    assert not GlobalHydra().is_initialized()
    expected_message = "expected string or bytes-like object"
    with raises(TypeError, match=expected_message):
        initialize(version_base=1.1)  # type: ignore
Example #10
0
def compose(
    config_name: Optional[str] = None,
    overrides: Optional[List[str]] = None,
    strict: Optional[bool] = None,
) -> DictConfig:
    """
    :param config_name: optional config name to load
    :param overrides: list of overrides for config file; ``None`` (the default)
           means no overrides. A copy is passed on, never the caller's list.
    :param strict: optionally override the default strict mode
    :return: the composed config, without its "hydra" node
    :raises AssertionError: if GlobalHydra has not been initialized
    """
    assert (
        GlobalHydra().is_initialized()
    ), "GlobalHydra is not initialized, use @hydra.main() or call hydra.experimental.initialize() first"

    # Fresh list per call: avoids a mutable default shared between calls and
    # keeps the caller's own list out of Hydra's hands.
    overrides = [] if overrides is None else list(overrides)

    gh = GlobalHydra.instance()
    assert gh.hydra is not None
    cfg = gh.hydra.compose_config(
        config_name=config_name, overrides=overrides, strict=strict
    )
    assert isinstance(cfg, DictConfig)

    # the "hydra" node is always stripped from the returned config
    if "hydra" in cfg:
        with open_dict(cfg):
            del cfg["hydra"]
    return cfg
Example #11
0
def test_initialize_compat_version_base(hydra_restore_singletons: Any) -> None:
    """Omitting version_base warns and then assumes the compat-version defaults."""
    assert not GlobalHydra().is_initialized()
    expected_warning = f"Will assume defaults for version {version.__compat_version__}"
    with raises(UserWarning, match=expected_warning):
        initialize()
    compat_reached = version.base_at_least(str(version.__compat_version__))
    assert compat_reached
Example #12
0
def compose(
    config_name: Optional[str] = None,
    overrides: Optional[List[str]] = None,
    return_hydra_config: bool = False,
    strict: Optional[bool] = None,
) -> DictConfig:
    """
    :param config_name: the name of the config
           (usually the file name without the .yaml extension)
    :param overrides: list of overrides for config file; ``None`` (the default)
           means no overrides. A copy is passed on, never the caller's list.
    :param return_hydra_config: True to return the hydra config node in the result
    :param strict: DEPRECATED. If false, returned config has struct mode disabled.
    :return: the composed config
    :raises AssertionError: if GlobalHydra has not been initialized
    """
    assert (
        GlobalHydra().is_initialized()
    ), "GlobalHydra is not initialized, use @hydra.main() or call one of the hydra initialization methods first"

    # Fresh list per call: avoids a mutable default shared between calls and
    # keeps the caller's own list out of Hydra's hands.
    overrides = [] if overrides is None else list(overrides)

    gh = GlobalHydra.instance()
    assert gh.hydra is not None
    cfg = gh.hydra.compose_config(
        config_name=config_name,
        overrides=overrides,
        run_mode=RunMode.RUN,
        from_shell=False,
        with_log_configuration=False,
    )
    assert isinstance(cfg, DictConfig)

    if not return_hydra_config:
        if "hydra" in cfg:
            with open_dict(cfg):
                del cfg["hydra"]

    if strict is not None:
        # DEPRECATED: remove in 1.2
        deprecation_warning(
            dedent(
                """
                The strict flag in the compose API is deprecated and will be removed in the next version of Hydra.
                See https://hydra.cc/docs/upgrades/0.11_to_1.0/strict_mode_flag_deprecated for more info.
                """
            )
        )
        # strict is applied after composition, by toggling struct mode on the result
        OmegaConf.set_struct(cfg, strict)

    return cfg
Example #13
0
    def get_test_configuration(config_override: "list | None" = None):
        """Compose ``test_config.yaml`` from the ``data/config`` directory next to this file.

        Side effects:
        - ensures ``<this dir>/output`` exists;
        - registers the ``output_path`` and ``source_path`` OmegaConf resolvers if
          they are missing; they ignore their argument and return the output and
          data directories respectively;
        - initializes global Hydra if it is not already initialized.

        :param config_override: optional list of Hydra overrides; ``None`` means none.
        :return: the composed config returned by ``hydra.compose``.
        """
        # fresh list per call instead of a shared mutable default
        if config_override is None:
            config_override = []

        test_path = os.path.dirname(os.path.abspath(__file__))
        source_path = os.path.join(test_path, 'data')
        output_path = os.path.join(test_path, 'output')

        # exist_ok closes the race between an existence check and creation
        os.makedirs(output_path, exist_ok=True)

        # Check each resolver on its own: re-registering an existing name raises,
        # so guarding both behind one check fails if only one is registered.
        if not OmegaConf.has_resolver('output_path'):
            OmegaConf.register_new_resolver('output_path',
                                            lambda sub_path: output_path)
        if not OmegaConf.has_resolver('source_path'):
            OmegaConf.register_new_resolver('source_path',
                                            lambda sub_path: source_path)

        if not GlobalHydra().is_initialized():
            hydra.initialize(config_path='data/config', caller_stack_depth=2)
        return hydra.compose("test_config.yaml", overrides=config_override)
Example #14
0
    def __enter__(self) -> "TaskTestFunction":
        """Run the task function once through Hydra and return ``self``.

        The job runs with ``hydra.run.dir`` set to a fresh temporary directory
        (``self.temp_dir``), and its result is stored in ``self.job_ret``.
        Global Hydra state is cleared when this method exits, whether the run
        succeeded or raised.
        """
        try:
            config_dir, config_file = split_config_path(self.config_path)
            self.hydra = Hydra.create_main_hydra_file_or_module(
                calling_file=self.calling_file,
                calling_module=self.calling_module,
                config_dir=config_dir,
                strict=self.strict,
            )

            self.temp_dir = tempfile.mkdtemp()
            run_overrides = copy.deepcopy(self.overrides)
            assert run_overrides is not None
            run_overrides.append(f"hydra.run.dir={self.temp_dir}")

            self.job_ret = self.hydra.run(
                config_file=config_file,
                task_function=self,
                overrides=run_overrides,
            )
        finally:
            GlobalHydra().clear()
        return self
Example #15
0
    def __enter__(self) -> "SweepTaskFunction":
        """Run the task function as a Hydra multirun sweep and return ``self``.

        Sweep output goes to a fresh temporary directory (``self.temp_dir``), and
        results are stored in ``self.returns``. Global Hydra state is cleared
        afterwards even if the sweep raises.
        """
        self.temp_dir = tempfile.mkdtemp()
        sweep_overrides = copy.deepcopy(self.overrides)
        assert sweep_overrides is not None
        sweep_overrides.append(f"hydra.sweep.dir={self.temp_dir}")

        try:
            config_dir, config_name = split_config_path(
                self.config_path, self.config_name)
            sweep_hydra = Hydra.create_main_hydra_file_or_module(
                calling_file=self.calling_file,
                calling_module=self.calling_module,
                config_dir=config_dir,
                strict=self.strict,
            )
            self.returns = sweep_hydra.multirun(
                config_name=config_name,
                task_function=self,
                overrides=sweep_overrides,
            )
        finally:
            GlobalHydra().clear()

        return self
Example #16
0
def test_initialize(hydra_restore_singletons: Any) -> None:
    """initialize(version_base=None) marks GlobalHydra as initialized."""
    initialized_before = GlobalHydra().is_initialized()
    assert not initialized_before
    initialize(version_base=None)
    initialized_after = GlobalHydra().is_initialized()
    assert initialized_after
Example #17
0
def test_initialize(hydra_restore_singletons: Any) -> None:
    """initialize(config_path=None) marks GlobalHydra as initialized."""
    initialized_before = GlobalHydra().is_initialized()
    assert not initialized_before
    initialize(config_path=None)
    initialized_after = GlobalHydra().is_initialized()
    assert initialized_after
Example #18
0
def test_initialize_dev_version_base(hydra_restore_singletons: Any) -> None:
    """A dev pre-release version_base such as 1.2.0.dev2 must still count as at least 1.2."""
    assert not GlobalHydra().is_initialized()
    # packaging orders "1.2.0.dev2" before "1.2", so the dev suffix needs special handling
    initialize(version_base="1.2.0.dev2")
    reached_1_2 = version.base_at_least("1.2")
    assert reached_1_2
Example #19
0
def test_initialize_cur_version_base(hydra_restore_singletons: Any) -> None:
    """version_base=None makes the current Hydra version the effective base."""
    assert not GlobalHydra().is_initialized()
    initialize(version_base=None)
    reached_current = version.base_at_least(__version__)
    assert reached_current
Example #20
0
 def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore
     """Clear global Hydra state on context exit; returns None, so exceptions propagate."""
     hydra_state = GlobalHydra()
     hydra_state.clear()
Example #21
0
def initialize(config_path: str,
               config_name: str,
               caller_stack_depth: int = 1) -> None:
    r"""Performs initialization needed for configuring multiruns / parameter sweeps via
    a YAML config file. Currently this is only needed if multirun configuration via a YAML
    file is desired. Otherwise, :func:`combustion.main` can be used without calling ``initialize``.
    See :func:`combustion.main` for usage examples.

    .. warning::
        This method makes use of Hydra's compose API, which is experimental as of version 1.0.

    .. warning::
        This method works by inspecting the "sweeper" section of the specified config file
        and altering ``sys.argv`` to include the chosen sweeper parameters.

    Args:
        config_path (str):
            Path to main configuration file. See :func:`hydra.main` for more details.

        config_name (str):
            Name of the main configuration file. See :func:`hydra.main` for more details.

        caller_stack_depth (int):
            Stack depth when calling :func:`initialize`. Defaults to 1 (direct caller).

    Raises:
        ImportError:
            If the installed Hydra version is older than ``1.0.0rc2``.

    Any previously initialized global Hydra state is cleared first. Command-line
    arguments that start with ``-`` are kept as flags. Every other argument is split at
    its first ``=`` or whitespace into a ``key``/``value`` override. Arguments with no
    such separator are passed through unchanged. Entries from the config's sweeper
    section replace overrides with the same key. If the sweeper section is non-empty and
    neither ``-m`` nor ``--multirun`` was given, ``--multirun`` is appended. ``sys.argv``
    is then rebuilt as the program name, the flags, and the overrides.

    Sample sweeper Hydra config
        .. code-block:: yaml

            sweeper:
              model.params.batch_size: 8,16,32
              optimizer.params.lr: 0.001,0.002,0.0003
    """
    assert caller_stack_depth >= 1
    # add one so the depth handed to Hydra also skips this wrapper's own frame
    caller_stack_depth += 1

    if version.parse(hydra.__version__) < version.parse("1.0.0rc2"):
        raise ImportError(
            f"Sweeping requires hydra>=1.0.0rc2, but you have {hydra.__version__}"
        )

    gh = GlobalHydra.instance()
    if GlobalHydra().is_initialized():
        gh.clear()

    # startswith() rather than x[0]: an empty-string argument must not raise IndexError
    cli_args = sys.argv[1:]
    flags = [x for x in cli_args if x.startswith("-")]
    overrides = [x for x in cli_args if not x.startswith("-")]

    # Split argv into a dict. An override with no "=" / whitespace separator
    # (e.g. a bare "~key") is kept verbatim instead of failing to unpack. A private
    # sentinel marks it, so it cannot collide with a sweeper value of None.
    no_value = object()
    overrides_dict = {}
    for x in overrides:
        parts = re.split(r"=|\s", x, maxsplit=1)
        if len(parts) == 2:
            key, value = parts
        else:
            key, value = x, no_value
        overrides_dict[key] = value

    # use compose api to inspect multirun values
    with hydra.experimental.initialize(config_path,
                                       caller_stack_depth=caller_stack_depth):
        assert gh.hydra is not None
        cfg = gh.hydra.compose_config(
            config_name=config_name,
            overrides=overrides,
            run_mode=RunMode.MULTIRUN,
        )
        assert isinstance(cfg, DictConfig)

        if "sweeper" in cfg.keys() and cfg.sweeper:
            log.debug("Using sweeper values: %s", cfg.sweeper)
            overrides_dict.update(cfg.sweeper)

            if "--multirun" not in flags and "-m" not in flags:
                log.warning(
                    "Multirun flag not given but sweeper config was non-empty. "
                    "Adding -m flag")
                flags.append("--multirun")
        else:
            log.debug("No sweeper config specified")

    # append key value pairs in sweeper config to sys.argv
    overrides = [
        key if value is no_value else f"{key}={value}"
        for key, value in overrides_dict.items()
    ]
    sys.argv = [sys.argv[0], *flags, *overrides]