예제 #1
0
 def test_parallel_walk_inconsistent_trees(self):
   """parallel_walk raises ValueError when the walked trees do not match."""
   base = parser.parse_str(
       textwrap.dedent("""
     def f(a):
       return a + 1
   """))
   with_binop = parser.parse_str(
       textwrap.dedent("""
     def f(a):
       return a + (a * 2)
   """))
   with_other_constant = parser.parse_str(
       textwrap.dedent("""
     def f(a):
       return a + 2
   """))

   # The right operand is a BinOp in one tree and a constant in the other.
   with self.assertRaises(ValueError):
     list(ast_util.parallel_walk(base, with_binop))

   # The trees differ only in the value of a constant. There is no particular
   # reason to reject them, but parallel_walk currently does.
   # TODO(mdan): This should probably be allowed.
   with self.assertRaises(ValueError):
     list(ast_util.parallel_walk(base, with_other_constant))
예제 #2
0
 def test_parallel_walk_inconsistent_trees(self):
     """parallel_walk raises ValueError when the walked trees do not match."""
     base = parser.parse_str(
         textwrap.dedent("""
   def f(a):
     return a + 1
 """))
     with_binop = parser.parse_str(
         textwrap.dedent("""
   def f(a):
     return a + (a * 2)
 """))
     with_other_constant = parser.parse_str(
         textwrap.dedent("""
   def f(a):
     return a + 2
 """))

     # The right operand is a BinOp in one tree and a constant in the other.
     with self.assertRaises(ValueError):
         list(ast_util.parallel_walk(base, with_binop))

     # The trees differ only in the value of a constant. There is no
     # particular reason to reject them, but parallel_walk currently does.
     # TODO(mdan): This should probably be allowed.
     with self.assertRaises(ValueError):
         list(ast_util.parallel_walk(base, with_other_constant))
예제 #3
0
def _build_source_map(node, code):
    """Maps line numbers in generated code back to user-code origin info.

    The generated `code` is reparsed so that its nodes carry their final line
    numbers. The resulting tree is then walked in lockstep with `node`, the
    AST from which `code` was generated.

    Args:
      node: An AST node of the generated code, before the source code was
        produced from it. Nodes that come from user code are expected to carry
        an `anno.Basic.ORIGIN` annotation.
      code: The source code string that was generated from `node`.

    Returns:
      Dict mapping the `line_number` of each reparsed node's ORIGIN annotation
      (presumably its line in `code`) to the ORIGIN annotation of the matching
      node in `node`. When several nodes report the same line number, the one
      visited last wins.

    Raises:
      ValueError: propagated from `ast_util.parallel_walk` when the two trees
        do not have matching structure.
    """
    # After we have the final generated code we reparse it to get the final line
    # numbers. Then we walk through the generated and original ASTs in parallel
    # to build the mapping between the user and generated code.
    new_node = parser.parse_str(code)
    # Presumably annotates the reparsed tree with origin info carrying its
    # line numbers in `code` -- TODO confirm against origin_info.resolve.
    origin_info.resolve(new_node, code)
    source_mapping = {}
    for before, after in ast_util.parallel_walk(node, new_node):
        # Need both checks: if origin information is ever copied over to new
        # nodes, we rely on the fact that only the original user code has the
        # origin annotation.
        if not (anno.hasanno(before, anno.Basic.ORIGIN)
                and anno.hasanno(after, anno.Basic.ORIGIN)):
            continue
        source_info = anno.getanno(before, anno.Basic.ORIGIN)
        new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
        source_mapping[new_line_number] = source_info
    return source_mapping
예제 #4
0
def _build_source_map(node, code):
  """Maps line numbers in generated code back to user-code origin info.

  The generated `code` is reparsed so that its nodes carry their final line
  numbers. The resulting tree is then walked in lockstep with `node`, the AST
  from which `code` was generated.

  Args:
    node: An AST node of the generated code, before the source code was
      produced from it. Nodes that come from user code are expected to carry
      an `anno.Basic.ORIGIN` annotation.
    code: The source code string that was generated from `node`.

  Returns:
    Dict mapping the `line_number` of each reparsed node's ORIGIN annotation
    (presumably its line in `code`) to the ORIGIN annotation of the matching
    node in `node`. When several nodes report the same line number, the one
    visited last wins.

  Raises:
    ValueError: propagated from `ast_util.parallel_walk` when the two trees do
      not have matching structure.
  """
  # After we have the final generated code we reparse it to get the final line
  # numbers. Then we walk through the generated and original ASTs in parallel
  # to build the mapping between the user and generated code.
  new_node = parser.parse_str(code)
  # Presumably annotates the reparsed tree with origin info carrying its line
  # numbers in `code` -- TODO confirm against origin_info.resolve.
  origin_info.resolve(new_node, code)
  source_mapping = {}
  for before, after in ast_util.parallel_walk(node, new_node):
    # Need both checks: if origin information is ever copied over to new
    # nodes, we rely on the fact that only the original user code has the
    # origin annotation.
    if not (anno.hasanno(before, anno.Basic.ORIGIN) and
            anno.hasanno(after, anno.Basic.ORIGIN)):
      continue
    source_info = anno.getanno(before, anno.Basic.ORIGIN)
    new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
    source_mapping[new_line_number] = source_info
  return source_mapping
예제 #5
0
 def test_parallel_walk(self):
   """Walking a tree against itself pairs every node with itself."""
   source = textwrap.dedent("""
     def f(a):
       return a + 1
   """)
   tree = parser.parse_str(source)
   for lhs, rhs in ast_util.parallel_walk(tree, tree):
     self.assertEqual(lhs, rhs)
예제 #6
0
 def test_parallel_walk(self):
     """Walking a tree against itself pairs every node with itself."""
     source = textwrap.dedent("""
   def f(a):
     return a + 1
 """)
     tree = parser.parse_str(source)
     for lhs, rhs in ast_util.parallel_walk(tree, tree):
         self.assertEqual(lhs, rhs)
def source_map(nodes, code, filename, indices_in_code):
    """Creates a source map between an annotated AST and the code it compiles to.

    Args:
      nodes: Iterable[ast.AST, ...], the annotated nodes that `code` was
        generated from.
      code: Text, the generated source code.
      filename: Optional[Text], used as the file component of every key in the
        returned map.
      indices_in_code: Union[int, Iterable[int, ...]], the positions at which
        nodes appear in code. The parser always returns a module when parsing
        code. This argument gives the position in that module's body at which
        each corresponding node appears. A single int is treated as a
        one-element sequence.

    Returns:
      Dict[LineLocation, OriginInfo], mapping locations in code to locations
      indicated by origin annotations in nodes.
    """
    if isinstance(indices_in_code, int):
        # Honor the documented Union[int, ...] contract; iterating a bare int
        # would raise TypeError.
        indices_in_code = (indices_in_code,)

    reparsed_nodes = parser.parse_str(code)
    reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]

    resolve(reparsed_nodes, code)
    result = {}

    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
        # Note: generated code might not be mapped back to its origin.
        # TODO(mdan): Generated code should always be mapped to something.
        origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
        final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
        if origin_info is None or final_info is None:
            continue

        line_loc = LineLocation(filename, final_info.loc.lineno)

        existing_origin = result.get(line_loc)
        if existing_origin is not None:
            # Overlaps may exist because of child nodes, but almost never to
            # different line locations. The exception is decorated functions,
            # where both lines are mapped to the same line in the AST.

            # Line overlaps: keep bottom node.
            # NOTE(review): if `loc.line_loc` is derived from `loc.lineno` (as
            # the name suggests), equal line_locs imply equal linenos, so the
            # inner check always succeeds and the existing entry is kept.
            # Confirm whether `!=` was intended before changing this.
            if existing_origin.loc.line_loc == origin_info.loc.line_loc:
                if existing_origin.loc.lineno >= origin_info.loc.lineno:
                    continue

            # In case of overlaps, keep the leftmost node.
            if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
                continue

        result[line_loc] = origin_info

    return result
예제 #8
0
def source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  Args:
    nodes: Iterable[ast.AST, ...], the annotated nodes that `code` was
      generated from.
    code: Text, the generated source code.
    filename: Optional[Text], used as the file component of every key in the
      returned map.
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
        nodes appear in code. The parser always returns a module when parsing
        code. This argument gives the position in that module's body at which
        each corresponding node appears. A single int is treated as a
        one-element sequence.

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in nodes.
  """
  if isinstance(indices_in_code, int):
    # Honor the documented Union[int, ...] contract; iterating a bare int
    # would raise TypeError.
    indices_in_code = (indices_in_code,)

  reparsed_nodes = parser.parse_str(code)
  reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]

  resolve(reparsed_nodes, code)
  result = {}

  for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
    # Note: generated code might not be mapped back to its origin.
    # TODO(mdan): Generated code should always be mapped to something.
    origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
    final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
    if origin_info is None or final_info is None:
      continue

    line_loc = LineLocation(filename, final_info.loc.lineno)

    existing_origin = result.get(line_loc)
    if existing_origin is not None:
      # Overlaps may exist because of child nodes, but almost never to
      # different line locations. The exception is decorated functions, where
      # both lines are mapped to the same line in the AST.

      # Line overlaps: keep bottom node.
      # NOTE(review): if `loc.line_loc` is derived from `loc.lineno` (as the
      # name suggests), equal line_locs imply equal linenos, so the inner
      # check always succeeds and the existing entry is kept. Confirm whether
      # `!=` was intended before changing this.
      if existing_origin.loc.line_loc == origin_info.loc.line_loc:
        if existing_origin.loc.lineno >= origin_info.loc.lineno:
          continue

      # In case of overlaps, keep the leftmost node.
      if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
        continue

    result[line_loc] = origin_info

  return result
예제 #9
0
 def test_parallel_walk(self):
     """Walking a hand-built tree against itself pairs each node with itself."""
     # Hand-built AST equivalent to:
     #   def f(a):
     #     return a + 1
     addition = ast.BinOp(op=ast.Add(),
                          left=ast.Name(id='a', ctx=ast.Load()),
                          right=ast.Num(1))
     signature = ast.arguments(args=[ast.Name(id='a', ctx=ast.Param())],
                               vararg=None,
                               kwarg=None,
                               defaults=[])
     func_def = ast.FunctionDef(name='f',
                                args=signature,
                                body=[ast.Return(addition)],
                                decorator_list=[],
                                returns=None)
     for lhs, rhs in ast_util.parallel_walk(func_def, func_def):
         self.assertEqual(lhs, rhs)
예제 #10
0
 def test_parallel_walk(self):
   """Walking a hand-built tree against itself pairs each node with itself."""
   # Hand-built AST equivalent to:
   #   def f(a):
   #     return a + 1
   addition = ast.BinOp(
       op=ast.Add(),
       left=ast.Name(id='a', ctx=ast.Load()),
       right=ast.Num(1))
   signature = ast.arguments(
       args=[ast.Name(id='a', ctx=ast.Param())],
       vararg=None,
       kwarg=None,
       defaults=[])
   func_def = ast.FunctionDef(
       name='f',
       args=signature,
       body=[ast.Return(addition)],
       decorator_list=[],
       returns=None)
   for lhs, rhs in ast_util.parallel_walk(func_def, func_def):
     self.assertEqual(lhs, rhs)