コード例 #1
0
    def __init__(self, config, *args, **kwargs):
        """Compose the video transforms listed under ``config.transforms``.

        ``config.transforms`` is either a single dict config or a list whose
        entries are a transform name (str) or a dict with a required ``type``
        and an optional ``params``. Each entry is turned into an object via
        ``self.get_transform_object(type, params)`` and the results are
        wrapped in ``img_transforms.Compose`` stored as ``self.transform``.

        Raises:
            AssertionError: if ``config.transforms`` is neither a dict nor a
                list config, if ``pytorchvideo`` cannot be found, or if an
                entry is neither a str nor a dict config.
        """
        specs = config.transforms
        assert OmegaConf.is_dict(specs) or OmegaConf.is_list(specs)
        if OmegaConf.is_dict(specs):
            # A lone dict is shorthand for a one-element list.
            specs = [specs]
        video_spec = importlib.util.find_spec("pytorchvideo")
        assert video_spec is not None, (
            "Must have pytorchvideo installed to use VideoTransforms")

        def _unpack(entry):
            # Split one entry into (transform name, params config).
            if not OmegaConf.is_dict(entry):
                assert isinstance(entry, str), (
                    "Each transform should either be str or dict containing "
                    "type and params")
                # A bare name carries no parameters.
                return entry, OmegaConf.create([])
            # Attribute access raises a config error when `type` is missing.
            return entry.type, entry.get("params", OmegaConf.create({}))

        self.transform = img_transforms.Compose([
            self.get_transform_object(*_unpack(entry)) for entry in specs
        ])
コード例 #2
0
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    **kwargs: Any,
) -> Any:
    """Turn an OmegaConf config into call arguments for instantiation.

    A ListConfig yields a Python list in which nested configs are converted
    recursively; note that nested list elements go through ``_get_kwargs``
    rather than ``instantiate``. A DictConfig first has ``kwargs`` merged
    into it IN PLACE (via ``merge_with``), then yields a plain dict of its
    items. When ``_is_recursive(config, kwargs)`` is true, every value for
    which ``_is_target`` holds is instantiated, and dict/list children are
    rebuilt as ``allow_objects`` configs whose own target children are
    instantiated and whose other config children are converted recursively.
    OmegaConf ``None`` leaves become Python ``None``.

    Raises:
        AssertionError: if ``config`` is not an OmegaConf config.
    """
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(item) if OmegaConf.is_config(item) else item
            for item in config
        ]

    assert OmegaConf.is_dict(
        config), "Input config is not an OmegaConf DictConfig"

    recursive = _is_recursive(config, kwargs)
    config.merge_with(OmegaConf.create(kwargs, flags={"allow_objects": True}))

    result = dict(config.items())
    if not recursive:
        return result

    def _convert_child(node):
        # Shared rule for entries nested one level below a top-level value.
        if _is_target(node):
            return instantiate(node)
        if OmegaConf.is_config(node):
            return _get_kwargs(node)
        return node

    for name, value in result.items():
        if _is_target(value):
            result[name] = instantiate(value)
        elif OmegaConf.is_dict(value) and not OmegaConf.is_none(value):
            rebuilt = OmegaConf.create({}, flags={"allow_objects": True})
            for sub_name, sub_value in value.items():
                rebuilt[sub_name] = _convert_child(sub_value)
            result[name] = rebuilt
        elif OmegaConf.is_list(value):
            rebuilt = OmegaConf.create([], flags={"allow_objects": True})
            for element in value:
                rebuilt.append(_convert_child(element))
            result[name] = rebuilt
        elif OmegaConf.is_none(value):
            result[name] = None
        else:
            result[name] = value

    return result
コード例 #3
0
    def __init__(self, config, *args, **kwargs):
        """Compose the transforms listed under ``config.transforms``.

        ``config.transforms`` is either a single dict config or a list whose
        entries are a transform name (str) or a dict with a required ``type``
        and an optional ``params``. Each name is looked up, in order, in the
        torchvision ``transforms`` module, then in torchaudio's transforms
        (``setup_torchaudio`` is called and torchaudio imported only when the
        first lookup misses), then via ``registry.get_processor_class``.
        Mapping params are passed as keyword arguments, list params as
        positional arguments. The result is ``self.transform``, a
        ``transforms.Compose`` of the built objects.

        Raises:
            AssertionError: if ``config.transforms`` is neither a dict nor a
                list config, an entry is neither str nor dict config, or a
                transform name cannot be resolved.
        """
        specs = config.transforms
        assert OmegaConf.is_dict(specs) or OmegaConf.is_list(specs)
        if OmegaConf.is_dict(specs):
            specs = [specs]

        def _find_transform(transform_type):
            # torchvision first.
            found = getattr(transforms, transform_type, None)
            if found is not None:
                return found
            # Only set up and import torchaudio when torchvision misses.
            from mmf.utils.env import setup_torchaudio

            setup_torchaudio()
            from torchaudio import transforms as torchaudio_transforms

            found = getattr(torchaudio_transforms, transform_type, None)
            if found is not None:
                return found
            # If torchvision or torchaudio doesn't contain this, check our
            # registry in case a custom transform was implemented as a processor.
            return registry.get_processor_class(transform_type)

        built = []
        for entry in specs:
            if OmegaConf.is_dict(entry):
                # This will throw config error if missing
                transform_type = entry.type
                raw_params = entry.get("params", OmegaConf.create({}))
            else:
                assert isinstance(entry, str), (
                    "Each transform should either be str or dict containing " +
                    "type and params")
                transform_type = entry
                raw_params = OmegaConf.create([])

            transform = _find_transform(transform_type)
            assert transform is not None, (
                f"transform {transform_type} is not present in torchvision, " +
                "torchaudio or processor registry")

            # https://github.com/omry/omegaconf/issues/248
            params = OmegaConf.to_container(raw_params)
            # A mapping is splatted as **kwargs, anything else as *args.
            if isinstance(params, collections.abc.Mapping):
                built.append(transform(**params))
            else:
                built.append(transform(*params))

        self.transform = transforms.Compose(built)
コード例 #4
0
ファイル: config_repository.py プロジェクト: odelalleau/hydra
    def _extract_defaults_list(self, config_path: str,
                               cfg: Container) -> ListConfig:
        """Return the ``defaults`` list of ``cfg``.

        Non-dict configs yield an empty ListConfig. For an untyped dict config
        the ``defaults`` key is popped (removed from ``cfg``). For a typed
        (Structured Config backed) one it is read and left in place, and
        ``cfg`` is tagged with ``__HYDRA_REMOVE_TOP_LEVEL_DEFAULTS__ = True``.

        Raises:
            ValueError: if the defaults entry is not a ListConfig.
        """
        fallback = OmegaConf.create([])
        if not OmegaConf.is_dict(cfg):
            return fallback
        assert isinstance(cfg, DictConfig)
        with read_write(cfg), open_dict(cfg):
            if cfg._is_typed():
                # If node is a backed by Structured Config, flag it and temporarily keep the defaults list in.
                # It will be removed later.
                # This is addressing an edge case where the defaults list re-appears once the dataclass is used
                # as a prototype during OmegaConf merge.
                cfg["__HYDRA_REMOVE_TOP_LEVEL_DEFAULTS__"] = True
                defaults = cfg.get("defaults", fallback)
            else:
                defaults = cfg.pop("defaults", fallback)

        if isinstance(defaults, ListConfig):
            return defaults

        type_str = ("mapping" if isinstance(defaults, DictConfig) else
                    type(defaults).__name__)
        raise ValueError(
            f"Invalid defaults list in '{config_path}', defaults must be a list (got {type_str})"
        )
コード例 #5
0
ファイル: test_zoo_urls.py プロジェクト: facebookresearch/mmf
    def _recurse_on_config(self, config: DictConfig,
                           callback: typing.Callable):
        """Walk a zoo config and hand every downloadable resource to ``callback``.

        A non-empty list config whose first element is a dict config with a
        ``url`` key is treated as a resource list. Each item, except the
        ``flickr30_images.tar.gz`` archive, is turned into a
        ``DownloadableFile`` and passed to ``callback``. A dict config is
        checked for paired ``version``/``resources`` keys and recursed into
        value by value. Lists that are not resource lists are not descended
        into.

        Args:
            config: the (sub)config to inspect.
            callback: called with each constructed ``DownloadableFile``.
        """
        # Check the element type first: `"url" in <str>` is a substring test,
        # so a list of plain strings could otherwise be misread as resources
        # and crash in DownloadableFile(**item).
        is_resource_list = (OmegaConf.is_list(config) and len(config) > 0
                            and OmegaConf.is_dict(config[0])
                            and "url" in config[0])
        if is_resource_list:
            # Found the urls, let's test them
            for item in config:
                # flickr30 download source is down, ignore dataset until a
                # mirror can be found
                if getattr(item, "file_name", "") == "flickr30_images.tar.gz":
                    continue
                # First try making the DownloadableFile class to make sure
                # everything is fine
                download = DownloadableFile(**item)
                # Now, call the actual callback which will test specific scenarios
                callback(download)

        elif OmegaConf.is_dict(config):
            # Both version and resources should be present
            if "version" in config:
                self.assertIn("resources", config)
            if "resources" in config:
                self.assertIn("version", config)

            # Let's continue recursing
            for item in config:
                self._recurse_on_config(config[item], callback=callback)
コード例 #6
0
ファイル: config_source.py プロジェクト: yangky11/hydra
    def _extract_defaults_list(
        config_path: Optional[str], cfg: Container
    ) -> List[DefaultElement]:
        """Pop ``defaults`` from ``cfg`` and convert it to DefaultElements.

        Non-dict configs and an empty defaults entry both yield ``[]``.
        NOTE: the ``defaults`` key is removed from ``cfg`` in place.
        """
        if not OmegaConf.is_dict(cfg):
            return []

        assert isinstance(cfg, DictConfig)
        with read_write(cfg), open_dict(cfg):
            raw_defaults = cfg.pop("defaults", OmegaConf.create([]))

        if len(raw_defaults) == 0:
            return []
        return ConfigSource._create_defaults_list(
            config_path=config_path, defaults=raw_defaults
        )
コード例 #7
0
    def _recur_log_omegaconf_param(self, key: str, val: object) -> None:
        """Log an OmegaConf value as flat ``key -> str(value)`` parameters.

        Dict configs are flattened with dotted keys, list configs with
        ``[index]`` suffixes, and any other value is logged through
        ``self.log_param(key, str(val))``. An empty ``key`` makes first-level
        dict keys appear without a leading dot.

        Args:
            key: dotted/indexed path accumulated so far ("" at the root).
            val: value found at ``key``; any type is accepted (the original
                ``any`` annotation named the builtin function, not a type).
        """
        if OmegaConf.is_dict(val):
            for k, v in val.items():
                child_key = k if key == "" else f"{key}.{k}"
                self._recur_log_omegaconf_param(child_key, v)
        elif OmegaConf.is_list(val):
            for i, v in enumerate(val):
                self._recur_log_omegaconf_param(f"{key}[{i}]", v)
        else:
            self.log_param(key, str(val))
コード例 #8
0
    def _extract_defaults_list(self, config_path: str,
                               cfg: Container) -> ListConfig:
        """Pop and return the ``defaults`` list of ``cfg``.

        Non-dict configs yield an empty ListConfig. NOTE: the ``defaults``
        key is removed from ``cfg`` in place.

        Raises:
            ValueError: if the popped defaults entry is not a ListConfig.
        """
        fallback = OmegaConf.create([])
        if not OmegaConf.is_dict(cfg):
            return fallback
        assert isinstance(cfg, DictConfig)
        with read_write(cfg), open_dict(cfg):
            defaults = cfg.pop("defaults", fallback)

        if isinstance(defaults, ListConfig):
            return defaults

        type_str = ("mapping" if isinstance(defaults, DictConfig) else
                    type(defaults).__name__)
        raise ValueError(
            f"Invalid defaults list in '{config_path}', defaults must be a list (got {type_str})"
        )
コード例 #9
0
def test_dict_with_structured_config() -> None:
    """Instantiate a dict-of-structured-configs field under each ``_convert_`` mode.

    ``none`` keeps an OmegaConf dict whose values stay typed as ``User``.
    ``partial`` yields a plain ``dict`` whose values are still ``User``-typed.
    ``all`` converts both levels to plain ``dict``.
    """

    @dataclass
    class DictValuesConf:
        _target_: str = "tests.test_utils.DictValues"
        d: Dict[str, User] = MISSING

    merged = OmegaConf.merge(
        OmegaConf.structured(DictValuesConf),
        {"d": {"007": {"name": "Bond", "age": 7}}},
    )

    def build(mode):
        return utils.instantiate(config=merged, _convert_=mode)

    kept = build("none")
    assert OmegaConf.is_dict(kept.d)
    assert OmegaConf.get_type(kept.d["007"]) == User

    partial = build("partial")
    assert isinstance(partial.d, dict)
    assert OmegaConf.get_type(partial.d["007"]) == User

    converted = build("all")
    assert isinstance(converted.d, dict)
    assert isinstance(converted.d["007"], dict)
コード例 #10
0
    def configure_optimizers(self):
        """Build the optimizer (and optional LR scheduler) from config.

        ``self.optimizer_args`` may be:

        * a list config of per-group entries, each with a ``params`` key.
          ``params: default`` marks kwargs passed to the optimizer itself.
          Any other value (a name substring or a list of them) marks kwargs
          for the parameters of ``self.model`` whose name contains that
          substring. A parameter uses the first matching entry, otherwise no
          extra kwargs.
        * a dict config of kwargs applied to all of ``self.parameters()``.

        The optimizer class is ``getattr(optimizers, self.optimizer)``. If
        ``self.scheduler`` is set, ``getattr(optimizers, self.scheduler)`` is
        built on top of it with ``self.scheduler_args``.

        Returns:
            The optimizer, or ``([optimizer], [scheduler])`` when a scheduler
            is configured.

        Raises:
            TypeError: if ``self.optimizer_args`` is neither a list nor a dict
                config.
        """
        if OmegaConf.is_list(self.optimizer_args):
            # Without a 'default' entry, the optimizer's own defaults apply
            # (previously this raised UnboundLocalError).
            default_kwargs = {}
            kwargs_dict = {}
            for entry in self.optimizer_args:
                # Work on a copy so that popping 'params' does not mutate the
                # stored config. Repeated calls would otherwise lose 'params',
                # and struct-mode configs reject pop().
                kwargs = dict(entry)
                param_names = kwargs.pop('params')
                if param_names == 'default':
                    default_kwargs = kwargs
                else:
                    if isinstance(param_names, str):
                        param_names = [param_names]

                    for param in param_names:
                        kwargs_dict[param] = kwargs

            optimized_params = []
            for n, p in self.model.named_parameters():
                # First entry whose name substring occurs in `n`, else no
                # extra kwargs. Every parameter is added even when there are
                # no per-group entries (previously nothing was added, leaving
                # the optimizer with an empty parameter list).
                group_kwargs = next(
                    (kw for param, kw in kwargs_dict.items() if param in n),
                    {})
                optimized_params.append({'params': p, **group_kwargs})

            optimizer = getattr(optimizers, self.optimizer)(optimized_params,
                                                            **default_kwargs)

        elif OmegaConf.is_dict(self.optimizer_args):
            optimizer = getattr(optimizers,
                                self.optimizer)(self.parameters(),
                                                **self.optimizer_args)
        else:
            raise TypeError(
                "optimizer_args must be a list or dict config, got "
                f"{type(self.optimizer_args).__name__}")

        if self.scheduler is None:
            return optimizer
        scheduler = getattr(optimizers,
                            self.scheduler)(optimizer, **self.scheduler_args)
        return [optimizer], [scheduler]
コード例 #11
0
    def _recurse_on_config(self, config):
        """Walk a zoo config and check the URL header of every resource.

        A non-empty list config whose first element is a dict config with a
        ``url`` key is treated as a resource list. Each item is turned into a
        ``DownloadableFile`` and its URL is checked with ``check_header``. A
        dict config is checked for paired ``version``/``resources`` keys and
        recursed into value by value. Lists that are not resource lists are
        not descended into.
        """
        # Check the element type first: `"url" in <str>` is a substring test,
        # so a list of plain strings could otherwise be misread as resources
        # and crash in DownloadableFile(**item).
        is_resource_list = (OmegaConf.is_list(config) and len(config) > 0
                            and OmegaConf.is_dict(config[0])
                            and "url" in config[0])
        if is_resource_list:
            # Found the urls, let's test them
            for item in config:
                # First try making the DownloadableFile class to make sure
                # everything is fine
                download = DownloadableFile(**item)
                # Now check the actual header
                check_header(download._url, from_google=download._from_google)
        elif OmegaConf.is_dict(config):
            # Both version and resources should be present
            if "version" in config:
                self.assertIn("resources", config)
            if "resources" in config:
                self.assertIn("version", config)

            # Let's continue recursing
            for item in config:
                self._recurse_on_config(config[item])
コード例 #12
0
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Build the object (or call the callable) described by ``config``.

    :param config: Config describing the call. Besides the call parameters it
                   must contain:
                   _target_ : target class or callable name (str)
                   and may contain:
                   _args_: list-like of positional arguments for the target
                   _recursive_: also construct nested objects (bool, default
                                True); may be overridden via a _recursive_
                                key in kwargs
                   _convert_: conversion strategy
                        none    : pass DictConfig and ListConfig objects (default)
                        partial : convert to dict and list, except Structured
                                  Configs (and their fields)
                        all     : pass plain dicts, lists and primitives with
                                  no OmegaConf containers left
                   _partial_: if True, return a functools.partial-wrapped
                              method or object instead (default False,
                              configured per target)
    :param args: optional positional parameters passed through to the target
    :param kwargs: optional named parameters overriding those in the config.
                   Parameters absent from the config are passed as-is to the
                   target. IMPORTANT: dataclass instances in kwargs are
                   interpreted as config and cannot be used as passthrough.
                   For a top-level list config, only the _recursive_,
                   _convert_ and _partial_ keys of kwargs are consulted.
    :return: ``None`` for a ``None`` config; otherwise the instantiated object
             if _target_ is a class, or the call's return value if it is a
             callable.
    """
    if config is None:
        return None

    # TargetConf edge case
    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        # TODO: print full key
        raise InstantiationException(
            dedent(f"""\
                Config has missing value for key `_target_`, cannot instantiate.
                Config type: {type(config).__name__}
                Check that the `_target_` key in your dataclass is properly annotated and overridden.
                A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"""
                   ))

    if isinstance(config, (dict, list)):
        config = _prepare_input_dict_or_list(config)

    kwargs = _prepare_input_dict_or_list(kwargs)

    # Structured Configs (and plain dicts/lists) are always converted to OmegaConf first.
    if is_structured_config(config) or isinstance(config, (dict, list)):
        config = OmegaConf.structured(config, flags={"allow_objects": True})

    is_dict_config = OmegaConf.is_dict(config)
    if not is_dict_config and not OmegaConf.is_list(config):
        raise InstantiationException(
            dedent(f"""\
                Cannot instantiate config of type {type(config).__name__}.
                Top level config must be an OmegaConf DictConfig/ListConfig object,
                a plain dict/list, or a Structured Config class or instance."""
                   ))

    def _working_copy(node):
        # Deep copy that allows arbitrary objects and is neither struct nor
        # read-only, re-attached to the original node's parent.
        clone = copy.deepcopy(node)
        clone._set_flag(flags=["allow_objects", "struct", "readonly"],
                        values=[True, False, False])
        clone._set_parent(node._get_parent())
        return clone

    config = _working_copy(config)

    # Only a dict config can absorb kwargs overrides.
    if is_dict_config and kwargs:
        config = OmegaConf.merge(config, kwargs)

    OmegaConf.resolve(config)

    # Instantiation options live in the config for dicts and in kwargs for lists.
    options = config if is_dict_config else kwargs
    _recursive_ = options.pop(_Keys.RECURSIVE, True)
    _convert_ = options.pop(_Keys.CONVERT, ConvertMode.NONE)
    _partial_ = options.pop(_Keys.PARTIAL, False)

    if _partial_ and not is_dict_config:
        raise InstantiationException(
            "The _partial_ keyword is not compatible with top-level list instantiation"
        )

    return instantiate_node(config,
                            *args,
                            recursive=_recursive_,
                            convert=_convert_,
                            partial=_partial_)
コード例 #13
0
    def __init__(self, cfg: FlashlightDecoderConfig,
                 tgt_dict: Dictionary) -> None:
        """Set up a flashlight beam-search decoder driven by the LM at ``cfg.lmpath``.

        With ``cfg.lexicon`` set, a ``LexiconDecoder`` is built over a trie of
        lexicon spellings and a ``FairseqLM`` wrapping the model restored from
        the checkpoint at ``cfg.lmpath``. Without a lexicon, a
        ``LexiconFreeDecoder`` is built over a ``KenLM`` loaded from the same
        ``cfg.lmpath``; that mode requires ``cfg.unitlm``.

        Args:
            cfg: decoder options (nbest, unitlm, lexicon and LM paths, beam
                sizes/threshold, LM/word/unk/silence weights).
            tgt_dict: target token dictionary. It maps lexicon spellings to
                indices, and its length is the token beam size when
                ``cfg.beamsizetoken`` is falsy.

        Raises:
            AssertionError: if a lexicon spelling contains a token unknown to
                ``tgt_dict``, or if lexicon-free decoding is requested without
                ``cfg.unitlm``.
        """
        super().__init__(tgt_dict)

        self.nbest = cfg.nbest
        self.unitlm = cfg.unitlm

        # Maps each word to an iterable of spellings, each an iterable of
        # tokens (see the trie-building loop below); None when no lexicon.
        self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None
        # Index -> word map; only filled when decoding with a unit LM.
        self.idx_to_wrd = {}

        # NOTE(review): torch.load unpickles the file, so only point
        # cfg.lmpath at trusted checkpoints. This load runs in both branches
        # below, yet the lexicon-free branch also hands the same path to
        # KenLM; confirm the file is meant to be readable both ways.
        checkpoint = torch.load(cfg.lmpath, map_location="cpu")

        # Prefer the stored config; otherwise convert the argparse-style
        # `args` entry to an OmegaConf config.
        if "cfg" in checkpoint and checkpoint["cfg"] is not None:
            lm_args = checkpoint["cfg"]
        else:
            lm_args = convert_namespace_to_omegaconf(checkpoint["args"])

        if not OmegaConf.is_dict(lm_args):
            lm_args = OmegaConf.create(lm_args)

        # Point the LM task's data directory at the checkpoint's folder
        # (presumably where its dictionary lives; confirm). open_dict
        # presumably lifts struct mode so the key can be assigned.
        with open_dict(lm_args.task):
            lm_args.task.data = osp.dirname(cfg.lmpath)

        task = tasks.setup_task(lm_args.task)
        model = task.build_model(lm_args.model)
        # strict=False: mismatched/missing state-dict keys are tolerated.
        model.load_state_dict(checkpoint["model"], strict=False)

        # Built unconditionally, but only filled and used by the lexicon branch.
        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unitlm:
                    # Unit LM: words are indexed by lexicon position, score 0.
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    # Word LM: look the word up in the LM dictionary and score
                    # it from the LM start state.
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state,
                                             word_idx,
                                             no_cache=True)

                # Insert every spelling (as tgt_dict token indices) into the
                # trie; a spelling containing an unknown token is rejected.
                for spelling in spellings:
                    spelling_idxs = [
                        tgt_dict.index(token) for token in spelling
                    ]
                    assert (
                        tgt_dict.unk()
                        not in spelling_idxs), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            # beam_size_token falls back to the full token vocabulary size
            # when cfg.beamsizetoken is unset/falsy.
            self.decoder_opts = LexiconDecoderOptions(
                beam_size=cfg.beam,
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                word_score=cfg.wordscore,
                unk_score=cfg.unkweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=CriterionType.CTC,
            )

            # NOTE(review): the empty list is presumably the transitions
            # argument; confirm against the flashlight decoder API.
            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                [],
                self.unitlm,
            )
        else:
            assert self.unitlm, "Lexicon-free decoding requires unit LM"

            # Without a lexicon every target symbol is its own "word"; this
            # replaces the word_dict and FairseqLM assigned above with a
            # KenLM over those symbols.
            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            self.lm = KenLM(cfg.lmpath, self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=cfg.beam,
                beam_size_token=cfg.beamsizetoken or len(tgt_dict),
                beam_threshold=cfg.beamthreshold,
                lm_weight=cfg.lmweight,
                sil_score=cfg.silweight,
                log_add=False,
                criterion_type=CriterionType.CTC,
            )
            self.decoder = LexiconFreeDecoder(self.decoder_opts, self.lm,
                                              self.silence, self.blank, [])
コード例 #14
0
ファイル: fsm.py プロジェクト: JamzumSum/yNet
def splitNameConf(conf, search, default_name: str = None):
    """Resolve a ``(attribute, config)`` pair from a name-plus-config spec.

    Two spellings are accepted:

    * a dict config. Its ``name`` key (or ``default_name`` when absent) is
      looked up on ``search`` with ``getattr``, and the remaining dict is
      returned as the config. NOTE: ``name`` is popped, so ``conf`` is
      mutated in place.
    * a list config ``[name]`` or ``[name, conf]``. Element 0 is looked up on
      ``search``, and element 1 (or ``{}`` when absent) is returned as the
      config.

    Raises:
        TypeError: if ``conf`` is neither a dict nor a list config, instead
            of silently returning ``None``.
    """
    if OmegaConf.is_dict(conf):
        return getattr(search, conf.pop("name", default_name)), conf
    if OmegaConf.is_list(conf):
        return getattr(search, conf[0]), {} if len(conf) == 1 else conf[1]
    raise TypeError(
        f"conf must be a dict or list config, got {type(conf).__name__}")
コード例 #15
0
def test_is_config(cfg: Any, is_conf: bool, is_list: bool, is_dict: bool,
                   type_: Type[Any]) -> None:
    """Check OmegaConf's type predicates and ``get_type`` against expectations."""
    expectations = (
        (OmegaConf.is_config, is_conf),
        (OmegaConf.is_list, is_list),
        (OmegaConf.is_dict, is_dict),
        (OmegaConf.get_type, type_),
    )
    for probe, expected in expectations:
        assert probe(cfg) == expected