def assert_ast_eq(testcase, orig_ast, expected_ast):
    """
    Fail ``testcase`` unless the two AST trees compare equal under ``cmp_ast``.

    On mismatch both trees are rendered with ``str_ast`` and shown side by side
    in the failure message.

    :param testcase: test case whose ``fail`` method reports the mismatch
    :param orig_ast: the AST under test (shown as "left")
    :param expected_ast: the reference AST (shown as "right")
    """
    if cmp_ast(orig_ast, expected_ast):
        return

    # Render left first, then right, to mirror the message layout
    left_dump, right_dump = (
        str_ast(tree, indent=' ', newline='\n')
        for tree in (orig_ast, expected_ast)
    )
    testcase.fail(
        'AST Trees are not equal\n## left ########### \n%s\n## right ########### \n%s'
        % (left_dump, right_dump)
    )
def UNPACK_SEQUENCE(self, instr):
    """
    Translate an UNPACK_SEQUENCE instruction into an assignment to a tuple target.

    The next ``instr.oparg`` instructions are consumed from ``self.ilst``. Each is
    visited and must leave a node with ``targets`` on the stack. The first target
    of each such node becomes one element of a ``Tuple`` in Store context.

    If the expression below them on the stack is already an ``Assign`` (a chained
    assignment), the tuple is appended to its targets. The duplicated value pushed
    for the chain is popped and checked against the assignment's value. Otherwise
    a fresh ``Assign`` is built with the tuple as its only target.
    """
    count = instr.oparg
    elements = []
    # The Tuple keeps a reference to ``elements``, which is filled in below
    tuple_target = _ast.Tuple(elts=elements, ctx=_ast.Store(),
                              lineno=instr.lineno, col_offset=0)

    for _ in range(count):
        follower = self.ilst.pop(0)
        # Placeholder slot; presumably consumed by the following store handler
        self.ast_stack.append(None)
        self.visit(follower)
        store_node = self.ast_stack.pop()
        elements.append(store_node.targets[0])

    source_expr = self.ast_stack.pop()

    if not isinstance(source_expr, _ast.Assign):
        self.ast_stack.append(
            _ast.Assign(targets=[tuple_target], value=source_expr,
                        lineno=instr.lineno, col_offset=0)
        )
        return

    # Chained assignment: extend the existing Assign with another target
    source_expr.targets.append(tuple_target)
    duplicate_value = self.ast_stack.pop()
    assert cmp_ast(source_expr.value, duplicate_value)
    self.ast_stack.append(source_expr)
def UNPACK_SEQUENCE(self, instr):
    """
    Translate an UNPACK_SEQUENCE instruction into an assignment to a tuple target.

    The next ``instr.oparg`` instructions are consumed from ``self.ilst``. Each is
    visited with a ``None`` placeholder pushed first, and must leave a node with
    ``targets`` on the AST stack. The first target of each such node becomes one
    element of a Store-context ``Tuple``.

    A preceding ``Assign`` (chained assignment) gains the tuple as an extra
    target, after its duplicated value is popped and compared. Any other
    expression becomes the value of a new ``Assign``.
    """
    tuple_members = []
    # Build the Tuple up front; it shares ``tuple_members`` and sees later appends
    unpack_target = _ast.Tuple(elts=tuple_members, ctx=_ast.Store(),
                               lineno=instr.lineno, col_offset=0)

    remaining = instr.oparg
    while remaining > 0:
        remaining -= 1
        upcoming = self.ilst.pop(0)
        self.push_ast_item(None)
        self.visit(upcoming)
        produced = self.pop_ast_item()
        tuple_members.append(produced.targets[0])

    rhs = self.pop_ast_item()

    if isinstance(rhs, _ast.Assign):
        rhs.targets.append(unpack_target)
        repeated_value = self.pop_ast_item()
        assert cmp_ast(rhs.value, repeated_value)
        result = rhs
    else:
        result = _ast.Assign(targets=[unpack_target], value=rhs,
                             lineno=instr.lineno, col_offset=0)

    self.push_ast_item(result)
def test_to_file(self) -> None:
    """
    Tests whether `file` constructs a file, and fills it with the right content.

    `class_ast` is emitted twice, once without and once with black formatting.
    The two outputs must differ textually but parse to equal ASTs.
    """

    def emit_and_read(destination, skip_black):
        # Emit `class_ast` to `destination` and return what was written to disk
        emit.file(class_ast, destination, skip_black=skip_black)
        with open(destination, "rt") as handle:
            return handle.read()

    # TemporaryDirectory removes everything inside it on exit, including the file
    with TemporaryDirectory() as tempdir:
        filename = os.path.join(tempdir, "delete_me.py")
        ugly = emit_and_read(filename, skip_black=True)
        # Presumably removed so the second emit starts from an empty path
        os.remove(filename)
        blacked = emit_and_read(filename, skip_black=False)

    self.assertNotEqual(ugly, blacked)
    self.assertTrue(
        cmp_ast(ast.parse(ugly), ast.parse(blacked)),
        "Ugly AST doesn't match blacked AST",
    )
def test_parse_to_scalar(self) -> None:
    """
    Test various inputs and outputs for `parse_to_scalar`
    """
    # Scalars pass through; wrapped constants / empty expressions are unwrapped
    for fst, snd in (
        (5, 5),
        ("5", "5"),
        (set_value(5), 5),
        (ast.Expr(None), NoneStr),
    ):
        self.assertEqual(parse_to_scalar(fst), snd)

    # An expression statement holding a list literal yields that list node
    self.assertEqual(
        get_value(parse_to_scalar(ast.parse("[5]").body[0]).elts[0]), 5
    )
    self.assertTrue(
        cmp_ast(
            parse_to_scalar(ast.parse("[5]").body[0]),
            List([set_value(5)], Load()),
        )
    )

    # A whole Module maps to its source text
    self.assertEqual(parse_to_scalar(ast.parse("[5]")), "[5]")

    # Unsupported input types are rejected
    self.assertRaises(NotImplementedError, parse_to_scalar, memoryview(b""))
def assert_ast_eq(testcase, orig_ast, expected_ast):
    """
    Assert that two AST trees are equal according to ``cmp_ast``.

    When they differ, ``testcase.fail`` is called with both trees pretty-printed
    via ``str_ast`` (the tree under test as "left", the expectation as "right").
    """
    trees_match = cmp_ast(orig_ast, expected_ast)
    if not trees_match:
        rendered = [str_ast(node, indent=' ', newline='\n')
                    for node in (orig_ast, expected_ast)]
        template = ('AST Trees are not equal\n'
                    '## left ########### \n%s\n'
                    '## right ########### \n%s')
        testcase.fail(template % tuple(rendered))
def run_ast_test(test_case_instance, gen_ast, gold, skip_black=False):
    """
    Compares `gen_ast` with `gold` standard.

    Both trees are rendered back to source with `source_transformer.to_code` and
    compared as text, optionally after black formatting. They are then compared
    structurally with `cmp_ast`.

    :param test_case_instance: instance of `TestCase`
    :type test_case_instance: ```unittest.TestCase```

    :param gen_ast: generated AST; a `str` is parsed and its first statement used
    :type gen_ast: ```Union[ast.Module, ast.ClassDef, ast.FunctionDef, str]```

    :param gold: mocked AST
    :type gold: ```Union[ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param skip_black: If True, compare the raw generated source (faster). If
      False, run both sources through black first, which gives a more readable diff.
    :type skip_black: ```bool```
    """
    if isinstance(gen_ast, str):
        gen_ast = ast.parse(gen_ast).body[0]

    assert gen_ast is not None, "gen_ast is None"
    assert gold is not None, "gold is None"

    # Defensive copies: the caller's trees are never handed to to_code / cmp_ast
    gen_ast = deepcopy(gen_ast)
    gold = deepcopy(gold)

    if skip_black:
        normalise = identity
    else:
        normalise = partial(
            format_str,
            mode=Mode(
                target_versions=set(),
                line_length=60,
                is_pyi=False,
                string_normalization=False,
            ),
        )

    # gold is rendered first, matching assertEqual(expected, actual) ordering
    test_case_instance.assertEqual(
        normalise(source_transformer.to_code(gold)),
        normalise(source_transformer.to_code(gen_ast)),
    )

    test_case_instance.assertTrue(
        cmp_ast(gen_ast, gold), "Generated AST doesn't match reference AST"
    )
def run_ast_test(test_case_instance, gen_ast, gold, skip_black=False):
    """
    Assert that a generated AST matches a reference ("gold") AST.

    Docstrings, when both trees have one, are compared after stripping
    surrounding whitespace and then removed from the trees. Both trees are then
    turned into source and compared (black-formatted unless `skip_black`), and
    finally compared structurally with `cmp_ast`.

    :param test_case_instance: instance of `TestCase`
    :type test_case_instance: ```unittest.TestCase```

    :param gen_ast: generated AST; a `str` is parsed and its first statement used
    :type gen_ast: ```Union[ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param gold: mocked AST
    :type gold: ```Union[ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param skip_black: Whether to skip black
    :type skip_black: ```bool```
    """
    if isinstance(gen_ast, str):
        gen_ast = ast.parse(gen_ast).body[0]

    assert gen_ast is not None, "gen_ast is None"
    assert gold is not None, "gold is None"

    gen_ast, gold = deepcopy(gen_ast), deepcopy(gold)

    if hasattr(gen_ast, "body") and len(gen_ast.body) > 0:
        gen_doc = ast.get_docstring(gen_ast)
        gold_doc = ast.get_docstring(gold)
        both_documented = gen_doc is not None and gold_doc is not None
        if both_documented:
            test_case_instance.assertEqual(gold_doc.strip(), gen_doc.strip())
            # Docstrings agree modulo indentation; drop them so indentation
            # differences cannot break the source comparison below
            del gen_ast.body[0]
            del gold.body[0]

    black_mode = Mode(
        target_versions=set(),
        line_length=60,
        is_pyi=False,
        string_normalization=False,
    )

    def render(node):
        # Source text of `node`, black-formatted unless skip_black is set
        source = doctrans.source_transformer.to_code(node)
        return source if skip_black else format_str(source, mode=black_mode)

    test_case_instance.assertEqual(render(gold), render(gen_ast))
    test_case_instance.assertTrue(
        cmp_ast(gen_ast, gold), "Generated AST doesn't match reference AST"
    )
def assertAstEqual(self, left, right):
    """
    Fail unless ``left`` and ``right`` are AST nodes that ``cmp_ast`` deems equal.

    :param left: generated AST node
    :param right: expected AST node
    :raises self.failureException: if either argument is not an ``_ast.AST``, or
        if the trees differ. In the latter case the message contains both trees
        rendered with ``print_ast``.
    """
    for node in (left, right):
        if not isinstance(node, _ast.AST):
            # Format with a 1-tuple: ``% node`` would mis-format, or raise
            # TypeError, when ``node`` is itself a tuple
            raise self.failureException("%s is not an _ast.AST instance" % (node,))

    if cmp_ast(left, right):
        return

    lstream = StringIO()
    print_ast(left, indent='', file=lstream, newline='')
    rstream = StringIO()
    print_ast(right, indent='', file=rstream, newline='')

    msg = 'Ast Not Equal:\nGenerated: %r\nExpected: %r' % (
        lstream.getvalue(), rstream.getvalue())
    raise self.failureException(msg)
def STORE_SLICE_3(self, instr):
    """
    Handle ``obj[lower:upper] = expr``.

    Pops, in order, ``upper``, ``lower``, the subscripted object and the
    assigned expression, then builds a Store-context ``Subscript`` over a
    ``Slice``. An ``AugAssign`` expression is kept as-is after checking that its
    target matches the rebuilt subscript. Anything else is wrapped in a new
    ``Assign``. The resulting statement is pushed back onto the stack.
    """
    stack = self.ast_stack
    upper = stack.pop()
    lower = stack.pop()
    target_obj = stack.pop()
    rhs = stack.pop()

    position = {'lineno': instr.lineno, 'col_offset': 0}
    bounds = _ast.Slice(lower=lower, step=None, upper=upper, **position)
    subscript = _ast.Subscript(value=target_obj, slice=bounds,
                               ctx=_ast.Store(), **position)

    if not isinstance(rhs, _ast.AugAssign):
        stack.append(_ast.Assign(targets=[subscript], value=rhs, **position))
        return

    # Augmented assignment already names its target; only sanity-check it.
    # The comparison runs even under -O; only the assert itself is stripped.
    targets_agree = cmp_ast(rhs.target, subscript)
    assert targets_agree
    stack.append(rhs)
def test_get_at_root(self) -> None:
    """
    Tests that `get_at_root` successfully gets the imports.

    The `mocks/eval.py` fixture next to this file is expected to contain exactly
    one root-level import: `import doctrans.tests.mocks`.
    """
    mock_file = path.join(path.dirname(__file__), "mocks", "eval.py")
    with open(mock_file) as f:
        module = ast.parse(f.read())
    imports = get_at_root(module, (Import, ImportFrom))

    self.assertIsInstance(imports, list)
    self.assertEqual(len(imports), 1)

    expected_import = ast.Import(
        names=[
            ast.alias(
                asname=None,
                name="doctrans.tests.mocks",
                identifier=None,
                identifier_name=None,
            )
        ],
        alias=None,
    )
    self.assertTrue(cmp_ast(imports[0], expected_import))
def assertAstEqual(self, left, right):
    """
    Assert structural equality of two AST nodes using ``cmp_ast``.

    :param left: generated AST node
    :param right: expected AST node
    :raises self.failureException: when either argument is not an ``_ast.AST``,
        or when the trees differ. On a difference, both trees are dumped with
        ``print_ast`` into the message.
    """
    if not isinstance(left, _ast.AST):
        # 1-tuple so that a tuple argument is printed rather than unpacked
        raise self.failureException("%s is not an _ast.AST instance" % (left,))
    if not isinstance(right, _ast.AST):
        raise self.failureException("%s is not an _ast.AST instance" % (right,))

    if not cmp_ast(left, right):
        lstream = StringIO()
        print_ast(left, indent='', file=lstream, newline='')
        rstream = StringIO()
        print_ast(right, indent='', file=rstream, newline='')
        msg = 'Ast Not Equal:\nGenerated: %r\nExpected: %r' % (
            lstream.getvalue(), rstream.getvalue())
        raise self.failureException(msg)
def STORE_SLICE_3(self, instr):
    """
    Handle ``obj[lower:upper] = expr``.

    The operands are taken off the AST stack via ``pop_ast_item`` in the order
    upper, lower, object, expression. A Store-context ``Subscript`` with a
    ``Slice`` is built from them. If the expression is an ``AugAssign`` it is
    reused after verifying its target against that subscript; otherwise a new
    ``Assign`` is created. The result is pushed with ``push_ast_item``.
    """
    hi = self.pop_ast_item()
    lo = self.pop_ast_item()
    container = self.pop_ast_item()
    assigned = self.pop_ast_item()

    location = {'lineno': instr.lineno, 'col_offset': 0}
    slice_node = _ast.Slice(lower=lo, step=None, upper=hi, **location)
    store_target = _ast.Subscript(value=container, slice=slice_node,
                                  ctx=_ast.Store(), **location)

    if isinstance(assigned, _ast.AugAssign):
        # cmp_ast is evaluated unconditionally; only the assert is -O sensitive
        same_target = cmp_ast(assigned.target, store_target)
        assert same_target
        statement = assigned
    else:
        statement = _ast.Assign(targets=[store_target], value=assigned, **location)

    self.push_ast_item(statement)
def _conform_filename(
    filename,
    search,
    emit_func,
    replacement_node_ir,
    type_wanted,
):
    """
    Conform the given file to the `intermediate_repr`.

    - If `filename` does not exist, it is created from the emitted node.
    - If the node addressed by `search` is not found in the file, the emitted
      node is appended to the file.
    - Otherwise the found node is replaced in place, and the file rewritten,
      when it differs from the emitted node.

    :param filename: Location of file (`~` is expanded and the path resolved)
    :type filename: ```str```

    :param search: Search query, e.g., ['node_name', 'function_name', 'arg_name']
    :type search: ```List[str]```

    :param emit_func: Emitter that turns `replacement_node_ir` (plus keyword
      options) into an AST node
    :type emit_func: ```Callable[..., AST]```

    :param replacement_node_ir: Replace what is found with the contents of this param
    :type replacement_node_ir: ```dict```

    :param type_wanted: AST node class that `emit_func` is expected to produce
    :type type_wanted: ```Type[AST]```

    :return: filename, whether the file was modified
    :rtype: ```Tuple[str, bool]```
    """
    filename = path.realpath(path.expanduser(filename))

    if not path.isfile(filename):
        emit.file(
            emit_func(replacement_node_ir, emit_default_doc=False),
            filename=filename,
            mode="wt",
            skip_black=False,
        )
        return filename, True

    with open(filename, "rt") as f:
        parsed_ast = ast_parse(f.read(), filename=filename)
    assert isinstance(parsed_ast, Module)

    original_node = find_in_ast(search, parsed_ast)
    # `_default_options(...)` returns a callable; calling it yields emit_func's kwargs
    replacement_node = emit_func(
        replacement_node_ir,
        **_default_options(
            node=original_node, search=search, type_wanted=type_wanted
        )()
    )
    if original_node is None:
        # Nothing to replace: append the emitted node to the end of the file
        emit.file(replacement_node, filename=filename, mode="a", skip_black=False)
        return filename, True

    assert len(search) > 0
    # Report both sides by class name so the message is consistent
    assert type(replacement_node) is type_wanted, "Expected {!r} got {!r}".format(
        getattr(type_wanted, "__name__", type_wanted),
        type(replacement_node).__name__,
    )

    replaced = False
    if not cmp_ast(original_node, replacement_node):
        rewrite_at_query = RewriteAtQuery(
            search=search,
            replacement_node=replacement_node,
        )
        rewrite_at_query.visit(parsed_ast)

        print(
            "modified" if rewrite_at_query.replaced else "unchanged",
            filename,
            sep="\t",
        )
        if rewrite_at_query.replaced:
            emit.file(parsed_ast, filename, mode="wt", skip_black=False)

        replaced = rewrite_at_query.replaced

    return filename, replaced