def test_resolve(self):
    """Resolves origin info for a function parsed from a source string."""

    source = """
        def test_fn(x):
          '''Docstring.'''
          return x  # comment
    """
    source = textwrap.dedent(source)
    fn_node = parser.parse_str(source)
    origin_info.resolve(fn_node, source)

    def check_origin(target, lineno, col_offset, code_line, comment):
        # Compares the ORIGIN annotation of `target` with the expected values;
        # a `comment` of None means that no comment is expected.
        found = anno.getanno(target, anno.Basic.ORIGIN)
        self.assertEqual(found.loc.lineno, lineno)
        self.assertEqual(found.loc.col_offset, col_offset)
        self.assertEqual(found.source_code_line, code_line)
        if comment is None:
            self.assertIsNone(found.comment)
        else:
            self.assertEqual(found.comment, comment)

    # The source string starts with a newline, so the `def` sits on line 2.
    check_origin(fn_node, 2, 0, 'def test_fn(x):', None)
    check_origin(fn_node.body[0], 3, 2, "  '''Docstring.'''", None)
    check_origin(fn_node.body[1], 4, 2, '  return x  # comment', 'comment')
  def test_resolve(self):
    """Resolves origin info for a function parsed from its live definition."""

    def test_fn(x):
      """Docstring."""
      return x  # comment

    parsed, source = parser.parse_entity(test_fn)
    fn_node = parsed.body[0]
    origin_info.resolve(fn_node, source)

    def check_origin(target, lineno, col_offset, code_line, comment):
      # Compares the ORIGIN annotation of `target` with the expected values;
      # a `comment` of None means that no comment is expected.
      found = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(found.loc.lineno, lineno)
      self.assertEqual(found.loc.col_offset, col_offset)
      self.assertEqual(found.source_code_line, code_line)
      if comment is None:
        self.assertIsNone(found.comment)
      else:
        self.assertEqual(found.comment, comment)

    check_origin(fn_node, 1, 0, 'def test_fn(x):', None)
    check_origin(fn_node.body[0], 2, 2, '  """Docstring."""', None)
    check_origin(fn_node.body[1], 3, 2, '  return x  # comment', 'comment')
示例#3
0
  def prepare(self,
              test_fn,
              namespace,
              namer=None,
              arg_types=None,
              owner_type=None,
              recursive=True,
              strip_decorators=()):
    """Parses `test_fn` and runs standard analysis on it for converter tests.

    Args:
      test_fn: the function under test; its source is parsed.
      namespace: dict used as the entity namespace; `ConversionOptions` is
        added to it in place.
      namer: optional namer; a `FakeNamer` is created when omitted.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo`.
      recursive: forwarded to `converter.ConversionOptions`.
      strip_decorators: forwarded to `converter.ConversionOptions`.

    Returns:
      A tuple `(node, ctx)`: the analyzed function node and its
      `converter.EntityContext`.
    """
    namespace['ConversionOptions'] = converter.ConversionOptions

    module_node, source = parser.parse_entity(test_fn)
    fn_node = module_node.body[0]
    active_namer = FakeNamer() if namer is None else namer

    options = converter.ConversionOptions(
        recursive=recursive,
        strip_decorators=strip_decorators,
        verbose=True)
    prog_ctx = converter.ProgramContext(
        options=options,
        partial_types=None,
        autograph_module=None,
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        owner_type=owner_type)
    entity_ctx = converter.EntityContext(active_namer, info, prog_ctx)

    origin_info.resolve(fn_node, source, test_fn)
    fn_node = converter.standard_analysis(fn_node, entity_ctx, is_initial=True)
    return fn_node, entity_ctx
    def test_resolve(self):
        """Resolves origin info for a function parsed from its live definition.

        `test_fn` below is written with 4-space indentation, so its body
        statements are expected at column 4 and their source lines keep that
        indentation. (Assumes parsing preserves indentation relative to the
        `def` line - confirm against `parser.parse_entity`.)
        """
        def test_fn(x):
            """Docstring."""
            return x  # comment

        node, source = parser.parse_entity(test_fn)
        fn_node = node.body[0]
        origin_info.resolve(fn_node, source)

        origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
        self.assertEqual(origin.loc.lineno, 1)
        self.assertEqual(origin.loc.col_offset, 0)
        self.assertEqual(origin.source_code_line, 'def test_fn(x):')
        self.assertIsNone(origin.comment)

        # Body expectations follow the 4-space indentation of `test_fn`.
        origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
        self.assertEqual(origin.loc.lineno, 2)
        self.assertEqual(origin.loc.col_offset, 4)
        self.assertEqual(origin.source_code_line, '    """Docstring."""')
        self.assertIsNone(origin.comment)

        origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
        self.assertEqual(origin.loc.lineno, 3)
        self.assertEqual(origin.loc.col_offset, 4)
        self.assertEqual(origin.source_code_line, '    return x  # comment')
        self.assertEqual(origin.comment, 'comment')
示例#5
0
    def test_resolve(self):
        """Resolves origin info with an explicit file name and line/col anchor.

        The root node is expected at line 10, column 10 of 'test_file', and the
        child statements relative to it. Every origin, not only the root's,
        must report 'test_file' as its file name.
        """

        source = """
      def test_fn(x):
        '''Docstring.'''
        return x  # comment
    """
        source = textwrap.dedent(source)
        node = parser.parse(source)
        origin_info.resolve(node, source, 'test_file', 10, 10)

        def_origin = anno.getanno(node, anno.Basic.ORIGIN)
        self.assertEqual(def_origin.loc.filename, 'test_file')
        self.assertEqual(def_origin.loc.lineno, 10)
        self.assertEqual(def_origin.loc.col_offset, 10)
        self.assertEqual(def_origin.source_code_line, 'def test_fn(x):')
        self.assertIsNone(def_origin.comment)

        docstring_origin = anno.getanno(node.body[0], anno.Basic.ORIGIN)
        # Check the docstring's own origin, not the root's.
        self.assertEqual(docstring_origin.loc.filename, 'test_file')
        self.assertEqual(docstring_origin.loc.lineno, 11)
        self.assertEqual(docstring_origin.loc.col_offset, 12)
        self.assertEqual(docstring_origin.source_code_line,
                         "  '''Docstring.'''")
        self.assertIsNone(docstring_origin.comment)

        ret_origin = anno.getanno(node.body[1], anno.Basic.ORIGIN)
        # Check the return statement's own origin, not the root's.
        self.assertEqual(ret_origin.loc.filename, 'test_file')
        self.assertEqual(ret_origin.loc.lineno, 12)
        self.assertEqual(ret_origin.loc.col_offset, 12)
        self.assertEqual(ret_origin.source_code_line, '  return x  # comment')
        self.assertEqual(ret_origin.comment, 'comment')
示例#6
0
  def test_origin_info_preserved_in_moved_nodes(self):
    """Statements hoisted out of an `if` keep their original line numbers."""

    class TestTransformer(transformer.Base):

      def visit_If(self, node):
        # Returns the `if` body so its statements take the `if`'s place.
        return node.body

    def test_fn():
      x = 1
      if x > 0:
        x = 1
        x += 3
      return x

    tr = TestTransformer(self._simple_context())
    fn_node, source = parser.parse_entity(test_fn, future_features=())
    origin_info.resolve(fn_node, source)
    fn_node = tr.visit(fn_node)

    def lineno_of(stmt):
      return anno.getanno(stmt, anno.Basic.ORIGIN).loc.lineno

    # body[1] and body[2] are the statements formerly nested in the `if`.
    self.assertEqual(lineno_of(fn_node.body[1]), 4)
    self.assertEqual(lineno_of(fn_node.body[2]), 5)
示例#7
0
    def disabled_test_resolve_with_future_imports(self):
        """Resolves origin info when the parsed source has extra leading lines.

        The function node is taken from the end of the parsed body, and line
        numbers are expected one higher than within `test_fn` itself
        (presumably due to a future-import preamble - confirm). Column offsets
        and source lines follow the 4-space indentation of `test_fn`.
        """
        def test_fn(x):
            """Docstring."""
            print(x)
            return x  # comment

        node, source = parser.parse_entity(test_fn)
        fn_node = node.body[-1]

        origin_info.resolve(fn_node, source)

        origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
        self.assertEqual(origin.loc.lineno, 2)
        self.assertEqual(origin.loc.col_offset, 0)
        self.assertEqual(origin.source_code_line, 'def test_fn(x):')
        self.assertIsNone(origin.comment)

        # Body expectations follow the 4-space indentation of `test_fn`.
        origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
        self.assertEqual(origin.loc.lineno, 3)
        self.assertEqual(origin.loc.col_offset, 4)
        self.assertEqual(origin.source_code_line, '    """Docstring."""')
        self.assertIsNone(origin.comment)

        origin = anno.getanno(fn_node.body[2], anno.Basic.ORIGIN)
        self.assertEqual(origin.loc.lineno, 5)
        self.assertEqual(origin.loc.col_offset, 4)
        self.assertEqual(origin.source_code_line, '    return x  # comment')
        self.assertEqual(origin.comment, 'comment')
示例#8
0
  def test_origin_info_preserved_in_moved_nodes(self):
    """Statements hoisted out of an `if` keep their anchored line numbers."""

    class TestTransformer(transformer.Base):

      def visit_If(self, node):
        # Returns the `if` body so its statements take the `if`'s place.
        return node.body

    def test_fn():
      x = 1
      if x > 0:
        x = 1
        x += 3
      return x

    tr = TestTransformer(self._simple_context())
    fn_node, source = parser.parse_entity(test_fn, future_features=())
    # Resolve against file 'test_file', anchored at line 100, column 0.
    origin_info.resolve(fn_node, source, 'test_file', 100, 0)
    fn_node = tr.visit(fn_node)

    def lineno_of(stmt):
      return anno.getanno(stmt, anno.Basic.ORIGIN).loc.lineno

    # The hoisted statements keep their original line numbers.
    self.assertEqual(lineno_of(fn_node.body[1]), 103)
    self.assertEqual(lineno_of(fn_node.body[2]), 104)
示例#9
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: the function to convert.
      program_ctx: program-wide conversion context; supplies the autograph
        module and a namer, and receives the updated name map.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo` and to the namer when
        choosing the compiled function name.

    Returns:
      A tuple `([node], new_name, namespace)`.

    Raises:
      NotImplementedError: if the namer did not rename the function and the
        converted node's name differs from `f.__name__`.
    """
    module_node, source = parser.parse_entity(f)
    fn_node = module_node.body[0]
    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(fn_node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    info = transformer.EntityInfo(source_code=source,
                                  source_file='<fragment>',
                                  namespace=namespace,
                                  arg_values=arg_values,
                                  arg_types=arg_types,
                                  owner_type=owner_type)
    fn_node = node_to_graph(fn_node,
                            converter.EntityContext(namer, info, program_ctx))

    # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
        if fn_node.name != f.__name__:
            raise NotImplementedError(
                'Strange corner case. Send us offending code!')
        new_name = f.__name__
    fn_node.name = new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [fn_node], new_name, namespace
示例#10
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: the function (possibly a lambda) to convert.
      program_ctx: program-wide conversion context; supplies the autograph
        module and a namer, and receives the updated name map.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo` and to the namer when
        choosing the compiled function name.

    Returns:
      A tuple `([node], new_name, namespace)`. For a lambda, `node` assigns
      the converted lambda to a newly generated name.

    Raises:
      ValueError: if the parsed source does not contain exactly one
        definition matching `f`.
    """
    module_node, source = parser.parse_entity(f)

    # In general, the output of inspect.getsource is inexact because it uses
    # regex matching to adjust the exact location around the line number that
    # CPython records. This is particularly problematic for lambda functions,
    # where the entire containing lines are returned.
    candidates = ast_util.find_matching_definitions(module_node.body[0], f)
    if len(candidates) != 1:
        if f.__name__ == '<lambda>':
            message = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        else:
            message = (
                'Unable to identify source code of function {}. The source code'
                ' reported by Python did not include exactly one matching signature:'
                '\n{}\n. This is an extremely rare occurrence. Please report it to'
                ' the TensorFlow team.'.format(f, source))
        raise ValueError(message)
    fn_node, = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(fn_node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    info = transformer.EntityInfo(source_code=source,
                                  source_file='<fragment>',
                                  namespace=namespace,
                                  arg_values=arg_values,
                                  arg_types=arg_types,
                                  owner_type=owner_type)
    converted = node_to_graph(
        fn_node, converter.EntityContext(namer, info, program_ctx))

    if not isinstance(converted, gast.Lambda):
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name, did_rename = namer.compiled_function_name(
            f.__name__, f, owner_type)
        if did_rename:
            converted.name = new_name
        else:
            new_name = f.__name__
            assert converted.name == new_name
    else:
        # A converted lambda is bound to a freshly generated name.
        new_name = namer.new_symbol('tf__lambda', ())
        converted = gast.Assign(
            targets=[gast.Name(new_name, gast.Store(), None)], value=converted)

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [converted], new_name, namespace
示例#11
0
    def test_resolve(self):
        """Resolves origin info for a function parsed from a source string."""

        source = """
        def test_fn(x):
          '''Docstring.'''
          return x  # comment
    """
        source = textwrap.dedent(source)
        fn_node = parser.parse_str(source)
        origin_info.resolve(fn_node, source)

        def check_origin(target, lineno, col_offset, code_line, comment):
            # Compares the ORIGIN annotation of `target` with the expected
            # values; a `comment` of None means that no comment is expected.
            found = anno.getanno(target, anno.Basic.ORIGIN)
            self.assertEqual(found.loc.lineno, lineno)
            self.assertEqual(found.loc.col_offset, col_offset)
            self.assertEqual(found.source_code_line, code_line)
            if comment is None:
                self.assertIsNone(found.comment)
            else:
                self.assertEqual(found.comment, comment)

        # The source string starts with a newline, so the `def` is on line 2.
        check_origin(fn_node, 2, 0, 'def test_fn(x):', None)
        check_origin(fn_node.body[0], 3, 2, "  '''Docstring.'''", None)
        check_origin(fn_node.body[1], 4, 2, '  return x  # comment', 'comment')
示例#12
0
  def prepare(self,
              test_fn,
              namespace,
              namer=None,
              arg_types=None,
              owner_type=None,
              recursive=True,
              strip_decorators=()):
    """Parses `test_fn` and runs standard analysis on it for converter tests.

    Args:
      test_fn: the function under test; its source is parsed.
      namespace: dict used as the entity namespace; `ConversionOptions` is
        added to it in place.
      namer: optional namer; a `FakeNamer` is created when omitted.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo`.
      recursive: forwarded to `converter.ConversionOptions`.
      strip_decorators: forwarded to `converter.ConversionOptions`.

    Returns:
      A tuple `(node, ctx)`: the analyzed function node and its
      `converter.EntityContext`.
    """
    namespace['ConversionOptions'] = converter.ConversionOptions

    parsed, source = parser.parse_entity(test_fn)
    target = parsed.body[0]
    if namer is None:
      effective_namer = FakeNamer()
    else:
      effective_namer = namer

    conversion_options = converter.ConversionOptions(
        recursive=recursive,
        strip_decorators=strip_decorators,
        verbose=True)
    program_context = converter.ProgramContext(
        options=conversion_options,
        partial_types=None,
        autograph_module=None,
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        owner_type=owner_type)
    entity_context = converter.EntityContext(effective_namer, info,
                                             program_context)

    origin_info.resolve(target, source, test_fn)
    analyzed = converter.standard_analysis(
        target, entity_context, is_initial=True)
    return analyzed, entity_context
示例#13
0
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: the function (possibly a lambda) to convert.
      program_ctx: program-wide conversion context; supplies the autograph
        module.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      do_rename: if True, a converted non-lambda function receives a name
        generated by the namer; otherwise it keeps `f.__name__`.

    Returns:
      A tuple `((node,), new_name, entity_info)`. For a lambda, `node`
      assigns the converted lambda to a newly generated name.

    Raises:
      ValueError: if `f` is a lambda and its source does not contain exactly
        one matching lambda.
      errors.InternalError: if conversion raises ValueError, AttributeError,
        KeyError or NotImplementedError.
    """
    future_features = inspect_utils.getfutureimports(f)
    target, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # Parsed AST should contain future imports and one function def node.

    # In general, the output of inspect.getsource is inexact for lambdas because
    # it uses regex matching to adjust the exact location around the line number
    # that CPython records. Then, the entire containing line is returned, which
    # we may have trouble disambiguating. For example:
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        matches = ast_util.find_matching_definitions(target, f)
        if len(matches) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        target, = matches

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(target, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    entity_info = transformer.EntityInfo(source_code=source,
                                         source_file='<fragment>',
                                         future_features=future_features,
                                         namespace=namespace,
                                         arg_values=arg_values,
                                         arg_types=arg_types)
    entity_ctx = converter.EntityContext(namer, entity_info, program_ctx)
    # TODO(mdan): Catch and rethrow syntax errors.
    try:
        target = node_to_graph(target, entity_ctx)
    except (ValueError, AttributeError, KeyError, NotImplementedError) as exc:
        logging.error(1, 'Error converting %s', f, exc_info=True)
        raise errors.InternalError('conversion', exc)

    if isinstance(target, gast.Lambda):
        # A converted lambda is bound to a freshly generated name.
        new_name = namer.new_symbol('tf__lambda', ())
        target = gast.Assign(targets=[gast.Name(new_name, gast.Store(), None)],
                             value=target)
    elif not do_rename:
        new_name = f.__name__
        assert target.name == new_name
    else:
        new_name = namer.function_name(f.__name__)
        target.name = new_name

    return (target,), new_name, entity_info
示例#14
0
def function_to_graph(f,
                      program_ctx,
                      arg_values,
                      arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function (possibly a lambda) to convert.
    program_ctx: program-wide conversion context; supplies the autograph
      module and a namer, and receives the updated name map.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer when
      choosing the compiled function name.

  Returns:
    A tuple `([node], new_name, namespace)`. For a lambda, `node` assigns the
    converted lambda to a newly generated name.

  Raises:
    ValueError: if `f` is a lambda and the parsed source does not yield
      exactly one matching lambda definition.
  """
  module_node, source = parser.parse_entity(f)
  target = module_node.body[0]

  # TODO(mdan): Can we convert everything and scoop the lambda afterwards?
  if f.__name__ == '<lambda>':
    lambdas = ast_util.find_matching_lambda_definitions(target, f)
    if len(lambdas) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which contains multiple lambdas with'
          ' identical argument names. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    target, = lambdas

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(target, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  converted = node_to_graph(
      target, converter.EntityContext(namer, info, program_ctx))

  if not isinstance(converted, gast.Lambda):
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(f.__name__, f,
                                                        owner_type)
    if did_rename:
      converted.name = new_name
    else:
      new_name = f.__name__
      assert converted.name == new_name
  else:
    # A converted lambda is bound to a freshly generated name.
    new_name = namer.new_symbol('tf__lambda', ())
    converted = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=converted)

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [converted], new_name, namespace
示例#15
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Specialization of `convert_entity_to_ast` for callable functions.

  Args:
    f: the function (possibly a lambda) to convert.
    program_ctx: program-wide conversion context; supplies the autograph
      module.
    do_rename: if True, a converted non-lambda function receives a name
      generated by the namer; otherwise it keeps `f.__name__`.

  Returns:
    A tuple `((node,), new_name, entity_info)`. For a lambda, `node` assigns
    the converted lambda to a newly generated name.

  Raises:
    ValueError: if `f` is a lambda and its source does not contain exactly
      one matching lambda.
    errors.InternalError: if conversion raises ValueError, AttributeError,
      KeyError or NotImplementedError.
  """
  future_features = inspect_utils.getfutureimports(f)
  target, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # Parsed AST should contain future imports and one function def node.

  # In general, the output of inspect.getsource is inexact for lambdas because
  # it uses regex matching to adjust the exact location around the line number
  # that CPython records. Then, the entire containing line is returned, which
  # we may have trouble disambiguating. For example:
  # x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    matches = ast_util.find_matching_definitions(target, f)
    if len(matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    target, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(target, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  entity_ctx = converter.EntityContext(namer, entity_info, program_ctx)
  # TODO(mdan): Catch and rethrow syntax errors.
  try:
    target = node_to_graph(target, entity_ctx)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as exc:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    raise errors.InternalError('conversion', exc)

  if isinstance(target, gast.Lambda):
    # A converted lambda is bound to a freshly generated name.
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=target)
  elif not do_rename:
    new_name = f.__name__
    assert target.name == new_name
  else:
    new_name = namer.function_name(f.__name__)
    target.name = new_name

  return (target,), new_name, entity_info
示例#16
0
    def test_resolve_with_trailing_garbage(self):
        """Resolves a lambda whose recorded source line has trailing garbage."""
        # The tokenizer fails before reaching the end of this line, so a
        # comment placed there would be missed.
        source = '   lambda: foo([], bar=1)), baz=2)()'
        lambda_node = parser.parse('lambda: foo([], bar=1)').value
        origin_info.resolve(lambda_node, source, 'test_file', 10, 10)

        origin = anno.getanno(lambda_node, anno.Basic.ORIGIN)
        self.assertEqual(origin.loc.lineno, 10)
        self.assertEqual(origin.loc.col_offset, 10)
        # The whole raw line, garbage included, is reported as the source line.
        self.assertEqual(origin.source_code_line, source)
        self.assertIsNone(origin.comment)
示例#17
0
  def test_basic_codegen(self):
    """Emits brace-delimited pseudo-code through a minimal CodeGenerator."""

    class TestCodegen(transformer.CodeGenerator):

      def visit_Assign(self, node):
        # Emits the statement verbatim, followed by a newline.
        self.emit(parser.unparse(node, include_encoding_marker=False))
        self.emit('\n')

      # Return statements are emitted exactly like assignments.
      visit_Return = visit_Assign

      def visit_If(self, node):
        self.emit('if ')
        # For simplicity the condition is unparsed directly; a real generator
        # would walk the expression tree and emit proper code.
        self.emit(parser.unparse(node.test, include_encoding_marker=False))
        self.emit(' {\n')
        self.visit_block(node.body)
        self.emit('} else {\n')
        self.visit_block(node.orelse)
        self.emit('}\n')

    def test_fn():
      x = 1
      if x > 0:
        x = 2
        if x > 1:
          x = 3
      return x

    codegen = TestCodegen(self._simple_context())
    fn_node, source = parser.parse_entity(test_fn, future_features=())
    origin_info.resolve(fn_node, source, 'test_file', 100, 0)
    codegen.visit(fn_node)

    expected_lines = (
        'x = 1',
        'if (x > 0) {',
        'x = 2',
        'if (x > 1) {',
        'x = 3',
        '} else {',
        '}',
        '} else {',
        '}',
        'return x',
        '',
    )
    self.assertEqual(codegen.code_buffer, '\n'.join(expected_lines))
    def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
        """Parses `test_fn` and runs standard analysis on it for tests.

        Args:
          test_fn: the function under test; its source is parsed.
          namespace: dict used as the entity namespace; `ConversionOptions`
            is added to it in place.
          arg_types: forwarded to `transformer.EntityInfo`.
          recursive: forwarded to `converter.ConversionOptions`.

        Returns:
          A tuple `(node, ctx)` with the analyzed node and its
          `converter.EntityContext`.
        """
        namespace['ConversionOptions'] = converter.ConversionOptions

        # The third value returned by parse_entity is not needed here.
        fn_node, source, _ = parser.parse_entity(test_fn)
        namer = naming.Namer(namespace)
        options = converter.ConversionOptions(recursive=recursive)
        prog_ctx = converter.ProgramContext(options=options,
                                            autograph_module=None)
        info = transformer.EntityInfo(source_code=source,
                                      source_file='<fragment>',
                                      namespace=namespace,
                                      arg_values=None,
                                      arg_types=arg_types)
        entity_ctx = converter.EntityContext(namer, info, prog_ctx)

        origin_info.resolve(fn_node, source, test_fn)
        fn_node = converter.standard_analysis(fn_node, entity_ctx,
                                              is_initial=True)
        return fn_node, entity_ctx
示例#19
0
  def prepare(self, test_fn, namespace, recursive=True):
    """Parses `test_fn` with fixed future features and analyzes it.

    Args:
      test_fn: the function under test; its source is parsed.
      namespace: dict used as the entity namespace; `ConversionOptions` is
        added to it in place.
      recursive: forwarded to `converter.ConversionOptions`.

    Returns:
      A tuple `(node, ctx)` with the analyzed node and its
      `converter.EntityContext`.
    """
    namespace['ConversionOptions'] = converter.ConversionOptions

    # The same future features go to both the parser and the EntityInfo.
    features = ('print_function', 'division')
    fn_node, source = parser.parse_entity(test_fn, future_features=features)
    namer = naming.Namer(namespace)
    prog_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=recursive),
        autograph_module=None)
    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=features,
        namespace=namespace)
    entity_ctx = converter.EntityContext(namer, info, prog_ctx)

    origin_info.resolve(fn_node, source, test_fn)
    fn_node = converter.standard_analysis(fn_node, entity_ctx, is_initial=True)
    return fn_node, entity_ctx
    def test_source_map(self):
        """Builds a source map for a function with two inserted statements.

        A statement that copies its ORIGIN annotation from the `if` must map
        back to that line, and a statement without one must be absent from the
        map. Expected source lines follow the 4-space indentation of `test_fn`
        (assumes parsing preserves indentation relative to the `def` line).
        """
        def test_fn(x):
            if x > 0:
                x += 1
            return x

        node, source = parser.parse_entity(test_fn)
        fn_node = node.body[0]
        origin_info.resolve(fn_node, source)

        # Insert a traced line.
        new_node = parser.parse_str('x = abs(x)').body[0]
        anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
        fn_node.body.insert(0, new_node)

        # Insert an untraced line.
        fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

        modified_source = compiler.ast_to_source(fn_node)

        source_map = origin_info.source_map(fn_node, modified_source,
                                            'test_filename', [0])

        loc = origin_info.LineLocation('test_filename', 1)
        origin = source_map[loc]
        self.assertEqual(origin.source_code_line, 'def test_fn(x):')
        self.assertEqual(origin.loc.lineno, 1)

        # The untraced line, inserted second, has no entry in the map.
        loc = origin_info.LineLocation('test_filename', 2)
        self.assertNotIn(loc, source_map)

        # The traced line, inserted first, carries the `if` statement's origin.
        loc = origin_info.LineLocation('test_filename', 3)
        origin = source_map[loc]
        self.assertEqual(origin.source_code_line, '    if x > 0:')
        self.assertEqual(origin.loc.lineno, 2)

        # The original `if` statement.
        loc = origin_info.LineLocation('test_filename', 4)
        origin = source_map[loc]
        self.assertEqual(origin.source_code_line, '    if x > 0:')
        self.assertEqual(origin.loc.lineno, 2)
示例#21
0
    def test_origin_info_propagated_to_new_nodes(self):
        """A node created by a transformer inherits the origin it replaces."""
        class TestTransformer(transformer.Base):
            def visit_If(self, node):
                # Replaces every `if` statement with a fresh `pass`.
                return gast.Pass()

        def test_fn():
            x = 1
            if x > 0:
                x = 1
            return x

        tr = TestTransformer(self._simple_context())
        fn_node, source = parser.parse_entity(test_fn, future_features=())
        origin_info.resolve(fn_node, source, 'test_file', 100, 0)
        fn_node = tr.visit(fn_node)

        created_pass = fn_node.body[1]
        pass_origin = anno.getanno(created_pass, anno.Basic.ORIGIN)
        # Takes the line number of the if statement.
        self.assertEqual(pass_origin.loc.lineno, 102)
示例#22
0
  def test_origin_info_propagated_to_new_nodes(self):
    """A node created by a transformer inherits the origin it replaces."""

    class TestTransformer(transformer.Base):

      def visit_If(self, node):
        # Replaces every `if` statement with a fresh `pass`.
        return gast.Pass()

    def test_fn():
      x = 1
      if x > 0:
        x = 1
      return x

    tr = TestTransformer(self._simple_context())
    fn_node, source = parser.parse_entity(test_fn, future_features=())
    origin_info.resolve(fn_node, source)
    fn_node = tr.visit(fn_node)

    pass_origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
    # The new `pass` reports line 3, where the replaced `if` was.
    self.assertEqual(pass_origin.loc.lineno, 3)
示例#23
0
def function_to_graph(f,
                      program_ctx,
                      arg_values,
                      arg_types,
                      owner_type=None,
                      rewrite_errors=True):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function to convert.
    program_ctx: program-wide conversion context; supplies the autograph
      module and a namer, and receives the updated name map.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer when
      choosing the compiled function name.
    rewrite_errors: forwarded to `node_to_graph`.

  Returns:
    A tuple `([node], new_name, namespace)`.

  Raises:
    NotImplementedError: if the namer did not rename the function and the
      converted node's name differs from `f.__name__`.
  """
  module_node, source = parser.parse_entity(f)
  fn_node = module_node.body[0]
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  entity_ctx = converter.EntityContext(namer, info, program_ctx)
  fn_node = node_to_graph(fn_node, entity_ctx, rewrite_errors=rewrite_errors)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
  if not did_rename:
    if fn_node.name != f.__name__:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    new_name = f.__name__
  fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [fn_node], new_name, namespace
示例#24
0
def function_to_graph(f,
                      program_ctx,
                      arg_values,
                      arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function (possibly a lambda) to convert.
    program_ctx: program-wide conversion context; supplies the autograph
      module and a namer, and receives the updated name map.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to the namer when
      choosing the compiled function name.

  Returns:
    A tuple `([node], new_name, namespace)`. For a lambda, `node` assigns the
    converted lambda to a newly generated name.

  Raises:
    ValueError: if the parsed source does not contain exactly one definition
      matching `f`.
  """
  module_node, source = parser.parse_entity(f)

  # In general, the output of inspect.getsource is inexact because it uses crude
  # regex matching methods to search the source file. This is particularly
  # problematic for lambda functions, where the entire containing lines are
  # returned. Certain distributions of CPython may also return the enclosing
  # function for local functions.
  candidates = ast_util.find_matching_definitions(module_node.body[0], f)
  if len(candidates) != 1:
    if f.__name__ == '<lambda>':
      message = (
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    else:
      message = (
          'Unable to identify source code of function {}. The source code'
          ' reported by Python did not include exactly one matching signature:'
          '\n{}\nTo avoid ambiguity, use a unique name for each'
          ' function.'.format(f, source))
    raise ValueError(message)
  target, = candidates

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(target, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  converted = node_to_graph(
      target, converter.EntityContext(namer, info, program_ctx))

  if not isinstance(converted, gast.Lambda):
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(f.__name__, f,
                                                        owner_type)
    if did_rename:
      converted.name = new_name
    else:
      new_name = f.__name__
      assert converted.name == new_name
  else:
    # A converted lambda is bound to a freshly generated name.
    new_name = namer.new_symbol('tf__lambda', ())
    converted = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=converted)

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [converted], new_name, namespace
示例#25
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: the function (possibly a lambda) to convert.
      program_ctx: program-wide conversion context; supplies the autograph
        module and a namer, and receives the updated name map.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo` and to the namer when
        choosing the compiled function name.

    Returns:
      A tuple `([node], new_name, namespace)`. For a lambda, `node` assigns
      the converted lambda to a newly generated name.

    Raises:
      ValueError: if the parsed source does not contain exactly one
        definition matching `f`.
    """
    module_node, source = parser.parse_entity(f)

    # In general, the output of inspect.getsource is inexact because it uses
    # crude regex matching methods to search the source file. This is
    # particularly problematic for lambda functions, where the entire
    # containing lines are returned. Certain distributions of CPython may also
    # return the enclosing function for local functions.
    candidates = ast_util.find_matching_definitions(module_node.body[0], f)
    if len(candidates) != 1:
        if f.__name__ == '<lambda>':
            message = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        else:
            # The inspect.getsource bug is currently known to occur in the
            # Windows integration tests which run Python 3.6.
            # TODO(mdan): Find out exactly which distribution of Python is that.
            message = (
                'Unable to identify source code of function {}. The source code'
                ' reported by Python did not include exactly one matching signature:'
                '\n{}\nTo avoid ambiguity, use a unique name for each'
                ' function.\nNote that some distributions of Python may report source'
                ' code incorrectly. It may be possible to avoid that bug by'
                ' organizing the code into smaller units (smaller files, functions or'
                ' classes), or by turning AutoGraph off.'.format(f, source))
        raise ValueError(message)
    target, = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(target, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    info = transformer.EntityInfo(source_code=source,
                                  source_file='<fragment>',
                                  namespace=namespace,
                                  arg_values=arg_values,
                                  arg_types=arg_types,
                                  owner_type=owner_type)
    converted = node_to_graph(
        target, converter.EntityContext(namer, info, program_ctx))

    if not isinstance(converted, gast.Lambda):
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name, did_rename = namer.compiled_function_name(
            f.__name__, f, owner_type)
        if did_rename:
            converted.name = new_name
        else:
            new_name = f.__name__
            assert converted.name == new_name
    else:
        # A converted lambda is bound to a freshly generated name.
        new_name = namer.new_symbol('tf__lambda', ())
        converted = gast.Assign(
            targets=[gast.Name(new_name, gast.Store(), None)], value=converted)

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [converted], new_name, namespace