def test_parallel_walk_inconsistent_trees(self):
    """Asserts that `parallel_walk` raises ValueError on mismatched trees.

    Three small functions are parsed: a baseline, one whose return
    expression has a different shape, and one that differs from the
    baseline only in the value of a constant. Walking the baseline in
    parallel with either of the other two is expected to raise.
    """

    def parse(source):
        # Every fixture goes through the same dedent-then-parse pipeline.
        return parser.parse_str(textwrap.dedent(source))

    def walk_all(tree_a, tree_b):
        # Exhaust the walk so that any mismatch found along the way is raised.
        for _ in ast_util.parallel_walk(tree_a, tree_b):
            pass

    node_1 = parse(""" def f(a): return a + 1 """)
    node_2 = parse(""" def f(a): return a + (a * 2) """)
    node_3 = parse(""" def f(a): return a + 2 """)

    # The right operand is a constant in one tree and a nested expression in
    # the other.
    with self.assertRaises(ValueError):
        walk_all(node_1, node_2)

    # There is no particular reason to reject trees that differ only in the
    # value of a constant.
    # TODO(mdan): This should probably be allowed.
    with self.assertRaises(ValueError):
        walk_all(node_1, node_3)
def _build_source_map(node, code):
    """Maps line numbers of the generated code back to original origin info.

    The generated source is reparsed and `origin_info.resolve` is run on the
    result (presumably attaching ORIGIN annotations that carry the final line
    numbers -- confirm against `origin_info`). The original and reparsed
    trees are then walked in parallel, and each pair of nodes that both carry
    an ORIGIN annotation contributes one entry to the map.

    Args:
        node: An AST node of the original generated code, before the source
            code is generated.
        code: The string representation of the source code for the newly
            generated code.

    Returns:
        A dict keyed by the `line_number` of the reparsed node's ORIGIN
        annotation, whose values are the ORIGIN annotations of the matching
        original nodes. When several node pairs share a line number, the pair
        visited last wins.
    """
    # Reparse the final generated code so that node positions reflect it.
    reparsed = parser.parse_str(code)
    origin_info.resolve(reparsed, code)

    mapping = {}
    for original, generated in ast_util.parallel_walk(node, reparsed):
        # Both checks are needed: if origin information is ever copied over
        # to new nodes, we rely on the fact that only the original user code
        # has the origin annotation.
        if not anno.hasanno(original, anno.Basic.ORIGIN):
            continue
        if not anno.hasanno(generated, anno.Basic.ORIGIN):
            continue

        user_origin = anno.getanno(original, anno.Basic.ORIGIN)
        generated_line = anno.getanno(generated, anno.Basic.ORIGIN).line_number
        mapping[generated_line] = user_origin

    return mapping
def test_parallel_walk(self):
    """Walks a parsed tree in parallel with itself; pairs must be equal."""
    source = textwrap.dedent(""" def f(a): return a + 1 """)
    tree = parser.parse_str(source)

    # Both sides of the walk are the same tree, so each yielded pair is
    # expected to hold the same node twice.
    for left, right in ast_util.parallel_walk(tree, tree):
        self.assertEqual(left, right)
def source_map(nodes, code, filename, indices_in_code):
    """Creates a source map between an annotated AST and the code it compiles to.

    `code` is reparsed, the module-body entries selected by `indices_in_code`
    are passed to `resolve`, and the result is walked in parallel with
    `nodes`. Each pair of nodes that both carry an ORIGIN annotation maps the
    line of the reparsed node to the origin of the input node.

    Args:
        nodes: Iterable[ast.AST, ...]
        code: Text
        filename: Optional[Text]
        indices_in_code: Iterable[int, ...], the positions at which nodes
            appear in code. The parser always returns a module when parsing
            code; each index selects the entry of that module's body at which
            the corresponding node should appear. The value is iterated
            directly, so a bare int is not accepted by this implementation.

    Returns:
        Dict[LineLocation, OriginInfo], mapping locations in code to locations
        indicated by origin annotations in nodes.
    """
    module = parser.parse_str(code)
    reparsed_nodes = [module.body[i] for i in indices_in_code]
    resolve(reparsed_nodes, code)

    result = {}
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
        # Note: generated code might not be mapped back to its origin.
        # TODO(mdan): Generated code should always be mapped to something.
        source_origin = anno.getanno(before, anno.Basic.ORIGIN, default=None)
        generated_origin = anno.getanno(after, anno.Basic.ORIGIN, default=None)
        if source_origin is None or generated_origin is None:
            continue

        target = LineLocation(filename, generated_origin.loc.lineno)
        incumbent = result.get(target)
        if incumbent is None:
            # First origin seen for this generated line.
            result[target] = source_origin
            continue

        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. Decorated functions are the exception:
        # both lines are mapped to the same line in the AST.
        # Line overlaps: keep the bottom node.
        same_line = incumbent.loc.line_loc == source_origin.loc.line_loc
        if same_line and incumbent.loc.lineno >= source_origin.loc.lineno:
            continue

        # In case of overlaps, keep the leftmost node.
        if incumbent.loc.col_offset <= source_origin.loc.col_offset:
            continue

        result[target] = source_origin

    return result
def test_parallel_walk(self):
    """Walks a hand-built `def f(a): return a + 1` tree against itself."""
    # NOTE(review): `ast.Num` is deprecated since Python 3.8 in favor of
    # `ast.Constant`; it is kept here to leave the test's behavior unchanged.
    addition = ast.BinOp(
        op=ast.Add(),
        left=ast.Name(id='a', ctx=ast.Load()),
        right=ast.Num(1))
    ret = ast.Return(addition)

    signature = ast.arguments(
        args=[ast.Name(id='a', ctx=ast.Param())],
        vararg=None,
        kwarg=None,
        defaults=[])
    node = ast.FunctionDef(
        name='f',
        args=signature,
        body=[ret],
        decorator_list=[],
        returns=None)

    # The tree is paired with itself, so both items of every yielded pair
    # should compare equal.
    for first, second in ast_util.parallel_walk(node, node):
        self.assertEqual(first, second)
def test_parallel_walk(self):
    """Checks that a parallel walk of a tree with itself yields equal pairs.

    The tree is assembled by hand and corresponds to a function `f` taking
    one argument `a` and returning `a + 1`.
    """
    # Return statement: `return a + 1`.
    # NOTE(review): `ast.Num` is deprecated since Python 3.8 (`ast.Constant`
    # replaces it); left untouched so the constructed tree stays the same.
    return_expr = ast.BinOp(
        op=ast.Add(),
        left=ast.Name(id='a', ctx=ast.Load()),
        right=ast.Num(1),
    )
    return_stmt = ast.Return(return_expr)

    # Function definition wrapping the return statement.
    params = ast.arguments(
        args=[ast.Name(id='a', ctx=ast.Param())],
        vararg=None,
        kwarg=None,
        defaults=[],
    )
    func_def = ast.FunctionDef(
        name='f',
        args=params,
        body=[return_stmt],
        decorator_list=[],
        returns=None,
    )

    for node_a, node_b in ast_util.parallel_walk(func_def, func_def):
        self.assertEqual(node_a, node_b)