Ejemplo n.º 1
0
    def prepare(self,
                test_fn,
                namespace,
                namer=None,
                arg_types=None,
                owner_type=None,
                recursive=True,
                strip_decorators=()):
        """Parses `test_fn` and runs the initial standard analysis on it.

        Args:
          test_fn: Entity whose source is parsed with `parser.parse_entity`.
          namespace: Dict used as the entity namespace. It is mutated:
            `converter.ConversionOptions` is stored under the key
            'ConversionOptions'.
          namer: Namer handed to the entity context; a `FakeNamer` is created
            when this is None.
          arg_types: Forwarded to `transformer.EntityInfo`.
          owner_type: Forwarded to `transformer.EntityInfo`.
          recursive: Forwarded to `converter.ConversionOptions`.
          strip_decorators: Forwarded to `converter.ConversionOptions`.

        Returns:
          A `(node, ctx)` tuple: the analyzed first statement of the parsed
          source and the `converter.EntityContext` built for it.
        """
        namespace['ConversionOptions'] = converter.ConversionOptions

        # Only the first statement of the parsed source is analyzed.
        parsed, source = parser.parse_entity(test_fn)
        fn_node = parsed.body[0]

        namer = FakeNamer() if namer is None else namer

        options = converter.ConversionOptions(
            recursive=recursive, strip_decorators=strip_decorators)
        program_ctx = converter.ProgramContext(
            options=options,
            partial_types=None,
            autograph_module=None,
            uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
        info = transformer.EntityInfo(
            source_code=source,
            source_file='<fragment>',
            namespace=namespace,
            arg_values=None,
            arg_types=arg_types,
            owner_type=owner_type)
        ctx = converter.EntityContext(namer, info, program_ctx)

        analyzed = converter.standard_analysis(fn_node, ctx, is_initial=True)
        return analyzed, ctx
Ejemplo n.º 2
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: The function to convert; its source is parsed with
        `parser.parse_entity`.
      program_ctx: Program context; supplies `autograph_module` and
        `new_namer`, and receives the namer through `update_name_map`.
      arg_values: Forwarded to `transformer.EntityInfo`.
      arg_types: Forwarded to `transformer.EntityInfo`.
      owner_type: Forwarded to `transformer.EntityInfo` and to the namer.

    Returns:
      A `([node], new_name, namespace)` tuple.

    Raises:
      NotImplementedError: If the namer did not rename the function and the
        converted node's name differs from `f.__name__`.
    """
    parsed, source = parser.parse_entity(f)
    fn_node = parsed.body[0]
    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(fn_node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types,
        owner_type=owner_type)
    converted = node_to_graph(
        fn_node, converter.EntityContext(namer, info, program_ctx))

    # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
    compiled_name, was_renamed = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if was_renamed:
        new_name = compiled_name
    else:
        # Without a rename the converted node must still carry the original
        # function name.
        new_name = f.__name__
        if converted.name != f.__name__:
            raise NotImplementedError(
                'Strange corner case. Send us offending code!')
    converted.name = new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [converted], new_name, namespace
Ejemplo n.º 3
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: The function or lambda to convert.
      program_ctx: Program context; supplies `autograph_module` and
        `new_namer`, and receives the namer through `update_name_map`.
      arg_values: Forwarded to `transformer.EntityInfo`.
      arg_types: Forwarded to `transformer.EntityInfo`.
      owner_type: Forwarded to `transformer.EntityInfo` and to the namer.

    Returns:
      A `([node], new_name, namespace)` tuple. For a lambda, `node` is an
      assignment of the converted lambda to a freshly generated symbol.

    Raises:
      ValueError: If the parsed source does not contain exactly one
        definition matching `f`.
    """
    parsed, source = parser.parse_entity(f)
    root = parsed.body[0]

    # In general, the output of inspect.getsource is inexact because it uses
    # regex matching to adjust the exact location around the line number that
    # CPython records. This is particularly problematic for lambda functions,
    # where the entire containing lines are returned.
    candidates = ast_util.find_matching_definitions(root, f)
    if len(candidates) != 1:
        if f.__name__ == '<lambda>':
            template = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.')
        else:
            template = (
                'Unable to identify source code of function {}. The source code'
                ' reported by Python did not include exactly one matching signature:'
                '\n{}\n. This is an extremely rare occurrence. Please report it to'
                ' the TensorFlow team.')
        raise ValueError(template.format(f, source))
    [fn_node] = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(fn_node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types,
        owner_type=owner_type)
    converted = node_to_graph(
        fn_node, converter.EntityContext(namer, info, program_ctx))

    if isinstance(converted, gast.Lambda):
        # Lambdas have no name; bind the converted lambda to a new symbol.
        new_name = namer.new_symbol('tf__lambda', ())
        target = gast.Name(new_name, gast.Store(), None)
        converted = gast.Assign(targets=[target], value=converted)
    else:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name, was_renamed = namer.compiled_function_name(
            f.__name__, f, owner_type)
        if not was_renamed:
            new_name = f.__name__
            assert converted.name == new_name
        else:
            converted.name = new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [converted], new_name, namespace
Ejemplo n.º 4
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Specialization of `convert_entity_to_ast` for callable functions.

    Args:
      f: The function or lambda to convert.
      program_ctx: Program context; supplies `autograph_module` and is passed
        to the entity context.
      do_rename: When True, a non-lambda function is renamed using
        `namer.function_name`; when False its original name is kept.

    Returns:
      A `((node,), new_name, entity_info)` tuple. For a lambda, `node` is an
      assignment of the converted lambda to a freshly generated symbol.

    Raises:
      ValueError: If `f` is a lambda and the parsed source does not contain
        exactly one matching definition.
    """
    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # Parsed AST should contain future imports and one function def node.

    # In general, the output of inspect.getsource is inexact for lambdas because
    # it uses regex matching to adjust the exact location around the line number
    # that CPython records. Then, the entire containing line is returned, which
    # we may have trouble disambiguating. For example:
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(node, f)
        if len(candidates) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        [node] = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    # The target name is chosen before conversion because the entity context
    # takes it as an argument.
    if isinstance(node, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
    elif not do_rename:
        new_name = f.__name__
    else:
        new_name = namer.function_name(f.__name__)

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    entity_ctx = converter.EntityContext(
        namer, entity_info, program_ctx, new_name)
    node = node_to_graph(node, entity_ctx)

    if isinstance(node, gast.Lambda):
        # Bind the converted lambda to the generated symbol.
        target = gast.Name(
            new_name, ctx=gast.Store(), annotation=None, type_comment=None)
        node = gast.Assign(targets=[target], value=node)
    elif not do_rename:
        assert node.name == new_name
    else:
        node.name = new_name

    return (node, ), new_name, entity_info
Ejemplo n.º 5
0
def function_to_graph(f,
                      program_ctx,
                      arg_values,
                      arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: The function or lambda to convert.
    program_ctx: Program context; supplies `autograph_module` and
      `new_namer`, and receives the namer through `update_name_map`.
    arg_values: Forwarded to `transformer.EntityInfo`.
    arg_types: Forwarded to `transformer.EntityInfo`.
    owner_type: Forwarded to `transformer.EntityInfo` and to the namer.

  Returns:
    A `([node], new_name, namespace)` tuple. For a lambda, `node` is an
    assignment of the converted lambda to a freshly generated symbol.

  Raises:
    ValueError: If `f` is a lambda and the lookup in the parsed source does
      not yield exactly one matching lambda definition.
  """
  parsed, source = parser.parse_entity(f)
  fn_node = parsed.body[0]

  # TODO(mdan): Can we convert everything and scoop the lambda afterwards?
  if f.__name__ == '<lambda>':
    lambdas = ast_util.find_matching_lambda_definitions(fn_node, f)
    if len(lambdas) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which contains multiple lambdas with'
          ' identical argument names. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    [fn_node] = lambdas

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(fn_node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  converted = node_to_graph(
      fn_node, converter.EntityContext(namer, info, program_ctx))

  if isinstance(converted, gast.Lambda):
    # Lambdas have no name; bind the converted lambda to a new symbol.
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    converted = gast.Assign(targets=[target], value=converted)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, was_renamed = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not was_renamed:
      new_name = f.__name__
      assert converted.name == new_name
    else:
      converted.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [converted], new_name, namespace
Ejemplo n.º 6
0
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: The function or lambda to convert.
      program_ctx: Program context; supplies `autograph_module` and is passed
        to the entity context.
      arg_values: Forwarded to `transformer.EntityInfo`.
      arg_types: Forwarded to `transformer.EntityInfo`.
      do_rename: When True, a non-lambda function is renamed using
        `namer.function_name`; when False its original name is kept.

    Returns:
      A `([node], new_name, namespace)` tuple. For a lambda, `node` is an
      assignment of the converted lambda to a freshly generated symbol.

    Raises:
      ValueError: If `f` is a lambda and the parsed source does not contain
        exactly one matching definition.
      errors.InternalError: If conversion raises ValueError, AttributeError,
        KeyError or NotImplementedError.
    """
    parsed, source = parser.parse_entity(f)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    fn_node = parsed.body[0]

    # In general, the output of inspect.getsource is inexact for lambdas because
    # it uses regex matching to adjust the exact location around the line number
    # that CPython records. Then, the entire containing line is returned, which
    # we may have trouble disambiguating. For example:
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(fn_node, f)
        if len(candidates) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        [fn_node] = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(fn_node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types)
    entity_ctx = converter.EntityContext(namer, info, program_ctx)

    # Exception types that are logged and re-raised as an internal error.
    wrapped_errors = (ValueError, AttributeError, KeyError,
                      NotImplementedError)
    try:
        converted = node_to_graph(fn_node, entity_ctx)
    except wrapped_errors as err:
        logging.error(1, 'Error converting %s', f, exc_info=True)
        raise errors.InternalError('conversion', err)
        # TODO(mdan): Catch and rethrow syntax errors.

    if isinstance(converted, gast.Lambda):
        # Lambdas have no name; bind the converted lambda to a new symbol.
        new_name = namer.new_symbol('tf__lambda', ())
        target = gast.Name(new_name, gast.Store(), None)
        converted = gast.Assign(targets=[target], value=converted)
    elif not do_rename:
        new_name = f.__name__
        assert converted.name == new_name
    else:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name = namer.function_name(f.__name__)
        converted.name = new_name

    return [converted], new_name, namespace
Ejemplo n.º 7
0
def _parse_and_analyze(f):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze

    Returns:
      node: A Python AST node representing the function, after the break,
        continue, return and ANF transformations applied here.
      ctx: The AutoGraph `EntityContext` that was used for those analyses
        and transformations.
    """
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    tree, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    info = transformer.EntityInfo(
        source_code='',
        source_file=None,
        future_features=future_features,
        namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(
        namer=naming.Namer(namespace),
        entity_info=info,
        program_ctx=program_ctx)

    def analyze(current, is_initial=False):
        # Runs the standard analysis with the shared entity context.
        return converter.standard_analysis(
            current, ctx, is_initial=is_initial)

    # Canonicalize away break statements
    tree = analyze(tree, is_initial=True)
    tree = break_statements.transform(tree, ctx)

    # Canonicalize away continue statements
    tree = analyze(tree)
    tree = continue_statements.transform(tree, ctx)

    # Force single returns
    tree = analyze(tree)
    tree = return_statements.transform(
        tree, ctx, default_to_null_return=False)

    # Transform into ANF, then analyze the result once more
    tree = anf.transform(tree, ctx)
    tree = analyze(tree)

    return tree, ctx
Ejemplo n.º 8
0
    def prepare(self, test_fn, namespace, recursive=True):
        """Parses `test_fn` and runs the initial standard analysis on it.

        Args:
          test_fn: Entity whose source is parsed with `parser.parse_entity`.
          namespace: Dict used as the entity namespace. It is mutated:
            `converter.ConversionOptions` is stored under the key
            'ConversionOptions'.
          recursive: Forwarded to `converter.ConversionOptions`.

        Returns:
          A `(node, ctx)` tuple: the analyzed node and the
          `converter.EntityContext` built for it.
        """
        namespace['ConversionOptions'] = converter.ConversionOptions

        # The same future features go to the parser and to the entity info.
        future_features = ('print_function', 'division')
        parsed, source = parser.parse_entity(
            test_fn, future_features=future_features)

        namer = naming.Namer(namespace)
        options = converter.ConversionOptions(recursive=recursive)
        program_ctx = converter.ProgramContext(
            options=options, autograph_module=None)
        info = transformer.EntityInfo(
            source_code=source,
            source_file='<fragment>',
            future_features=future_features,
            namespace=namespace)
        ctx = converter.EntityContext(namer, info, program_ctx)

        analyzed = converter.standard_analysis(parsed, ctx, is_initial=True)
        return analyzed, ctx
    def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
        """Parses `test_fn`, resolves origin info and runs initial analysis.

        Args:
          test_fn: Entity whose source is parsed with `parser.parse_entity`.
          namespace: Dict used as the entity namespace. It is mutated:
            `converter.ConversionOptions` is stored under the key
            'ConversionOptions'.
          arg_types: Forwarded to `transformer.EntityInfo`.
          recursive: Forwarded to `converter.ConversionOptions`.

        Returns:
          A `(node, ctx)` tuple: the analyzed node and the
          `converter.EntityContext` built for it.
        """
        namespace['ConversionOptions'] = converter.ConversionOptions

        # The third value returned by the parser is not needed here.
        parsed, source, _ = parser.parse_entity(test_fn)

        namer = naming.Namer(namespace)
        options = converter.ConversionOptions(recursive=recursive)
        program_ctx = converter.ProgramContext(
            options=options, autograph_module=None)
        info = transformer.EntityInfo(
            source_code=source,
            source_file='<fragment>',
            namespace=namespace,
            arg_values=None,
            arg_types=arg_types)
        ctx = converter.EntityContext(namer, info, program_ctx)

        origin_info.resolve(parsed, source, test_fn)
        analyzed = converter.standard_analysis(parsed, ctx, is_initial=True)
        return analyzed, ctx
Ejemplo n.º 10
0
    def test_to_ast(self):
        """Checks that options converted by `to_ast` match the originals.

        The options are converted to an AST, spliced into a template
        function, compiled, and the options returned by that function are
        compared field by field with the original ones.
        """
        opts = converter.ConversionOptions()

        program_ctx = converter.ProgramContext(
            options=opts,
            partial_types=None,
            autograph_module=None,
            uncompiled_modules=())
        info = transformer.EntityInfo(
            source_code='',
            source_file='<fragment>',
            namespace={},
            arg_values=None,
            arg_types={},
            owner_type=None)
        ctx = converter.EntityContext(
            converter_testing.FakeNamer(), info, program_ctx)
        opts_ast = opts.to_ast(ctx)

        template = '''
    def test_fn():
      return opts_ast
    '''
        packed = templates.replace(template, opts_ast=opts_ast)

        module, _ = compiler.ast_to_object(packed)
        # The compiled module looks up `ag__`, so provide a fake one.
        module.__dict__['ag__'] = self.make_fake_mod(
            'fake_ag', converter.ConversionOptions, converter.Feature)

        reparsed_opts = module.test_fn()

        compared_fields = (
            'recursive',
            'verbose',
            'force_conversion',
            'internal_convert_user_code',
            'optional_features',
        )
        for field in compared_fields:
            self.assertEqual(getattr(opts, field),
                             getattr(reparsed_opts, field))
Ejemplo n.º 11
0
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze
      autobatch_functions: List of Python `str` names of autobatched functions.
        Arguments to these functions will be canonicalized to variable
        references, but others will not.

    Returns:
      node: A Python AST node representing the function, after the break,
        continue, return and ANF transformations applied here.
      ctx: The AutoGraph `EntityContext` that was used for those analyses
        and transformations.
    """
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    tree, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    info = transformer.EntityInfo(
        source_code='',
        source_file=None,
        future_features=future_features,
        namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(
        namer=naming.Namer(namespace),
        entity_info=info,
        program_ctx=program_ctx)

    def analyze(current, is_initial=False):
        # Runs the standard analysis with the shared entity context.
        return converter.standard_analysis(
            current, ctx, is_initial=is_initial)

    # Canonicalize away break statements
    tree = analyze(tree, is_initial=True)
    tree = break_statements.transform(tree, ctx)

    # Canonicalize away continue statements
    tree = analyze(tree)
    tree = continue_statements.transform(tree, ctx)

    # Force single returns
    tree = analyze(tree)
    tree = return_statements.transform(
        tree, ctx, default_to_null_return=False)

    # Transform into ANF
    # Replacing if tests and autobatched function call arguments because
    # that's where divergence can happen.
    # Replacing all function calls because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        # Only the called function's qualified name decides the outcome.
        del field_name, child
        if not anno.hasanno(parent.func, anno.Basic.QN):
            return False
        qualified_name = anno.getanno(parent.func, anno.Basic.QN)
        return str(qualified_name) in autobatch_functions

    if_test_pattern = anf.ASTEdgePattern(gast.If, 'test', anf.ANY)
    any_call_pattern = anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call)
    call_args_pattern = anf.ASTEdgePattern(gast.Call, 'args', anf.ANY)
    anf_config = [
        (if_test_pattern, anf.REPLACE),
        (any_call_pattern, anf.REPLACE),
        (call_args_pattern, maybe_replace_function_argument),
    ]
    tree = anf.transform(tree, ctx, config=anf_config)
    tree = analyze(tree)

    return tree, ctx
Ejemplo n.º 12
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Args:
      f: The function or lambda to convert.
      program_ctx: Program context; supplies `autograph_module` and
        `new_namer`, and receives the namer through `update_name_map`.
      arg_values: Forwarded to `transformer.EntityInfo`.
      arg_types: Forwarded to `transformer.EntityInfo`.
      owner_type: Forwarded to `transformer.EntityInfo` and to the namer.

    Returns:
      A `([node], new_name, namespace)` tuple. For a lambda, `node` is an
      assignment of the converted lambda to a freshly generated symbol.

    Raises:
      ValueError: If the parsed source does not contain exactly one
        definition matching `f`.
    """
    parsed, source = parser.parse_entity(f)
    root = parsed.body[0]

    # In general, the output of inspect.getsource is inexact because it uses crude
    # regex matching methods to search the source file. This is particularly
    # problematic for lambda functions, where the entire containing lines are
    # returned. Certain distributions of CPython may also return the enclosing
    # function for local functions.
    candidates = ast_util.find_matching_definitions(root, f)
    if len(candidates) != 1:
        if f.__name__ == '<lambda>':
            template = (
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.')
        else:
            # The inspect.getsource bug is currently known to occur in the Windows
            # integration tests which run Python 3.6.
            # TODO(mdan): Find out exactly which distribution of Python is that.
            template = (
                'Unable to identify source code of function {}. The source code'
                ' reported by Python did not include exactly one matching signature:'
                '\n{}\nTo avoid ambiguity, use a unique name for each'
                ' function.\nNote that some distributions of Python may report source'
                ' code incorrectly. It may be possible to avoid that bug by'
                ' organizing the code into smaller units (smaller files, functions or'
                ' classes), or by turning AutoGraph off.')
        raise ValueError(template.format(f, source))
    [fn_node] = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(fn_node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types,
        owner_type=owner_type)
    converted = node_to_graph(
        fn_node, converter.EntityContext(namer, info, program_ctx))

    if isinstance(converted, gast.Lambda):
        # Lambdas have no name; bind the converted lambda to a new symbol.
        new_name = namer.new_symbol('tf__lambda', ())
        target = gast.Name(new_name, gast.Store(), None)
        converted = gast.Assign(targets=[target], value=converted)
    else:
        # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
        new_name, was_renamed = namer.compiled_function_name(
            f.__name__, f, owner_type)
        if not was_renamed:
            new_name = f.__name__
            assert converted.name == new_name
        else:
            converted.name = new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [converted], new_name, namespace