示例#1
0
    def test_resolve_entity(self):
        """Checks the origin annotations attached to a plain function's AST."""
        fn = basic_definitions.simple_function
        future_imports = inspect_utils.getfutureimports(fn)
        node, source = parser.parse_entity(fn, future_imports)
        origin_info.resolve_entity(node, source, fn)

        # Expected line numbers are offsets from where simple_function starts
        # in basic_definitions.py.
        start_line = inspect.getsourcelines(fn)[1]

        def check_origin(ast_node, lineno, col_offset, line, comment):
            origin = anno.getanno(ast_node, anno.Basic.ORIGIN)
            self.assertEqual(origin.loc.lineno, lineno)
            self.assertEqual(origin.loc.col_offset, col_offset)
            self.assertEqual(origin.source_code_line, line)
            if comment is None:
                self.assertIsNone(origin.comment)
            else:
                self.assertEqual(origin.comment, comment)

        check_origin(node, start_line, 0, 'def simple_function(x):', None)
        check_origin(node.body[0], start_line + 1, 2, '  """Docstring."""',
                     None)
        check_origin(node.body[1], start_line + 2, 2,
                     '  return x  # comment', 'comment')
示例#2
0
 def test_getfutureimports_methods(self):
     """Checks the `__future__` features reported for a method."""
     features = inspect_utils.getfutureimports(
         basic_definitions.SimpleClass.method_with_print)
     for present in ('absolute_import', 'division', 'print_function'):
         self.assertIn(present, features)
     self.assertNotIn('generators', features)
示例#3
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Specialization of `convert_entity_to_ast` for callable functions.

    Args:
      f: The function or lambda to convert.
      program_ctx: Program-wide conversion context; its `autograph_module` is
        registered into the function's namespace via `_add_self_references`.
      do_rename: If True, a converted function definition receives a fresh
        name from the namer; otherwise it keeps `f.__name__`. Lambdas always
        receive a fresh `tf__lambda` name.

    Returns:
      A tuple `((node,), new_name, entity_info)` where `node` is the
      converted function definition, or a `gast.Assign` binding the converted
      lambda to `new_name`.

    Raises:
      ValueError: If the source of a lambda cannot be narrowed down to exactly
        one matching definition.
    """
    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # The parsed AST is expected to hold the future imports and a single
    # function definition node.

    # inspect.getsource is imprecise for lambdas: it regex-matches around the
    # line number CPython recorded and returns the whole containing line, which
    # may hold several lambdas, e.g.
    #   x, y = lambda: 1, lambda: 2
    # Keep only the definition whose signature matches `f`.
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(node, f)
        if len(candidates) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        (node,) = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    # The new name is reserved before conversion so it can be handed to the
    # EntityContext below.
    if isinstance(node, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
    elif do_rename:
        new_name = namer.function_name(f.__name__)
    else:
        new_name = f.__name__

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = converter.EntityContext(namer, entity_info, program_ctx,
                                      new_name)
    converted = node_to_graph(node, context)

    if isinstance(converted, gast.Lambda):
        # A converted lambda is an expression; bind it to the new name so it
        # can be emitted as a statement.
        target = gast.Name(new_name,
                           ctx=gast.Store(),
                           annotation=None,
                           type_comment=None)
        converted = gast.Assign(targets=[target], value=converted)
    elif do_rename:
        converted.name = new_name
    else:
        assert converted.name == new_name

    return (converted,), new_name, entity_info
示例#4
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Specialization of `convert_entity_to_ast` for callable functions.

  Args:
    f: The function or lambda to convert.
    program_ctx: Program-wide conversion context; its `autograph_module` is
      registered into the function's namespace via `_add_self_references`.
    do_rename: If True, a converted function definition receives a fresh name
      from the namer; otherwise it keeps `f.__name__`. Lambdas always receive
      a fresh `tf__lambda` name.

  Returns:
    A tuple `((node,), new_name, entity_info)` where `node` is the converted
    function definition, or a `gast.Assign` binding the converted lambda to
    `new_name`.

  Raises:
    ValueError: If the source of a lambda cannot be narrowed down to exactly
      one matching definition.
    errors.InternalError: If `node_to_graph` raises ValueError,
      AttributeError, KeyError or NotImplementedError.
  """
  future_features = inspect_utils.getfutureimports(f)
  node, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # Parsed AST should contain future imports and one function def node.

  # In general, the output of inspect.getsource is inexact for lambdas because
  # it uses regex matching to adjust the exact location around the line number
  # that CPython records. Then, the entire containing line is returned, which
  # we may have trouble disambiguating. For example:
  # x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    nodes = ast_util.find_matching_definitions(node, f)
    if len(nodes) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = nodes

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  # TODO(mdan): Catch and rethrow syntax errors.
  try:
    node = node_to_graph(node, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    raise errors.InternalError('conversion', e)

  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    # Build the Name with keyword arguments: gast node constructors reject
    # positional calls whose arity differs from the node's field count, and
    # newer gast releases add a `type_comment` field to Name. The keyword form
    # works with either field layout.
    node = gast.Assign(
        targets=[
            gast.Name(
                new_name,
                ctx=gast.Store(),
                annotation=None,
                type_comment=None)
        ],
        value=node)

  elif do_rename:
    new_name = namer.function_name(f.__name__)
    node.name = new_name
  else:
    new_name = f.__name__
    assert node.name == new_name

  return (node,), new_name, entity_info
示例#5
0
  def test_create_source_map_no_origin_info(self):

    test_fn = basic_definitions.simple_function
    node, _ = parser.parse_entity(test_fn,
                                  inspect_utils.getfutureimports(test_fn))
    # No origin information should result in an empty map.
    test_fn_lines, _ = tf_inspect.getsourcelines(test_fn)
    source_map = origin_info.create_source_map(node, '\n'.join(test_fn_lines),
                                               test_fn)

    self.assertEmpty(source_map)
示例#6
0
    def _transform_function(self, fn, user_context):
        """Performs source code transformation on a function.

        Args:
          fn: The function or lambda to transform.
          user_context: Opaque object forwarded to `transformer.Context`.

        Returns:
          A `(node, context)` tuple. For a lambda, `node` is a `gast.Assign`
          binding the transformed lambda to its new name; otherwise it is the
          transformed definition, renamed in place.

        Raises:
          ValueError: If a lambda's source cannot be narrowed down to exactly
            one matching definition.
        """
        future_features = inspect_utils.getfutureimports(fn)
        node, source = parser.parse_entity(fn, future_features=future_features)
        logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

        is_lambda = fn.__name__ == '<lambda>'
        if is_lambda:
            # inspect.getsource is imprecise for lambdas: it regex-matches
            # around the line number CPython recorded and returns the whole
            # containing line, which may hold several lambdas, e.g.
            #   x, y = lambda: 1, lambda: 2
            # Keep only the definition that matches `fn`.
            matches = ast_util.find_matching_definitions(node, fn)
            if len(matches) != 1:
                raise ValueError(
                    'Unable to identify source code of lambda function {}.'
                    ' It was defined in this code:\n'
                    '{}\n'
                    'This code must contain a single distinguishable lambda.'
                    ' To avoid this problem, define each lambda in a separate'
                    ' expression.'.format(fn, source))
            (node,) = matches

        origin_info.resolve_entity(node, source, fn)

        namespace = inspect_utils.getnamespace(fn)
        namer = naming.Namer(namespace)
        new_name = namer.new_symbol(self.get_transformed_name(node), ())
        entity_info = transformer.EntityInfo(
            name=new_name,
            source_code=source,
            source_file='<fragment>',
            future_features=future_features,
            namespace=namespace)
        context = transformer.Context(entity_info, namer, user_context)

        result = self.transform_ast(self._erase_arg_defaults(node), context)

        if not is_lambda:
            result.name = new_name
            return result, context

        # A transformed lambda is an expression; bind it to the new name so
        # it can be emitted as a statement.
        binding = gast.Name(new_name,
                            ctx=gast.Store(),
                            annotation=None,
                            type_comment=None)
        return gast.Assign(targets=[binding], value=result), context
示例#7
0
def _parse_and_analyze(f):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful. The AST of `f`
    has break, continue and return statements canonicalized away and is then
    converted to ANF; `converter.standard_analysis` is re-run before each
    canonicalization step and once more after the ANF conversion.

    Args:
      f: Function to analyze

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`
      ctx: An AutoGraph `converter.EntityContext` holding the `EntityInfo`
        built for `f`.  Required for initializing `_AutoBatchingTransformer`.
    """
    # Deliberately empty namespace, shared by the EntityInfo and Namer below.
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    # The parsed source text is discarded above, so none is recorded here.
    entity_info = transformer.EntityInfo(source_code='',
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                  entity_info=entity_info,
                                  program_ctx=program_ctx)

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx, is_initial=True)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    node = anf.transform(node, ctx)
    # Refresh analysis annotations on the final ANF tree.
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx
示例#8
0
    def transform_function(self, fn, user_context):
        """Transforms a function.

        Subclasses may override this method. The return value is opaque.

        The method receives the original AST. The result is passed as-is to
        the output of `transform`.

        Args:
          fn: A function or lambda.
          user_context: An opaque object (may be None) that is forwarded to
            transform_ast, through the ctx.user_context argument.

        Returns:
          Tuple[Any, Any]: `(node, context)`. `node` is the output of
          transform_ast; if that output is a `gast.Lambda` it is wrapped in a
          `gast.Assign` to the newly generated name, otherwise its `name` is
          set to that name. `context` is the `transformer.Context` used for
          the transformation.
        """
        future_features = inspect_utils.getfutureimports(fn)
        node, source = parser.parse_entity(fn, future_features=future_features)
        logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

        origin_info.resolve_entity(node, source, fn)

        namespace = inspect_utils.getnamespace(fn)
        namer = naming.Namer(namespace)
        new_name = namer.new_symbol(self.get_transformed_name(node), ())
        entity_info = transformer.EntityInfo(name=new_name,
                                             source_code=source,
                                             source_file='<fragment>',
                                             future_features=future_features,
                                             namespace=namespace)
        context = transformer.Context(entity_info, namer, user_context)

        node = self._erase_arg_defaults(node)
        node = self.transform_ast(node, context)

        if isinstance(node, gast.Lambda):
            # A lambda is an expression; bind it to the new name so it can be
            # emitted as a statement.
            node = gast.Assign(targets=[
                gast.Name(new_name,
                          ctx=gast.Store(),
                          annotation=None,
                          type_comment=None)
            ],
                               value=node)
        else:
            node.name = new_name

        return node, context
示例#9
0
  def test_resolve_entity_nested_function(self):
    """Checks origin annotations for a function defined inside another."""
    fn = basic_definitions.nested_functions
    node, source = parser.parse_entity(fn, inspect_utils.getfutureimports(fn))
    origin_info.resolve_entity(node, source, fn)

    # Expected line numbers are offsets from where nested_functions starts in
    # basic_definitions.py.
    start_line = inspect.getsourcelines(fn)[1]
    inner_def = node.body[1]

    def_origin = anno.getanno(inner_def, anno.Basic.ORIGIN)
    self.assertEqual(def_origin.loc.lineno, start_line + 3)
    self.assertEqual(def_origin.loc.col_offset, 2)
    self.assertEqual(def_origin.source_code_line, '  def inner_fn(y):')
    self.assertIsNone(def_origin.comment)

    return_origin = anno.getanno(inner_def.body[0], anno.Basic.ORIGIN)
    self.assertEqual(return_origin.loc.lineno, start_line + 4)
    self.assertEqual(return_origin.loc.col_offset, 4)
    self.assertEqual(return_origin.source_code_line, '    return y')
    self.assertIsNone(return_origin.comment)
示例#10
0
    def test_resolve_entity_indented_block(self):
        """Checks origin annotations for a method (indented in its source)."""
        import inspect  # Local import: only this test needs it here.

        test_fn = basic_definitions.SimpleClass.simple_method
        node, source = parser.parse_entity(
            test_fn, inspect_utils.getfutureimports(test_fn))
        origin_info.resolve_entity(node, source, test_fn)

        # Derive the expected line numbers from the method's actual position in
        # basic_definitions.py rather than hard-coding them, so unrelated edits
        # to that file do not break this test.
        fn_start = inspect.getsourcelines(test_fn)[1]

        def_origin = anno.getanno(node, anno.Basic.ORIGIN)
        self.assertEqual(def_origin.loc.lineno, fn_start)
        self.assertEqual(def_origin.loc.col_offset, 2)
        self.assertEqual(def_origin.source_code_line,
                         'def simple_method(self):')
        self.assertIsNone(def_origin.comment)

        ret_origin = anno.getanno(node.body[0], anno.Basic.ORIGIN)
        self.assertEqual(ret_origin.loc.lineno, fn_start + 1)
        self.assertEqual(ret_origin.loc.col_offset, 4)
        self.assertEqual(ret_origin.source_code_line, '  return self')
        self.assertIsNone(ret_origin.comment)
示例#11
0
  def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The method receives the original AST. The result is passed as-is to the
    output of `transform`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      Tuple[Any, Any]: the output of transform_ast, paired with the
      `transformer.Context` (holding the entity info, namer and user context)
      that was used for the transformation.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    transformed_name = namer.new_symbol(self.get_transformed_name(node), ())
    info = transformer.EntityInfo(
        name=transformed_name,
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    ctx = transformer.Context(info, namer, user_context)

    stripped = self._erase_arg_defaults(node)
    return self.transform_ast(stripped, ctx), ctx
示例#12
0
    def test_resolve_entity_decorated_function(self):
        """Checks origin annotations for a decorated function."""
        fn = basic_definitions.decorated_function
        node, source = parser.parse_entity(
            fn, inspect_utils.getfutureimports(fn))
        origin_info.resolve_entity(node, source, fn)

        # Expected line numbers are offsets from where decorated_function's
        # source starts in basic_definitions.py.
        start = inspect.getsourcelines(fn)[1]

        fn_origin = anno.getanno(node, anno.Basic.ORIGIN)
        if sys.version_info < (3, 8):
            self.assertEqual(fn_origin.loc.lineno, start)
            self.assertEqual(fn_origin.source_code_line, '@basic_decorator')
        else:
            # From Python 3.8 a decorated def's line number is that of the
            # `def` keyword rather than of its first decorator.
            self.assertEqual(fn_origin.loc.lineno, start + 2)
            self.assertEqual(fn_origin.source_code_line,
                             'def decorated_function(x):')
        self.assertEqual(fn_origin.loc.col_offset, 0)
        self.assertIsNone(fn_origin.comment)

        def check(stmt, offset, col, text):
            origin = anno.getanno(stmt, anno.Basic.ORIGIN)
            self.assertEqual(origin.loc.lineno, start + offset)
            self.assertEqual(origin.loc.col_offset, col)
            self.assertEqual(origin.source_code_line, text)
            self.assertIsNone(origin.comment)

        check(node.body[0], 3, 2, '  if x > 0:')
        check(node.body[0].body[0], 4, 4, '    return 1')
        check(node.body[1], 5, 2, '  return 2')
示例#13
0
  def _transform_function(self, fn, user_context):
    """Performs source code transformation on a function.

    Args:
      fn: The function or lambda to transform.
      user_context: Opaque object forwarded to `transformer.Context`.

    Returns:
      A `(node, context)` tuple. If the transformed node is a `gast.Lambda`
      it is wrapped in a `gast.Assign` to the new name; otherwise it is
      renamed in place.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    new_name = namer.new_symbol(self.get_transformed_name(node), ())
    context = transformer.Context(
        transformer.EntityInfo(
            name=new_name,
            source_code=source,
            source_file='<fragment>',
            future_features=future_features,
            namespace=namespace),
        namer,
        user_context)

    transformed = self.transform_ast(self._erase_arg_defaults(node), context)

    if not isinstance(transformed, gast.Lambda):
      transformed.name = new_name
      return transformed, context

    # A lambda is an expression; bind it to the new name so it can be
    # emitted as a statement.
    binding = gast.Name(
        new_name, ctx=gast.Store(), annotation=None, type_comment=None)
    return gast.Assign(targets=[binding], value=transformed), context
示例#14
0
 def test_getfutureimports_methods(self):
     """Checks the exact `__future__` features reported for a method."""
     actual = inspect_utils.getfutureimports(
         basic_definitions.SimpleClass.method_with_print)
     expected = ('absolute_import', 'division', 'print_function',
                 'with_statement')
     self.assertEqual(actual, expected)
示例#15
0
 def test_getfutureimports_lambdas(self):
     """Checks the exact `__future__` features reported for a lambda."""
     actual = inspect_utils.getfutureimports(basic_definitions.simple_lambda)
     expected = ('absolute_import', 'division', 'print_function',
                 'with_statement')
     self.assertEqual(actual, expected)
 def test_getfutureimports_methods(self):
   """A method reports the same `__future__` features as its module."""
   reported = inspect_utils.getfutureimports(future_import_module.Foo.f)
   self.assertEqual(reported, future_import_module_statements)
 def test_getfutureimports_lambdas(self):
   """A lambda reports the same `__future__` features as its module."""
   reported = inspect_utils.getfutureimports(future_import_module.lambda_f)
   self.assertEqual(reported, future_import_module_statements)
示例#18
0
 def test_getfutureimports_lambdas(self):
     """Checks the `__future__` features reported for a lambda."""
     features = inspect_utils.getfutureimports(basic_definitions.simple_lambda)
     for present in ('absolute_import', 'division', 'print_function'):
         self.assertIn(present, features)
     self.assertNotIn('generators', features)
示例#19
0
 def test_getfutureimports_functions(self):
   """A plain function reports the same `__future__` features as its module."""
   reported = inspect_utils.getfutureimports(future_import_module.f)
   self.assertEqual(reported, future_import_module_statements)
示例#20
0
 def test_getfutureimports_simple_case(self):
     """Checks the exact `__future__` features reported for a plain function."""
     actual = inspect_utils.getfutureimports(future_import_module.f)
     self.assertEqual(actual, ('absolute_import', 'division',
                               'print_function', 'with_statement'))
示例#21
0
 def test_getfutureimports_lambdas(self):
   """A lambda reports the same `__future__` features as its module."""
   reported = inspect_utils.getfutureimports(future_import_module.lambda_f)
   self.assertEqual(reported, future_import_module_statements)
示例#22
0
 def test_getfutureimports_methods(self):
   """A method reports the same `__future__` features as its module."""
   reported = inspect_utils.getfutureimports(future_import_module.Foo.f)
   self.assertEqual(reported, future_import_module_statements)
示例#23
0
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful. The AST of `f`
    has break, continue and return statements canonicalized away and is then
    converted to ANF with a custom replacement config;
    `converter.standard_analysis` is re-run before each canonicalization step
    and once more after the ANF conversion.

    Args:
      f: Function to analyze
      autobatch_functions: List of Python `str` names of autobatched functions.
        Arguments to these functions will be canonicalized to variable
        references, but others will not.

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`
      ctx: An AutoGraph `converter.EntityContext` holding the `EntityInfo`
        built for `f`.  Required for initializing `_AutoBatchingTransformer`.
    """
    # Deliberately empty namespace, shared by the EntityInfo and Namer below.
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    # The parsed source text is discarded above, so none is recorded here.
    entity_info = transformer.EntityInfo(source_code='',
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                  entity_info=entity_info,
                                  program_ctx=program_ctx)

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx, is_initial=True)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    # Replacing if tests and autobatched function call arguments because
    # that's where divergence can happen.
    # Replacing all function calls because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        """ANF predicate for the arguments of a call.

        Registered below for `gast.Call` -> `args` edges. Returns True (replace
        the argument) only when the call's `func` carries a qualified-name
        annotation whose string form is in `autobatch_functions`; the edge's
        field name and child node are ignored.
        """
        del field_name, child
        if not anno.hasanno(parent.func, anno.Basic.QN):
            return False
        func_name = anno.getanno(parent.func, anno.Basic.QN)
        if str(func_name) in autobatch_functions:
            return True
        return False

    anf_config = [
        (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
        (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
        (anf.ASTEdgePattern(gast.Call, 'args',
                            anf.ANY), maybe_replace_function_argument),
    ]
    node = anf.transform(node, ctx, config=anf_config)
    # Refresh analysis annotations on the final ANF tree.
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx
 def test_getfutureimports_functions(self):
   """A plain function reports the same `__future__` features as its module."""
   reported = inspect_utils.getfutureimports(future_import_module.f)
   self.assertEqual(reported, future_import_module_statements)