def test_insert_ast(self):
    """Inserting an AST before the existing statement yields a new root."""
    inserted = AST.from_string("y = 2\n", language=ASTLanguage.Python, deepest=True)
    result = AST.insert(self.root, self.statement, inserted)

    # The result is a distinct object (different oid) from the original root.
    self.assertNotEqual(result.oid, self.root.oid)
    self.assertEqual(2, len(result.children))
    # The new statement precedes the pre-existing ``x = 88``.
    self.assertEqual("y = 2\nx = 88\n", result.source_text)
def x_to_z_gt(ast: AST) -> Optional[LiteralOrAST]:
    """Rename an ``x`` on the left of a ``>`` comparison to ``z``.

    Returns a copy of AST whose first operand is replaced by ``"z"`` when AST
    is a Python comparison whose first operator is ``>`` and whose first
    operand's source text is exactly ``x``. Otherwise returns None.
    """
    if not isinstance(ast, PythonComparisonOperator):
        return None
    first_operand, *other_operands = ast.child_slot("CHILDREN")
    first_operator, *_ = ast.child_slot("OPERATORS")
    is_greater_than = isinstance(first_operator, PythonGreaterThan)
    if is_greater_than and first_operand.source_text == "x":
        return AST.copy(ast, children=["z", *other_operands])
    return None
def test_asts_from_template(self):
    """asts_from_template returns the ASTs substituted for each placeholder."""
    # C: one string-literal argument for a single placeholder.
    c_asts = AST.asts_from_template("$1;", ASTLanguage.C, '"Foo: %d"')
    self.assertEqual(len(c_asts), 1)
    literal = c_asts[0]
    self.assertEqual(literal.source_text, '"Foo: %d"')
    self.assertIsInstance(literal, CStringLiteral)

    # Python: plain values become an identifier node and an integer node.
    py_asts = AST.asts_from_template("$1 + $2", ASTLanguage.Python, "x", 1)
    self.assertEqual(len(py_asts), 2)
    identifier = py_asts[0]
    self.assertEqual(identifier.source_text, "x")
    self.assertIsInstance(identifier, PythonIdentifier)
    integer = py_asts[1]
    self.assertEqual(integer.source_text, "1")
    self.assertIsInstance(integer, PythonInteger)
def test_delete_print_statements(self):
    """Transforming with a print-eliding callback matches the expected file."""

    def is_print_statement(ast: AST) -> bool:
        """Return True if AST is an expression statement that calls ``print``."""
        if not isinstance(ast, ExpressionStatementAST):
            return False
        callee_names = [
            call.call_function().source_text for call in ast.call_asts()
        ]
        return "print" in callee_names

    def delete_print_statements(ast: AST) -> Optional[LiteralOrAST]:
        """Return a copy of AST without its print-statement children.

        Only root and compound ASTs are rewritten; any other AST yields None.
        """
        if not isinstance(ast, (RootAST, CompoundAST)):
            return None
        survivors = [
            child for child in ast.children if not is_print_statement(child)
        ]
        # An empty body would be a syntax error, so fall back to ``pass``.
        return AST.copy(ast, children=survivors or ["pass\n"])

    transformed = AST.transform(self.root, delete_print_statements)
    expected = slurp(DATA_DIR / "transform" / "delete_print_statements.py")
    self.assertEqual(transformed.source_text, expected)
def test_provided_by(self):
    """provided_by attributes ``os.path.join(...)`` to ``os.path``."""
    source = "import os\nos.path.join(a, b)"
    root = AST.from_string(source, ASTLanguage.Python)
    calls = root.call_asts()
    self.assertEqual(1, len(calls))
    join_call = calls[0]
    self.assertEqual("os.path", join_call.provided_by(root))
def test_imports(self):
    """imports() lists every import visible from the final statement.

    Entries take the forms ``[module]``, ``[module, alias]`` and
    ``[module, None, name]`` for ``import m``, ``import m as a`` and
    ``from m import name`` respectively.
    """
    lines = [
        "import os",
        "import sys as s",
        "from json import dump",
        "print('Hello')",
    ]
    root = AST.from_string("\n".join(lines), ASTLanguage.Python)
    last_statement = root.children[-1]
    expected = [["os"], ["sys", "s"], ["json", None, "dump"]]
    self.assertEqual(expected, last_statement.imports(root))
def test_remove_cast(self):
    """Replacing every C cast with its ``value`` child matches the expected file."""

    def remove_cast(ast: AST):
        # Substitute the cast's ``value`` child; None for non-cast nodes.
        return ast.value if isinstance(ast, CCastExpression) else None

    result = AST.transform(self.root, remove_cast)
    expected = slurp(DATA_DIR / "transform" / "ccast-transformed.c")
    self.assertEqual(result.source_text, expected)
def test_ast_constructor_deepest_parameter(self):
    """deepest=True re-parses the same text into a node of a different type."""
    original_text = self.root.source_text
    reparsed = AST.from_string(
        original_text,
        language=self.root.language,
        deepest=True,
    )
    self.assertEqual(reparsed.source_text, original_text)
    self.assertNotEqual(type(reparsed), type(self.root))
def test_no_arguments(self):
    """A bare ``foo()`` call has no provider, callee ``foo`` and no arguments."""
    root = AST.from_string("foo()", ASTLanguage.Python)
    calls = root.call_asts()
    self.assertEqual(1, len(calls))
    call = calls[0]
    # assertIsNone checks identity and reports the actual value on failure.
    self.assertIsNone(call.provided_by(root))
    self.assertEqual("foo", call.call_function().source_text)
    self.assertEqual([], call.call_arguments())
def test_no_params(self):
    """A parameterless function exposes its name, empty params and body."""
    root = AST.from_string("def foo(): return None", ASTLanguage.Python)
    functions = root.function_asts()
    self.assertEqual(1, len(functions))
    function = functions[0]
    self.assertEqual("foo", function.function_name())
    self.assertEqual([], function.function_parameters())
    body = function.function_body()
    self.assertEqual("return None", body.source_text)
def is_print_statement(ast: AST) -> bool:
    """Return True if AST is an expression statement that calls ``print``.

    The statement qualifies when any call returned by its ``call_asts()``
    has a callee whose source text is exactly ``print``.
    """
    if not isinstance(ast, ExpressionStatementAST):
        return False
    callee_names = [call.call_function().source_text for call in ast.call_asts()]
    return "print" in callee_names
def test_ast_template(self):
    """ast_template fills named, positional, list and AST placeholders."""
    python = ASTLanguage.Python

    def check(node, expected_text, expected_type):
        # Assert both the rendered source text and the resulting node type.
        self.assertEqual(node.source_text, expected_text)
        self.assertIsInstance(node, expected_type)

    # Named placeholders are filled from the lower-cased keyword argument;
    # an ``@`` placeholder splices a list as comma-separated elements.
    check(AST.ast_template("$ID = 1", python, id="x"), "x = 1", PythonAssignment0)
    check(
        AST.ast_template("fn(@ARGS)", python, args=[1, 2, 3]),
        "fn(1, 2, 3)",
        PythonCall,
    )

    # Positional placeholders ($1, @1, ...) are filled from extra arguments.
    check(AST.ast_template("$1 = $2", python, "x", 1), "x = 1", PythonAssignment0)
    check(AST.ast_template("fn(@1)", python, [1, 2, 3]), "fn(1, 2, 3)", PythonCall)

    # An existing AST can be substituted directly.
    lhs = AST.from_string("x", python, deepest=True)
    check(
        AST.ast_template("$1 = value", python, lhs),
        "x = value",
        PythonAssignment0,
    )

    # Multi-word placeholder names map to snake_case keyword arguments.
    check(
        AST.ast_template(
            "$LEFT_HAND_SIDE = $RIGHT_HAND_SIDE",
            python,
            left_hand_side="x",
            right_hand_side=1,
        ),
        "x = 1",
        PythonAssignment0,
    )
def test_multiple_arguments(self):
    """``bar(a, b)`` has no provider, callee ``bar`` and arguments ``a, b``."""
    root = AST.from_string("bar(a, b)", ASTLanguage.Python)
    calls = root.call_asts()
    self.assertEqual(1, len(calls))
    call = calls[0]
    arg_texts = [arg.source_text for arg in call.call_arguments()]
    # assertIsNone checks identity and reports the actual value on failure.
    self.assertIsNone(call.provided_by(root))
    self.assertEqual("bar", call.call_function().source_text)
    self.assertEqual(["a", "b"], arg_texts)
def test_transform_x_to_y(self):
    """Renaming every ``x`` identifier to ``y`` matches the expected file."""

    def x_to_y(ast: AST) -> Optional[LiteralOrAST]:
        """Return the literal ``"y"`` for identifier ASTs spelled ``x``."""
        is_x_identifier = isinstance(ast, IdentifierAST) and ast.source_text == "x"
        return "y" if is_x_identifier else None

    result = AST.transform(self.root, x_to_y)
    expected = slurp(DATA_DIR / "transform" / "transform_x_to_y.py")
    self.assertEqual(result.source_text, expected)
def test_multiple_parameters(self):
    """A two-parameter function exposes its name, parameters and body."""
    root = AST.from_string("def bar(a, b): return a*b", ASTLanguage.Python)
    functions = root.function_asts()
    self.assertEqual(1, len(functions))
    function = functions[0]
    param_texts = [param.source_text for param in function.function_parameters()]
    self.assertEqual("bar", function.function_name())
    self.assertEqual(["a", "b"], param_texts)
    self.assertEqual("return a*b", function.function_body().source_text)
def test_inner_parent_asts(self):
    """Check that comment text after a list element is wrapped in an InnerParent.

    The node following the ``"c"`` element of the list literal is expected to
    be an ``InnerParent`` whose source text begins with the first comment.
    """
    # NOTE(review): SOURCE arrived with its line breaks collapsed; the
    # newlines and four-space indentation inside this literal are a
    # reconstruction. Confirm against the original file, because the
    # child-index path and the startswith() check below depend on exact
    # whitespace.
    text = """__all__ = [
    "c", # first comment
    # second comment
]"""
    root = AST.from_string(text, ASTLanguage.Python)
    # Presumably module -> statement -> assignment -> list -> node after "c";
    # TODO confirm the intermediate node kinds.
    inner_parent = root.children[0].children[0].children[1].children[1]
    self.assertIsInstance(inner_parent, InnerParent)
    self.assertTrue(
        inner_parent.source_text.startswith(" # first comment"))
def delete_print_statements(ast: AST) -> Optional[LiteralOrAST]:
    """Return a copy of AST with its print-statement children removed.

    Only root and compound ASTs are rewritten; any other AST yields None.
    If every child is removed, a ``pass`` statement is substituted so the
    resulting body is not empty.
    """
    if not isinstance(ast, (RootAST, CompoundAST)):
        return None
    survivors = [
        child for child in ast.children if not is_print_statement(child)
    ]
    # An empty body would be a syntax error, so fall back to a no-op.
    return AST.copy(ast, children=survivors or ["pass\n"])
def test_transform_x_to_z_gt(self):
    """Only an ``x`` on the left of ``>`` becomes ``z``; result matches file."""

    def x_to_z_gt(ast: AST) -> Optional[LiteralOrAST]:
        """Rename an ``x`` on the left of a ``>`` comparison to ``z``.

        Returns a copy of AST with its first operand replaced by ``"z"``
        when the first operator is ``>`` and the first operand is exactly
        ``x``; otherwise returns None.
        """
        if not isinstance(ast, PythonComparisonOperator):
            return None
        first_operand, *other_operands = ast.child_slot("CHILDREN")
        first_operator, *_ = ast.child_slot("OPERATORS")
        is_greater_than = isinstance(first_operator, PythonGreaterThan)
        if is_greater_than and first_operand.source_text == "x":
            return AST.copy(ast, children=["z", *other_operands])
        return None

    result = AST.transform(self.root, x_to_z_gt)
    expected = slurp(DATA_DIR / "transform" / "transform_x_to_z_gt.py")
    self.assertEqual(result.source_text, expected)
def test_vars_in_scope_no_globals(self):
    """Without globals, the first variables in scope are ``bar``'s parameters."""
    root = AST.from_string("def bar(a, b): return a*b", ASTLanguage.Python)
    # Follow last children down to a node nested inside the definition.
    inner = root.children[-1].children[-1].children[-1]
    in_scope = inner.get_vars_in_scope(root, keep_globals=False)

    names = [entry["name"] for entry in in_scope]
    self.assertEqual(names[0], "a")
    self.assertEqual(names[1], "b")

    # At this point at least two entries are known to exist.
    scopes = [entry["scope"] for entry in in_scope]
    for scope in scopes[:2]:
        self.assertIsInstance(scope, PythonFunctionDefinition2)

    decls = [entry["decl"] for entry in in_scope]
    for decl in decls[:2]:
        self.assertIsInstance(decl, PythonIdentifier)
def test_copy_with_kwargs(self):
    """AST.copy replaces the ``left`` child with an AST or a literal value.

    Each case must produce a new AST (distinct ``oid``) whose source text
    reflects the replacement. String values are inserted as source text, so
    ``'"hi"'`` yields a string literal and ``"y"`` yields an identifier.
    ``subTest`` reports every failing case instead of stopping at the first.
    """
    cases = [
        (AST.from_string("y", ASTLanguage.Python, deepest=True), "y + 1"),
        (0.5, "0.5 + 1"),
        (2, "2 + 1"),
        ('"hi"', '"hi" + 1'),
        ("y", "y + 1"),
    ]
    for left, expected_text in cases:
        with self.subTest(left=left):
            copy = AST.copy(self.root, left=left)
            self.assertEqual(copy.source_text, expected_text)
            self.assertNotEqual(copy.oid, self.root.oid)
def test_no_imports(self):
    """An empty module reports no imports."""
    empty = AST.from_string("", ASTLanguage.Python)
    found = empty.imports(empty)
    self.assertEqual([], found)
def test_no_vars_in_scope(self):
    """An empty module has no variables in scope."""
    empty = AST.from_string("", ASTLanguage.Python)
    in_scope = empty.get_vars_in_scope(empty)
    self.assertEqual([], in_scope)
def test_error_handling(self):
    """An unknown language name makes from_string raise ASTException."""
    self.assertRaises(ASTException, AST.from_string, "foo()", language="foo")
def test_replace_literal(self):
    """Replacing the left-hand side with a plain string rewrites the source."""
    target = self.statement.children[0].children[0]
    replaced = AST.replace(self.root, target, "y")
    # The result is a distinct object (different oid) from the original root.
    self.assertNotEqual(replaced.oid, self.root.oid)
    self.assertEqual("y = 88\n", replaced.source_text)
def setUp(self):
    """Parse the shared Python transform fixture into ``self.root``."""
    fixture = DATA_DIR / "transform" / "original.py"
    self.root = AST.from_string(slurp(fixture), ASTLanguage.Python)
def setUp(self):
    """Parse the shared C cast fixture into ``self.root``."""
    fixture = DATA_DIR / "transform" / "ccast-original.c"
    self.root = AST.from_string(slurp(fixture), ASTLanguage.C)
def test_utf8_multibyte_characters(self):
    """Multi-byte UTF-8 text round-trips and gets the expected source range."""
    text = '"反复请求多次"'
    root = AST.from_string(text, ASTLanguage.Python)
    first_range = root.ast_source_ranges()[0][1]
    self.assertEqual(text, root.source_text)
    # Two quotes plus six CJK characters span 8 characters; the expected
    # end position 9 is consistent with counting characters, not bytes.
    self.assertEqual([[1, 1], [1, 9]], first_range)
def test_no_calls(self):
    """An empty module contains no call ASTs."""
    empty = AST.from_string("", ASTLanguage.Python)
    calls = empty.call_asts()
    self.assertEqual([], calls)
def simple_parse_driver(self, text, language):
    """Parse TEXT as LANGUAGE and check the language and text round-trip."""
    parsed = AST.from_string(text, language)
    self.assertEqual(parsed.language, language)
    self.assertEqual(parsed.source_text, text)
def test_no_functions(self):
    """An empty module contains no function ASTs."""
    empty = AST.from_string("", ASTLanguage.Python)
    functions = empty.function_asts()
    self.assertEqual([], functions)