Example #1
0
def _combine_protocols(p1: Instance, p2: Instance) -> Instance:
    """Build a synthetic protocol instance that combines ``p1`` and ``p2``.

    The generated ``TypeInfo`` merges both symbol tables and both sets of
    class keywords (entries from ``p2`` overwrite same-named entries from
    ``p1``), uses ``p1`` and ``p2`` as its direct bases, and concatenates
    their abstract attributes, type variables and type arguments.

    Args:
        p1: First instance to combine.
        p2: Second instance to combine.

    Returns:
        An ``Instance`` of a class whose full name is
        ``pfun.Intersection[...]``; the brackets hold the sorted,
        comma-separated reprs of the flattened bases of ``p1`` and ``p2``.
    """
    def get_bases(base):
        # Flatten nested intersections so the generated name only mentions
        # the underlying (non-intersection) bases, each one at most once.
        if 'pfun.Intersection' in base.type.fullname:
            bases = set()
            for b in base.type.bases:
                bases |= get_bases(b)
            return bases
        return {base}

    # Work on copies so the symbol table / keywords of ``p1`` stay untouched.
    names = p1.type.names.copy()
    names.update(p2.type.names)
    keywords = p1.type.defn.keywords.copy()
    keywords.update(p2.type.defn.keywords)
    bases = get_bases(p1) | get_bases(p2)
    # Sets are unordered; sorting keeps the generated name deterministic.
    bases_repr = ', '.join(sorted(repr(base) for base in bases))
    name = f'Intersection[{bases_repr}]'
    defn = ClassDef(name, Block([]),
                    p1.type.defn.type_vars + p2.type.defn.type_vars,
                    [NameExpr(p1.type.fullname),
                     NameExpr(p2.type.fullname)], None, list(keywords.items()))
    # The full name is assigned before the TypeInfo is built from ``defn``.
    defn.fullname = f'pfun.{name}'
    info = TypeInfo(names, defn, '')
    info.is_protocol = True
    info.is_abstract = True
    info.bases = [p1, p2]
    info.abstract_attributes = (p1.type.abstract_attributes +
                                p2.type.abstract_attributes)
    calculate_mro(info)
    return Instance(info, p1.args + p2.args)
Example #2
0
def _create_intersection(args, context, api):
    """Wrap ``args`` in a placeholder ``pfun.Intersection`` and translate it.

    Args:
        args: Type arguments handed to the placeholder ``Instance``.
        context: Object whose ``line`` and ``column`` are copied onto the
            placeholder instance.
        api: Passed through to ``TranslateIntersection``.

    Returns:
        Whatever the ``TranslateIntersection`` visitor produces for the
        placeholder instance.
    """
    class_def = ClassDef('Intersection', Block([]))
    # The full name is assigned before the TypeInfo is built from the def.
    class_def.fullname = 'pfun.Intersection'

    placeholder_info = TypeInfo({}, class_def, 'pfun')
    placeholder_info.is_protocol = True
    calculate_mro(placeholder_info)

    placeholder = Instance(
        placeholder_info,
        args,
        line=context.line,
        column=context.column,
    )
    translator = TranslateIntersection(api, placeholder)
    return placeholder.accept(translator)
Example #3
0
    def visit_instance(self, t: Instance) -> Type:
        """Translate a ``pfun.Intersection`` placeholder instance.

        Instances of any other class are delegated to the parent visitor.
        For a placeholder, the arguments (after ``get_proper_type``) are
        handled in this order:

        * any ``AnyType`` argument yields ``AnyType``;
        * if every argument is ``builtins.object``, the first argument is
          returned;
        * if some argument is neither a ``TypeVarType``, a protocol nor
          ``builtins.object``, an error is reported through
          ``self.api.msg.fail`` and ``AnyType`` is returned;
        * if ``has_no_typevars(t)`` is false, ``t`` is returned unchanged;
        * otherwise the bases gathered with ``self.get_bases`` are returned
          directly when there is exactly one, or become the bases of a new
          protocol ``TypeInfo`` whose instance is returned. An ``MroError``
          while computing its MRO is reported and yields ``AnyType``.
        """
        if t.type.fullname != 'pfun.Intersection':
            return super().visit_instance(t)

        def is_plain_object(candidate):
            return (hasattr(candidate, 'type')
                    and candidate.type.fullname == 'builtins.object')

        def is_allowed(candidate):
            # Type variables are always accepted; anything else must expose
            # ``.type`` and be a protocol or ``builtins.object``.
            if isinstance(candidate, TypeVarType):
                return True
            if not hasattr(candidate, 'type'):
                return False
            return (candidate.type.is_protocol
                    or candidate.type.fullname == 'builtins.object')

        proper_args = [get_proper_type(member) for member in t.args]

        for member in proper_args:
            if isinstance(member, AnyType):
                return AnyType(TypeOfAny.special_form)

        # NOTE(review): for an empty argument list this check is vacuously
        # true and ``proper_args[0]`` raises IndexError - confirm that
        # callers never hand over zero arguments.
        if all(map(is_plain_object, proper_args)):
            return proper_args[0]

        if any(not is_allowed(member) for member in proper_args):
            shown = str(t)
            if self.inferred:
                message = (f'All arguments to "Intersection" '
                           f'must be protocols but inferred "{shown}"')
            else:
                message = (f'All arguments to "Intersection" '
                           f'must be protocols, but got "{shown}"')
            self.api.msg.fail(message, self.context)
            return AnyType(TypeOfAny.special_form)

        if not has_no_typevars(t):
            return t

        collected = []
        for member in proper_args:
            # Arguments already present among the gathered bases are skipped.
            if member not in collected:
                collected.extend(self.get_bases(member, []))
        if len(collected) == 1:
            return collected[0]

        joined = ', '.join(map(repr, collected))
        class_name = f'Intersection[{joined}]'
        base_exprs = [
            NameExpr(member.name if isinstance(member, TypeVarType)
                     else member.type.fullname)
            for member in proper_args
        ]
        class_def = ClassDef(class_name, Block([]), [], base_exprs, None, [])
        # The full name is assigned before the TypeInfo is built from the def.
        class_def.fullname = f'pfun.{class_name}'

        protocol_info = TypeInfo({}, class_def, '')
        protocol_info.is_protocol = True
        protocol_info.is_abstract = True
        protocol_info.bases = collected
        protocol_info.abstract_attributes = [
            attribute
            for base in collected
            if not isinstance(base, TypeVarType)
            for attribute in base.type.abstract_attributes
        ]

        try:
            calculate_mro(protocol_info)
        except MroError:
            self.api.msg.fail(
                'Cannot determine consistent method resolution '
                'order (MRO) for "%s"' % class_def.fullname, self.context)
            return AnyType(TypeOfAny.special_form)
        return Instance(protocol_info, [])