def testAddInheritedMethods(self):
  """AddInheritedMethods copies A's `foo` and `f` into B but keeps B's own `h`.

  B already defines `h` with a different signature, so the expected output
  retains B's version; `foo` and `f` are appended after B's own members.
  """
  source = textwrap.dedent("""
      class A(nothing):
          foo: bool
          def f(self, x: int) -> float
          def h(self) -> complex

      class B(A):
          bar: int
          def g(self, y: int) -> bool
          def h(self, z: float) -> ?
  """)
  want = textwrap.dedent("""
      class A(nothing):
          foo: bool
          def f(self, x: int) -> float
          def h(self) -> complex

      class B(A):
          bar: int
          foo: bool
          def g(self, y: int) -> bool
          def h(self, z: float) -> ?
          def f(self, x: int) -> float
  """)
  tree = self.Parse(source)
  tree = visitors.LookupClasses(tree, builtins.GetBuiltins())
  with_inherited = tree.Visit(optimize.AddInheritedMethods())
  self.AssertSourceEquals(with_inherited, want)
def Optimize(node, flags=None):
  """Shrink a PYTD tree by running it through a pipeline of visitors.

  The input node is not modified: each stage produces a new tree via
  `node.Visit(...)`, and the output of the last stage is returned.

  Args:
    node: A pytd node to be optimized.
    flags: An instance of OptimizeFlags choosing which optional stages run.
      The attributes consulted are `lossy`, `use_abcs`, `max_union` and
      `remove_mutable`. May be None, in which case every optional stage is
      skipped and only the unconditional stages run.

  Returns:
    An optimized node.
  """
  # Unconditional stages, always applied in this fixed order.
  for stage in (RemoveDuplicates, CombineReturnsAndExceptions, Factorize,
                ApplyOptionalArguments, CombineContainers):
    node = node.Visit(stage())

  if flags and flags.lossy:
    # The lossy stage needs the superclass hierarchy of the current tree.
    superclasses_by_name = node.Visit(visitors.ExtractSuperClassesByName())
    # `flags` is known truthy here, so `flags.use_abcs` equals the original
    # `flags and flags.use_abcs`.
    node = node.Visit(
        FindCommonSuperClasses(superclasses_by_name, flags.use_abcs))

  union_limit = flags and flags.max_union
  if union_limit:
    for collapser in (CollapseLongParameterUnions, CollapseLongReturnUnions):
      node = node.Visit(collapser(union_limit))

  if flags and flags.remove_mutable:
    node = node.Visit(AbsorbMutableParameters())
    node = node.Visit(CombineContainers())
    node = node.Visit(MergeTypeParameters())
    node = node.Visit(visitors.AdjustSelf(force=True))

  # Resolve class references (with the builtins available) so that methods
  # repeating an inherited definition can be detected and dropped.
  node = visitors.LookupClasses(node, builtins.GetBuiltins())
  return node.Visit(RemoveInheritedMethods())
def testFindUnknownVisitor(self):
  """RaiseIfContainsUnknown detects references to `~unknown...` classes.

  A reference via a parameter type (A), a constant type (C) or a base class
  (D) must raise HasUnknown; B only references `~int` and must not raise.
  """
  source = textwrap.dedent("""
      class `~unknown1`(nothing):
          pass
      class `~unknown_foobar`(nothing):
          pass
      class `~int`(nothing):
          pass

      class A(nothing):
          def foobar(self, x: `~unknown1`) -> ?
      class B(nothing):
          def foobar(self, x: `~int`) -> ?
      class C(nothing):
          x: `~unknown_foobar`
      class D(`~unknown1`):
          pass
  """)
  tree = self.Parse(source)
  tree = visitors.LookupClasses(tree)

  def find_on(class_name):
    return tree.Lookup(class_name).Visit(visitors.RaiseIfContainsUnknown())

  has_unknown = visitors.RaiseIfContainsUnknown.HasUnknown
  self.assertRaises(has_unknown, find_on, "A")
  find_on("B")  # shouldn't raise: `~int` is not an unknown
  for class_name in ("C", "D"):
    self.assertRaises(has_unknown, find_on, class_name)
def testLookupClasses(self):
  """LookupClasses resolves mutually-referencing classes without changing output.

  The resolved tree must print back to the original source and must pass the
  VerifyLookup visitor.
  """
  source = """
      class object:
          pass

      class A:
          def a(self, a: A, b: B) -> A or B raises A, B

      class B:
          def b(self, a: A, b: B) -> A or B raises A, B
  """
  parsed = self.Parse(source)
  resolved = visitors.LookupClasses(parsed)
  self.AssertSourceEquals(resolved, source)
  resolved.Visit(VerifyLookup())
def testGeneric(self):
  """Two references to the same generic type `A<?>` match each other."""
  ast = parser.parse_string(textwrap.dedent("""
      class A<T extends nothing>(nothing):
          pass

      left: A<?>
      right: A<?>
  """))
  ast = visitors.LookupClasses(ast)
  m = type_match.TypeMatch()
  # assertEqual rather than the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(
      m.match_type_against_type(
          ast.Lookup("left").type, ast.Lookup("right").type, {}),
      booleq.TRUE)
def testClassMatch(self):
  """Left matches Right, which has a superset of Left's methods, but not vice versa."""
  ast = parser.parse_string(textwrap.dedent("""
      class Left(nothing):
          def method(self) -> ?

      class Right(nothing):
          def method(self) -> ?
          def method2(self) -> ?
  """))
  ast = visitors.LookupClasses(ast)
  m = type_match.TypeMatch()
  left, right = ast.Lookup("Left"), ast.Lookup("Right")
  # Non-deprecated assertion names (assertEquals/assertNotEquals were
  # removed in Python 3.12).
  self.assertEqual(m.match(left, right, {}), booleq.TRUE)
  self.assertNotEqual(m.match(right, left, {}), booleq.TRUE)
def testSubclasses(self):
  """With B registered under A's type, left(B) -> B matches right(A) -> A only."""
  ast = parser.parse_string(textwrap.dedent("""
      class A(nothing):
          pass
      class B(A):
          pass
      a : A
      def left(a: B) -> B
      def right(a: A) -> A
  """))
  ast = visitors.LookupClasses(ast)
  # The constant `a` is declared as `a : A`, so the key here is A's type.
  # NOTE(review): presumably TypeMatch's argument maps a type to its known
  # subclasses; confirm against type_match.TypeMatch.
  m = type_match.TypeMatch({ast.Lookup("a").type: [ast.Lookup("B")]})
  left, right = ast.Lookup("left"), ast.Lookup("right")
  # Non-deprecated assertion names (assertEquals/assertNotEquals were
  # removed in Python 3.12).
  self.assertEqual(m.match(left, right, {}), booleq.TRUE)
  self.assertNotEqual(m.match(right, left, {}), booleq.TRUE)
def testSuperClasses(self):
  """ExtractSuperClasses maps each class to its direct base classes only."""
  source = textwrap.dedent("""
      class A(nothing):
          pass
      class B(nothing):
          pass
      class C(A):
          pass
      class D(A,B):
          pass
      class E(C,D,A):
          pass
  """)
  tree = visitors.LookupClasses(self.Parse(source))
  superclasses = tree.Visit(visitors.ExtractSuperClasses())
  expected_bases = [
      ("A", []),
      ("B", []),
      ("C", ["A"]),
      ("D", ["A", "B"]),
      # Direct bases only: B is an ancestor of E (via D) but is not listed.
      ("E", ["C", "D", "A"]),
  ]
  for class_name, bases in expected_bases:
    actual = [base.name for base in superclasses[tree.Lookup(class_name)]]
    self.assertItemsEqual(bases, actual)
def testRemoveInheritedMethods(self):
  """Methods identical to an inherited definition are removed.

  B.f and C.f duplicate A.f, and D.g duplicates A.g, which D reaches through B.
  All of them are dropped, while the classes' unique methods remain.
  """
  source = textwrap.dedent("""
      class A(nothing):
          def f(self, y: int) -> bool
          def g(self) -> float

      class B(A):
          def b(self) -> B
          def f(self, y: int) -> bool

      class C(A):
          def c(self) -> C
          def f(self, y: int) -> bool

      class D(B):
          def g(self) -> float
          def d(self) -> D
  """)
  want = textwrap.dedent("""
      class A(nothing):
          def f(self, y: int) -> bool
          def g(self) -> float

      class B(A):
          def b(self) -> B

      class C(A):
          def c(self) -> C

      class D(B):
          def d(self) -> D
  """)
  tree = self.Parse(source)
  tree = visitors.LookupClasses(tree, builtins.GetBuiltins())
  pruned = tree.Visit(optimize.RemoveInheritedMethods())
  self.AssertSourceEquals(pruned, want)
def testGetBuiltins(self):
  """self.builtins is set, exposes `modules`, and all its names resolve."""
  builtins_ast = self.builtins
  self.assertIsNotNone(builtins_ast)
  self.assertTrue(hasattr(builtins_ast, "modules"))
  # LookupClasses throws on unresolved identifiers, so completing this call
  # without an exception is the check.
  visitors.LookupClasses(builtins_ast)