Exemplo n.º 1
0
    def _set_backend(args):
        """Pick the backend classes for the classes handed to the builder.

        *Args*:

            args: list of classes passed to the builder

        *Returns*:

            backend: dict mapping 'builder', 'payload' and 'saver' to the
            attrs backend classes

        *Raises*:

            TypeError: on the first entry of ``args`` that is not an attrs
            class (the message names the class when it has a ``__name__``)

        """
        # Every entry must be an attrs class; report the first offender.
        for candidate in args:
            if attr.has(candidate):
                continue
            if hasattr(candidate, '__name__'):
                raise TypeError(
                    f"*args must be of all attrs backend -- missing a @spock decorator on class "
                    f"{candidate.__name__}")
            raise TypeError(
                f"*args must be of all attrs backend -- invalid type "
                f"{type(candidate)}")
        return {
            'builder': AttrBuilder,
            'payload': AttrPayload,
            'saver': AttrSaver
        }
Exemplo n.º 2
0
 def assert_proper_col_class(obj, obj_tuple):
     """Recursively check that collection types survived the conversion.

     Walks the attrs fields of ``obj`` in parallel with the positional
     entries of ``obj_tuple``. A list/tuple or dict field must map to a value
     of exactly the same type; attrs instances found directly in a field, or
     inside those collections (dict keys and values included), are checked
     recursively.
     """
     for position, attribute in enumerate(fields(obj.__class__)):
         value = getattr(obj, attribute.name)
         if has(value.__class__):
             # Nested attrs instance: descend into it.
             assert_proper_col_class(value, obj_tuple[position])
             continue
         if isinstance(value, (list, tuple)):
             counterpart = obj_tuple[position]
             assert type(value) is type(counterpart)  # noqa: E721
             for item, converted_item in zip(value, counterpart):
                 if has(item.__class__):
                     assert_proper_col_class(item, converted_item)
         elif isinstance(value, dict):
             counterpart = obj_tuple[position]
             assert type(value) is type(counterpart)  # noqa: E721
             paired_items = zip(value.items(), counterpart.items())
             for (key, val), (conv_key, conv_val) in paired_items:
                 if has(key.__class__):
                     assert_proper_col_class(key, conv_key)
                 if has(val.__class__):
                     assert_proper_col_class(val, conv_val)
Exemplo n.º 3
0
def check_if_different(testobj1: object, testobj2: object) -> None:
    """Assert two objects are distinct, recursing through attrs fields.

    ``testobj1`` and ``testobj2`` must not be the same object. When both are
    attrs instances, every field of ``testobj1`` that holds a dict, a list or
    an attrs instance is checked recursively against the same-named field of
    ``testobj2``, so shared mutable sub-objects are detected.

    Note: the contents of dicts and lists are not inspected, only their
    identity.

    :raises AssertionError: if an identical object is found.
    :raises KeyError: if ``testobj2`` lacks a field that ``testobj1`` has.
    """
    assert testobj1 is not testobj2
    if not (attr.has(testobj1.__class__) and attr.has(testobj2.__class__)):
        return
    # Build the second object's shallow field dict once, not once per key.
    other_fields = attr.asdict(testobj2, recurse=False)
    for key, val in attr.asdict(testobj1, recurse=False).items():
        # Check the value's *class*, as done for the top-level objects above:
        # a field that merely stores an attrs class object is not an instance
        # and may legitimately be shared.
        if isinstance(val, (dict, list)) or attr.has(val.__class__):
            check_if_different(val, other_fields[key])
Exemplo n.º 4
0
 def assert_proper_col_class(obj, obj_tuple):
     """Recursively check that collection types survived the conversion.

     Walks the attrs fields of ``obj`` in parallel with the positional
     entries of ``obj_tuple``. A list/tuple or dict field must map to a value
     of exactly the same type; attrs instances found directly in a field, or
     inside those collections (dict keys and values included), are checked
     recursively.
     """
     for position, attribute in enumerate(fields(obj.__class__)):
         value = getattr(obj, attribute.name)
         if has(value.__class__):
             # Nested attrs instance: descend into it.
             assert_proper_col_class(value, obj_tuple[position])
             continue
         if isinstance(value, (list, tuple)):
             counterpart = obj_tuple[position]
             assert type(value) is type(counterpart)  # noqa: E721
             for item, converted_item in zip(value, counterpart):
                 if has(item.__class__):
                     assert_proper_col_class(item, converted_item)
         elif isinstance(value, dict):
             counterpart = obj_tuple[position]
             assert type(value) is type(counterpart)  # noqa: E721
             paired_items = zip(value.items(), counterpart.items())
             for (key, val), (conv_key, conv_val) in paired_items:
                 if has(key.__class__):
                     assert_proper_col_class(key, conv_key)
                 if has(val.__class__):
                     assert_proper_col_class(val, conv_val)
Exemplo n.º 5
0
    def wrapper(self, context, *args, **kwargs):
        """Resolve lineage inlets/outlets for the operator, then call ``func``.

        Normalizes ``self._inlets`` to a list, pulls the outlets that the
        selected upstream tasks pushed to XCom under ``PIPELINE_OUTLETS``,
        re-instantiates them, renders the attrs-annotated entries of
        ``self._inlets`` / ``self._outlets`` against ``context``, appends the
        results to ``self.inlets`` / ``self.outlets``, and finally delegates to
        the wrapped ``func`` with the original arguments.

        :raises AttributeError: if ``self._inlets`` is truthy but is not a
            list after normalization (i.e. not a list, string, Operator or
            attrs object).
        """
        self.log.debug("Preparing lineage inlets and outlets")

        # A single string, Operator or attrs object is wrapped into a list.
        if isinstance(self._inlets, (str, Operator)) or attr.has(self._inlets):
            self._inlets = [
                self._inlets,
            ]

        if self._inlets and isinstance(self._inlets, list):
            # get task_ids that are specified as parameter and make sure they are upstream
            # (string entries whose lowercase form differs from AUTO, plus the
            # task_id of every Operator entry, intersected with the ids
            # returned by get_flat_relative_ids(upstream=True))
            task_ids = set(
                filter(lambda x: isinstance(x, str) and x.lower() != AUTO,
                       self._inlets)).union(
                           map(
                               lambda op: op.task_id,
                               filter(lambda op: isinstance(op, Operator),
                                      self._inlets))).intersection(
                                          self.get_flat_relative_ids(
                                              upstream=True))

            # pick up unique direct upstream task_ids if AUTO is specified
            # (A | (A ^ B) == A | B, so this adds all of upstream_task_ids)
            if AUTO.upper() in self._inlets or AUTO.lower() in self._inlets:
                task_ids = task_ids.union(
                    task_ids.symmetric_difference(self.upstream_task_ids))

            # NOTE(review): presumably one list of serialized outlets (or a
            # falsy value) per task id — confirm against the XCom push side.
            _inlets = self.xcom_pull(context,
                                     task_ids=task_ids,
                                     dag_id=self.dag_id,
                                     key=PIPELINE_OUTLETS)

            # re-instantiate the pulled inlets (flattening the per-task lists,
            # skipping empty ones), then add the rendered attrs entries given
            # directly in self._inlets
            _inlets = [
                _get_instance(structure(item, Metadata)) for sublist in _inlets
                if sublist for item in sublist
            ]
            _inlets.extend([
                _render_object(i, context) for i in self._inlets if attr.has(i)
            ])

            self.inlets.extend(_inlets)

        elif self._inlets:
            raise AttributeError(
                "inlets is not a list, operator, string or attr annotated object"
            )

        # Outlets: wrap a single value into a list; only attrs objects are
        # rendered and recorded.
        if not isinstance(self._outlets, list):
            self._outlets = [
                self._outlets,
            ]

        _outlets = list(
            map(lambda i: _render_object(i, context),
                filter(attr.has, self._outlets)))

        self.outlets.extend(_outlets)

        self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
        # ``func`` is the wrapped callable from the enclosing scope.
        return func(self, context, *args, **kwargs)
Exemplo n.º 6
0
def as_json_dict(obj):
    # type: (Any) -> Dict
    """
    Similar to attr.asdict, but will prioritize an `as_dict` instance method
    over ``attr.asdict`` (if present) on nested objects and tries to convert
    common python types to json-encodable values.

    Ex: datetimes, dates and times are converted to isoformat, enums are
    converted to their values, lists/tuples/sets become lists whose items are
    converted the same way, nested dicts are converted recursively, etc.

    Expected Usage::

        @attr.s
        class MyClass(object):
            a = attr.ib()

            as_dict = as_json_dict

        # ... later
        my_dict = MyClass(a=x).as_dict()

    Note the caveat that ``as_json_dict(my_class)`` won't equal
    ``my_class.as_dict()`` if ``my_class`` defines custom logic within an
    overridden ``as_dict`` method.

    The input is never modified: a new dict is returned even when ``obj`` is
    itself a dict, and nested dicts are rebuilt rather than rewritten.

    :param obj: attrs object to convert to a dictionary. Optionally can be
        a dictionary, which will recursively serialize keys/values the same
        way.
    :returns: a new dict
    """
    if isinstance(obj, dict):
        source = obj
    else:
        # ``recurse=False`` hands back the field values themselves (including
        # dicts shared with the object), so never write into ``source``.
        source = attr.asdict(obj, recurse=False)  # handling recursing manually

    return {k: _to_json_value(v) for k, v in source.items()}


def _to_json_value(value):
    # type: (Any) -> Any
    """Convert one nested value of ``as_json_dict`` to a json-encodable form.

    Dates/times become isoformat strings, enums their ``value``, sequences
    and sets a list of converted items, dicts a converted copy, and attrs
    instances either their own ``as_dict()`` (when callable) or
    ``as_json_dict(value)``. Anything else is returned unchanged.
    """
    if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, enum.Enum):
        return value.value
    # Exact built-in scalars are already encodable; skip the attrs lookup.
    # (Exact types only, so attrs subclasses of these still go below.)
    if value is None or type(value) in (str, int, float, bool):
        return value
    if isinstance(value, (list, tuple, set)):
        return [_to_json_value(item) for item in value]
    if isinstance(value, dict):
        return as_json_dict(value)
    if attr.has(value.__class__):
        if callable(getattr(value, "as_dict", None)):
            return value.as_dict()
        return as_json_dict(value)
    return value
Exemplo n.º 7
0
def create_robcoewmtype_str(obj: Any) -> str:
    """
    Create a type string from an robcoewm type class.

    This function should be used to save the data types of classes with "attr"
    attributes as a string to unstructure and restructure these classes later
    using "cattr" module.

    Input could be an instance of a class with "attr" attributes or a
    non-empty list of such instances.

    Returns ``'<module>.<Class>'`` for a single object,
    ``'List[<module>.<Class>]'`` when all list entries share one class, and
    ``'List[Union[<module>.<A>,<module>.<B>,...]]'`` (classes in order of first
    appearance) for a list of mixed classes.

    Raises TypeError if a class has no "attr" attributes or if ``obj`` is an
    empty list.
    """
    # List of objects
    if isinstance(obj, list):
        if not obj:
            # No element type can be derived; the union branch below would
            # otherwise produce the invalid expression 'List[Union[]]'.
            raise TypeError('Cannot create a type string from an empty list')
        clslist: List[Type[object]] = []
        clsmodulelist = []
        for entry in obj:
            # If data type is not collected yet, append it to list
            if entry.__class__ not in clslist:
                if not attr.has(entry.__class__):
                    raise TypeError(
                        '"{}" is not a class with "attr" attributes'.format(
                            entry.__class__))
                clslist.append(entry.__class__)
                clsmodulelist.append(entry.__module__)

        # Convert class module and name to a string for all classes
        clsstrlist = [
            '{mod}.{cls}'.format(mod=clsmodule, cls=cls.__name__)
            for clsmodule, cls in zip(clsmodulelist, clslist)
        ]
        # List with one data type
        if len(clsstrlist) == 1:
            return 'List[{}]'.format(clsstrlist[0])
        # List with a union of data types
        return 'List[Union[{clsjoin}]]'.format(clsjoin=','.join(clsstrlist))

    # Single object
    clsmodule = obj.__module__
    cls = obj.__class__
    # Check for class for attr attributes
    if not attr.has(cls):
        raise TypeError(
            '"{}" is not a class with "attr" attributes'.format(cls))

    # Convert class module and name to a string
    return '{mod}.{cls}'.format(mod=clsmodule, cls=cls.__name__)
Exemplo n.º 8
0
    def wrapper(self, context, *args, **kwargs):
        """Resolve lineage inlets/outlets for the operator, then call ``func``.

        Normalizes ``self._inlets`` to a list, pulls the outlets that the
        selected upstream tasks pushed to XCom under ``PIPELINE_OUTLETS``,
        re-instantiates them, appends them together with ``self._inlets`` to
        ``self.inlets``, appends ``self._outlets`` to ``self.outlets``, then
        keeps only the attrs objects in both lists (rendered against
        ``context``) and delegates to the wrapped ``func``.

        :raises AttributeError: if ``self._inlets`` is truthy but is not a
            list after normalization.
        """
        # Imported here rather than at module level — presumably to avoid an
        # import cycle; confirm before moving it.
        from airflow.models.abstractoperator import AbstractOperator

        self.log.debug("Preparing lineage inlets and outlets")

        # A single string, operator or attrs object is wrapped into a list.
        if isinstance(self._inlets, (str, AbstractOperator)) or attr.has(self._inlets):
            self._inlets = [
                self._inlets,
            ]

        if self._inlets and isinstance(self._inlets, list):
            # get task_ids that are specified as parameter and make sure they are upstream
            # (string entries plus the task_id of each operator entry,
            # intersected with get_flat_relative_ids(upstream=True))
            task_ids = (
                {o for o in self._inlets if isinstance(o, str)}
                .union(op.task_id for op in self._inlets if isinstance(op, AbstractOperator))
                .intersection(self.get_flat_relative_ids(upstream=True))
            )

            # pick up unique direct upstream task_ids if AUTO is specified
            # (A | (A ^ B) == A | B, so this adds all of upstream_task_ids)
            if AUTO.upper() in self._inlets or AUTO.lower() in self._inlets:
                task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids))

            # NOTE(review): presumably one list of serialized outlets (or a
            # falsy value) per task id — confirm against the XCom push side.
            _inlets = self.xcom_pull(context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS)

            # re-instantiate the obtained inlets
            _inlets = [
                _get_instance(structure(item, Metadata)) for sublist in _inlets if sublist for item in sublist
            ]

            self.inlets.extend(_inlets)
            self.inlets.extend(self._inlets)

        elif self._inlets:
            raise AttributeError("inlets is not a list, operator, string or attr annotated object")

        if not isinstance(self._outlets, list):
            self._outlets = [
                self._outlets,
            ]

        self.outlets.extend(self._outlets)

        # render inlets and outlets
        # (rebinding drops every non-attrs entry, e.g. the task-id strings and
        # operators that were appended above)
        self.inlets = [_render_object(i, context) for i in self.inlets if attr.has(i)]

        self.outlets = [_render_object(i, context) for i in self.outlets if attr.has(i)]

        self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
        # ``func`` is the wrapped callable from the enclosing scope.
        return func(self, context, *args, **kwargs)
Exemplo n.º 9
0
def _guess_type(python_type, name: str) -> Tuple[TType, Subtype]:
    """Infer the ``TType`` (and nested subtype) for a Python type annotation.

    Returns ``(ttype, subtype)`` where ``subtype`` is ``None`` for str/bytes,
    bool, float and attrs classes; the element's ``(ttype, subtype)`` for
    lists and sets; and a ``(key, value)`` pair of those for dicts.

    :param python_type: annotation to inspect.
    :param name: field name used in error messages; extended with
        "item"/"key"/"value" when recursing into containers.
    :raises ValueError: for ``int`` (integer width is ambiguous), for a
        list/dict/set annotation without type arguments, and for any other
        unsupported type.
    """
    if python_type == str or python_type == bytes:
        return TType.BINARY, None
    elif python_type == bool:
        return TType.BOOL, None
    elif python_type == int:
        raise ValueError(f"Ambiguous integer field {name}")
    elif python_type == float:
        return TType.DOUBLE, None
    elif attr.has(python_type):
        return TType.STRUCT, None

    type_class = _get_type_class(python_type)
    args = getattr(python_type, "__args__", None)
    if type_class in (list, dict, set) and not args:
        # Without ``__args__`` there is no element type to recurse into;
        # ``args[0]`` below would otherwise fail with an opaque TypeError.
        raise ValueError(f"Missing type arguments for {python_type} in {name}")
    if type_class == list:
        return TType.LIST, _guess_type(args[0], f"{name} item")
    elif type_class == dict:
        return TType.MAP, (
            _guess_type(args[0], f"{name} key"),
            _guess_type(args[1], f"{name} value"),
        )
    elif type_class == set:
        return TType.SET, _guess_type(args[0], f"{name} item")

    raise ValueError(f"Unknown type {python_type} for {name}")
Exemplo n.º 10
0
def enumerate_randomizable_params(
        parameters: PType) -> Iterable[_RandomizableParam]:
    """
    Walk an attrs parameters instance and yield every randomizable field.

    A field whose metadata has a truthy ``"randomizable"`` entry yields a
    _RandomizableParam carrying its declared type, current value, the
    ``(low, high)`` range from its metadata and the owning instance. Fields
    whose declared type is itself an attrs class are descended into.

    :param parameters: The parameters instance.
    """
    for attribute in attr.fields(type(parameters)):
        meta = attribute.metadata
        field_name = attribute.name

        if meta.get("randomizable", False):
            assert attribute.type
            assert attribute.default is not None
            yield _RandomizableParam(
                name=field_name,
                value_type=attribute.type,
                default=getattr(parameters, field_name),
                value_range=(meta["low"], meta["high"]),
                parent_instance=parameters,
            )

        assert attribute.type, f"No type available for field {attribute}"

        if attr.has(attribute.type):
            # Nested parameter group: recurse into the child instance.
            yield from enumerate_randomizable_params(
                getattr(parameters, field_name))
Exemplo n.º 11
0
 def __make_cmp_key(self, value):
     """Converts `value` to a hashable key.

     Processed recursively:
       * int/float/bool, `dtypes.DType`, `TypeSpec`, bytes/text and None are
         returned unchanged;
       * dicts become a tuple of `(key, value)` pairs sorted by key (so keys
         must be mutually orderable);
       * attrs instances are converted with `attr.asdict` and then handled
         like dicts;
       * tuples are converted element-wise; lists likewise, but prefixed with
         the `list` type object so they cannot collide with a tuple key;
       * `TensorShape`s and numpy arrays are tagged with their type.

     The error message suggests `value` presumably comes from `_serialize` —
     confirm against callers.

     Raises:
       ValueError: if `value` (or a nested value) has an unsupported type.
     """
     if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
         return value
     if isinstance(value, compat.bytes_or_text_types):
         return value
     if value is None:
         return value
     if isinstance(value, dict):
         # Sorting makes the key independent of dict insertion order.
         return tuple([
             tuple([
                 self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])
             ]) for key in sorted(value.keys())
         ])
     # NOTE(review): passes the instance rather than its class to attr.has —
     # presumably relies on __attrs_attrs__ being visible on instances.
     if attr.has(value):
         d = attr.asdict(value, dict_factory=collections.OrderedDict)
         return self.__make_cmp_key(d)
     if isinstance(value, tuple):
         return tuple([self.__make_cmp_key(v) for v in value])
     if isinstance(value, list):
         return (list, tuple([self.__make_cmp_key(v) for v in value]))
     if isinstance(value, tensor_shape.TensorShape):
         if value.ndims is None:
             # Note: we include a type object in the tuple, to ensure we can't get
             # false-positive matches (since users can't include type objects).
             return (tensor_shape.TensorShape, None)
         return (tensor_shape.TensorShape, tuple(value.as_list()))
     if isinstance(value, np.ndarray):
         return (np.ndarray, value.shape,
                 TypeSpec.__nested_list_to_tuple(value.tolist()))
     raise ValueError("Unsupported value type %s returned by "
                      "%s._serialize" %
                      (type(value).__name__, type(self).__name__))
Exemplo n.º 12
0
def pretty_repr(obj: Type):
    """
    Add a pretty-printing ``__repr__`` function to the decorated attrs class.

    The added ``__repr__`` delegates to ``prettyprinter.pformat``. Objects
    that are not attrs classes are returned unchanged.

    .. code-block:: python

        >>> import attr
        >>> from attr_utils.pprinter import pretty_repr

        >>> @pretty_repr
        ... @attr.s
        ... class Person(object):
        ...     name = attr.ib()

        >>> Person(name="Bob")
        Person(name='Bob')

    :param obj: the class to decorate.
    :returns: ``obj`` itself, modified in place when it is an attrs class.
    """
    if not attr.has(obj):
        return obj

    def __repr__(self) -> str:
        return prettyprinter.pformat(self)

    __repr__.__doc__ = f"Return a string representation of the :class:`~.{obj.__name__}`."

    obj.__repr__ = __repr__  # type: ignore
    # Make the injected method look like it was defined on the class itself.
    obj.__repr__.__qualname__ = f"{obj.__name__}.__repr__"
    obj.__repr__.__module__ = obj.__module__

    return obj
Exemplo n.º 13
0
def generate_toml_help(config_cls, *, parent=None):
    """Build a commented TOML table documenting the fields of ``config_cls``.

    On the top-level call (``parent is None``) a new table is created and the
    lines of the class docstring (passed through ``trim``) are added as
    comments, followed by a blank line. Then, for each attrs field:

    * a field whose type is itself an attrs class becomes a nested table,
      generated recursively;
    * otherwise the field's ``help`` text from its ``CNF_KEY`` metadata (if
      any) is added as a comment, then ``name = <default>`` or, when no
      usable default exists, a commented ``name =`` placeholder, followed by
      a blank line.

    :param config_cls: attrs class to document.
    :param parent: table to add entries to; created when ``None``.
    :returns: the populated table.
    """
    if parent is None:
        parent = tomlkit.table()
        doclines = trim(config_cls.__doc__).split("\n")
        for line in doclines:
            parent.add(tomlkit.comment(line))
        parent.add(tomlkit.nl())

    for attrib in attr.fields(config_cls):
        meta = attrib.metadata.get(CNF_KEY)
        if attr.has(attrib.type):
            sub_doc = generate_toml_help(attrib.type)
            parent.add(attrib.name, sub_doc)
            continue

        if meta:
            parent.add(tomlkit.comment(meta.help))

        default = attrib.default
        # ``attr.ib(factory=...)`` / ``default=attr.Factory(...)`` store a
        # Factory object, which is not itself callable. A factory that needs
        # ``self`` cannot be evaluated without an instance, so it is shown
        # like a missing default.
        is_factory = isinstance(default, attr.Factory)
        if default in (missing, attr.NOTHING) or (is_factory and default.takes_self):
            parent.add(tomlkit.comment(f"{attrib.name} ="))
        else:
            if is_factory:
                default = default.factory()
            elif callable(default):
                default = default()
            parent.add(attrib.name, default)

        parent.add(tomlkit.nl())

    return parent
Exemplo n.º 14
0
    def __create_requirements_check(
        argument_types: List[type],
    ) -> Callable:
        """Build one callable that checks the requirements of all arguments.

        Each entry of ``argument_types`` gets its own checker: attrs classes
        use a ``_CompositeArgumentCheck`` over their non-optional field types,
        any other type uses ``_direct_qn_check``. Sequence types are unwrapped
        to their element type and their checker is wrapped by
        ``_sequence_input_check``. The checkers are combined with
        ``_check_all_arguments``.
        """

        def _checker_for(input_type):
            is_sequence = _is_sequence_type(input_type)
            element_type = (
                input_type.__args__[0]  # type: ignore
                if is_sequence
                else input_type
            )
            if attr.has(element_type):
                required_field_types = [
                    class_field.type
                    for class_field in attr.fields(element_type)
                    if not _is_optional(class_field.type)
                ]
                check: Callable[..., bool] = _CompositeArgumentCheck(
                    required_field_types  # type: ignore
                )
            else:
                check = _direct_qn_check(element_type)
            return _sequence_input_check(check) if is_sequence else check

        return _check_all_arguments(
            [_checker_for(input_type) for input_type in argument_types]
        )
Exemplo n.º 15
0
 def __call__(self, cls: AnyClass) -> AnyClass:
     """Apply the jsonable decorator of ``self.kind`` to ``cls``.

     When ``self.after`` is set, the decoration is stored on the class as
     ``DelayedData`` and registered via ``self.delay``; otherwise it is
     applied immediately. Pending delayed decorations are then processed
     with ``self.decorate_delayed()``. Classes that are already jsonable or
     already delayed are returned untouched.

     :raises TypeError: if ``cls`` is not a class.
     :raises JsonableError: if ``cls`` was already processed by attrs or
         dataclasses (this decorator must be the inner-most one).
     """
     if not isinstance(cls, type):
         raise TypeError("Can only decorate classes")
     if attr.has(cls):
         raise JsonableError(
             "@jsonable must be the inner-most decorator when used with @attr.s"
         )
     if dataclasses.is_dataclass(cls):
         raise JsonableError(
             "@jsonable must be the inner-most decorator when used with @dataclass"
         )
     # Never decorate or delay the same class twice.
     if is_jsonable_class(cls) or self.is_delayed(cls):
         return cls

     decorator = _decorator_table[self.kind]
     decorator_kwargs = {
         option: getattr(self, option) for option in self._KWARGS_OUT[self.kind]
     }
     if not self.after:
         cls = decorator(cls, **decorator_kwargs)
     else:
         pending = DelayedData(
             functools.partial(decorator, **decorator_kwargs), self.kind
         )
         setattr(cls, PHERES_ATTR, pending)
         self.delay(cls, self.after)
     self.decorate_delayed()
     return cls
Exemplo n.º 16
0
def generate_env_help(cls, env_prefix=""):
    """Generate a list of all environment options.

    Returns one upper-cased, underscore-joined variable name per line for
    every non-attrs field yielded by ``traverse_attrs(cls)``; a non-empty
    ``env_prefix`` becomes the first name component.
    """
    option_names = []
    for path, field_type, _target in traverse_attrs(cls):
        if attr.has(field_type):
            # Nested attrs classes are containers, not options themselves.
            continue
        parts = (env_prefix, ) + path if env_prefix else path
        option_names.append("_".join(parts).upper())
    return "\n".join(option_names)
Exemplo n.º 17
0
def _klasses_from_attr(cls):
    """Map ``cls`` and each attrs-class field type of ``cls`` to its fields.

    Only direct field types are included (no deeper recursion); the values
    come from ``_fields_from_attr``.
    """
    klasses = {cls: _fields_from_attr(cls)}
    klasses.update(
        (field.type, _fields_from_attr(field.type))
        for field in attr.fields(cls)
        if field.type and attr.has(field.type)
    )
    return klasses
Exemplo n.º 18
0
def attr_to_json(obj, **kwds):
    """Serialize an attrs instance to a JSON string.

    ``obj`` is converted with ``asdict_plus`` and dumped with ``json.dumps``
    using the remaining keyword arguments. ``pretty=True`` (removed before
    dumping) sets ``sort_keys=True``, ``separators=(',', ': ')`` and
    ``indent=4``, overriding any values given for those keys.

    :raises TypeError: if ``obj`` is not an instance of an attrs class.
    """
    # Explicit check instead of ``assert`` so validation survives
    # ``python -O``; testing ``type(obj)`` also rejects attrs *classes*
    # passed in place of instances.
    if not attr.has(type(obj)):
        raise TypeError(
            f"attr_to_json expects an attrs instance, got {type(obj).__name__}")

    if kwds.pop("pretty", False):
        kwds.update(sort_keys=True, separators=(',', ': '), indent=4)

    return json.dumps(asdict_plus(obj), **kwds)
Exemplo n.º 19
0
def create_test_instance(cls, name=(), *, overrides=None):
    """Build a placeholder value of the annotated type ``cls`` for tests.

    - If the dotted ``name`` path is a key of ``overrides``, that value is
      returned as-is.
    - attrs classes are built field by field, recursing into each field's
      declared type with the field name appended to ``name``.
    - ``Union`` annotations containing ``None`` yield ``None``; ``List[T]``
      yields a one-element list holding a ``T`` placeholder.
    - Enum classes yield their first member.
    - Anything else is called with no arguments.

    :param cls: type or typing annotation to instantiate.
    :param name: tuple of field names leading to this value.
    :param overrides: mapping of dotted paths to values used verbatim.
    :raises Exception: for unsupported generic annotations, or when the
        no-argument construction fails (the original error is chained).
    """
    if overrides is None:
        # ``None`` instead of a shared mutable ``{}`` default.
        overrides = {}
    dotted_name = '.'.join(name)
    if dotted_name in overrides:
        return overrides[dotted_name]
    if attr.has(cls):
        args = {}
        for field in attr.fields(cls):
            args[field.name] = create_test_instance(field.type,
                                                    name + (field.name, ),
                                                    overrides=overrides)
        return cls(**args)
    elif hasattr(cls, '__origin__'):
        t = cls.__origin__
        if t is typing.Union and type(None) in cls.__args__:
            return None
        elif t in [list, typing.List]:
            return [
                create_test_instance(cls.__args__[0],
                                     name,
                                     overrides=overrides),
            ]
        else:
            raise Exception("do not understand annotation {} at {}".format(
                t, dotted_name))
    elif isinstance(cls, type) and issubclass(cls, enum.Enum):
        # ``issubclass`` raises TypeError for non-class annotations; those
        # fall through to the generic instantiation attempt below instead.
        return next(iter(cls))
    else:
        try:
            return cls()
        except Exception as exc:
            # Keep the underlying error as the cause for easier debugging.
            raise Exception("instantiating {} failed".format(cls)) from exc
Exemplo n.º 20
0
def _sequence_like(instance, args):
    """Converts the sequence `args` to the same type as `instance`.

    Args:
      instance: an instance of `dict`, `tuple`, `list`, a `namedtuple` class,
        or an attrs class.
      args: elements to be converted to a sequence. For a dict they are
        matched to the keys in sorted-key order (the result keeps the key
        order of `instance`); for namedtuples and attrs instances they are
        passed to the constructor positionally.

    Returns:
      `args` with the type of `instance`.
    """
    instance_type = type(instance)

    if isinstance(instance, dict):
        # Pack dictionaries in a deterministic order by sorting the keys.
        # Notice this means that we ignore the original order of `OrderedDict`
        # instances. This is intentional, to avoid potential bugs caused by mixing
        # ordered and plain dicts (e.g., flattening a dict but using a
        # corresponding `OrderedDict` to pack it back).
        values_by_key = dict(zip(_sorted(instance), args))
        return instance_type((key, values_by_key[key]) for key in instance)

    is_namedtuple = (
        isinstance(instance, tuple) and hasattr(instance, "_fields")
        and isinstance(instance._fields, _collections_abc.Sequence)
        and all(isinstance(f, _six.string_types) for f in instance._fields))
    if is_namedtuple or attr.has(instance):
        # namedtuples and attr classes take their fields positionally
        return instance_type(*args)
    # Plain sequence types take a single iterable
    return instance_type(args)
Exemplo n.º 21
0
def _guess_type(python_type, name: str) -> RecursiveType:
    """Infer the ``RecursiveType`` describing a Python type annotation.

    str/bytes, bool, float and attrs classes map to a flat ``RecursiveType``
    carrying ``python_type``; lists and sets carry the recursively guessed
    ``item_type``, dicts a ``key_type`` and ``value_type``.

    :param python_type: annotation to inspect.
    :param name: field name used in error messages; extended with
        "item"/"key"/"value" when recursing into containers.
    :raises ValueError: for ``int`` (integer width is ambiguous), for a
        list/dict/set annotation without type arguments, and for any other
        unsupported type.
    """
    if python_type == str or python_type == bytes:
        return RecursiveType(TType.BINARY, python_type=python_type)
    elif python_type == bool:
        return RecursiveType(TType.BOOL, python_type=python_type)
    elif python_type == int:
        raise ValueError(f"Ambiguous integer field {name}")
    elif python_type == float:
        return RecursiveType(TType.DOUBLE, python_type=python_type)
    elif attr.has(python_type):
        return RecursiveType(TType.STRUCT, python_type=python_type)

    type_class = _get_type_class(python_type)
    args = getattr(python_type, "__args__", None)
    if type_class in (list, dict, set) and not args:
        # Without ``__args__`` there is no element type to recurse into;
        # ``args[0]`` below would otherwise fail with an opaque TypeError.
        raise ValueError(f"Missing type arguments for {python_type} in {name}")
    if type_class == list:
        return RecursiveType(TType.LIST,
                             item_type=_guess_type(args[0], f"{name} item"),
                             python_type=list)
    elif type_class == dict:
        return RecursiveType(TType.MAP,
                             key_type=_guess_type(args[0], f"{name} key"),
                             value_type=_guess_type(args[1], f"{name} value"),
                             python_type=dict)
    elif type_class == set:
        return RecursiveType(TType.SET,
                             item_type=_guess_type(args[0], f"{name} item"),
                             python_type=set)

    raise ValueError(f"Unknown type {python_type} for {name}")
Exemplo n.º 22
0
def deserialize(value):
    """Inverse of 'serialize'.

    Lists are deserialized element-wise and plain dicts value-wise. A dict
    carrying a ``"_class"`` key becomes an instance of the attrs class of
    that name looked up on ``pulplib``; its remaining entries are deserialized
    first and passed as keyword arguments. Any other value is returned
    unchanged. The input is never mutated.

    :raises AttributeError: if ``pulplib`` has no attribute of that name.
    :raises TypeError: if the named ``pulplib`` attribute is not an attrs class.
    """
    if isinstance(value, list):
        return [deserialize(elem) for elem in value]

    if not isinstance(value, dict):
        return value

    if "_class" not in value:
        # Plain old dict
        return {key: deserialize(elem) for key, elem in value.items()}

    attributes = dict(value)
    class_name = attributes.pop("_class")
    model_class = getattr(pulplib, class_name)
    # Explicit check instead of ``assert`` (stripped under ``python -O``):
    # the class name comes from the serialized data and must not be able to
    # select an arbitrary ``pulplib`` callable.
    if not attr.has(model_class):
        raise TypeError(f"{class_name!r} does not name an attrs class in pulplib")

    # Deserialize everything inside it first using the plain dict
    # logic. This is where we recurse into nested attr classes, if any.
    return model_class(**deserialize(attributes))
Exemplo n.º 23
0
    def __create_argument_builder(
        argument_types: List[type],
    ) -> Callable:
        """Build one callable that assembles all rule arguments.

        Each entry of ``argument_types`` gets its own builder: attrs classes
        use ``_CompositeArgumentCreator``, edge/node quantum numbers use the
        matching ``_ValueExtractor``. Sequence types are unwrapped to their
        element type and their builder is wrapped by
        ``_sequence_arg_builder``. The builders are combined with
        ``_build_all_arguments``.

        :raises TypeError: for an element type that is neither an attrs class
            nor an edge/node quantum number.
        """
        builders = []
        for input_type in argument_types:
            is_sequence = _is_sequence_type(input_type)
            element_type = (
                input_type.__args__[0]  # type: ignore
                if is_sequence
                else input_type
            )

            if attr.has(element_type):
                builder: Callable[..., Any] = _CompositeArgumentCreator(
                    element_type
                )
            elif _is_edge_quantum_number(element_type):
                builder = _ValueExtractor[EdgeQuantumNumber](element_type)
            elif _is_node_quantum_number(element_type):
                builder = _ValueExtractor[NodeQuantumNumber](element_type)
            else:
                raise TypeError(
                    f"Quantum number type {element_type} is not supported."
                    " Has to be of type Edge/NodeQuantumNumber."
                )

            builders.append(
                _sequence_arg_builder(builder) if is_sequence else builder
            )

        return _build_all_arguments(builders)
Exemplo n.º 24
0
def parse_attr(cls):
    """Render an attrs class's docstring and fields as Markdown.

    Output: the first docstring line, the dedented remaining docstring lines,
    then a ``#### Fields`` section with one `` **name (type)** description``
    entry per field, each optionally followed by ``> Default:`` and
    ``> Constraints:`` lines.

    :param cls: an attrs class.
    :returns: the Markdown text.
    :raises TypeError: if ``cls`` is not an attrs class.
    """
    if not attr.has(cls):
        raise TypeError(f"{cls} is not an attrs class")

    entries = []
    for f in attr.fields(cls):
        entries.append(f" **{f.name} ({_parse_attr_type_name(f.type)})** {get_desc(f)}")
        if f.default is not attr.NOTHING:
            entries.append(f"> Default: {f.default}")
        if f.validator is not None:
            # Combined ("and") validators expose their parts as
            # ``_validators``. Look that up explicitly so an AttributeError
            # raised inside ``repr_validator`` is not mistaken for a single
            # validator.
            sub_validators = getattr(f.validator, "_validators", None)
            if sub_validators is None:
                v = repr_validator(f.validator)
            else:
                v = ", ".join(map(repr_validator, sub_validators))
            entries.append(f"> Constraints:  {v}")

    # ``"".splitlines()`` is [], which cannot be unpacked below; fall back to
    # one empty line for missing or empty docstrings.
    docstring = (cls.__doc__ or "").splitlines() or [""]
    short_desc, *rest = docstring
    fields_md = "\n\n#### Fields\n\n" + "\n\n".join(entries) + "\n\n"
    return short_desc + dedent("\n".join(rest)) + fields_md


def _parse_attr_type_name(field_type):
    """Return the display form of an attrs field annotation for parse_attr.

    Any annotation with ``__args__`` (e.g. a Union) is shown as its argument
    names joined by `` | `` (so ``List[int]`` shows as ``int``); arguments
    without ``__name__`` fall back to ``str()``. Plain classes show their
    ``__name__``; anything else (e.g. ``None``) is returned unchanged.
    """
    if hasattr(field_type, "__args__"):
        return " | ".join(
            getattr(arg, "__name__", str(arg)) for arg in field_type.__args__)
    if hasattr(field_type, "__name__"):
        return field_type.__name__
    return field_type
Exemplo n.º 25
0
def InitializeModel(model: Any,
                    protoType: Any,
                    namePrefix: Optional[str] = None) -> None:
    """Populate the members of ``model`` from ``protoType``.

    For transformed model classes, every instance variable is set either to
    a nested member model built from the same-named ``protoType`` attribute
    (named ``"<namePrefix>.<name>"`` when a prefix is given) or, when there
    is no member model, to ``model.TransformMember(...)``. If the model's
    class is an attrs class, its attrs members are then initialized through
    ``InitializeModelFromAttr``.
    """
    modelType = type(model)

    if IsTransformed(modelType):
        for name in GetInstanceVars(modelType):
            memberModel = GetMemberModel(protoType, name)

            if memberModel is None:
                value = model.TransformMember(name, protoType, namePrefix)
            else:
                # Nested transformed class: use it directly as a nested
                # transformation.
                memberName = (
                    name if namePrefix is None
                    else "{}.{}".format(namePrefix, name))
                value = memberModel(getattr(protoType, name), memberName)

            setattr(model, name, value)

    if attr.has(type(model)):
        # Initialize additional attrs members.
        InitializeModelFromAttr(model, namePrefix)
Exemplo n.º 26
0
 def decorator(cls):
     """Tag ``cls`` with the include-list ``name`` and register it.

     Raises NotAnAttrsClassError unless ``cls`` is an attrs class.
     """
     if attr.has(cls):
         cls._include_list_name = name
         IncludeLists.register(name, cls)
         return cls
     raise attr.exceptions.NotAnAttrsClassError(
         'Include list feature works with attr.s classes only')
Exemplo n.º 27
0
    def structure(d: Mapping, t: type) -> Any:
        """
        Structure a raw settings mapping into an instance of ``t``.

        Meant to be registered with cattr.register_structure_hook() and called
        with cattr.structure(). Values that already are attrs instances are kept
        as-is; ``hyperparameters`` is structured into the settings class of the
        given ``trainer_type``; ``max_steps`` is coerced to int through float;
        every other key goes through ``check_and_structure``.

        :raises TrainerConfigError: if ``d`` is not a mapping, or
            hyperparameters are given without a trainer_type.
        """
        if not isinstance(d, Mapping):
            raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.")
        d_copy: Dict[str, Any] = dict(d)

        for key, val in d_copy.items():
            if attr.has(type(val)):
                # Don't convert already-converted attrs classes.
                continue
            if key == "max_steps":
                # In some legacy configs, max steps was specified as a float
                d_copy[key] = int(float(val))
            elif key != "hyperparameters":
                d_copy[key] = check_and_structure(key, val, t)
            elif "trainer_type" not in d_copy:
                raise TrainerConfigError(
                    "Hyperparameters were specified but no trainer_type was given."
                )
            else:
                d_copy[key] = strict_to_cls(
                    val, TrainerType(d_copy["trainer_type"]).to_settings()
                )
        return t(**d_copy)
def all_fields(klass):
    """Returns all attr fields on klass and all base classes.

    The result is ``attr.fields(klass)`` followed by ``attr.fields(base)``
    for every direct base of ``klass`` that is an attrs class, as one tuple.
    """
    # NOTE(review): attr.fields() already includes inherited attrs fields, so
    # fields of attrs bases likely appear twice — confirm whether callers
    # rely on that before deduplicating.
    collected = attr.fields(klass)
    for base in (b for b in klass.__bases__ if attr.has(b)):
        collected = collected + attr.fields(base)
    return collected
Exemplo n.º 29
0
def traverse_attrs(cls, *, target=None, get_type_hints=typing.get_type_hints):
    """Traverse a nested attrs structure, create a dictionary for each nested
    attrs class and yield all fields resp. path, type and target dictionary.

    Walks depth-first, visiting each class's fields in reverse declaration
    order. For a field whose (type-hinted) type is an attrs class, a fresh
    dict is stored under the field name in the current target and that
    field is yielded before its own children, which are yielded with the new
    dict as their target.
    """

    def walk(klass, level_target, level_path):
        level_fields = list(attr.fields(klass))
        level_hints = get_type_hints(klass)
        for field in reversed(level_fields):
            field_path = level_path + (field.name, )
            field_type = level_hints[field.name]
            if attr.has(field_type):
                level_target[field.name] = nested_target = {}
                # XXX should we yield also attrs classes?
                yield field_path, field_type, level_target
                yield from walk(field_type, nested_target, field_path)
            else:
                yield field_path, field_type, level_target

    yield from walk(cls, target if target is not None else {}, ())
Exemplo n.º 30
0
 def assert_proper_tuple_class(obj, obj_tuple):
     """Check ``obj_tuple`` (and nested entries) are ``tuple_class`` instances.

     Attrs instances held directly in fields of ``obj`` are checked
     recursively against the entry at the same position.
     """
     assert isinstance(obj_tuple, tuple_class)
     field_values = (
         (position, getattr(obj, attribute.name))
         for position, attribute in enumerate(fields(obj.__class__))
     )
     for position, value in field_values:
         if has(value.__class__):
             # Nested attrs instance: recurse into its converted tuple.
             assert_proper_tuple_class(value, obj_tuple[position])
Exemplo n.º 31
0
 def assert_proper_tuple_class(obj, obj_tuple):
     """Check ``obj_tuple`` (and nested entries) are ``tuple_class`` instances.

     Attrs instances held directly in fields of ``obj`` are checked
     recursively against the entry at the same position.
     """
     assert isinstance(obj_tuple, tuple_class)
     field_values = (
         (position, getattr(obj, attribute.name))
         for position, attribute in enumerate(fields(obj.__class__))
     )
     for position, value in field_values:
         if has(value.__class__):
             # Nested attrs instance: recurse into its converted tuple.
             assert_proper_tuple_class(value, obj_tuple[position])
Exemplo n.º 32
0
def is_attr_class(obj) -> bool:
    """Return whether the *type* of ``obj`` is an attrs class.

    Note that ``obj`` is treated as an instance: ``type(obj)`` is what gets
    checked, so passing an attrs class itself checks its metaclass instead.
    Returns False when the ``attr`` package cannot be imported.
    """
    try:
        from attr import has as _has_attrs
    except ImportError:  # pragma: no cover
        return False
    return _has_attrs(type(obj))
Exemplo n.º 33
0
 def decorator(cls):
   """Tag cls with the enclosing include-list ``name`` and register it.

   Returns cls itself. Raises attr.exceptions.NotAnAttrsClassError when cls
   is not an attrs class.
   """
   if attr.has(cls):
     cls._include_list_name = name
     IncludeLists.register(name, cls)
     return cls
   raise attr.exceptions.NotAnAttrsClassError(
     'Include list feature works with attr.s classes only'
   )
Exemplo n.º 34
0
    def test_positive_empty(self):
        """
        `has` reports True for a decorated class that declares no attributes.
        """
        class D(object):
            pass

        D = attributes(D)
        assert has(D)
Exemplo n.º 35
0
 def assert_proper_dict_class(obj, obj_dict):
     """Assert obj_dict (and nested dumped dicts) are dict_class instances.

     Recurses into fields holding attrs instances, into attrs elements of
     sequence-valued fields, and into attrs values of mapping-valued fields
     (whose dumped counterpart must itself be a dict_class).
     """
     assert isinstance(obj_dict, dict_class)
     for attribute in fields(obj.__class__):
         value = getattr(obj, attribute.name)
         if has(value.__class__):
             # Nested attrs instance: verify its dumped dict recursively.
             assert_proper_dict_class(value, obj_dict[attribute.name])
             continue
         if isinstance(value, Sequence):
             dumped_items = obj_dict[attribute.name]
             for element, element_dict in zip(value, dumped_items):
                 if has(element.__class__):
                     assert_proper_dict_class(element, element_dict)
             continue
         if isinstance(value, Mapping):
             # A mapping field must itself be dumped as a dict_class.
             assert isinstance(obj_dict[attribute.name], dict_class)
             for key, item in value.items():
                 if has(item.__class__):
                     assert_proper_dict_class(item,
                                              obj_dict[attribute.name][key])
Exemplo n.º 36
0
def test_hashability():
    """
    Every attrs class exposed by the validators module defines its own
    ``__hash__`` instead of falling back to ``object.__hash__``.
    """
    for attr_name in dir(validator_module):
        candidate = getattr(validator_module, attr_name)
        if has(candidate):
            hasher = getattr(candidate, '__hash__', None)
            assert hasher is not None
            assert hasher is not object.__hash__
Exemplo n.º 37
0
def stats_to_dict(stats, include_lists=None):
  """ Renders stats entity to dictionary. If include_lists is specified
  only the attributes it reports as included are rendered.

  Args:
    stats: An instance of stats entity.
    include_lists: An instance of IncludeLists, or None to render every
      attribute via attr.asdict.

  Returns:
    A dictionary representation of stats. Attributes whose value is MISSED
    are omitted. Nested attrs values, and non-empty dicts/lists whose first
    element is an attrs value, are rendered recursively with the same
    include_lists.
  """
  if not include_lists:
    return attr.asdict(stats)
  result = {}
  # Iterate the included attributes directly: the former
  # `if att not in included` re-check was dead for lists and, for an
  # iterator, consumed the remaining items and silently dropped fields.
  for att in include_lists.get_included_attrs(stats.__class__):
    value = getattr(stats, att.name)
    if value is MISSED:
      continue
    if value and isinstance(value, dict):
      # Only collections of attr types (stats models) should be converted;
      # the first value is taken as representative of the whole collection.
      # next(iter(...values())) / items() work on both Python 2 and 3.
      if attr.has(next(iter(value.values()))):
        value = {
          k: stats_to_dict(v, include_lists)
          for k, v in value.items()
        }
    elif value and isinstance(value, list):
      if attr.has(value[0]):
        # Only collections of attr types (stats models) should be converted
        value = [stats_to_dict(v, include_lists) for v in value]
    elif attr.has(value):
      value = stats_to_dict(value, include_lists)
    result[att.name] = value
  return result
Exemplo n.º 38
0
 def process_obj(obj, seen_entity_obj_ids):
   """Convert an obj to dict replacing circular references.

   Args:
     obj: Value to convert; attrs instances and lists are walked recursively,
       anything else is returned unchanged.
     seen_entity_obj_ids: List of id()s of the entity objects on the path from
       the root `obj` down to (but not including) the current obj.

   Returns:
     An OrderedDict of converted field values (in attribute definition order)
     for attrs instances, a list of converted elements for lists, and obj
     itself otherwise. An attrs instance already on the current path is
     replaced by the string "<obj.obj_type()> was here".
   """
   if attr.has(obj):
     if id(obj) in seen_entity_obj_ids:
       return "{} was here".format(obj.obj_type())
     # Built once per entity; the list is never mutated, only extended via +.
     child_seen_ids = seen_entity_obj_ids + [id(obj)]
     # dict_factory keeps attribute definition order; a plain dict from
     # asdict would already have lost it on Python 2.
     shallow = attr.asdict(obj, recurse=False,
                           dict_factory=collections.OrderedDict)
     obj_dict = collections.OrderedDict()
     for name, value in shallow.items():
       obj_dict[name] = process_obj(
           value, seen_entity_obj_ids=child_seen_ids)
     return obj_dict
   if isinstance(obj, list):
     # A list is not an entity, so the seen path is passed through unchanged.
     return [process_obj(list_elem, seen_entity_obj_ids)
             for list_elem in obj]
   return obj
Exemplo n.º 39
0
 def test_negative(self):
     """
     A plain, undecorated class is not reported as an attrs class.
     """
     is_attrs = has(object)
     assert not is_attrs
Exemplo n.º 40
0
 def test_positive(self, C):
     """
     A class produced by the attrs decorator is reported by `has`.
     """
     is_attrs = has(C)
     assert is_attrs