def VisitSignature(self, sig):
    """Expand one signature into its union-free variants.

    Every parameter whose type is a pytd.UnionType contributes one
    alternative per member of the union; every other parameter contributes
    exactly one alternative. The cartesian product of these alternatives is
    emitted with the rightmost parameter varying fastest, so
    (a or b, c or d) yields (a, c), (a, d), (b, c), (b, d), in that order.

    Arguments:
      sig: A pytd.Signature instance.

    Returns:
      A list with one signature per combination. The visit function of the
      parent node (VisitFunction) processes this list further.
    """
    # One inner list per parameter, holding every concrete choice for it.
    choices_per_param = [
        [pytd.Parameter(name, member) for member in param_type.type_list]
        if isinstance(param_type, pytd.UnionType)
        else [pytd.Parameter(name, param_type)]
        for name, param_type in sig.params
    ]
    expanded = []
    for picked in itertools.product(*choices_per_param):
        expanded.append(sig.Replace(params=tuple(picked)))
    return expanded  # Hand list over to VisitFunction
def testTokens(self):
    """Check that literal tokens in a signature parse into pytd.Scalar values.

    Covers a backquoted function name, a double-quoted string, a
    single-quoted string containing a double quote, a negative float and an
    integer.
    """
    # TODO: a test with '"' or "'" in a string
    data = textwrap.dedent(""" def `interface`(abcde: "xyz", foo: 'a"b', b: -1.0, c: 666) -> int """)
    parsed_function = self.Parse(data).Lookup("interface")

    # (parameter name, expected scalar value), in declaration order.
    expected_scalars = (
        ("abcde", "xyz"),
        ("foo", 'a"b'),
        ("b", -1.0),
        ("c", 666),
    )
    expected_params = tuple(
        pytd.Parameter(name=name, type=pytd.Scalar(value=value))
        for name, value in expected_scalars
    )
    expected_signature = pytd.Signature(
        params=expected_params,
        return_type=pytd.NamedType("int"),
        exceptions=(),
        template=(),
        has_optional=False,
    )
    expected_function = pytd.Function(
        name="interface",
        signatures=(expected_signature,),
    )
    self.assertEqual(parsed_function, expected_function)
def VisitFunction(self, f):
    """Shrink a function, by factorizing cartesian products of arguments.

    Greedily groups signatures, looking at the arguments from left to right.
    For each argument position, signatures that are identical apart from
    that argument are merged into one signature whose argument type is the
    join of all types seen at that position. This algorithm is *not*
    optimal, but it does the right thing for the typical cases.

    Arguments:
      f: An instance of pytd.Function. If this function has more than one
        signature, we will try to combine some of these signatures by
        introducing union types.

    Returns:
      A new, potentially optimized, instance of pytd.Function. A function
      without any signatures comes back with an empty signature tuple.
    """
    signatures = f.signatures
    # max() raises ValueError on an empty sequence, so fall back to zero
    # arguments when the function carries no signatures at all.
    max_argument_count = max([len(s.params) for s in signatures] or [0])
    # range() instead of xrange(): identical iteration, and it exists on
    # both Python 2 and Python 3.
    for i in range(max_argument_count):
        new_sigs = []
        for sig, types in self._GroupByOmittedArg(signatures, i):
            if types:
                # One or more options for argument <i>: "sig" is the group
                # key with argument <i> blanked out, so fill in the join.
                new_params = list(sig.params)
                new_params[i] = pytd.Parameter(sig.params[i].name,
                                               JoinTypes(types))
                sig = sig.Replace(params=tuple(new_params))
            # Otherwise the signature doesn't have argument <i>, and "sig"
            # is the untouched original.
            new_sigs.append(sig)
        signatures = new_sigs
    return f.Replace(signatures=tuple(signatures))
def _GroupByOmittedArg(self, signatures, i):
    """Group signatures that are identical if argument i is ignored.

    Arguments:
      signatures: A list of function signatures.
      i: The index of the argument to ignore during comparison.

    Returns:
      The items of an ordered mapping, i.e. (signature, types) pairs in
      first-seen order. For signatures that have an argument i, "signature"
      is a copy whose argument i has its type replaced by None and "types"
      is the list of types found at that position across the group.
      Signatures with too few arguments are passed through as
      (original, None).
    """
    buckets = collections.OrderedDict()
    for signature in signatures:
        if len(signature.params) <= i:
            # Too few arguments to blank out position i; keep the original
            # signature and mark it with None.
            buckets[signature] = None
        else:
            original_param = signature.params[i]
            # Blank out the type of parameter i to form the grouping key.
            blanked = list(signature.params)
            blanked[i] = pytd.Parameter(original_param.name, None)
            key = signature.Replace(params=tuple(blanked))
            seen_types = buckets.get(key)
            if not seen_types:
                buckets[key] = [original_param.type]
            else:
                seen_types.append(original_param.type)
    return buckets.items()
def testComplexCombinedType(self):
    """Check parsing of a type that mixes union and intersection.

    The unparenthesized spelling and the explicitly parenthesized spelling
    must both parse to a union of Foo with the intersection of Bar and Zot.
    """
    data1 = r"def foo(a: Foo or Bar and Zot) -> object"
    data2 = r"def foo(a: Foo or (Bar and Zot)) -> object"
    result1 = self.Parse(data1)
    result2 = self.Parse(data2)

    # Build the expected tree from the inside out.
    bar_and_zot = pytd.IntersectionType(
        type_list=(pytd.NamedType("Bar"), pytd.NamedType("Zot")))
    foo_or_rest = pytd.UnionType(
        type_list=(pytd.NamedType("Foo"), bar_and_zot))
    expected_signature = pytd.Signature(
        params=(pytd.Parameter(name="a", type=foo_or_rest),),
        return_type=pytd.NamedType("object"),
        template=(),
        has_optional=False,
        exceptions=(),
    )
    expected = pytd.Function(name="foo", signatures=(expected_signature,))

    for parsed in (result1, result2):
        self.assertEqual(expected, parsed.Lookup("foo"))
def VisitParameter(self, p):
    """Give parameters called "self" the type of the enclosing class.

    The parameter is rewritten only while visiting inside a class
    (self.class_types is non-empty), and only if either self.force is set
    or the current type is one of self.replaced_self_types (the
    "unoccupied" types: "object" or, if configured, "?").

    Args:
      p: pytd.Parameter instance.

    Returns:
      A new pytd.Parameter typed with the innermost class, or p itself when
      no adjustment applies.
    """
    if not self.class_types:
        # Not within a class, so this cannot be a method parameter.
        return p
    if p.name != "self":
        return p
    if not (self.force or p.type in self.replaced_self_types):
        # The parameter already carries a meaningful type; leave it alone.
        return p
    innermost_class = self.class_types[-1]
    return pytd.Parameter("self", innermost_class)
def p_param_and_type(self, p):
    """param : NAME COLON type"""
    # NOTE: the docstring above is the grammar production for this rule
    # (presumably consumed by a PLY-style parser generator), so it must
    # stay verbatim. p[1] is the NAME token, p[3] the parsed type; the
    # result of the production is stored in p[0].
    param_name = p[1]
    param_type = p[3]
    p[0] = pytd.Parameter(param_name, param_type)
def p_param(self, p):
    """param : NAME"""
    # NOTE: the docstring above is the grammar production for this rule
    # (presumably consumed by a PLY-style parser generator), so it must
    # stay verbatim.
    param_name = p[1]
    # The type annotation is optional; without one the parameter is typed
    # as "object".
    default_type = pytd.NamedType('object')
    p[0] = pytd.Parameter(param_name, default_type)
def VisitMutableParameter(self, p):
    """Turn a mutable parameter into a plain parameter covering both types.

    Arguments:
      p: A node exposing "name", "type" and "new_type" attributes.

    Returns:
      A pytd.Parameter with the same name whose type is the result of
      utils.JoinTypes applied to [p.type, p.new_type].
    """
    merged_type = utils.JoinTypes([p.type, p.new_type])
    return pytd.Parameter(p.name, merged_type)