Ejemplo n.º 1
0
    def testAddInheritedMethods(self):
        # AddInheritedMethods is expected to copy into B whatever B lacks from
        # its base A (the constant `foo` and the method `f`), while B's own
        # `h` stays exactly as declared.
        pytd_before = textwrap.dedent("""
        class A(nothing):
            foo: bool
            def f(self, x: int) -> float
            def h(self) -> complex

        class B(A):
            bar: int
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
    """)
        pytd_after = textwrap.dedent("""
        class A(nothing):
            foo: bool
            def f(self, x: int) -> float
            def h(self) -> complex

        class B(A):
            bar: int
            foo: bool
            def g(self, y: int) -> bool
            def h(self, z: float) -> ?
            def f(self, x: int) -> float
    """)
        # Resolve class references (with the builtins available), then run
        # the visitor under test and compare the printed result.
        resolved = visitors.LookupClasses(self.Parse(pytd_before),
                                          builtins.GetBuiltins())
        self.AssertSourceEquals(
            resolved.Visit(optimize.AddInheritedMethods()), pytd_after)
Ejemplo n.º 2
0
def Optimize(node, flags=None):
    """Run the optimization passes over a PYTD tree and return the result.

    A fixed sequence of passes always runs first. Further passes are enabled
    individually through `flags`. Finally class references are resolved
    against the builtins and RemoveInheritedMethods is applied.

    Args:
      node: The pytd node to optimize. It is not modified; a new node is
        returned instead.
      flags: An OptimizeFlags instance selecting the optional passes (`lossy`,
        `use_abcs`, `max_union`, `remove_mutable`) and their parameters. May be
        None (or otherwise falsy), in which case no optional pass runs.

    Returns:
      The optimized node.
    """
    # Passes that run regardless of flags. Each visitor is instantiated right
    # before it is applied, in this exact order.
    unconditional_passes = (
        RemoveDuplicates,
        CombineReturnsAndExceptions,
        Factorize,
        ApplyOptionalArguments,
        CombineContainers,
    )
    for visitor_class in unconditional_passes:
        node = node.Visit(visitor_class())

    if flags:
        if flags.lossy:
            superclasses = node.Visit(visitors.ExtractSuperClassesByName())
            node = node.Visit(
                FindCommonSuperClasses(superclasses, flags.use_abcs))

        if flags.max_union:
            union_limit = flags.max_union
            node = node.Visit(CollapseLongParameterUnions(union_limit))
            node = node.Visit(CollapseLongReturnUnions(union_limit))

        if flags.remove_mutable:
            node = node.Visit(AbsorbMutableParameters())
            node = node.Visit(CombineContainers())
            node = node.Visit(MergeTypeParameters())
            node = node.Visit(visitors.AdjustSelf(force=True))

    # RemoveInheritedMethods is applied to the tree with resolved classes.
    resolved = visitors.LookupClasses(node, builtins.GetBuiltins())
    return resolved.Visit(RemoveInheritedMethods())
Ejemplo n.º 3
0
 def testFindUnknownVisitor(self):
     # RaiseIfContainsUnknown should raise HasUnknown for A (parameter typed
     # `~unknown1`), C (constant typed `~unknown_foobar`) and D (derives from
     # `~unknown1`), but not for B, which only refers to `~int`.
     pytd_source = textwrap.dedent("""
     class `~unknown1`(nothing):
       pass
     class `~unknown_foobar`(nothing):
       pass
     class `~int`(nothing):
       pass
     class A(nothing):
       def foobar(self, x: `~unknown1`) -> ?
     class B(nothing):
       def foobar(self, x: `~int`) -> ?
     class C(nothing):
       x: `~unknown_foobar`
     class D(`~unknown1`):
       pass
 """)
     tree = visitors.LookupClasses(self.Parse(pytd_source))
     has_unknown = visitors.RaiseIfContainsUnknown.HasUnknown

     def find_on(class_name):
         # Run a fresh visitor over just the named class.
         return tree.Lookup(class_name).Visit(
             visitors.RaiseIfContainsUnknown())

     self.assertRaises(has_unknown, find_on, "A")
     find_on("B")  # shouldn't raise
     for class_name in ("C", "D"):
         self.assertRaises(has_unknown, find_on, class_name)
Ejemplo n.º 4
0
 def testLookupClasses(self):
     # Resolving class references should leave the printed form of the tree
     # equal to the input. The resolved tree is then walked with VerifyLookup
     # (defined elsewhere in this file; presumably it fails on anything left
     # unresolved -- confirm against its definition).
     pytd_source = """
     class object:
       pass
     class A:
       def a(self, a: A, b: B) -> A or B raises A, B
     class B:
       def b(self, a: A, b: B) -> A or B raises A, B
 """
     resolved = visitors.LookupClasses(self.Parse(pytd_source))
     self.AssertSourceEquals(resolved, pytd_source)
     resolved.Visit(VerifyLookup())
Ejemplo n.º 5
0
 def testGeneric(self):
     # Two constants of the same generic type A<?> are expected to match each
     # other unconditionally, i.e. the matcher returns booleq.TRUE.
     ast = parser.parse_string(
         textwrap.dedent("""
   class A<T extends nothing>(nothing):
     pass
   left: A<?>
   right: A<?>
 """))
     ast = visitors.LookupClasses(ast)
     m = type_match.TypeMatch()
     # assertEqual instead of the deprecated assertEquals alias, which was
     # removed from unittest in Python 3.12.
     self.assertEqual(
         m.match_type_against_type(
             ast.Lookup("left").type,
             ast.Lookup("right").type, {}), booleq.TRUE)
Ejemplo n.º 6
0
 def testClassMatch(self):
     # Right declares every method Left does plus one more, so matching Left
     # against Right is expected to yield booleq.TRUE while the reverse
     # direction is not.
     ast = parser.parse_string(
         textwrap.dedent("""
   class Left(nothing):
     def method(self) -> ?
   class Right(nothing):
     def method(self) -> ?
     def method2(self) -> ?
 """))
     ast = visitors.LookupClasses(ast)
     m = type_match.TypeMatch()
     left, right = ast.Lookup("Left"), ast.Lookup("Right")
     # assertEqual / assertNotEqual instead of the deprecated assertEquals /
     # assertNotEquals aliases, which were removed from unittest in
     # Python 3.12.
     self.assertEqual(m.match(left, right, {}), booleq.TRUE)
     self.assertNotEqual(m.match(right, left, {}), booleq.TRUE)
Ejemplo n.º 7
0
 def testSubclasses(self):
     # `left` works on B and `right` on A, where B derives from A. Matching
     # left against right is expected to yield booleq.TRUE, the reverse is
     # not.
     ast = parser.parse_string(
         textwrap.dedent("""
   class A(nothing):
     pass
   class B(A):
     pass
   a : A
   def left(a: B) -> B
   def right(a: A) -> A
 """))
     ast = visitors.LookupClasses(ast)
     # The matcher is seeded with a mapping from the type of `a` (A) to the
     # list [B]; presumably this is its table of direct subclasses -- confirm
     # against TypeMatch.__init__.
     m = type_match.TypeMatch({ast.Lookup("a").type: [ast.Lookup("B")]})
     left, right = ast.Lookup("left"), ast.Lookup("right")
     # assertEqual / assertNotEqual instead of the deprecated assertEquals /
     # assertNotEquals aliases, which were removed from unittest in
     # Python 3.12.
     self.assertEqual(m.match(left, right, {}), booleq.TRUE)
     self.assertNotEqual(m.match(right, left, {}), booleq.TRUE)
Ejemplo n.º 8
0
 def testSuperClasses(self):
     # ExtractSuperClasses returns a mapping keyed by class node; for every
     # class below, the names of the mapped entries must equal the declared
     # bases (compared without regard to order).
     pytd_source = textwrap.dedent("""
   class A(nothing):
       pass
   class B(nothing):
       pass
   class C(A):
       pass
   class D(A,B):
       pass
   class E(C,D,A):
       pass
 """)
     ast = visitors.LookupClasses(self.Parse(pytd_source))
     superclasses = ast.Visit(visitors.ExtractSuperClasses())
     expected_bases = (
         ("A", []),
         ("B", []),
         ("C", ["A"]),
         ("D", ["A", "B"]),
         ("E", ["C", "D", "A"]),
     )
     # NOTE(review): assertItemsEqual is the Python 2 unittest name; Python 3
     # calls it assertCountEqual.
     for class_name, base_names in expected_bases:
         found = [base.name for base in superclasses[ast.Lookup(class_name)]]
         self.assertItemsEqual(base_names, found)
Ejemplo n.º 9
0
    def testRemoveInheritedMethods(self):
        # RemoveInheritedMethods is expected to drop methods that repeat a
        # signature already available from a base class: B.f and C.f repeat
        # A.f, and D.g repeats A.g (reached through B).
        pytd_before = textwrap.dedent("""
        class A(nothing):
            def f(self, y: int) -> bool
            def g(self) -> float

        class B(A):
            def b(self) -> B
            def f(self, y: int) -> bool

        class C(A):
            def c(self) -> C
            def f(self, y: int) -> bool

        class D(B):
            def g(self) -> float
            def d(self) -> D
    """)
        pytd_after = textwrap.dedent("""
        class A(nothing):
            def f(self, y: int) -> bool
            def g(self) -> float

        class B(A):
            def b(self) -> B

        class C(A):
            def c(self) -> C

        class D(B):
            def d(self) -> D
    """)
        # Resolve class references (with the builtins available), then run
        # the visitor under test and compare the printed result.
        resolved = visitors.LookupClasses(self.Parse(pytd_before),
                                          builtins.GetBuiltins())
        self.AssertSourceEquals(
            resolved.Visit(optimize.RemoveInheritedMethods()), pytd_after)
Ejemplo n.º 10
0
 def testGetBuiltins(self):
     # Sanity checks on the builtins held by the fixture: they exist, expose
     # a `modules` attribute, and can be passed through LookupClasses.
     loaded = self.builtins
     self.assertIsNotNone(loaded)
     self.assertTrue(hasattr(loaded, "modules"))
     # Will throw an error for unresolved identifiers:
     visitors.LookupClasses(loaded)