def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    """Build the keyword arguments for instantiating the object described by ``config``.

    Keyword arguments whose values are primitives, dicts or lists are merged into a
    copy of ``config.params`` via ``merge_with``; every other keyword argument is
    passed through unchanged and overrides a same-named entry coming from ``params``.

    :param config: ``ObjectConf`` or ``DictConfig`` describing the object. It is
        copied/converted before the merge, so the caller's object is not mutated.
    :param kwargs: Extra keyword arguments that extend or override ``params``.
    :return: A plain ``dict`` of the final keyword arguments.
    :raises AssertionError: if ``config.params`` exists but is not a ``DictConfig``.
    """
    if isinstance(config, ObjectConf):
        # NOTE(review): ObjectConf is presumably a structured (dataclass) config,
        # not an OmegaConf node, so it has no _get_parent()/_set_parent() and a
        # `"params" in config` test would not see its fields. Convert it to a
        # DictConfig instead of deep-copying it.
        config = OmegaConf.structured(config)
    else:
        # copy config to avoid mutating it when merging with kwargs
        config_copy = copy.deepcopy(config)
        # Manually set parent as deepcopy does not currently handle it
        # (https://github.com/omry/omegaconf/issues/130)
        # noinspection PyProtectedMember
        config_copy._set_parent(config._get_parent())  # type: ignore
        config = config_copy

    params = config.params if "params" in config else None
    if params is None:
        # Missing or null params: start from an empty DictConfig so kwargs can
        # still be merged below.
        params = OmegaConf.create()
    assert isinstance(
        params, DictConfig
    ), f"Input config params are expected to be a mapping, found {type(params).__name__}"

    # Primitives, dicts and lists are merged through OmegaConf; anything else
    # (e.g. arbitrary objects) is passed through untouched.
    primitives = {}
    rest = {}
    for k, v in kwargs.items():
        if _utils.is_primitive_type(v) or isinstance(v, (dict, list)):
            primitives[k] = v
        else:
            rest[k] = v

    # read_write() presumably lifts a read-only flag for the duration of the merge.
    with read_write(params):
        params.merge_with(OmegaConf.create(primitives))

    final_kwargs = dict(params.items())
    # Non-mergeable kwargs win over same-named params.
    final_kwargs.update(rest)
    return final_kwargs
def convert_imports(imports: Set[Type], string_imports: List[str]) -> List[str]:
    """Return the sorted, de-duplicated import lines needed for ``imports``.

    ``string_imports`` are taken verbatim. For each type in ``imports`` a
    ``from <module> import <name>`` line is generated when the type is not
    primitive or is an ``Enum`` subclass. ``Any``/``Optional`` and the
    ``list``/``tuple``/``dict`` generics are spelled with their ``typing`` names.
    """
    typing_aliases = {list: "List", tuple: "Tuple", dict: "Dict"}
    lines = set(string_imports)
    for typ in imports:
        origin = getattr(typ, "__origin__", None)
        if typ is Any:
            name = "Any"
        elif typ is Optional:
            name = "Optional"
        else:
            # Generic aliases (e.g. List[int]) are named after their typing alias;
            # everything else uses its own __name__.
            name = typing_aliases.get(origin) or typ.__name__

        needs_import = not is_primitive_type(typ) or issubclass(typ, Enum)
        if needs_import:
            lines.add(f"from {typ.__module__} import {name}")
    return sorted(lines)
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    """Build the keyword arguments for instantiating the object described by ``config``.

    Keyword arguments whose values are primitives, dicts or lists are merged into a
    copy of ``config.params`` via ``merge_with``; every other keyword argument is
    passed through unchanged and overrides a same-named entry coming from ``params``.
    A config without ``params`` (or with ``params`` set to null) is treated as
    having empty params.

    :param config: ``ObjectConf`` or ``DictConfig`` describing the object. It is
        converted/copied before the merge, so the caller's object is not mutated.
    :param kwargs: Extra keyword arguments that extend or override ``params``.
    :return: A plain ``dict`` of the final keyword arguments.
    :raises AssertionError: if ``config.params`` exists but is not a ``DictConfig``.
    """
    if isinstance(config, ObjectConf):
        config = OmegaConf.structured(config)
    else:
        # Copy so that merging kwargs below never mutates the caller's config.
        config_copy = copy.deepcopy(config)
        # Re-attach the original parent in case deepcopy does not carry it over
        # (https://github.com/omry/omegaconf/issues/130).
        # noinspection PyProtectedMember
        config_copy._set_parent(config._get_parent())  # type: ignore
        config = config_copy

    # Use a membership test rather than attribute access to detect params.
    # NOTE(review): attribute access on a missing DictConfig key may return None
    # instead of raising, which makes hasattr() an unreliable presence test.
    params = config.params if "params" in config else None
    if params is None:
        # An empty DictConfig (not a plain dict), so read_write()/merge_with()
        # below work.
        params = OmegaConf.create()
    assert isinstance(
        params, DictConfig
    ), f"Input config params are expected to be a mapping, found {type(params).__name__}"
    # Attach params to config (presumably so interpolations inside params
    # resolve against it, including for the freshly created empty params).
    # noinspection PyProtectedMember
    params._set_parent(config)

    # Primitives, dicts and lists are merged through OmegaConf; anything else
    # (e.g. arbitrary objects) is passed through untouched.
    primitives = {}
    rest = {}
    for k, v in kwargs.items():
        if _utils.is_primitive_type(v) or isinstance(v, (dict, list)):
            primitives[k] = v
        else:
            rest[k] = v

    with read_write(params):
        params.merge_with(primitives)

    final_kwargs = dict(params.items())
    # Non-mergeable kwargs win over same-named params.
    final_kwargs.update(rest)
    return final_kwargs
def convert_imports(imports: Set[Type], string_imports: Set[str]) -> List[str]:
    """Return the sorted union of ``string_imports`` and imports generated for ``imports``.

    For each type, a ``from <module> import <name>`` line is produced when the type
    is not primitive or is an ``Enum`` subclass. ``Any``/``Optional`` and the
    ``list``/``tuple``/``dict`` generics use their ``typing`` names.
    """
    alias_by_origin = {list: "List", tuple: "Tuple", dict: "Dict"}
    generated = set()
    for type_ in imports:
        origin = getattr(type_, "__origin__", None)
        if type_ is Any:
            class_name = "Any"
        elif type_ is Optional:
            class_name = "Optional"
        elif origin in alias_by_origin:
            class_name = alias_by_origin[origin]
        else:
            class_name = type_.__name__

        if is_primitive_type(type_) and not issubclass(type_, Enum):
            # Primitive, non-enum types are always in scope; nothing to import.
            continue
        generated.add(f"from {type_.__module__} import {class_name}")
    return sorted({*generated, *string_imports})
def test_is_primitive_type(type_: Any, is_primitive: bool) -> None:
    """Check that ``_utils.is_primitive_type`` classifies ``type_`` as expected."""
    actual = _utils.is_primitive_type(type_)
    assert actual == is_primitive