Esempio n. 1
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Convert a single Python function into its AutoGraph-transformed AST.

    This is the callable-function specialization of `entity_to_graph`.

    Args:
      f: the Python function to convert.
      program_ctx: the program context; supplies `autograph_module`,
        `new_namer` and `update_name_map`.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo` and to
        `namer.compiled_function_name`.

    Returns:
      A tuple (node, new_name, namespace): the converted function AST node,
      the name assigned to that node, and the namespace of `f`.

    Raises:
      NotImplementedError: if the namer kept the original name but the
        converted node is named differently from `f.__name__`.
    """
    parsed, source = parser.parse_entity(f)
    fn_node = parsed.body[0]
    origin_info.resolve(fn_node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types,
        owner_type=owner_type)
    fn_node = node_to_graph(
        fn_node, converter.EntityContext(namer, info, program_ctx))

    # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
    original_name = f.__name__
    compiled_name, renamed = namer.compiled_function_name(
        original_name, f, owner_type)
    if not renamed:
        # Without a rename, the converted node must still carry the user's name.
        if fn_node.name != original_name:
            raise NotImplementedError(
                'Strange corner case. Send us offending code!')
        compiled_name = original_name

    fn_node.name = compiled_name
    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return fn_node, compiled_name, namespace
Esempio n. 2
0
def _build_source_map(node, code):
    """Map line numbers of regenerated code back to the user's original source.

    The generated `code` is parsed again so that its final line numbers are
    known. Then the pre-codegen AST (`node`) and the re-parsed AST are walked in
    parallel. Wherever both nodes carry an ORIGIN annotation, the re-parsed
    node's line number is associated with the original node's origin info.

    Args:
      node: An AST node of the generated code, taken before the source code was
        emitted. Nodes that came from user code carry `anno.Basic.ORIGIN`.
      code: The source string that was emitted for `node`.

    Returns:
      A dict keyed by the `line_number` of the ORIGIN annotation on each
      annotated node of the re-parsed `code`. Each value is the ORIGIN
      annotation of the corresponding node in `node`. When several annotated
      nodes share a line, the last one visited by `ast_util.parallel_walk`
      wins.
    """
    # Re-parse the final generated code to obtain its actual line numbers.
    new_node = parser.parse_str(code)
    origin_info.resolve(new_node, code)

    source_mapping = {}
    for before, after in ast_util.parallel_walk(node, new_node):
        # Both checks are needed: if origin information is ever copied over to
        # new nodes, we still rely on the fact that only the original user code
        # has the origin annotation.
        if not (anno.hasanno(before, anno.Basic.ORIGIN)
                and anno.hasanno(after, anno.Basic.ORIGIN)):
            continue
        source_info = anno.getanno(before, anno.Basic.ORIGIN)
        new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
        source_mapping[new_line_number] = source_info
    return source_mapping
Esempio n. 3
0
def _build_source_map(node, code):
  """Map line numbers of regenerated code back to the user's original source.

  The generated `code` is parsed again so that its final line numbers are
  known. Then the pre-codegen AST (`node`) and the re-parsed AST are walked in
  parallel. Wherever both nodes carry an ORIGIN annotation, the re-parsed
  node's line number is associated with the original node's origin info.

  Args:
    node: An AST node of the generated code, taken before the source code was
      emitted. Nodes that came from user code carry `anno.Basic.ORIGIN`.
    code: The source string that was emitted for `node`.

  Returns:
    A dict keyed by the `line_number` of the ORIGIN annotation on each
    annotated node of the re-parsed `code`. Each value is the ORIGIN
    annotation of the corresponding node in `node`. When several annotated
    nodes share a line, the last one visited by `ast_util.parallel_walk` wins.
  """
  # Re-parse the final generated code to obtain its actual line numbers.
  new_node = parser.parse_str(code)
  origin_info.resolve(new_node, code)

  source_mapping = {}
  for before, after in ast_util.parallel_walk(node, new_node):
    # Both checks are needed: if origin information is ever copied over to new
    # nodes, we still rely on the fact that only the original user code has
    # the origin annotation.
    if not (anno.hasanno(before, anno.Basic.ORIGIN) and
            anno.hasanno(after, anno.Basic.ORIGIN)):
      continue
    source_info = anno.getanno(before, anno.Basic.ORIGIN)
    new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
    source_mapping[new_line_number] = source_info
  return source_mapping
  def test_resolve(self):
    """Checks the ORIGIN annotations that `origin_info.resolve` attaches."""

    def test_fn(x):
      """Docstring."""
      return x  # comment

    parsed, source = parser.parse_entity(test_fn)
    fn_node = parsed.body[0]
    origin_info.resolve(fn_node, source)

    # Each row: (index into fn_node.body, or None for the function node itself;
    #            expected lineno; expected col_offset; expected source line;
    #            expected comment text, or None when no comment is expected).
    cases = (
        (None, 1, 0, 'def test_fn(x):', None),
        (0, 2, 2, '  """Docstring."""', None),
        (1, 3, 2, '  return x  # comment', 'comment'),
    )
    for body_index, lineno, col_offset, code_line, comment in cases:
      # Resolve the target lazily so nodes are looked up in the same order
      # as the assertions that check them.
      target = fn_node if body_index is None else fn_node.body[body_index]
      origin = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(origin.loc.lineno, lineno)
      self.assertEqual(origin.loc.col_offset, col_offset)
      self.assertEqual(origin.source_code_line, code_line)
      if comment is None:
        self.assertIsNone(origin.comment)
      else:
        self.assertEqual(origin.comment, comment)
  def test_source_map(self):
    """Checks that `origin_info.source_map` maps generated lines to origins.

    Two statements are prepended to the function body. One copies the ORIGIN
    annotation of the `if` statement (traced); the other has no annotation
    (untraced). Only lines whose nodes carry an origin should appear in the map.
    """

    def test_fn(x):
      if x > 0:
        x += 1
      return x

    node, source = parser.parse_entity(test_fn)
    fn_node = node.body[0]
    origin_info.resolve(fn_node, source)

    # Insert a traced line.
    new_node = parser.parse_str('x = abs(x)').body[0]
    anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
    fn_node.body.insert(0, new_node)

    # Insert an untraced line.
    fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

    modified_source = compiler.ast_to_source(fn_node)

    source_map = origin_info.source_map(fn_node, modified_source,
                                        'test_filename', [0])

    # Expected generated layout: line 1 `def`, line 2 untraced `x = 0`,
    # line 3 traced `x = abs(x)`, line 4 the original `if x > 0:`.
    loc = origin_info.LineLocation('test_filename', 1)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, 'def test_fn(x):')
    self.assertEqual(origin.loc.lineno, 1)

    # The untraced line, inserted second (so it ends up first in the body).
    loc = origin_info.LineLocation('test_filename', 2)
    # assertNotIn reports the offending key and the map contents on failure.
    self.assertNotIn(loc, source_map)

    # The traced line, inserted first; it inherits the `if` statement's origin.
    loc = origin_info.LineLocation('test_filename', 3)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, '  if x > 0:')
    self.assertEqual(origin.loc.lineno, 2)

    # The original `if` statement itself.
    loc = origin_info.LineLocation('test_filename', 4)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, '  if x > 0:')
    self.assertEqual(origin.loc.lineno, 2)
Esempio n. 6
0
def function_to_graph(f,
                      program_ctx,
                      arg_values,
                      arg_types,
                      owner_type=None,
                      rewrite_errors=True):
  """Convert one Python function into its AutoGraph-transformed AST.

  This is the callable-function specialization of `entity_to_graph`.

  Args:
    f: the Python function to convert.
    program_ctx: the program context; provides `autograph_module`,
      `new_namer` and `update_name_map`.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to
      `namer.compiled_function_name`.
    rewrite_errors: forwarded to `node_to_graph`.

  Returns:
    A tuple ([node], new_name, namespace): a one-element list holding the
    converted function node, the name assigned to that node, and the
    namespace of `f`.

  Raises:
    NotImplementedError: if the namer kept the original name but the
      converted node is named differently from `f.__name__`.
  """
  parsed, source = parser.parse_entity(f)
  fn_node = parsed.body[0]
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(namer, info, program_ctx)
  fn_node = node_to_graph(fn_node, ctx, rewrite_errors=rewrite_errors)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  original_name = f.__name__
  compiled_name, renamed = namer.compiled_function_name(
      original_name, f, owner_type)
  if not renamed:
    # Without a rename, the converted node must still carry the user's name.
    if fn_node.name != original_name:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    compiled_name = original_name
  fn_node.name = compiled_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], compiled_name, namespace