def AssertSourceEquals(self, src_or_tree_1, src_or_tree_2):
  """Fail unless both inputs print to the same source and have equal ASTs.

  Both arguments are converted with self.ToAST(), so anything that helper
  accepts may be passed. The first argument is reported as "Actual" and the
  second as "Expected" in the failure output.

  Args:
    src_or_tree_1: The actual source or tree.
    src_or_tree_2: The expected source or tree.
  """
  actual_ast = self.ToAST(src_or_tree_1)
  expected_ast = self.ToAST(src_or_tree_2)
  # Normalise surrounding whitespace so leading/trailing "\n"s are ignored.
  actual_src = pytd_utils.Print(actual_ast).strip() + "\n"
  expected_src = pytd_utils.Print(expected_ast).strip() + "\n"
  actual_ast = actual_ast.Visit(visitors.ClassTypeToNamedType())
  expected_ast = expected_ast.Visit(visitors.ClassTypeToNamedType())

  # Success requires both the printed form and the ASTs to agree.
  if actual_src == expected_src and pytd_utils.ASTeq(actual_ast, expected_ast):
    return

  # Opinions differ on the best debug output, so the PY_UNITTEST_DIFF
  # environment variable selects unittest's diff instead of the raw dump.
  if os.getenv("PY_UNITTEST_DIFF"):
    # No size cap on the diff produced by assertMultiLineEqual.
    self.maxDiff = None  # pylint: disable=invalid-name
    self.assertMultiLineEqual(actual_src, expected_src)
  else:
    sys.stdout.flush()
    sys.stderr.flush()
    rule = "-" * 36
    print("Source files or ASTs differ:", file=sys.stderr)
    print(rule, " Actual ", rule, file=sys.stderr)
    print(textwrap.dedent(actual_src).strip(), file=sys.stderr)
    print(rule, "Expected", rule, file=sys.stderr)
    print(textwrap.dedent(expected_src).strip(), file=sys.stderr)
    print("-" * 80, file=sys.stderr)
  if not pytd_utils.ASTeq(actual_ast, expected_ast):
    print("Actual AST:", actual_ast, file=sys.stderr)
    print("Expect AST:", expected_ast, file=sys.stderr)
  self.fail("source files differ")
def test_asteq(self):
  """ASTeq ignores declaration order and union member order; == does not."""
  # The two sources declare the same things in a different order, and the
  # union is spelled (int, str) in one and (str, int) in the other. The trees
  # therefore compare unequal with ==, but ASTeq reports them as equivalent.
  first_src = textwrap.dedent("""
      from typing import Union
      def foo(a: Union[int, str]) -> C: ...
      T = TypeVar('T')
      class C(typing.Generic[T], object):
          def bar(x: T) -> NoneType: ...
      CONSTANT = ... # type: C[float]
      """)
  second_src = textwrap.dedent("""
      from typing import Union
      CONSTANT = ... # type: C[float]
      T = TypeVar('T')
      class C(typing.Generic[T], object):
          def bar(x: T) -> NoneType: ...
      def foo(a: Union[str, int]) -> C: ...
      """)
  first = parser.parse_string(first_src, python_version=self.python_version)
  second = parser.parse_string(second_src, python_version=self.python_version)
  both = (first, second)

  for tree in both:
    tree.Visit(visitors.VerifyVisitor())
  # Each tree must be non-trivial: constants, classes and functions present.
  for tree in both:
    self.assertTrue(tree.constants)
    self.assertTrue(tree.classes)
    self.assertTrue(tree.functions)
  for tree in both:
    self.assertIsInstance(tree, pytd.TypeDeclUnit)

  swapped = ((first, second), (second, first))
  # The raw == / != operators are exercised on purpose here.
  # pylint: disable=g-generic-assert
  # pylint: disable=comparison-with-itself
  for tree in both:
    self.assertTrue(tree == tree)
  for left, right in swapped:
    self.assertFalse(left == right)
  for tree in both:
    self.assertFalse(tree != tree)
  for left, right in swapped:
    self.assertTrue(left != right)
  # pylint: enable=g-generic-assert
  # pylint: enable=comparison-with-itself
  for tree in both:
    self.assertEqual(tree, tree)
  self.assertNotEqual(first, second)

  # ASTeq holds for every pairing, including each tree with itself.
  pairings = ((first, second), (first, first), (second, first),
              (second, second))
  for left, right in pairings:
    self.assertTrue(pytd_utils.ASTeq(left, right))
def assertInferredPyiEquals(self, expected_pyi=None, filename=None):
  """Assert that self.stdout parses to an AST equal to the expected pyi.

  Exactly one of the two arguments must be truthy.

  Args:
    expected_pyi: The expected pyi text.
    filename: Name of a file, resolved with self._DataPath(), from which the
      expected pyi text is read.
  """
  assert bool(expected_pyi) != bool(filename)
  if filename:
    with open(self._DataPath(filename), "r") as expected_file:
      expected_pyi = expected_file.read()
  # Show both texts when the comparison fails.
  message = "".join([
      "\n==Expected pyi==\n",
      expected_pyi,
      "\n==Actual pyi==\n",
      self.stdout,
  ])
  actual_ast = self._ParseString(self.stdout)
  expected_ast = self._ParseString(expected_pyi)
  self.assertTrue(pytd_utils.ASTeq(actual_ast, expected_ast), message)
def testInferToFile(self):
  """Run pytype with --output and compare the written pyi to simple.pyi."""
  self.pytype_args[self._DataPath("simple.py")] = self.INCLUDE
  output_path = self._TmpPath("simple.pyi")
  self.pytype_args["--output"] = output_path
  self._RunPytype(self.pytype_args)
  self.assertOutputStateMatches(stdout=False, stderr=False, returncode=False)

  with open(output_path, "r") as written:
    actual_text = written.read()
  with open(self._DataPath("simple.pyi"), "r") as golden:
    expected_text = golden.read()

  actual_ast = self._ParseString(actual_text)
  expected_ast = self._ParseString(expected_text)
  self.assertTrue(pytd_utils.ASTeq(actual_ast, expected_ast))
def test_infer_to_file(self):
  """Run pytype with --output and compare the written pyi to simple.pyi."""
  self.pytype_args[self._data_path("simple.py")] = self.INCLUDE
  output_path = self._tmp_path("simple.pyi")
  self.pytype_args["--output"] = output_path
  self._run_pytype(self.pytype_args)
  self.assertOutputStateMatches(stdout=False, stderr=False, returncode=False)

  with open(output_path, "r") as written:
    actual_text = written.read()
  with open(self._data_path("simple.pyi"), "r") as golden:
    expected_text = golden.read()

  actual_ast = self._parse_string(actual_text)
  expected_ast = self._parse_string(expected_text)
  self.assertTrue(pytd_utils.ASTeq(actual_ast, expected_ast))
def test_load_with_different_module_name(self):
  """A pickled AST renamed on load equals a fresh AST built under that name."""
  with file_utils.Tempdir() as d:
    stored_name = "module1"
    renamed = "wurstbrot.module2"
    pickle_path = os.path.join(d.path, "module1.pyi.pickled")
    module_map = self._store_ast(d, stored_name, pickle_path)
    # Take the stored module out of the map before processing the pickle.
    original_ast = module_map.pop(stored_name)

    # Rename the pickled AST before it is processed.
    serializable_ast = serialize_ast.EnsureAstName(
        pytd_utils.LoadPickle(pickle_path), renamed, fix=True)
    loaded_ast = serialize_ast.ProcessAst(serializable_ast, module_map)

    self.assertTrue(loaded_ast)
    self.assertIsNot(loaded_ast, original_ast)
    self.assertEqual(loaded_ast.name, renamed)
    loaded_ast.Visit(visitors.VerifyLookup())
    # After the rename the loaded AST no longer matches the original one...
    self.assertFalse(pytd_utils.ASTeq(original_ast, loaded_ast))
    # ...but it does match an AST obtained directly under the new name.
    ast_new_module, _ = self._get_ast(temp_dir=d, module_name=renamed)
    self.assertTrue(pytd_utils.ASTeq(ast_new_module, loaded_ast))
def test_load_top_level(self):
  """Tests that a pickled file can be read back into an equal AST."""
  with file_utils.Tempdir() as d:
    module_name = "module1"
    pickle_path = os.path.join(d.path, "module1.pyi.pickled")
    module_map = self._store_ast(d, module_name, pickle_path)
    # Take the stored module out of the map before processing the pickle.
    original_ast = module_map.pop(module_name)

    pickled = pytd_utils.LoadPickle(pickle_path)
    loaded_ast = serialize_ast.ProcessAst(pickled, module_map)

    self.assertTrue(loaded_ast)
    # A distinct object with the same name and an equal AST is expected.
    self.assertIsNot(loaded_ast, original_ast)
    self.assertEqual(loaded_ast.name, module_name)
    self.assertTrue(pytd_utils.ASTeq(original_ast, loaded_ast))
    loaded_ast.Visit(visitors.VerifyLookup())
def test_load_with_same_module_name(self):
  """Store module1 under a dotted module name and load it back unchanged."""
  with file_utils.Tempdir() as d:
    self._create_files(tempdir=d)
    first = _Module(module_name="foo.bar.module1", file_name="module1.pyi")
    second = _Module(module_name="module2", file_name="module2.pyi")
    loader, ast = self._load_ast(tempdir=d, module=first)
    self._pickle_modules(loader, d, first, second)

    pickle_path = self._get_path(d, first.file_name + ".pickled")
    # StoreAst is expected to return None here.
    self.assertIsNone(serialize_ast.StoreAst(ast, pickle_path))

    loaded_ast = self._load_pickled_module(d, first)
    self.assertTrue(loaded_ast)
    # The reloaded AST is a different object that is equal to the original.
    self.assertIsNot(loaded_ast, ast)
    self.assertTrue(pytd_utils.ASTeq(ast, loaded_ast))
    loaded_ast.Visit(visitors.VerifyLookup())
def write_pickle(ast, options, loader=None):
  """Dump a pickle of the ast to a file.

  Args:
    ast: The AST to export.
    options: Options object. This function reads its module_name, nofail,
      python_version, verify_pickle, output and open_function attributes.
    loader: Optional loader. When falsy, one is created with
      load_pytd.create_loader(options).

  Raises:
    parser.ParseError: If preparing the AST for export fails and
      options.nofail is not set.
    AssertionError: If options.verify_pickle is set and the exported AST
      differs from the AST loaded from that file.
  """
  loader = loader or load_pytd.create_loader(options)
  try:
    ast = serialize_ast.PrepareForExport(options.module_name, ast, loader)
  except parser.ParseError as e:
    if not options.nofail:
      raise
    # Best effort: export the default builtins AST instead of failing.
    ast = serialize_ast.PrepareForExport(
        options.module_name,
        pytd_builtins.GetDefaultAst(options.python_version), loader)
    # %s stringifies the exception lazily, so no explicit str() is needed.
    log.warning("***Caught exception: %s", e, exc_info=True)
  if options.verify_pickle:
    # Compare the AST about to be written with the one loaded from the
    # verification file, after clearing class pointers on both sides.
    ast1 = ast.Visit(visitors.LateTypeToClassType())
    ast1 = ast1.Visit(visitors.ClearClassPointers())
    ast2 = loader.load_file(options.module_name, options.verify_pickle)
    ast2 = ast2.Visit(visitors.ClearClassPointers())
    if not pytd_utils.ASTeq(ast1, ast2):
      raise AssertionError(
          "Pickle verification failed: AST for module %r differs from the "
          "AST loaded from %r" % (options.module_name, options.verify_pickle))
  serialize_ast.StoreAst(ast, options.output, options.open_function)
def test_load_with_same_module_name(self):
  """Explicitly set the module name and reload with the same name.

  Unlike the top-level load test, the module name used here does not match
  the file location of the pickled file.
  """
  with file_utils.Tempdir() as d:
    module_name = "foo.bar.module1"
    pickle_path = os.path.join(d.path, "module1.pyi.pickled")
    module_map = self._store_ast(d, module_name, pickle_path)
    # Take the stored module out of the map before processing the pickle.
    original_ast = module_map.pop(module_name)

    pickled = pytd_utils.LoadPickle(pickle_path)
    loaded_ast = serialize_ast.ProcessAst(pickled, module_map)

    self.assertTrue(loaded_ast)
    self.assertIsNot(loaded_ast, original_ast)
    # The dotted name must survive the round trip.
    self.assertEqual(loaded_ast.name, module_name)
    self.assertTrue(pytd_utils.ASTeq(original_ast, loaded_ast))
    loaded_ast.Visit(visitors.VerifyLookup())
def test_method_pyi(self):
  """Overloaded methods infer the same pyi with or without annotated analysis."""
  src = """
    from typing import overload
    class Foo(object):
      @overload
      def f(self, x: int) -> int:
        pass
      @overload
      def f(self, x: str) -> str:
        pass
      def f(self, x):
        return x
  """
  ty = self.Infer(src, analyze_annotated=False)
  ty_annotated = self.Infer(src, analyze_annotated=True)
  # Both inference modes must produce equivalent ASTs.
  self.assertTrue(pytd_utils.ASTeq(ty, ty_annotated))
  # Only the overload signatures are expected in the inferred pyi.
  expected_pytd = """
    class Foo(object):
      @overload
      def f(self, x: int) -> int: ...
      @overload
      def f(self, x: str) -> str: ...
  """
  self.assertTypesMatchPytd(ty, expected_pytd)
def test_pyi(self):
  """Overloaded functions keep their overloads in the pyi and when imported."""
  src = """
    from typing import overload
    @overload
    def f(x: int) -> int:
      pass
    @overload
    def f(x: str) -> str:
      pass
    def f(x):
      return x
    def g():
      return f
  """
  ty = self.Infer(src, analyze_annotated=False)
  ty_annotated = self.Infer(src, analyze_annotated=True)
  # Both inference modes must produce equivalent ASTs.
  self.assertTrue(pytd_utils.ASTeq(ty, ty_annotated))
  expected_pytd = """
    from typing import Callable
    @overload
    def f(x: int) -> int: ...
    @overload
    def f(x: str) -> str: ...
    def g() -> Callable: ...
  """
  self.assertTypesMatchPytd(ty, expected_pytd)

  # Write the inferred pyi out as module foo and check calls against it.
  with file_utils.Tempdir() as d:
    d.create_file("foo.pyi", pytd_utils.Print(ty))
    client_src = """
      import foo
      foo.f(0) # ok
      foo.f("") # ok
      foo.f(0.0) # wrong-arg-types[e]
    """
    errors = self.CheckWithErrors(client_src, pythonpath=[d.path])
    self.assertErrorRegexes(errors, {"e": r"int.*float"})