def testRemoveRedundantSignature(self):
    """`foo(a: int)` is dropped because `foo(a: int or bool)` covers it.

    Legacy `or`-union spelling of test_remove_redundant_signature.
    """
    source = textwrap.dedent("""
        def foo(a: int) -> int
        def foo(a: int or bool) -> int
    """)
    want = textwrap.dedent("""
        def foo(a: int or bool) -> int
    """)
    tree = self.Parse(source)
    remover = optimize.RemoveRedundantSignatures(
        optimize.SuperClassHierarchy({}))
    self.AssertSourceEquals(tree.Visit(remover), want)
def test_remove_redundant_signature(self):
    """`foo(a: int)` is dropped because `foo(a: Union[int, bool])` covers it."""
    source = pytd_src("""
        def foo(a: int) -> int: ...
        def foo(a: Union[int, bool]) -> int: ...
    """)
    want = pytd_src("""
        def foo(a: Union[int, bool]) -> int: ...
    """)
    tree = self.Parse(source)
    remover = optimize.RemoveRedundantSignatures(
        optimize.SuperClassHierarchy({}))
    self.AssertSourceEquals(tree.Visit(remover), want)
def testRemoveRedundantSignatureGenericLeftSide(self):
    """With the same TypeVar first argument, `b: int` is subsumed by `b: Any`.

    Legacy-syntax counterpart of
    test_remove_redundant_signature_generic_left_side.
    """
    source = textwrap.dedent("""
        X = TypeVar("X")
        def foo(a: X, b: int) -> X
        def foo(a: X, b: Any) -> X
    """)
    want = textwrap.dedent("""
        X = TypeVar("X")
        def foo(a: X, b: Any) -> X
    """)
    tree = self.Parse(source)
    remover = optimize.RemoveRedundantSignatures(
        optimize.SuperClassHierarchy({}))
    self.AssertSourceEquals(tree.Visit(remover), want)
def test_simplify_unions_with_superclasses_generic(self):
    """The generic union collapses to `AbstractSet[int]`.

    The builtins hierarchy is used to decide that `frozenset[int]` is
    redundant next to `AbstractSet[int]`.
    """
    source = pytd_src("""
        x = ... # type: Union[frozenset[int], AbstractSet[int]]
    """)
    want = pytd_src("""
        x = ... # type: AbstractSet[int]
    """)
    supers_by_name = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    simplifier = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = visitors.LookupClasses(self.Parse(source), self.builtins)
    self.AssertSourceEquals(tree.Visit(simplifier), want)
def testSimplifyUnionsWithSuperclassesGeneric(self):
    """Legacy `or` spelling: the generic union collapses to `AbstractSet[int]`."""
    source = textwrap.dedent("""
        x = ... # type: frozenset[int] or AbstractSet[int]
    """)
    want = textwrap.dedent("""
        x = ... # type: AbstractSet[int]
    """)
    supers_by_name = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    simplifier = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = visitors.LookupClasses(self.Parse(source), self.builtins)
    self.AssertSourceEquals(tree.Visit(simplifier), want)
def testFindCommonSuperClasses(self):
    """A union of `int` and a LateType (`other.Bar`) is left unchanged."""
    source = textwrap.dedent("""
        x = ... # type: int or other.Bar
    """)
    want = textwrap.dedent("""
        x = ... # type: int or other.Bar
    """)
    tree = self.Parse(source)
    # Swap `other.Bar` for a LateType before computing the hierarchy.
    late_bar = {"other.Bar": pytd.LateType("other.Bar")}
    tree = tree.Visit(visitors.ReplaceTypes(late_bar))
    supers_by_name = tree.Visit(visitors.ExtractSuperClassesByName())
    finder = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = tree.Visit(finder)
    tree = tree.Visit(visitors.LateTypeToClassType())
    self.AssertSourceEquals(tree, want)
def test_remove_redundant_signature_polymorphic(self):
    """A polymorphic `foo(a: T) -> T` makes the `Union[int, bool]` overload redundant.

    The stubs use `Union[...]` and `: ...` bodies, matching the other
    snake_case RemoveRedundantSignatures tests, instead of the legacy
    `int or bool` spelling with bodiless `def` lines.
    """
    src = textwrap.dedent("""
        T = TypeVar("T")
        def foo(a: T) -> T: ...
        def foo(a: Union[int, bool]) -> int: ...
    """)
    expected = textwrap.dedent("""
        T = TypeVar("T")
        def foo(a: T) -> T: ...
    """)
    ast = self.Parse(src)
    ast = ast.Visit(
        optimize.RemoveRedundantSignatures(optimize.SuperClassHierarchy({})))
    self.AssertSourceEquals(ast, expected)
def testBuiltinSuperClasses(self):
    """Builtin unions in `f`'s signature are collapsed by FindCommonSuperClasses.

    NOTE(review): a later method with this exact name
    (`testBuiltinSuperClasses`) also exists. If both live in the same class,
    the later definition rebinds the name and this body never runs. Consider
    giving one of them a distinct name.
    """
    source = textwrap.dedent("""
        def f(x: list or object, y: int or float) -> int or bool
    """)
    want = textwrap.dedent("""
        def f(x, y) -> int
    """)
    builtins_ast = builtins.GetBuiltinsPyTD()
    supers_by_name = builtins_ast.Visit(visitors.ExtractSuperClassesByName())
    finder = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = self.ParseAndResolve(source)
    tree = tree.Visit(finder)
    tree = tree.Visit(visitors.DropBuiltinPrefix())
    self.AssertSourceEquals(tree, want)
def test_remove_redundant_signature_with_exceptions(self):
    """An overload that declares a `raise` is not removed as redundant.

    `foo(a: int)` would otherwise be covered by `foo(a: Union[int, bool])`.
    The stubs use `Union[...]` and `: ...` bodies, matching the other
    snake_case RemoveRedundantSignatures tests, instead of the legacy
    `int or bool` spelling.
    """
    src = textwrap.dedent("""
        def foo(a: int) -> int:
            raise IOError()
        def foo(a: Union[int, bool]) -> int: ...
    """)
    expected = textwrap.dedent("""
        def foo(a: int) -> int:
            raise IOError()
        def foo(a: Union[int, bool]) -> int: ...
    """)
    ast = self.Parse(src)
    ast = ast.Visit(
        optimize.RemoveRedundantSignatures(optimize.SuperClassHierarchy({})))
    self.AssertSourceEquals(ast, expected)
def testBuiltinSuperClasses(self):
    """Builtin unions collapse to a common superclass.

    The expected collapses are:
    - `list or object` -> `object`
    - `complex or memoryview` -> `object`
    - `int or bool` -> `int`
    """
    source = textwrap.dedent("""
        def f(x: list or object, y: complex or memoryview) -> int or bool
    """)
    want = textwrap.dedent("""
        def f(x: object, y: object) -> int
    """)
    supers_by_name = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    # Merge in typing's classes so the hierarchy is complete for these types.
    supers_by_name.update(
        self.typing.Visit(visitors.ExtractSuperClassesByName()))
    finder = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = self.ParseAndResolve(source)
    tree = tree.Visit(finder)
    tree = tree.Visit(visitors.DropBuiltinPrefix())
    tree = tree.Visit(visitors.CanonicalOrderingVisitor())
    self.AssertSourceEquals(tree, want)
def test_builtin_superclasses(self):
    """Builtin unions collapse to fully qualified `builtins.*` superclasses."""
    source = pytd_src("""
        def f(x: Union[list, object], y: Union[complex, memoryview]) -> Union[int, bool]: ...
    """)
    want = pytd_src("""
        def f(x: builtins.object, y: builtins.object) -> builtins.int: ...
    """)
    supers_by_name = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    # Merge in typing's classes so the hierarchy is complete for these types.
    supers_by_name.update(
        self.typing.Visit(visitors.ExtractSuperClassesByName()))
    finder = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = self.ParseAndResolve(source)
    tree = tree.Visit(finder)
    tree = tree.Visit(visitors.CanonicalOrderingVisitor())
    self.AssertSourceEquals(tree, want)
def test_find_common_superclasses(self):
    """A union of `int` and a LateType (`other.Bar`) is left unchanged."""
    source = pytd_src("""
        x = ... # type: Union[int, other.Bar]
    """)
    want = pytd_src("""
        x = ... # type: Union[int, other.Bar]
    """)
    tree = self.Parse(source)
    # Swap `other.Bar` for a LateType before computing the hierarchy.
    late_bar = {"other.Bar": pytd.LateType("other.Bar")}
    tree = tree.Visit(visitors.ReplaceTypes(late_bar))
    supers_by_name = tree.Visit(visitors.ExtractSuperClassesByName())
    finder = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = tree.Visit(finder)
    tree = tree.Visit(visitors.LateTypeToClassType())
    self.AssertSourceEquals(tree, want)
def test_remove_redundant_signature_generic_left_side(self):
    """With the same TypeVar first argument, `b: int` is subsumed by `b: Any`."""
    source = pytd_src("""
        from typing import Any
        X = TypeVar("X")
        def foo(a: X, b: int) -> X: ...
        def foo(a: X, b: Any) -> X: ...
    """)
    want = pytd_src("""
        from typing import Any
        X = TypeVar("X")
        def foo(a: X, b: Any) -> X: ...
    """)
    tree = self.Parse(source)
    remover = optimize.RemoveRedundantSignatures(
        optimize.SuperClassHierarchy({}))
    self.AssertSourceEquals(tree.Visit(remover), want)
def testSimplifyUnionsWithSuperclasses(self):
    """`bool` is absorbed into `int` in each union; `list[int] or int` is kept.

    Legacy `or` spelling of test_simplify_unions_with_superclasses.
    """
    source = textwrap.dedent("""
        x = ... # type: int or bool
        y = ... # type: int or bool or float
        z = ... # type: list[int] or int
    """)
    want = textwrap.dedent("""
        x = ... # type: int
        y = ... # type: int or float
        z = ... # type: list[int] or int
    """)
    supers_by_name = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    simplifier = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = visitors.LookupClasses(self.Parse(source), self.builtins)
    self.AssertSourceEquals(tree.Visit(simplifier), want)
def testRemoveRedundantSignatureTemplate(self):
    """Inside `A(Generic[T])`, only the `foo(a: int) -> int` overload is removed.

    Legacy-syntax counterpart of test_remove_redundant_signature_template.
    """
    source = textwrap.dedent("""
        T = TypeVar("T")
        class A(Generic[T]):
            def foo(a: int) -> int
            def foo(a: T) -> T
            def foo(a: int or bool) -> int
    """)
    want = textwrap.dedent("""
        T = TypeVar("T")
        class A(Generic[T]):
            def foo(a: T) -> T
            def foo(a: int or bool) -> int
    """)
    tree = self.Parse(source)
    remover = optimize.RemoveRedundantSignatures(
        optimize.SuperClassHierarchy({}))
    self.AssertSourceEquals(tree.Visit(remover), want)
def test_simplify_unions_with_superclasses(self):
    """`bool` is absorbed into `int` in each union; `Union[list[int], int]` is kept."""
    source = pytd_src("""
        x = ... # type: Union[int, bool]
        y = ... # type: Union[int, bool, float]
        z = ... # type: Union[list[int], int]
    """)
    want = pytd_src("""
        x = ... # type: int
        y = ... # type: Union[int, float]
        z = ... # type: Union[list[int], int]
    """)
    supers_by_name = self.builtins.Visit(visitors.ExtractSuperClassesByName())
    simplifier = optimize.SimplifyUnionsWithSuperclasses(
        optimize.SuperClassHierarchy(supers_by_name))
    tree = visitors.LookupClasses(self.Parse(source), self.builtins)
    self.AssertSourceEquals(tree.Visit(simplifier), want)
def test_remove_redundant_signature_template(self):
    """Inside `A(Generic[T])`, only the `foo(a: int) -> int` overload is removed."""
    source = pytd_src("""
        T = TypeVar("T")
        class A(Generic[T]):
            def foo(a: int) -> int: ...
            def foo(a: T) -> T: ...
            def foo(a: Union[int, bool]) -> int: ...
    """)
    want = pytd_src("""
        T = TypeVar("T")
        class A(Generic[T]):
            def foo(a: T) -> T: ...
            def foo(a: Union[int, bool]) -> int: ...
    """)
    tree = self.Parse(source)
    remover = optimize.RemoveRedundantSignatures(
        optimize.SuperClassHierarchy({}))
    self.AssertSourceEquals(tree.Visit(remover), want)
def test_user_superclass_hierarchy(self):
    """Unions of user-defined classes collapse to their common base.

    The expected collapses are:
    - `Union[A, B]` -> `AB`
    - `Union[E, F, G]` -> `EFG`
    - `g`'s four-way union parameter -> `object`
    """
    class_defs = pytd_src("""
        class AB:
            pass
        class EFG:
            pass
        class A(AB, EFG):
            pass
        class B(AB):
            pass
        class E(EFG, AB):
            pass
        class F(EFG):
            pass
        class G(EFG):
            pass
    """)
    source = pytd_src("""
        from typing import Any
        def f(x: Union[A, B], y: A, z: B) -> Union[E, F, G]: ...
        def g(x: Union[E, F, G, B]) -> Union[E, F]: ...
        def h(x) -> Any: ...
    """) + class_defs
    want = pytd_src("""
        from typing import Any
        def f(x: AB, y: A, z: B) -> EFG: ...
        def g(x: object) -> EFG: ...
        def h(x) -> Any: ...
    """) + class_defs
    # The hierarchy comes from the classes declared in the source itself.
    supers_by_name = self.Parse(source).Visit(
        visitors.ExtractSuperClassesByName())
    finder = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(supers_by_name))
    rewritten = self.ApplyVisitorToString(source, finder)
    self.AssertSourceEquals(rewritten, want)
def testRemoveRedundantSignatureTwoTypeParams(self):
    """With two class type parameters, neither `foo` overload is removed.

    Legacy-syntax counterpart of
    test_remove_redundant_signature_two_type_params.
    """
    source = textwrap.dedent("""
        X = TypeVar("X")
        Y = TypeVar("Y")
        class A(Generic[X, Y]):
            def foo(a: X) -> Y
            def foo(a: Y) -> Y
    """)
    want = textwrap.dedent("""
        X = TypeVar("X")
        Y = TypeVar("Y")
        class A(Generic[X, Y]):
            def foo(a: X) -> Y
            def foo(a: Y) -> Y
    """)
    tree = self.Parse(source)
    remover = optimize.RemoveRedundantSignatures(
        optimize.SuperClassHierarchy({}))
    self.AssertSourceEquals(tree.Visit(remover), want)
def testUserSuperClassHierarchy(self):
    """Unions of user-defined classes collapse to their common base.

    This uses the legacy `or` / `?` spelling. The expected collapses are:
    - `A or B` -> `AB`
    - `E or F or G` -> `EFG`
    - `g`'s parameter is printed without an annotation.
    """
    class_defs = textwrap.dedent("""
        class AB(object):
            pass
        class EFG(object):
            pass
        class A(AB, EFG):
            pass
        class B(AB):
            pass
        class E(EFG, AB):
            pass
        class F(EFG):
            pass
        class G(EFG):
            pass
    """)
    source = textwrap.dedent("""
        def f(x: A or B, y: A, z: B) -> E or F or G
        def g(x: E or F or G or B) -> E or F
        def h(x) -> ?
    """) + class_defs
    want = textwrap.dedent("""
        def f(x: AB, y: A, z: B) -> EFG
        def g(x) -> EFG
        def h(x) -> ?
    """) + class_defs
    # The hierarchy comes from the classes declared in the source itself.
    supers_by_name = self.Parse(source).Visit(
        visitors.ExtractSuperClassesByName())
    finder = optimize.FindCommonSuperClasses(
        optimize.SuperClassHierarchy(supers_by_name))
    rewritten = self.ApplyVisitorToString(source, finder)
    self.AssertSourceEquals(rewritten, want)
def test_remove_redundant_signature_two_type_params(self):
    """With two class type parameters, neither `foo` overload is removed."""
    source = pytd_src("""
        X = TypeVar("X")
        Y = TypeVar("Y")
        class A(Generic[X, Y]):
            def foo(a: X) -> Y: ...
            def foo(a: Y) -> Y: ...
    """)
    want = pytd_src("""
        X = TypeVar("X")
        Y = TypeVar("Y")
        class A(Generic[X, Y]):
            def foo(a: X) -> Y: ...
            def foo(a: Y) -> Y: ...
    """)
    tree = self.Parse(source)
    remover = optimize.RemoveRedundantSignatures(
        optimize.SuperClassHierarchy({}))
    self.AssertSourceEquals(tree.Visit(remover), want)