Esempio n. 1
0
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    """Combine ``config.params`` with caller-supplied keyword overrides.

    Works on a deep copy of ``config`` so the caller's object is not mutated
    by the merge. Overrides that are primitive values, dicts or lists are
    merged into the params config; every other override is passed through
    untouched and wins over a params entry with the same key.

    :param config: configuration object, optionally holding a ``params`` mapping
    :param kwargs: override values, keyed by parameter name
    :return: plain dict of the merged params followed by the pass-through overrides
    """
    detached = copy.deepcopy(config)

    # deepcopy does not currently carry the parent link over
    # (https://github.com/omry/omegaconf/issues/130), so re-attach it by hand.
    # noinspection PyProtectedMember
    detached._set_parent(config._get_parent())  # type: ignore
    config = detached

    if "params" in config:
        params = config.params
    else:
        params = OmegaConf.create()
    assert isinstance(
        params, DictConfig
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"

    # Split overrides into those that can be merged into the params config
    # and those that have to bypass it.
    mergeable = {}
    passthrough = {}
    for name, value in kwargs.items():
        can_merge = _utils.is_primitive_type(value) or isinstance(value, (dict, list))
        bucket = mergeable if can_merge else passthrough
        bucket[name] = value

    with read_write(params):
        params.merge_with(OmegaConf.create(mergeable))

    final_kwargs = {name: value for name, value in params.items()}
    final_kwargs.update(passthrough)
    return final_kwargs
Esempio n. 2
0
def convert_imports(imports: Set[Type], string_imports: List[str]) -> List[str]:
    """Return a sorted, de-duplicated list of import statements.

    ``string_imports`` are taken verbatim. For each entry of ``imports`` a
    ``from <module> import <name>`` line is generated unless the type is
    reported as primitive by ``is_primitive_type`` and is not an ``Enum``
    subclass.

    :param imports: types that need an import statement
    :param string_imports: already formatted import statements
    :return: all statements, without duplicates, in sorted order
    """
    statements = set(string_imports)
    for type_ in imports:
        origin = getattr(type_, "__origin__", None)
        if type_ is Any:
            name = "Any"
        elif type_ is Optional:
            name = "Optional"
        else:
            # Generic aliases are identified by their runtime origin class;
            # anything else is imported under its own class name.
            for container, alias in ((list, "List"), (tuple, "Tuple"), (dict, "Dict")):
                if origin is container:
                    name = alias
                    break
            else:
                name = type_.__name__

        # Primitive non-Enum types need no import statement.
        if is_primitive_type(type_) and not issubclass(type_, Enum):
            continue
        statements.add(f"from {type_.__module__} import {name}")
    return sorted(statements)
Esempio n. 3
0
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    """Combine ``config.params`` with caller-supplied keyword overrides.

    An ``ObjectConf`` is converted with ``OmegaConf.structured``; any other
    config is deep-copied, so the caller's object is not mutated by the merge.
    Overrides that are primitive values, dicts or lists are merged into the
    params config; every other override is passed through untouched and wins
    over a params entry with the same key.

    :param config: configuration object, optionally holding a ``params`` mapping
    :param kwargs: override values, keyed by parameter name
    :return: plain dict of the merged params followed by the pass-through overrides
    :raises AssertionError: if ``config.params`` is present but is not a mapping
    """
    if isinstance(config, ObjectConf):
        config = OmegaConf.structured(config)
    else:
        config = copy.deepcopy(config)

    # Fall back to an empty DictConfig rather than a plain dict: a plain dict
    # fails the DictConfig assertion below, so a config without ``params``
    # could never be processed.
    # NOTE(review): whether hasattr() reports a missing key as absent depends
    # on the OmegaConf version/struct mode - confirm against the pinned version.
    params = config.params if hasattr(config, "params") else OmegaConf.create()

    assert isinstance(
        params, MutableMapping
    ), f"Input config params are expected to be a mapping, found {type(params).__name__}"

    if isinstance(config, DictConfig):
        assert isinstance(params, DictConfig)
        # Attach params to the working copy of the config.
        params._set_parent(config)

    # Split overrides into those that can be merged into the params config
    # and those that have to bypass it.
    primitives = {}
    rest = {}
    for k, v in kwargs.items():
        if _utils.is_primitive_type(v) or isinstance(v, (dict, list)):
            primitives[k] = v
        else:
            rest[k] = v

    with read_write(params):
        params.merge_with(primitives)

    final_kwargs = {k: v for k, v in params.items()}
    # Pass-through overrides take precedence over merged params.
    final_kwargs.update(rest)
    return final_kwargs
Esempio n. 4
0
def convert_imports(imports: Set[Type], string_imports: Set[str]) -> List[str]:
    """Return a sorted, de-duplicated list of import statements.

    For each entry of ``imports`` a ``from <module> import <name>`` line is
    generated unless the type is reported as primitive by
    ``is_primitive_type`` and is not an ``Enum`` subclass. The pre-formatted
    ``string_imports`` are then added verbatim.

    :param imports: types that need an import statement
    :param string_imports: already formatted import statements
    :return: all statements, without duplicates, in sorted order
    """
    statements = set()
    for type_ in imports:
        origin = getattr(type_, "__origin__", None)
        if type_ is Any:
            name = "Any"
        elif type_ is Optional:
            name = "Optional"
        else:
            # Generic aliases are identified by their runtime origin class;
            # anything else is imported under its own class name.
            for container, alias in ((list, "List"), (tuple, "Tuple"), (dict, "Dict")):
                if origin is container:
                    name = alias
                    break
            else:
                name = type_.__name__

        # Primitive non-Enum types need no import statement.
        if is_primitive_type(type_) and not issubclass(type_, Enum):
            continue
        statements.add(f"from {type_.__module__} import {name}")

    statements.update(string_imports)
    return sorted(statements)
Esempio n. 5
0
def test_is_primitive_type(type_: Any, is_primitive: bool) -> None:
    """Check that ``_utils.is_primitive_type`` gives the expected verdict for ``type_``."""
    verdict = _utils.is_primitive_type(type_)
    assert verdict == is_primitive