def _combine_protocols(p1: Instance, p2: Instance) -> Instance:
    """Merge two protocol instances into one synthetic intersection protocol.

    The result wraps a new abstract protocol ``TypeInfo`` whose name is
    ``Intersection[...]`` (full name ``pfun.Intersection[...]``), listing the
    sorted reprs of the flattened bases of both operands.  Its symbol table
    and class keywords are the union of those of ``p1`` and ``p2`` (entries
    from ``p2`` override ``p1`` on a clash); its type variables, abstract
    attributes and type arguments are the concatenation of ``p1``'s and
    ``p2``'s; its direct bases are ``[p1, p2]``.

    Args:
        p1: First protocol instance.
        p2: Second protocol instance.

    Returns:
        An ``Instance`` of the new ``TypeInfo`` with arguments
        ``p1.args + p2.args``.
    """

    def get_bases(base):
        # Recursively flatten intersection instances (recognised by
        # ``pfun.Intersection`` occurring in their fullname) into the set of
        # their bases; any other instance contributes only itself.
        if 'pfun.Intersection' in base.type.fullname:
            return set().union(*(get_bases(b) for b in base.type.bases))
        return {base}

    names = p1.type.names.copy()
    names.update(p2.type.names)
    keywords = p1.type.defn.keywords.copy()
    keywords.update(p2.type.defn.keywords)
    bases = get_bases(p1) | get_bases(p2)
    # Sort the reprs so the generated name does not depend on set ordering.
    bases_repr = ', '.join(sorted(repr(base) for base in bases))
    name = f'Intersection[{bases_repr}]'
    defn = ClassDef(
        name,
        Block([]),
        p1.type.defn.type_vars + p2.type.defn.type_vars,
        [NameExpr(p1.type.fullname), NameExpr(p2.type.fullname)],
        None,
        list(keywords.items()),
    )
    defn.fullname = f'pfun.{name}'
    info = TypeInfo(names, defn, '')
    info.is_protocol = True
    info.is_abstract = True
    info.bases = [p1, p2]
    info.abstract_attributes = (
        p1.type.abstract_attributes + p2.type.abstract_attributes)
    # NOTE(review): calculate_mro can raise MroError (visit_instance catches
    # it for the same call); it is not caught here — confirm callers
    # handle it, e.g. for operand orders whose bases cannot be linearized.
    calculate_mro(info)
    return Instance(info, p1.args + p2.args)
def _create_intersection(args, context, api):
    """Build a raw ``pfun.Intersection`` instance and translate it.

    A bare protocol ``TypeInfo`` named ``Intersection`` (full name
    ``pfun.Intersection``) is created, its MRO computed, and an ``Instance``
    of it is built from ``args`` at the source position of ``context``.  The
    instance is then visited by a ``TranslateIntersection`` for ``api``, and
    whatever that visit produces is returned.

    Args:
        args: Type arguments of the intersection.
        context: Node supplying ``line``/``column`` for the new instance.
        api: Plugin API handed to ``TranslateIntersection``.
    """
    class_def = ClassDef('Intersection', Block([]))
    class_def.fullname = 'pfun.Intersection'
    type_info = TypeInfo({}, class_def, 'pfun')
    type_info.is_protocol = True
    calculate_mro(type_info)
    raw = Instance(
        type_info, args, line=context.line, column=context.column)
    return raw.accept(TranslateIntersection(api, raw))
def visit_instance(self, t: Instance) -> Type:
    """Resolve a raw ``pfun.Intersection[...]`` instance into a protocol type.

    Instances whose type is not exactly ``pfun.Intersection`` are handed to
    the parent visitor unchanged.  For an intersection:

    * any ``Any`` argument makes the result ``Any``;
    * if every argument is ``builtins.object``, the first one is returned;
    * an argument that is not a type variable, a protocol or ``object``
      reports an error at ``self.context`` and yields ``Any``;
    * an intersection that still contains type variables is returned as is;
    * otherwise the bases collected via ``self.get_bases`` are merged,
      without duplicates, into a synthetic abstract protocol named
      ``pfun.Intersection[...]``.  A single base is returned directly.

    Returns ``Any`` after reporting an error when no consistent MRO can be
    computed for the merged bases.
    """
    if t.type.fullname != 'pfun.Intersection':
        return super().visit_instance(t)

    def is_object(arg):
        return hasattr(arg, 'type') and arg.type.fullname == 'builtins.object'

    def is_allowed(arg):
        # Type variables, protocols and ``object`` may be intersected.
        if isinstance(arg, TypeVarType):
            return True
        return hasattr(arg, 'type') and (
            arg.type.is_protocol or arg.type.fullname == 'builtins.object')

    args = [get_proper_type(arg) for arg in t.args]
    if any(isinstance(arg, AnyType) for arg in args):
        return AnyType(TypeOfAny.special_form)
    # ``all`` of an empty sequence is True; guard so an argument-less
    # intersection does not crash on ``args[0]``.
    if args and all(is_object(arg) for arg in args):
        return args[0]
    if not all(is_allowed(arg) for arg in args):
        s = str(t)
        if self.inferred:
            msg = (f'All arguments to "Intersection" '
                   f'must be protocols but inferred "{s}"')
        else:
            msg = (f'All arguments to "Intersection" '
                   f'must be protocols, but got "{s}"')
        self.api.msg.fail(msg, self.context)
        return AnyType(TypeOfAny.special_form)
    if not has_no_typevars(t):
        # Intersections that mention type variables are returned unchanged.
        return t

    bases = []
    for arg in args:
        # Deduplicate per collected base rather than per argument: an
        # argument may expand (via self.get_bases, presumably flattening
        # nested intersections — confirm) to bases already collected, and
        # duplicate entries in info.bases can make calculate_mro raise
        # MroError.
        for base in self.get_bases(arg, []):
            if base not in bases:
                bases.append(base)
    if len(bases) == 1:
        return bases[0]

    bases_repr = ', '.join(repr(base) for base in bases)
    name = f'Intersection[{bases_repr}]'
    base_type_exprs = [
        NameExpr(arg.name) if isinstance(arg, TypeVarType)
        else NameExpr(arg.type.fullname)
        for arg in args
    ]
    defn = ClassDef(name, Block([]), [], base_type_exprs, None, [])
    defn.fullname = f'pfun.{name}'
    info = TypeInfo({}, defn, '')
    info.is_protocol = True
    info.is_abstract = True
    info.bases = bases
    # Type variables have no TypeInfo, so they contribute no attributes.
    info.abstract_attributes = [
        attr
        for base in bases
        if not isinstance(base, TypeVarType)
        for attr in base.type.abstract_attributes
    ]
    try:
        calculate_mro(info)
    except MroError:
        self.api.msg.fail(
            'Cannot determine consistent method resolution '
            f'order (MRO) for "{defn.fullname}"', self.context)
        return AnyType(TypeOfAny.special_form)
    return Instance(info, [])