Beispiel #1
0
        def test_update_non_abc(self):
            # A class that does not use the ABC metaclass: after attaching
            # a function marked abstract and calling
            # update_abstractmethods(), it must still be instantiable and
            # must not have gained an __abstractmethods__ attribute.
            class A:
                pass

            @abc.abstractmethod
            def updated_foo(self):
                pass

            setattr(A, 'foo', updated_foo)
            abc.update_abstractmethods(A)

            # Instantiation is expected to succeed here.
            A()
            has_abstract_set = hasattr(A, '__abstractmethods__')
            self.assertFalse(has_abstract_set)
Beispiel #2
0
        def test_update_del(self):
            # Removing the only abstract method from an ABC: the recorded
            # __abstractmethods__ set is only refreshed once
            # update_abstractmethods() is called.
            class A(metaclass=abc_ABCMeta):
                @abc.abstractmethod
                def foo(self):
                    pass

            delattr(A, 'foo')

            # Before the update the set still names the deleted method.
            stale_names = A.__abstractmethods__
            self.assertEqual(stale_names, {'foo'})
            self.assertFalse(hasattr(A, 'foo'))

            abc.update_abstractmethods(A)

            # After the update nothing is abstract, so A can be created.
            refreshed_names = A.__abstractmethods__
            self.assertEqual(refreshed_names, set())
            A()
Beispiel #3
0
        def test_update_new_abstractmethods(self):
            # Attaching a second abstract method to an ABC after class
            # creation: update_abstractmethods() must report both names,
            # and instantiation must fail naming both.
            class A(metaclass=abc_ABCMeta):
                @abc.abstractmethod
                def bar(self):
                    pass

            @abc.abstractmethod
            def updated_foo(self):
                pass

            setattr(A, 'foo', updated_foo)
            abc.update_abstractmethods(A)

            abstract_names = A.__abstractmethods__
            self.assertEqual(abstract_names, {'foo', 'bar'})

            with self.assertRaisesRegex(
                    TypeError, "class A with abstract methods bar, foo"):
                A()
Beispiel #4
0
        def test_update_del_implementation(self):
            # A concrete subclass loses its implementation of an inherited
            # abstract method; after update_abstractmethods() it can no
            # longer be instantiated.
            class A(metaclass=abc_ABCMeta):
                @abc.abstractmethod
                def foo(self):
                    pass

            class B(A):
                def foo(self):
                    pass

            # While the override exists, B is instantiable.
            B()

            delattr(B, 'foo')
            abc.update_abstractmethods(B)

            with self.assertRaisesRegex(
                    TypeError, "class B with abstract method foo"):
                B()
Beispiel #5
0
        def test_update_implementation(self):
            # A subclass that starts out abstract becomes instantiable
            # once an implementation is attached and
            # update_abstractmethods() is called.
            class A(metaclass=abc_ABCMeta):
                @abc.abstractmethod
                def foo(self):
                    pass

            class B(A):
                pass

            # Initially B inherits the abstract method unchanged.
            with self.assertRaisesRegex(
                    TypeError, "class B with abstract method foo"):
                B()
            self.assertEqual(B.__abstractmethods__, {'foo'})

            setattr(B, 'foo', lambda self: None)
            abc.update_abstractmethods(B)

            # With foo implemented, creation succeeds and the set is empty.
            B()
            remaining = B.__abstractmethods__
            self.assertEqual(remaining, set())
Beispiel #6
0
        def test_update_layered_implementation(self):
            # The abstract method is declared two levels up (A -> B -> C).
            # C implements it, then loses the implementation; after
            # update_abstractmethods() C must be abstract again.
            class A(metaclass=abc_ABCMeta):
                @abc.abstractmethod
                def foo(self):
                    pass

            class B(A):
                pass

            class C(B):
                def foo(self):
                    pass

            # While the override exists, C is instantiable.
            C()

            del C.foo

            abc.update_abstractmethods(C)

            # Use the same message wording as the sibling tests
            # ("class X with abstract method foo") so that all of them
            # match the same interpreter error text.
            msg = "class C with abstract method foo"
            self.assertRaisesRegex(TypeError, msg, C)
Beispiel #7
0
        def test_update_multi_inheritance(self):
            # C(B, A) declares foo abstract itself.  Once that declaration
            # is removed and update_abstractmethods() is called, the test
            # expects C to be concrete: B, which comes first among the
            # bases, defines a non-abstract foo.
            class A(metaclass=abc_ABCMeta):
                @abc.abstractmethod
                def foo(self):
                    pass

            class B(metaclass=abc_ABCMeta):
                def foo(self):
                    pass

            class C(B, A):
                @abc.abstractmethod
                def foo(self):
                    pass

            names_before = C.__abstractmethods__
            self.assertEqual(names_before, {'foo'})

            delattr(C, 'foo')
            abc.update_abstractmethods(C)

            names_after = C.__abstractmethods__
            self.assertEqual(names_after, set())

            # Nothing abstract remains, so instantiation must succeed.
            C()
Beispiel #8
0
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
    """Process *cls* in place as a dataclass and return it.

    Fields are gathered from every base class that carries a _FIELDS
    attribute and from the annotations defined directly on *cls*; the
    combined mapping is stored back on the class under _FIELDS.  The
    boolean flags are recorded on the class (as a _DataclassParams
    under _PARAMS) and select which methods are generated: __init__,
    __repr__, __eq__, the four ordering methods, the frozen attribute
    hooks, and possibly __hash__.  A class doc-string is synthesized
    if the class has none, and abc.update_abstractmethods() is called
    on the class before it is returned.

    Raises:
        TypeError: a class attribute is a Field without an annotation;
            a frozen dataclass is mixed with non-frozen dataclass
            bases (or vice versa); or an ordering / frozen method
            would overwrite an attribute that already exists.
        ValueError: *order* is true while *eq* is false.
    """
    # Now that dicts retain insertion order, there's no reason to use
    # an ordered dict.  I am leveraging that ordering here, because
    # derived class fields overwrite base class fields, but the order
    # is defined by the base class, which is found first.
    fields = {}

    if cls.__module__ in sys.modules:
        globals = sys.modules[cls.__module__].__dict__
    else:
        # Theoretically this can happen if someone writes
        # a custom string to cls.__module__.  In which case
        # such dataclass won't be fully introspectable
        # (w.r.t. typing.get_type_hints) but will still function
        # correctly.
        globals = {}

    setattr(cls, _PARAMS,
            _DataclassParams(init, repr, eq, order, unsafe_hash, frozen))

    # Find our base classes in reverse MRO order, and exclude
    # ourselves.  In reversed order so that more derived classes
    # override earlier field definitions in base classes.  As long as
    # we're iterating over them, see if any are frozen.
    any_frozen_base = False
    has_dataclass_bases = False
    for b in cls.__mro__[-1:0:-1]:
        # Only process classes that have been processed by our
        # decorator.  That is, they have a _FIELDS attribute.
        base_fields = getattr(b, _FIELDS, None)
        if base_fields is not None:
            has_dataclass_bases = True
            for f in base_fields.values():
                fields[f.name] = f
            if getattr(b, _PARAMS).frozen:
                any_frozen_base = True

    # Annotations that are defined in this class (not in base
    # classes).  If __annotations__ isn't present, then this class
    # adds no new annotations.  We use this to compute fields that are
    # added by this class.
    #
    # Fields are found from cls_annotations, which is guaranteed to be
    # ordered.  Default values are from class attributes, if a field
    # has a default.  If the default value is a Field(), then it
    # contains additional info beyond (and possibly including) the
    # actual default value.  Pseudo-fields ClassVars and InitVars are
    # included, despite the fact that they're not real fields.  That's
    # dealt with later.
    cls_annotations = cls.__dict__.get('__annotations__', {})

    # Now find fields in our class.  While doing so, validate some
    # things, and set the default values (as class attributes) where
    # we can.
    cls_fields = [
        _get_field(cls, name, tp) for name, tp in cls_annotations.items()
    ]
    for f in cls_fields:
        fields[f.name] = f

        # If the class attribute (which is the default value for this
        # field) exists and is of type 'Field', replace it with the
        # real default.  This is so that normal class introspection
        # sees a real default value, not a Field.
        if isinstance(getattr(cls, f.name, None), Field):
            if f.default is MISSING:
                # If there's no default, delete the class attribute.
                # This happens if we specify field(repr=False), for
                # example (that is, we specified a field object, but
                # no default value).  Also if we're using a default
                # factory.  The class attribute should not be set at
                # all in the post-processed class.
                delattr(cls, f.name)
            else:
                setattr(cls, f.name, f.default)

    # Do we have any Field members that don't also have annotations?
    for name, value in cls.__dict__.items():
        if isinstance(value, Field) and name not in cls_annotations:
            raise TypeError(f'{name!r} is a field but has no type annotation')

    # Check rules that apply if we are derived from any dataclasses.
    if has_dataclass_bases:
        # Raise an exception if any of our bases are frozen, but we're not.
        if any_frozen_base and not frozen:
            raise TypeError('cannot inherit non-frozen dataclass from a '
                            'frozen one')

        # Raise an exception if we're frozen, but none of our bases are.
        if not any_frozen_base and frozen:
            raise TypeError('cannot inherit frozen dataclass from a '
                            'non-frozen one')

    # Remember all of the fields on our class (including bases).  This
    # also marks this class as being a dataclass.
    setattr(cls, _FIELDS, fields)

    # Was this class defined with an explicit __hash__?  Note that if
    # __eq__ is defined in this class, then python will automatically
    # set __hash__ to None.  This is a heuristic, as it's possible
    # that such a __hash__ == None was not auto-generated, but it's
    # close enough.
    class_hash = cls.__dict__.get('__hash__', MISSING)
    has_explicit_hash = not (class_hash is MISSING or
                             (class_hash is None and '__eq__' in cls.__dict__))

    # If we're generating ordering methods, we must be generating the
    # eq methods.
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    if init:
        # Does this class have a post-init function?
        has_post_init = hasattr(cls, _POST_INIT_NAME)

        # Include InitVars and regular fields (so, not ClassVars).
        flds = [
            f for f in fields.values()
            if f._field_type in (_FIELD, _FIELD_INITVAR)
        ]
        _set_new_attribute(
            cls,
            '__init__',
            _init_fn(
                flds,
                frozen,
                has_post_init,
                # The name to use for the "self"
                # param in __init__.  Use "self"
                # if possible.
                '__dataclass_self__' if 'self' in fields else 'self',
                globals,
            ))

    # Get the fields as a list, and include only real fields.  This is
    # used in all of the following methods.
    field_list = [f for f in fields.values() if f._field_type is _FIELD]

    if repr:
        flds = [f for f in field_list if f.repr]
        _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))

    if eq:
        # Create __eq__ method.  There's no need for a __ne__ method,
        # since python will call __eq__ and negate it.
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        _set_new_attribute(
            cls, '__eq__',
            _cmp_fn('__eq__', '==', self_tuple, other_tuple, globals=globals))

    if order:
        # Create and set the ordering methods.
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        for name, op in [
            ('__lt__', '<'),
            ('__le__', '<='),
            ('__gt__', '>'),
            ('__ge__', '>='),
        ]:
            # A truthy result from _set_new_attribute is treated here
            # as "the attribute was already defined on the class".
            if _set_new_attribute(
                    cls, name,
                    _cmp_fn(name, op, self_tuple, other_tuple,
                            globals=globals)):
                raise TypeError(f'Cannot overwrite attribute {name} '
                                f'in class {cls.__name__}. Consider using '
                                'functools.total_ordering')

    if frozen:
        for fn in _frozen_get_del_attr(cls, field_list, globals):
            if _set_new_attribute(cls, fn.__name__, fn):
                raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
                                f'in class {cls.__name__}')

    # Decide if/how we're going to create a hash function.
    hash_action = _hash_action[bool(unsafe_hash),
                               bool(eq),
                               bool(frozen), has_explicit_hash]
    if hash_action:
        # No need to call _set_new_attribute here, since by the time
        # we're here the overwriting is unconditional.
        cls.__hash__ = hash_action(cls, field_list, globals)

    if not getattr(cls, '__doc__'):
        # Create a class doc-string from the class name and signature,
        # without the return annotation.  inspect renders a None return
        # annotation as ' -> None'; if the annotation has been resolved
        # to type(None) it is rendered as ' -> NoneType'.  Strip either
        # form, longest first because ' -> None' is a prefix of
        # ' -> NoneType'.
        text_sig = str(inspect.signature(cls))
        text_sig = text_sig.replace(' -> NoneType', '').replace(' -> None', '')
        cls.__doc__ = cls.__name__ + text_sig

    abc.update_abstractmethods(cls)

    return cls