Esempio n. 1
0
def configure_log(log_config: DictConfig,
                  verbose_config: Union[bool, str, Sequence[str]]) -> None:
    """Set up Python logging and optionally switch loggers to DEBUG.

    :param log_config: logging configuration; it is converted to a plain
                       container (interpolations resolved) and passed to
                       ``logging.config.dictConfig``. When it is ``None`` the
                       root logger is set to INFO and gets a stdout handler.
    :param verbose_config: ``True`` sets the root logger to DEBUG, ``False``
                           changes nothing; a single logger name or an
                           OmegaConf list of names sets only those loggers
                           to DEBUG.
    """
    assert (isinstance(verbose_config, (bool, str))
            or OmegaConf.is_list(verbose_config))

    if log_config is None:
        # No configuration supplied: log INFO and above to stdout.
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(
            logging.Formatter(
                "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s"))
        root_logger.addHandler(stdout_handler)
    else:
        logging.config.dictConfig(
            OmegaConf.to_container(log_config, resolve=True))  # type: ignore

    if isinstance(verbose_config, bool):
        # A boolean only ever affects the root logger.
        if verbose_config:
            logging.getLogger().setLevel(logging.DEBUG)
        return

    if isinstance(verbose_config, str):
        # Wrap a single name so both cases share the loop below.
        verbose_list = OmegaConf.create([verbose_config])
    elif OmegaConf.is_list(verbose_config):
        verbose_list = verbose_config  # type: ignore
    else:
        assert False

    for logger_name in verbose_list:
        logging.getLogger(logger_name).setLevel(logging.DEBUG)
Esempio n. 2
0
def check_if_validation(cfg: DictConfig) -> bool:
    """Tell whether any inspected setting is given as an OmegaConf list.

    ``cfg.training.lr`` and ``cfg.training.regularization`` are always
    inspected; ``cfg.loss.xi`` is inspected only when
    ``check_if_gev_loss(get_loss_class(cfg))`` is truthy.

    NOTE(review): presumably a list means several candidate values that
    require a validation run -- confirm against the callers.
    """
    lr_is_list = OmegaConf.is_list(cfg.training.lr)
    regularization_is_list = OmegaConf.is_list(cfg.training.regularization)

    xi_is_list = False
    if check_if_gev_loss(get_loss_class(cfg)):
        xi_is_list = OmegaConf.is_list(cfg.loss.xi)

    return lr_is_list or regularization_is_list or xi_is_list
Esempio n. 3
0
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    **kwargs: Any,
) -> Any:
    """Turn an OmegaConf config into the arguments for a call.

    A list config yields a plain ``list`` in which nested configs are
    processed recursively (without ``kwargs``) and other items are kept.

    A dict config is first merged IN PLACE with ``kwargs`` (the caller's
    ``config`` object is modified), then copied into a plain ``dict``. When
    ``_is_recursive(config, kwargs)`` is truthy, values for which
    ``_is_target`` holds are replaced by ``instantiate(value)``, and nested
    dict/list configs are rebuilt with their own entries instantiated or
    recursively processed.

    :param config: OmegaConf DictConfig or ListConfig to process.
    :param kwargs: overrides merged into a dict ``config``.
    :return: a ``list`` for a list config, otherwise a ``dict``.
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(x) if OmegaConf.is_config(x) else x for x in config
        ]

    assert OmegaConf.is_dict(
        config), "Input config is not an OmegaConf DictConfig"

    final_kwargs = {}

    # The recursion decision is taken before kwargs are merged into config.
    recursive = _is_recursive(config, kwargs)
    # allow_objects lets the overrides carry arbitrary Python objects.
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    config.merge_with(overrides)

    for k, v in config.items():
        final_kwargs[k] = v

    if recursive:
        # Only values of existing keys are reassigned below, so the dict can
        # be updated while it is being iterated.
        for k, v in final_kwargs.items():
            if _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v) and not OmegaConf.is_none(v):
                # Rebuild the nested dict entry by entry.
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items():
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value)
                    else:
                        d[key] = value
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                # Same treatment for a nested list, item by item.
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x))
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                # A value OmegaConf reports as none becomes a plain None.
                if OmegaConf.is_none(v):
                    v = None
                final_kwargs[k] = v

    return final_kwargs
Esempio n. 4
0
    def __init__(self, fn):
        """
        Load the job description `fn` and record what has to be run.

        A path ending in `.py` is treated as a python command, and a yaml whose
        `base_dir` is None as a single-file job; in both cases `self.run_yamls`
        becomes `[fn]` and nothing else happens. Otherwise each `run_task`
        entry is prefixed with the project directory and stored in
        `self.run_yamls` (a list entry is a stage whose tasks can run in
        parallel), and the configs produced by `_overwrite_task` are passed to
        `_save_configs`.
        """
        if fn.endswith(".py"):
            # a python command.
            self.backend = "python"
            self.run_yamls = [fn]
            return

        job_config = recursive_config(fn)
        if job_config.base_dir is None:
            # single file job config.
            self.run_yamls = [fn]
            return

        self.project_dir = os.path.join("projects", job_config.project_dir)
        self.run_dir = os.path.join("runs", job_config.project_dir)

        stages = job_config.run_task
        if stages is not None:
            resolved_stages = []
            for stage in stages:
                if OmegaConf.is_list(stage):
                    # A nested list keeps its grouping: one list per stage.
                    resolved_stages.append([
                        os.path.join(self.project_dir, task_file)
                        for task_file in stage
                    ])
                else:
                    resolved_stages.append(
                        os.path.join(self.project_dir, stage))
            self.run_yamls = resolved_stages

        self._save_configs(self._overwrite_task(job_config))
Esempio n. 5
0
    def _recurse_on_config(self, config):
        """Walk `config` and check the header of every download entry found.

        A non-empty list whose first element contains "url" is treated as a
        list of download entries: each one is turned into a DownloadableFile
        and `check_header` is attempted up to 3 times, sleeping 2 seconds
        between attempts; the AssertionError of the last attempt propagates.
        For a dict, "version" and "resources" must appear together, and every
        value is visited recursively. Anything else is ignored.
        """
        has_urls = (OmegaConf.is_list(config) and len(config) > 0
                    and "url" in config[0])
        if has_urls:
            max_attempts = 3
            for entry in config:
                # Building the object first validates the entry itself.
                download = DownloadableFile(**entry)
                for attempt in range(1, max_attempts + 1):
                    try:
                        check_header(download._url,
                                     from_google=download._from_google)
                    except AssertionError:
                        if attempt == max_attempts:
                            raise
                        # Wait 2 seconds before the next attempt.
                        time.sleep(2)
                    else:
                        break
            return

        if not OmegaConf.is_dict(config):
            return

        # "version" and "resources" are only valid as a pair.
        if "version" in config:
            self.assertIn("resources", config)
        if "resources" in config:
            self.assertIn("version", config)

        for key in config:
            self._recurse_on_config(config[key])
Esempio n. 6
0
    def __init__(self, config, *args, **kwargs):
        """Build `self.transform` from the entries of `config.transforms`.

        `config.transforms` is an OmegaConf dict (one transform) or list. Each
        entry is either a dict with `type` and optional `params`, or a plain
        string naming the transform (an empty OmegaConf list is then used as
        its params). Every entry is resolved through
        `self.get_transform_object` and the results are composed in order.
        The `pytorchvideo` package must be importable.
        """
        transform_params = config.transforms
        single_transform = OmegaConf.is_dict(transform_params)
        assert single_transform or OmegaConf.is_list(transform_params)
        if single_transform:
            transform_params = [transform_params]

        assert importlib.util.find_spec("pytorchvideo") is not None, (
            "Must have pytorchvideo installed to use VideoTransforms")

        built_transforms = []
        for spec in transform_params:
            if not OmegaConf.is_dict(spec):
                assert isinstance(spec, str), (
                    "Each transform should either be str or dict containing "
                    "type and params")
                transform_type = spec
                transform_param = OmegaConf.create([])
            else:
                # Accessing `type` raises a config error when it is missing.
                transform_type = spec.type
                transform_param = spec.get("params", OmegaConf.create({}))

            built_transforms.append(
                self.get_transform_object(transform_type, transform_param))

        self.transform = img_transforms.Compose(built_transforms)
Esempio n. 7
0
    def _recurse_on_config(self, config: DictConfig,
                           callback: typing.Callable):
        """Walk `config` and run `callback` on every download entry found.

        A non-empty list whose first element contains "url" is treated as a
        list of download entries; each one becomes a DownloadableFile that is
        passed to `callback`. Entries whose `file_name` is
        "flickr30_images.tar.gz" are skipped. For a dict, "version" and
        "resources" must appear together, and every value is visited
        recursively. Anything else is ignored.
        """
        has_urls = (OmegaConf.is_list(config) and len(config) > 0
                    and "url" in config[0])
        if has_urls:
            for entry in config:
                # The flickr30 download source is down; skip that dataset
                # until a mirror can be found.
                if getattr(entry, "file_name", "") == "flickr30_images.tar.gz":
                    continue
                # Building the object first validates the entry itself.
                download = DownloadableFile(**entry)
                callback(download)
            return

        if not OmegaConf.is_dict(config):
            return

        # "version" and "resources" are only valid as a pair.
        if "version" in config:
            self.assertIn("resources", config)
        if "resources" in config:
            self.assertIn("version", config)

        for key in config:
            self._recurse_on_config(config[key], callback=callback)
Esempio n. 8
0
    def _get_matches(config: Container, word: str) -> List[str]:
        """Return the completion candidates of `word` inside `config`.

        A `word` ending in "." or "=" names an exact key: the result is the
        word followed by each completion of that node (a container is
        expanded, a primitive is offered as its value, a missing mandatory
        value as the empty string). A `word` containing a dot is split at the
        last dot and completed inside the selected sub-node. Otherwise the
        keys (or list indices) of `config` starting with `word` are returned,
        suffixed with "." for containers and "=" for other values.
        Returns an empty list when `config` is None.
        """
        def str_rep(in_key: Any, in_value: Any) -> str:
            # "." invites descending into a container, "=" expects a value.
            suffix = "." if OmegaConf.is_config(in_value) else "="
            return f"{in_key}{suffix}"

        if config is None:
            return []
        elif OmegaConf.is_config(config):
            matches = []
            if word.endswith((".", "=")):
                exact_key = word[:-1]
                try:
                    conf_node = OmegaConf.select(
                        config, exact_key, throw_on_missing=True
                    )
                except MissingMandatoryValue:
                    conf_node = ""

                if conf_node is None:
                    key_matches = []
                elif OmegaConf.is_config(conf_node):
                    key_matches = CompletionPlugin._get_matches(conf_node, "")
                else:
                    # primitive; booleans are offered in lower case
                    if isinstance(conf_node, bool):
                        conf_node = str(conf_node).lower()
                    key_matches = [conf_node]

                matches.extend(f"{word}{match}" for match in key_matches)
            else:
                base_key, dot, partial_key = word.rpartition(".")
                if dot:
                    conf_node = OmegaConf.select(config, base_key)
                    key_matches = CompletionPlugin._get_matches(conf_node, partial_key)
                    matches.extend(f"{base_key}.{match}" for match in key_matches)
                elif isinstance(config, DictConfig):
                    for key, value in config.items_ex(resolve=False):
                        if str(key).startswith(word):
                            matches.append(str_rep(key, value))
                elif OmegaConf.is_list(config):
                    assert isinstance(config, ListConfig)
                    for idx in range(len(config)):
                        try:
                            value = config[idx]
                            if str(idx).startswith(word):
                                matches.append(str_rep(idx, value))
                        except MissingMandatoryValue:
                            # A missing item is offered without the prefix test.
                            matches.append(str_rep(idx, ""))

        else:
            assert False, f"Object is not an instance of config : {type(config)}"

        return matches
Esempio n. 9
0
    def _get_matches(config: Container, word: str) -> List[str]:
        """Return the completion candidates of `word` inside `config`.

        A `word` ending in "." or "=" names an exact key: the result is the
        word followed by each completion of that node (a container is
        expanded, a primitive is offered as its value, a missing mandatory
        value as the empty string). A `word` containing a dot is split at the
        last dot and completed inside the selected sub-node. Otherwise the
        keys (or list indices) of `config` starting with `word` are returned,
        suffixed with "." for containers and "=" for other values.
        Returns an empty list when `config` is None.
        """
        def str_rep(in_key: Union[str, int], in_value: Any) -> str:
            # "." invites descending into a container, "=" expects a value.
            template = "{}." if OmegaConf.is_config(in_value) else "{}="
            return template.format(in_key)

        if config is None:
            return []
        elif OmegaConf.is_config(config):
            matches = []
            if word.endswith((".", "=")):
                exact_key = word[:-1]
                try:
                    conf_node = config.select(exact_key)
                except MissingMandatoryValue:
                    conf_node = ""

                if conf_node is None:
                    key_matches = []
                elif OmegaConf.is_config(conf_node):
                    key_matches = CompletionPlugin._get_matches(
                        conf_node, "")
                else:
                    # primitive; booleans are offered in lower case
                    if isinstance(conf_node, bool):
                        conf_node = str(conf_node).lower()
                    key_matches = [conf_node]

                matches.extend(
                    "{}{}".format(word, match) for match in key_matches)
            else:
                base_key, dot, partial_key = word.rpartition(".")
                if dot:
                    conf_node = config.select(base_key)
                    key_matches = CompletionPlugin._get_matches(
                        conf_node, partial_key)
                    matches.extend(
                        "{}.{}".format(base_key, match)
                        for match in key_matches)
                elif isinstance(config, DictConfig):
                    matches.extend(
                        str_rep(key, value)
                        for key, value in config.items_ex(resolve=False)
                        if key.startswith(word))
                elif OmegaConf.is_list(config):
                    matches.extend(
                        str_rep(idx, value)
                        for idx, value in enumerate(config)
                        if str(idx).startswith(word))
        else:
            assert False, "Object is not an instance of config : {}".format(
                type(config))

        return matches
Esempio n. 10
0
    def __init__(self, config, *args, **kwargs):
        """Build `self.transform` from the entries of `config.transforms`.

        `config.transforms` is an OmegaConf dict (one transform) or list. Each
        entry is either a dict with `type` and optional `params`, or a plain
        string naming the transform (an empty OmegaConf list is then used as
        its params). The transform is looked up by name in `transforms`, then
        in `torchaudio.transforms`, then in the processor registry. Dict
        params are passed as keyword arguments, list params positionally.
        """
        transform_params = config.transforms
        single_transform = OmegaConf.is_dict(transform_params)
        assert single_transform or OmegaConf.is_list(transform_params)
        if single_transform:
            transform_params = [transform_params]

        built_transforms = []
        for spec in transform_params:
            if not OmegaConf.is_dict(spec):
                assert isinstance(spec, str), (
                    "Each transform should either be str or dict containing " +
                    "type and params")
                transform_type = spec
                transform_param = OmegaConf.create([])
            else:
                # Accessing `type` raises a config error when it is missing.
                transform_type = spec.type
                transform_param = spec.get("params", OmegaConf.create({}))

            transform_cls = getattr(transforms, transform_type, None)
            if transform_cls is None:
                # torchaudio is set up and imported only when it is needed.
                from mmf.utils.env import setup_torchaudio

                setup_torchaudio()
                from torchaudio import transforms as torchaudio_transforms

                transform_cls = getattr(torchaudio_transforms, transform_type,
                                        None)
            if transform_cls is None:
                # Neither library has it: fall back to a custom transform
                # registered as a processor.
                transform_cls = registry.get_processor_class(transform_type)
            assert transform_cls is not None, (
                f"transform {transform_type} is not present in torchvision, " +
                "torchaudio or processor registry")

            # https://github.com/omry/omegaconf/issues/248
            plain_params = OmegaConf.to_container(transform_param)
            if isinstance(plain_params, collections.abc.Mapping):
                built_transforms.append(transform_cls(**plain_params))
            else:
                built_transforms.append(transform_cls(*plain_params))

        self.transform = transforms.Compose(built_transforms)
Esempio n. 11
0
    def _recur_log_omegaconf_param(self, key: str, val: any) -> None:
        """Flatten `val` and log every leaf through `self.log_param`.

        Dict entries extend the key as "key.child" (just "child" when `key`
        is empty), list items as "key[index]". A leaf is logged under the
        accumulated key as `str(val)`.
        """
        if OmegaConf.is_dict(val):
            for child_key, child_val in val.items():
                full_key = child_key if key == "" else f"{key}.{child_key}"
                self._recur_log_omegaconf_param(full_key, child_val)
            return

        if OmegaConf.is_list(val):
            for index, item in enumerate(val):
                self._recur_log_omegaconf_param(f"{key}[{index}]", item)
            return

        self.log_param(key, str(val))
Esempio n. 12
0
    def configure_optimizers(self):
        """Build the optimizer, and the scheduler if one is configured.

        `self.optimizer_args` is either an OmegaConf dict of keyword arguments
        applied to `self.parameters()`, or an OmegaConf list of groups. Each
        group carries a `params` entry:

        * the string 'default' marks the group whose remaining entries become
          the optimizer-wide keyword arguments;
        * otherwise a name, or a list of names, matched as substrings against
          the names from `self.model.named_parameters()`. A parameter gets the
          options of the first matching name; unmatched parameters get no
          per-parameter options.

        Returns:
            The optimizer alone when `self.scheduler` is None, otherwise
            `([optimizer], [scheduler])`.

        Raises:
            TypeError: If `self.optimizer_args` is neither an OmegaConf list
                nor an OmegaConf dict.
        """
        optimizer_cls = getattr(optimizers, self.optimizer)

        if OmegaConf.is_list(self.optimizer_args):
            # Stays empty when no 'default' group is configured, instead of
            # leaving the name unbound.
            default_kwargs = {}
            kwargs_by_name = {}
            for group in self.optimizer_args:
                # Copy rather than pop: the stored config is left untouched,
                # so this method can be called more than once.
                param_names = group['params']
                group_kwargs = {
                    key: value
                    for key, value in group.items() if key != 'params'
                }

                if param_names == 'default':
                    default_kwargs = group_kwargs
                    continue

                if isinstance(param_names, str):
                    param_names = [param_names]
                for name in param_names:
                    kwargs_by_name[name] = group_kwargs

            optimized_params = []
            for param_name, param in self.model.named_parameters():
                for name, group_kwargs in kwargs_by_name.items():
                    if name in param_name:
                        optimized_params.append({
                            'params': param,
                            **group_kwargs
                        })
                        break
                else:
                    # No group matched. for/else also covers the case of no
                    # named groups at all, where the loop body never runs.
                    optimized_params.append({'params': param})

            optimizer = optimizer_cls(optimized_params, **default_kwargs)

        elif OmegaConf.is_dict(self.optimizer_args):
            optimizer = optimizer_cls(self.parameters(),
                                      **self.optimizer_args)
        else:
            raise TypeError(
                "optimizer_args must be an OmegaConf list or dict, got "
                f"{type(self.optimizer_args).__name__}")

        if self.scheduler is None:
            return optimizer

        scheduler = getattr(optimizers,
                            self.scheduler)(optimizer, **self.scheduler_args)
        return [optimizer], [scheduler]
Esempio n. 13
0
    def _recurse_on_config(self, config):
        """Walk `config` and check the header of every download entry found.

        A non-empty list whose first element contains "url" is treated as a
        list of download entries: each one is turned into a DownloadableFile
        whose URL is passed to `check_header`. For a dict, "version" and
        "resources" must appear together, and every value is visited
        recursively. Anything else is ignored.
        """
        has_urls = (OmegaConf.is_list(config) and len(config) > 0
                    and "url" in config[0])
        if has_urls:
            for entry in config:
                # Building the object first validates the entry itself.
                download = DownloadableFile(**entry)
                check_header(download._url, from_google=download._from_google)
            return

        if not OmegaConf.is_dict(config):
            return

        # "version" and "resources" are only valid as a pair.
        if "version" in config:
            self.assertIn("resources", config)
        if "resources" in config:
            self.assertIn("version", config)

        for key in config:
            self._recurse_on_config(config[key])
Esempio n. 14
0
    def _recurse_on_config(self, config: DictConfig,
                           callback: typing.Callable):
        """Walk `config` and run `callback` on every download entry found.

        A non-empty list whose first element contains "url" is treated as a
        list of download entries; each one becomes a DownloadableFile that is
        passed to `callback`. For a dict, "version" and "resources" must
        appear together, and every value is visited recursively. Anything
        else is ignored.
        """
        has_urls = (OmegaConf.is_list(config) and len(config) > 0
                    and "url" in config[0])
        if has_urls:
            for entry in config:
                # Building the object first validates the entry itself.
                download = DownloadableFile(**entry)
                callback(download)
            return

        if not OmegaConf.is_dict(config):
            return

        # "version" and "resources" are only valid as a pair.
        if "version" in config:
            self.assertIn("resources", config)
        if "resources" in config:
            self.assertIn("version", config)

        for key in config:
            self._recurse_on_config(config[key], callback=callback)
Esempio n. 15
0
def test_is_list(cfg: Any, expected: bool) -> None:
    """Check that ``OmegaConf.is_list(cfg)`` equals ``expected``."""
    actual = OmegaConf.is_list(cfg)
    assert actual == expected
Esempio n. 16
0
    def __init__(self, config, mode, data_path, label_file):
        """Read the dataset settings and resolve the data files to load.

        Args:
            config: Configuration supporting attribute access and ``in``
                membership tests. Read here: ``target_type``, ``window_size``,
                ``file_type``, the per-mode section ``config.<mode>``
                (``slide_step`` and, for file_type 'file', ``file``), and
                optionally ``input_cols``, ``pred_time``, ``feature_cols``,
                ``file_prefix`` and ``file_extension``.
            mode: One of 'train', 'val', 'test', 'predict'.
            data_path: Path object (``exists``, ``/``, ``parent``, ``glob``
                and ``resolve`` are used) pointing at the data location.
            label_file: Passed to ``self.load_label`` when not None.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        assert mode in ['train', 'val', 'test', 'predict']

        self.config = config

        self.target_type = config.target_type
        assert self.target_type in ['generative', 'predict']

        # NOTE(review): the original author wondered whether input_cols is
        # expected to be a list here -- confirm.
        if self.target_type in ['predict']:
            self.target_cols = config.input_cols
            self.pred_time = config.pred_time

        self.window_size = config.window_size
        # Plain attribute lookup of the per-mode section; replaces eval() on
        # a string built from `mode`.
        mode_config = getattr(config, mode)
        self.slide_step = mode_config.slide_step

        # Always defined, even when the config has no list `feature_cols`.
        # The original only initialised the misspelled `featrue_cols`.
        self.feature_cols = []
        # Misspelled attribute kept in case other code still reads it.
        self.featrue_cols = []
        if 'feature_cols' in config:
            if OmegaConf.is_list(config.feature_cols):
                self.feature_cols = OmegaConf.to_container(config.feature_cols)

        self.data_file_type = config.file_type

        if not data_path.exists():
            raise FileNotFoundError(f'{data_path} does not found.')

        # With file_type 'file', the yaml names the files of each mode
        # directly.
        if self.data_file_type == 'file':
            datafiles = mode_config.file
            if OmegaConf.is_list(datafiles):
                datafiles = OmegaConf.to_container(datafiles)
            if isinstance(datafiles, list):
                candidates = []
                for file_name in datafiles:
                    # Look under data_path first, then next to it.
                    candidate = data_path / file_name
                    if not candidate.exists():
                        candidate = data_path.parent / file_name
                    candidates.append(candidate)
                self.data_file_paths = [
                    candidate.resolve() for candidate in candidates
                ]
            else:
                # Single file: same lookup order as above.
                candidate = data_path / datafiles
                if not candidate.exists():
                    candidate = data_path.parent / datafiles
                self.data_file_paths = [candidate.resolve()]
        elif self.data_file_type == 'dir':
            prefix = ''
            extension = 'csv'
            if 'file_prefix' in config:
                prefix = config.file_prefix
            if 'file_extension' in config:
                extension = config.file_extension
            data_file_paths = data_path.glob(prefix + '*.' + extension)
            self.data_file_paths = [file.resolve() for file in data_file_paths]

        # Label file handling.
        if label_file is not None:
            self.have_label = True
            self.labels = self.load_label(label_file)
        else:
            self.have_label = False
Esempio n. 17
0
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Instantiate the object, or call the callable, described by ``config``.

    :param config: An config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _args_: List-like of positional arguments to pass to the target
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
                   _partial_: If True, return functools.partial wrapped method or object
                              False by default. Configure per target.
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
                   IMPORTANT: dataclasses instances in kwargs are interpreted as config
                              and cannot be used as passthrough
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
             None if config is None
    :raises InstantiationException: if config is a TargetConf whose _target_
             is still "???", if _partial_ is requested for a top-level list,
             or if config is not a dict/list, an OmegaConf container or a
             Structured Config.
    """

    # Return None if config is None
    if config is None:
        return None

    # TargetConf edge case
    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            dedent(f"""\
                Config has missing value for key `_target_`, cannot instantiate.
                Config type: {type(config).__name__}
                Check that the `_target_` key in your dataclass is properly annotated and overridden.
                A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"""
                   ))
        # TODO: print full key

    # Plain dict/list inputs go through the same preparation as kwargs.
    if isinstance(config, (dict, list)):
        config = _prepare_input_dict_or_list(config)

    kwargs = _prepare_input_dict_or_list(kwargs)

    # Structured Config always converted first to OmegaConf
    if is_structured_config(config) or isinstance(config, (dict, list)):
        config = OmegaConf.structured(config, flags={"allow_objects": True})

    if OmegaConf.is_dict(config):
        # Finalize config (convert targets to strings, merge with kwargs)
        # Work on a deep copy so the caller's config is not modified; the
        # copy accepts arbitrary objects and is neither struct nor readonly.
        config_copy = copy.deepcopy(config)
        config_copy._set_flag(flags=["allow_objects", "struct", "readonly"],
                              values=[True, False, False])
        # Re-attach the copy to the original's parent node.
        config_copy._set_parent(config._get_parent())
        config = config_copy

        if kwargs:
            config = OmegaConf.merge(config, kwargs)

        OmegaConf.resolve(config)

        # The control keys are removed from the config after the merge, so
        # values supplied through kwargs take part as well.
        _recursive_ = config.pop(_Keys.RECURSIVE, True)
        _convert_ = config.pop(_Keys.CONVERT, ConvertMode.NONE)
        _partial_ = config.pop(_Keys.PARTIAL, False)

        return instantiate_node(config,
                                *args,
                                recursive=_recursive_,
                                convert=_convert_,
                                partial=_partial_)
    elif OmegaConf.is_list(config):
        # Finalize config (convert targets to strings, merge with kwargs)
        # Same copy-and-flag treatment as the dict branch; kwargs are not
        # merged into a list config.
        config_copy = copy.deepcopy(config)
        config_copy._set_flag(flags=["allow_objects", "struct", "readonly"],
                              values=[True, False, False])
        config_copy._set_parent(config._get_parent())
        config = config_copy

        OmegaConf.resolve(config)

        # A list has no keys of its own, so the control values are read
        # from kwargs.
        _recursive_ = kwargs.pop(_Keys.RECURSIVE, True)
        _convert_ = kwargs.pop(_Keys.CONVERT, ConvertMode.NONE)
        _partial_ = kwargs.pop(_Keys.PARTIAL, False)

        if _partial_:
            raise InstantiationException(
                "The _partial_ keyword is not compatible with top-level list instantiation"
            )

        return instantiate_node(config,
                                *args,
                                recursive=_recursive_,
                                convert=_convert_,
                                partial=_partial_)
    else:
        raise InstantiationException(
            dedent(f"""\
                Cannot instantiate config of type {type(config).__name__}.
                Top level config must be an OmegaConf DictConfig/ListConfig object,
                a plain dict/list, or a Structured Config class or instance."""
                   ))
def test_is_config(cfg: Any, is_conf: bool, is_list: bool, is_dict: bool) -> None:
    """Check the three OmegaConf type predicates against the expected flags."""
    assert is_conf == OmegaConf.is_config(cfg)
    assert is_list == OmegaConf.is_list(cfg)
    assert is_dict == OmegaConf.is_dict(cfg)
Esempio n. 19
0
def test_is_config(cfg: Any, is_conf: bool, is_list: bool, is_dict: bool,
                   type_: Type[Any]) -> None:
    """Check the OmegaConf type predicates and ``get_type`` for ``cfg``."""
    assert is_conf == OmegaConf.is_config(cfg)
    assert is_list == OmegaConf.is_list(cfg)
    assert is_dict == OmegaConf.is_dict(cfg)
    assert type_ == OmegaConf.get_type(cfg)
Esempio n. 20
0
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    root: bool = True,
    **kwargs: Any,
) -> Any:
    """Turn an OmegaConf config into the arguments for a call.

    A list config yields a plain ``list`` in which nested configs are
    processed recursively (with ``root=False`` and without ``kwargs``) and
    other items are kept.

    A dict config is first merged IN PLACE with ``kwargs`` (the caller's
    ``config`` object is modified). The result is a new DictConfig attached
    to the same parent as ``config``. When ``_is_recursive(config, kwargs)``
    is truthy, values for which ``_is_target`` holds are replaced by
    ``instantiate(value)``, and nested dict/list configs are rebuilt with
    their own entries instantiated or recursively processed; otherwise the
    values are copied unchanged.

    :param config: OmegaConf DictConfig or ListConfig to process.
    :param root: False for nested calls; the result then also inherits the
                 object type recorded in ``config``'s metadata.
    :param kwargs: overrides merged into a dict ``config``.
    :return: a ``list`` for a list config, otherwise a DictConfig.
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(x, root=False) if OmegaConf.is_config(x) else x
            for x in config
        ]

    assert OmegaConf.is_dict(
        config), "Input config is not an OmegaConf DictConfig"

    # The recursion decision is taken before kwargs are merged into config.
    recursive = _is_recursive(config, kwargs)
    # allow_objects lets the overrides carry arbitrary Python objects.
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    config.merge_with(overrides)

    final_kwargs = OmegaConf.create(flags={"allow_objects": True})
    final_kwargs._set_parent(config._get_parent())
    # Explicitly writable and non-struct while it is being filled; the flags
    # are reset to None further down.
    final_kwargs._set_flag("readonly", False)
    final_kwargs._set_flag("struct", False)
    if recursive:
        # resolve=False: values are read without resolving interpolations.
        for k, v in config.items_ex(resolve=False):
            if OmegaConf.is_none(v):
                final_kwargs[k] = v
            elif _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v):
                # Rebuild the nested dict entry by entry.
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items_ex(resolve=False):
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value, root=False)
                    else:
                        d[key] = value
                # Carry over the object type recorded on the original node.
                d._metadata.object_type = v._metadata.object_type
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                # Same treatment for a nested list, item by item.
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x, root=False))
                        lst[-1]._metadata.object_type = x._metadata.object_type
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                final_kwargs[k] = v
    else:
        for k, v in config.items_ex(resolve=False):
            final_kwargs[k] = v

    # Reset the flags set above to None now that the result is populated.
    final_kwargs._set_flag("readonly", None)
    final_kwargs._set_flag("struct", None)
    final_kwargs._set_flag("allow_objects", None)
    if not root:
        # This is tricky, since the root kwargs is exploded anyway we can treat is as an untyped dict
        # the motivation is that the object type is used as an indicator to treat the object differently during
        # conversion to a primitive container in some cases
        final_kwargs._metadata.object_type = config._metadata.object_type
    return final_kwargs
Esempio n. 21
0
def _wrap_with_list(x):
    """Return `x` unchanged if it is an OmegaConf list, else `[x]`."""
    return x if OmegaConf.is_list(x) else [x]
Esempio n. 22
0
def instantiate_node(
    node: Any,
    *args: Any,
    convert: Union[str, ConvertMode] = ConvertMode.NONE,
    recursive: bool = True,
) -> Any:
    """Instantiate `node` and, depending on the flags, everything below it.

    * None (or an OmegaConf none node) gives None; a value that is not an
      OmegaConf config is returned as is.
    * A list config gives a list of instantiated items: a plain ``list`` for
      ConvertMode.ALL / PARTIAL, otherwise a ListConfig parented to `node`.
    * A dict config for which ``_is_target`` holds is called through
      ``_call_target`` with `args` and its remaining entries as keyword
      arguments (instantiated first when `recursive` is true).
    * Any other dict config gives a plain ``dict`` (ALL, or PARTIAL on a node
      without an object type) or a DictConfig parented to `node`.

    `_convert_` / `_recursive_` keys found in a dict node override the
    `convert` / `recursive` arguments for that node and below.

    :raises TypeError: if the effective recursive flag is not a bool.
    """
    # Nothing to build from a None node
    if node is None or OmegaConf.is_none(node):
        return None

    # Non-config values are passed through untouched
    if not OmegaConf.is_config(node):
        return node

    node_is_dict = OmegaConf.is_dict(node)

    # Override parent modes from config if specified
    if node_is_dict:
        # Membership test plus getitem instead of get(key, default) because
        # OmegaConf raises if the key type is incompatible on get.
        if _Keys.CONVERT in node:
            convert = node[_Keys.CONVERT]
        if _Keys.RECURSIVE in node:
            recursive = node[_Keys.RECURSIVE]

    if not isinstance(recursive, bool):
        raise TypeError(
            f"_recursive_ flag must be a bool, got {type(recursive)}")

    if OmegaConf.is_list(node):
        items = [
            instantiate_node(item, convert=convert, recursive=recursive)
            for item in node._iter_ex(resolve=True)
        ]

        # ALL or PARTIAL: a plain list is the container
        if convert in (ConvertMode.ALL, ConvertMode.PARTIAL):
            return items

        # Otherwise a ListConfig is the container
        lst = OmegaConf.create(items, flags={"allow_objects": True})
        lst._set_parent(node)
        return lst

    elif node_is_dict:
        exclude_keys = {"_target_", "_convert_", "_recursive_"}
        if _is_target(node):
            target = _resolve_target(node.get(_Keys.TARGET))
            kwargs = {}
            for key, value in node.items_ex(resolve=True):
                if key in exclude_keys:
                    continue
                if recursive:
                    value = instantiate_node(value,
                                             convert=convert,
                                             recursive=recursive)
                kwargs[key] = _convert_node(value, convert)
            return _call_target(target, *args, **kwargs)

        # ALL, or PARTIAL on a non structured node: build a plain dict and
        # resolve interpolations eagerly.
        use_plain_dict = convert == ConvertMode.ALL or (
            convert == ConvertMode.PARTIAL
            and node._metadata.object_type is None)
        if use_plain_dict:
            # items inherit the recursive flag from the containing dict.
            return {
                key: instantiate_node(value,
                                      convert=convert,
                                      recursive=recursive)
                for key, value in node.items_ex(resolve=True)
            }

        # Otherwise use a DictConfig and resolve interpolations lazily.
        cfg = OmegaConf.create({}, flags={"allow_objects": True})
        for key, value in node.items_ex(resolve=False):
            cfg[key] = instantiate_node(value,
                                        convert=convert,
                                        recursive=recursive)
        cfg._set_parent(node)
        cfg._metadata.object_type = node._metadata.object_type
        return cfg

    else:
        assert False, f"Unexpected config type : {type(node).__name__}"
Esempio n. 23
0
def splitNameConf(conf, search, default_name: str = None):
    """Split `conf` into (attribute of `search`, remaining settings).

    For an OmegaConf dict the "name" entry (or `default_name` when absent)
    is removed from `conf` and looked up on `search`; the rest of `conf` is
    returned as the settings. For an OmegaConf list the first item is the
    name and the second item, if any, the settings (an empty dict when the
    list has exactly one item). Any other input yields None.
    """
    if OmegaConf.is_dict(conf):
        # pop() removes "name" from the caller's config.
        name = conf.pop("name", default_name)
        return getattr(search, name), conf
    elif OmegaConf.is_list(conf):
        target = getattr(search, conf[0])
        settings = {} if len(conf) == 1 else conf[1]
        return target, settings