Exemplo n.º 1
0
 def _to_str(obj, prefix=None, inside_call=False):
     if prefix is None:
         prefix = []
     if isinstance(obj, abc.Mapping) and "_target_" in obj:
         # Dict representing a function call
         target = _convert_target_to_string(obj.pop("_target_"))
         args = []
         for k, v in sorted(obj.items()):
             args.append(f"{k}={_to_str(v, inside_call=True)}")
         args = ", ".join(args)
         call = f"{target}({args})"
         return "".join(prefix) + call
     elif isinstance(obj, abc.Mapping) and not inside_call:
         # Dict that is not inside a call is a list of top-level config objects that we
         # render as one object per line with dot separated prefixes
         key_list = []
         for k, v in sorted(obj.items()):
             if isinstance(v, abc.Mapping) and "_target_" not in v:
                 key_list.append(_to_str(v, prefix=prefix + [k + "."]))
             else:
                 key = "".join(prefix) + k
                 key_list.append(f"{key}={_to_str(v)}")
         return "\n".join(key_list)
     elif isinstance(obj, abc.Mapping):
         # Dict that is inside a call is rendered as a regular dict
         return ("{" + ",".join(
             f"{repr(k)}: {_to_str(v, inside_call=inside_call)}"
             for k, v in sorted(obj.items())) + "}")
     elif isinstance(obj, list):
         return "[" + ",".join(
             _to_str(x, inside_call=inside_call) for x in obj) + "]"
     else:
         return repr(obj)
Exemplo n.º 2
0
    def test_compress_target(self):
        """
        Check that the string name of ``RandomCrop`` is its
        ``detectron2.data.transforms`` path and that ``locate`` maps the name
        back to the very same class.
        """
        from detectron2.data.transforms import RandomCrop

        expected_name = "detectron2.data.transforms.RandomCrop"
        target_name = _convert_target_to_string(RandomCrop)
        # The name must not contain 'augmentation_impl'.
        self.assertEqual(target_name, expected_name)
        located = locate(target_name)
        self.assertIs(RandomCrop, located)
Exemplo n.º 3
0
    def __call__(self, **kwargs):
        """
        Build a ``DictConfig`` from the keyword arguments plus a ``_target_``
        entry taken from ``self._target``.

        A dataclass target is stored by its string name; any other target is
        stored as the object itself.
        """
        # omegaconf object cannot hold dataclass type
        # https://github.com/omry/omegaconf/issues/784
        kwargs["_target_"] = (
            _convert_target_to_string(self._target)
            if is_dataclass(self._target)
            else self._target
        )
        return DictConfig(content=kwargs, flags={"allow_objects": True})
Exemplo n.º 4
0
def dump_dataclass(obj: Any):
    """
    Dump a dataclass recursively into a dict that can be later instantiated.

    Dataclass instances stored directly in a field, or as elements of a
    list/tuple field, are dumped recursively. List and tuple fields are
    returned as lists. Dataclass *types* (classes rather than instances)
    found in a field or in a list/tuple are kept as-is.

    Args:
        obj: a dataclass object

    Returns:
        dict
    """
    assert dataclasses.is_dataclass(obj) and not isinstance(
        obj, type), "dump_dataclass() requires an instance of a dataclass."

    def _is_dataclass_instance(x):
        # dataclasses.is_dataclass() is also True for the dataclass type
        # itself; only instances can be dumped recursively.
        return dataclasses.is_dataclass(x) and not isinstance(x, type)

    ret = {"_target_": _convert_target_to_string(type(obj))}
    for f in dataclasses.fields(obj):
        v = getattr(obj, f.name)
        if _is_dataclass_instance(v):
            v = dump_dataclass(v)
        if isinstance(v, (list, tuple)):
            v = [
                dump_dataclass(x) if _is_dataclass_instance(x) else x
                for x in v
            ]
        ret[f.name] = v
    return ret
Exemplo n.º 5
0
 def flatten(cls, obj):
     """
     Split ``obj`` into a 1-tuple holding ``obj.tensor`` and an instance of
     ``cls`` built from the string name of ``type(obj)``.
     """
     tensors = (obj.tensor,)
     schema = cls(_convert_target_to_string(type(obj)))
     return tensors, schema
Exemplo n.º 6
0
 def _replace_type_by_name(x):
     if "_target_" in x and callable(x._target_):
         try:
             x._target_ = _convert_target_to_string(x._target_)
         except AttributeError:
             pass
Exemplo n.º 7
0
 def _test_obj(self, obj):
     """
     Check that converting ``obj`` to its string name and locating that name
     gives back the very same object.
     """
     located = locate(_convert_target_to_string(obj))
     self.assertIs(obj, located)