def _to_str(obj, prefix=None, inside_call=False): if prefix is None: prefix = [] if isinstance(obj, abc.Mapping) and "_target_" in obj: # Dict representing a function call target = _convert_target_to_string(obj.pop("_target_")) args = [] for k, v in sorted(obj.items()): args.append(f"{k}={_to_str(v, inside_call=True)}") args = ", ".join(args) call = f"{target}({args})" return "".join(prefix) + call elif isinstance(obj, abc.Mapping) and not inside_call: # Dict that is not inside a call is a list of top-level config objects that we # render as one object per line with dot separated prefixes key_list = [] for k, v in sorted(obj.items()): if isinstance(v, abc.Mapping) and "_target_" not in v: key_list.append(_to_str(v, prefix=prefix + [k + "."])) else: key = "".join(prefix) + k key_list.append(f"{key}={_to_str(v)}") return "\n".join(key_list) elif isinstance(obj, abc.Mapping): # Dict that is inside a call is rendered as a regular dict return ("{" + ",".join( f"{repr(k)}: {_to_str(v, inside_call=inside_call)}" for k, v in sorted(obj.items())) + "}") elif isinstance(obj, list): return "[" + ",".join( _to_str(x, inside_call=inside_call) for x in obj) + "]" else: return repr(obj)
def test_compress_target(self):
    """Check the string produced for ``RandomCrop`` and that it resolves back."""
    from detectron2.data.transforms import RandomCrop

    target_str = _convert_target_to_string(RandomCrop)

    # name shouldn't contain 'augmentation_impl'
    expected = "detectron2.data.transforms.RandomCrop"
    self.assertEqual(target_str, expected)

    # Locating the produced name must give back the very same class object.
    self.assertIs(RandomCrop, locate(target_str))
def __call__(self, **kwargs):
    """
    Build a config node from the stored target and the given keyword arguments.

    Args:
        **kwargs: arguments recorded in the node alongside the target.

    Returns:
        DictConfig: ``kwargs`` plus a "_target_" entry, created with the
        ``allow_objects`` flag enabled.
    """
    # omegaconf object cannot hold dataclass type
    # https://github.com/omry/omegaconf/issues/784
    # so a dataclass target is stored by its string name instead.
    target = (
        _convert_target_to_string(self._target)
        if is_dataclass(self._target)
        else self._target
    )
    kwargs["_target_"] = target
    return DictConfig(content=kwargs, flags={"allow_objects": True})
def dump_dataclass(obj: Any):
    """
    Dump a dataclass recursively into a dict that can be later instantiated.

    Fields holding dataclass instances are dumped recursively. Fields holding
    a list or tuple are stored as a list whose dataclass-instance elements are
    dumped recursively. Every other value (including a dataclass *type*) is
    stored unchanged.

    Args:
        obj: a dataclass object

    Returns:
        dict: maps "_target_" to the string name of ``type(obj)`` and each
        field name to its dumped value.

    Raises:
        AssertionError: if ``obj`` is not an instance of a dataclass.
    """
    assert dataclasses.is_dataclass(obj) and not isinstance(
        obj, type), "dump_dataclass() requires an instance of a dataclass."
    ret = {"_target_": _convert_target_to_string(type(obj))}
    for f in dataclasses.fields(obj):
        v = getattr(obj, f.name)
        if _is_dataclass_instance(v):
            v = dump_dataclass(v)
        if isinstance(v, (list, tuple)):
            v = [
                dump_dataclass(x) if _is_dataclass_instance(x) else x
                for x in v
            ]
        ret[f.name] = v
    return ret


def _is_dataclass_instance(value):
    """Return True only for dataclass instances, not for dataclass types."""
    # dataclasses.is_dataclass() is also True for the class itself; recursing
    # into dump_dataclass() with a class would trip its instance assertion.
    return dataclasses.is_dataclass(value) and not isinstance(value, type)
def flatten(cls, obj):
    """
    Split ``obj`` into its tensor and a ``cls`` instance describing its type.

    Args:
        obj: an object exposing a ``tensor`` attribute.

    Returns:
        tuple: ``((obj.tensor,), cls(<string name of type(obj)>))``.
    """
    flat_values = (obj.tensor,)
    type_name = _convert_target_to_string(type(obj))
    schema = cls(type_name)
    return flat_values, schema
def _replace_type_by_name(x): if "_target_" in x and callable(x._target_): try: x._target_ = _convert_target_to_string(x._target_) except AttributeError: pass
def _test_obj(self, obj):
    """Assert that ``obj`` converts to a string that locates back to ``obj`` itself."""
    located = locate(_convert_target_to_string(obj))
    self.assertIs(obj, located)