Example #1
0
def _parse_and_analyze(f):
    """Performs preliminary analyses and transformations.

  The goal is to massage the source program into a form on which
  the `_AutoBatchingTransformer` below will be successful.

  Args:
    f: Function to analyze

  Returns:
    node: A Python AST node representing the function, suitable for
      passing to `_AutoBatchingTransformer.visit`
    entity_info: An AutoGraph `EntityInfo` object, with some information
      about `f`.  Required for initializing `_AutoBatchingTransformer`.
  """
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    entity_info = transformer.EntityInfo(source_code='',
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                  entity_info=entity_info,
                                  program_ctx=program_ctx)

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx, is_initial=True)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    node = anf.transform(node, ctx)
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx
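
A minimal usage sketch, assuming the AutoGraph-internal modules imported by the snippet above are available in the same way; `times_two` is a hypothetical input function, not part of any API.

import gast

def times_two(xs):
  result = []
  for x in xs:
    result.append(x * 2)
  return result

# node is a gast AST of the canonicalized function; ctx carries the analysis
# context needed by the downstream transformer.
node, ctx = _parse_and_analyze(times_two)
print(gast.dump(node))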
Example #2
0
    def prepare(self,
                test_fn,
                namespace,
                namer=None,
                arg_types=None,
                owner_type=None,
                recursive=True,
                strip_decorators=()):
        namespace['ConversionOptions'] = converter.ConversionOptions

        node, source = parser.parse_entity(test_fn)
        node = node.body[0]
        if namer is None:
            namer = FakeNamer()
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(
                recursive=recursive, strip_decorators=strip_decorators),
            partial_types=None,
            autograph_module=None,
            uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
        entity_info = transformer.EntityInfo(source_code=source,
                                             source_file='<fragment>',
                                             namespace=namespace,
                                             arg_values=None,
                                             arg_types=arg_types,
                                             owner_type=owner_type)
        ctx = converter.EntityContext(namer, entity_info, program_ctx)
        node = converter.standard_analysis(node, ctx, is_initial=True)
        return node, ctx
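
For orientation, a standard-library approximation of what `parser.parse_entity` does in the harness above: recover a function's source, dedent it, and parse it into an AST. This is illustrative only, not the AutoGraph implementation; `parse_entity_like` and `sample` are made-up names.

import ast
import inspect
import textwrap

def parse_entity_like(fn):
  # Roughly what parser.parse_entity does (AutoGraph uses gast rather than ast).
  source = textwrap.dedent(inspect.getsource(fn))
  return ast.parse(source), source

def sample(x):
  return x + 1

node, source = parse_entity_like(sample)
assert isinstance(node.body[0], ast.FunctionDef)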
Example #3
0
  def prepare(self,
              test_fn,
              namespace,
              namer=None,
              arg_types=None,
              owner_type=None,
              recursive=True,
              strip_decorators=()):
    namespace['ConversionOptions'] = converter.ConversionOptions

    node, source = parser.parse_entity(test_fn)
    node = node.body[0]
    if namer is None:
      namer = FakeNamer()
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            strip_decorators=strip_decorators,
            verbose=True),
        partial_types=None,
        autograph_module=None,
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        owner_type=owner_type)
    ctx = converter.EntityContext(namer, entity_info, program_ctx)
    origin_info.resolve(node, source, test_fn)
    node = converter.standard_analysis(node, ctx, is_initial=True)
    return node, ctx
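
A small standard-library illustration of the location data that `origin_info.resolve` works with: parsed AST nodes carry line and column offsets relative to the fragment, which the resolver maps back onto the original function's source. Not the AutoGraph API.

import ast

tree = ast.parse('x = 1\ny = x + 2\n')
assign = tree.body[1]
print(assign.lineno, assign.col_offset)  # 2 0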
Example #4
0
    def converted(self, entity, converter_module, namespace, *tf_symbols):
        node, ctx = self.prepare(entity, namespace)

        if not isinstance(converter_module, (list, tuple)):
            converter_module = (converter_module, )
        for m in converter_module:
            node = m.transform(node, ctx)
            node = converter.standard_analysis(node, ctx, is_initial=True)

        with self.compiled(node, namespace, *tf_symbols) as result:
            yield result
Example #5
  def converted(self, entity, converter_module, namespace, *tf_symbols):
    node, ctx = self.prepare(entity, namespace)

    if not isinstance(converter_module, (list, tuple)):
      converter_module = (converter_module,)
    for i, m in enumerate(converter_module):
      node = converter.standard_analysis(node, ctx, is_initial=not i)
      node = m.transform(node, ctx)

    with self.compiled(node, namespace, *tf_symbols) as result:
      yield result
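
The `converted` helpers above rely on `self.compiled` to turn the transformed AST back into an executable module exposed as a context manager. A minimal, self-contained sketch of that pattern, with hypothetical names (`compiled_module`, `double`):

import ast
import contextlib
import types

@contextlib.contextmanager
def compiled_module(node, namespace):
  # Compile an AST back into a throwaway module object, seeding its globals
  # with the supplied namespace.
  module = types.ModuleType('fragment')
  module.__dict__.update(namespace)
  code = compile(ast.fix_missing_locations(node), '<fragment>', 'exec')
  exec(code, module.__dict__)
  yield module

node = ast.parse('def double(x):\n  return 2 * x\n')
with compiled_module(node, {}) as m:
  assert m.double(3) == 6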
Example #6
0
def node_to_graph(node, context):
    """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext

  Returns:
    A tuple (node, deps):
        * node: A Python ast node, representing the converted code.
        * deps: A set of strings, the fully qualified names of entity
            dependencies that this node has.
  """
    # TODO(mdan): Insert list_comprehensions somewhere.
    unsupported_features_checker.verify(node)

    node = converter.standard_analysis(node, context, is_initial=True)
    # Past this point, line numbers are no longer accurate so we ignore the
    # source.
    # TODO(mdan): Is it feasible to reconstruct intermediate source code?
    context.info.source_code = None
    node = converter.apply_(node, context, arg_defaults)
    node = converter.apply_(node, context, directives)
    node = converter.apply_(node, context, break_statements)
    if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
        node = converter.apply_(node, context, asserts)
    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    node = converter.apply_(node, context, continue_statements)
    node = converter.apply_(node, context, return_statements)
    if context.program.options.uses(converter.Feature.LISTS):
        node = converter.apply_(node, context, lists)
        node = converter.apply_(node, context, slices)
    if context.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS):
        node = converter.apply_(node, context, builtin_functions)
    node = converter.apply_(node, context, call_trees)
    node = converter.apply_(node, context, control_flow)
    node = converter.apply_(node, context, conditional_expressions)
    if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS):
        node = converter.apply_(node, context, logical_expressions)
    if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
        node = converter.apply_(node, context, side_effect_guards)
    # TODO(mdan): If function scopes ever does more, the toggle will need moving.
    if context.program.options.uses(converter.Feature.NAME_SCOPES):
        node = converter.apply_(node, context, function_scopes)
    if context.program.options.uses(converter.Feature.ERROR_REWRITING):
        node = converter.apply_(node, context, error_handlers)
    return node
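
`node_to_graph` is structured as a chain of independent AST passes, each applied through `converter.apply_`. A self-contained sketch of that structure using plain `ast.NodeTransformer` passes with made-up names (`FoldAddZero`, `RenameFoo`):

import ast

class FoldAddZero(ast.NodeTransformer):
  """Example pass: rewrite `x + 0` to `x`."""

  def visit_BinOp(self, node):
    self.generic_visit(node)
    if (isinstance(node.op, ast.Add) and
        isinstance(node.right, ast.Constant) and node.right.value == 0):
      return node.left
    return node

class RenameFoo(ast.NodeTransformer):
  """Example pass: rename the variable `foo` to `bar`."""

  def visit_Name(self, node):
    if node.id == 'foo':
      return ast.copy_location(ast.Name(id='bar', ctx=node.ctx), node)
    return node

def apply_passes(node, passes):
  # Analogue of the repeated converter.apply_ calls above.
  for p in passes:
    node = ast.fix_missing_locations(p().visit(node))
  return node

tree = ast.parse('y = foo + 0')
tree = apply_passes(tree, [FoldAddZero, RenameFoo])
print(ast.unparse(tree))  # y = bar  (ast.unparse requires Python 3.9+)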
Example #7
0
def node_to_graph(node, context):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext

  Returns:
    A Python AST node representing the converted code.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.
  unsupported_features_checker.verify(node)

  node = converter.standard_analysis(node, context, is_initial=True)
  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  context.info.source_code = None
  node = converter.apply_(node, context, arg_defaults)
  node = converter.apply_(node, context, directives)
  node = converter.apply_(node, context, break_statements)
  if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
    node = converter.apply_(node, context, asserts)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = converter.apply_(node, context, continue_statements)
  node = converter.apply_(node, context, return_statements)
  if context.program.options.uses(converter.Feature.LISTS):
    node = converter.apply_(node, context, lists)
    node = converter.apply_(node, context, slices)
  if context.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS):
    node = converter.apply_(node, context, builtin_functions)
  node = converter.apply_(node, context, call_trees)
  node = converter.apply_(node, context, control_flow)
  node = converter.apply_(node, context, conditional_expressions)
  if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS):
    node = converter.apply_(node, context, logical_expressions)
  if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
    node = converter.apply_(node, context, side_effect_guards)
  # TODO(mdan): If function scopes ever does more, the toggle will need moving.
  if context.program.options.uses(converter.Feature.NAME_SCOPES):
    node = converter.apply_(node, context, function_scopes)
  if context.program.options.uses(converter.Feature.ERROR_REWRITING):
    node = converter.apply_(node, context, error_handlers)
  return node
Example #8
0
def node_to_graph(node, context, rewrite_errors=True):
    """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext
    rewrite_errors: Boolean, whether or not to rewrite the error traceback.

  Returns:
    A tuple (node, deps):
        * node: A Python ast node, representing the converted code.
        * deps: A set of strings, the fully qualified names of entity
            dependencies that this node has.
  """
    # TODO(mdan): Insert list_comprehensions somewhere.

    node = converter.standard_analysis(node, context, is_initial=True)
    # Past this point, line numbers are no longer accurate so we ignore the
    # source.
    # TODO(mdan): Is it feasible to reconstruct intermediate source code?
    context.info.source_code = None

    if context.program.options.uses(converter.Feature.DECORATORS):
        node = converter.apply_(node, context, decorators)
    node = converter.apply_(node, context, arg_defaults)
    node = converter.apply_(node, context, directives)
    node = converter.apply_(node, context, break_statements)
    node = converter.apply_(node, context, asserts)
    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    node = converter.apply_(node, context, continue_statements)
    node = converter.apply_(node, context, return_statements)
    if context.program.options.uses(converter.Feature.LISTS):
        node = converter.apply_(node, context, lists)
        node = converter.apply_(node, context, slices)
    node = converter.apply_(node, context, builtin_functions)
    node = converter.apply_(node, context, call_trees)
    node = converter.apply_(node, context, control_flow)
    node = converter.apply_(node, context, conditional_expressions)
    node = converter.apply_(node, context, logical_expressions)
    if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
        node = converter.apply_(node, context, side_effect_guards)
    node = converter.apply_(node, context, function_scopes)
    if rewrite_errors:
        node = converter.apply_(node, context, error_handlers)
    return node
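
Several passes above are gated on `context.program.options.uses(...)`. An illustrative stand-in for that feature-toggle pattern (not the TensorFlow `ConversionOptions` API):

import enum

class Feature(enum.Enum):
  LISTS = 'LISTS'
  AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'

class Options(object):
  def __init__(self, enabled=()):
    self._enabled = frozenset(enabled)

  def uses(self, feature):
    return feature in self._enabled

opts = Options({Feature.LISTS})
assert opts.uses(Feature.LISTS)
assert not opts.uses(Feature.AUTO_CONTROL_DEPS)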
Example #9
0
    def prepare(self, test_fn, namespace, recursive=True):
        namespace['ConversionOptions'] = converter.ConversionOptions

        future_features = ('print_function', 'division')
        node, source = parser.parse_entity(test_fn,
                                           future_features=future_features)
        namer = naming.Namer(namespace)
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(recursive=recursive),
            autograph_module=None)
        entity_info = transformer.EntityInfo(source_code=source,
                                             source_file='<fragment>',
                                             future_features=future_features,
                                             namespace=namespace)
        ctx = converter.EntityContext(namer, entity_info, program_ctx)
        node = converter.standard_analysis(node, ctx, is_initial=True)
        return node, ctx
Example #10
    def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
        namespace['ConversionOptions'] = converter.ConversionOptions

        node, source, _ = parser.parse_entity(test_fn)
        namer = naming.Namer(namespace)
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(recursive=recursive),
            autograph_module=None)
        entity_info = transformer.EntityInfo(source_code=source,
                                             source_file='<fragment>',
                                             namespace=namespace,
                                             arg_values=None,
                                             arg_types=arg_types)
        ctx = converter.EntityContext(namer, entity_info, program_ctx)
        origin_info.resolve(node, source, test_fn)
        node = converter.standard_analysis(node, ctx, is_initial=True)
        return node, ctx
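
The `naming.Namer(namespace)` objects constructed in these helpers mint symbols that do not collide with names already in the namespace. A hypothetical stand-in, not the AutoGraph `Namer`:

class FreshNamer(object):
  def __init__(self, namespace):
    self._taken = set(namespace)

  def new_symbol(self, name_root):
    name, i = name_root, 0
    while name in self._taken:
      i += 1
      name = '%s_%d' % (name_root, i)
    self._taken.add(name)
    return name

namer = FreshNamer({'x': 1})
assert namer.new_symbol('x') == 'x_1'
assert namer.new_symbol('y') == 'y'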
Example #11
0
def node_to_graph(node, context, rewrite_errors=True):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext
    rewrite_errors: Boolean, whether or not to rewrite the error traceback.

  Returns:
    A Python AST node representing the converted code.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.

  node = converter.standard_analysis(node, context, is_initial=True)
  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  context.info.source_code = None

  node = converter.apply_(node, context, decorators)
  node = converter.apply_(node, context, directives)
  node = converter.apply_(node, context, break_statements)
  node = converter.apply_(node, context, asserts)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = converter.apply_(node, context, continue_statements)
  node = converter.apply_(node, context, return_statements)
  if context.program.options.uses(converter.Feature.LISTS):
    node = converter.apply_(node, context, lists)
    node = converter.apply_(node, context, slices)
  node = converter.apply_(node, context, builtin_functions)
  node = converter.apply_(node, context, call_trees)
  node = converter.apply_(node, context, control_flow)
  node = converter.apply_(node, context, conditional_expressions)
  node = converter.apply_(node, context, logical_expressions)
  if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
    node = converter.apply_(node, context, side_effect_guards)
  node = converter.apply_(node, context, function_scopes)
  if rewrite_errors:
    node = converter.apply_(node, context, error_handlers)
  return node
Example #12
0
  def prepare(self, test_fn, namespace, recursive=True):
    namespace['ConversionOptions'] = converter.ConversionOptions

    future_features = ('print_function', 'division')
    node, source = parser.parse_entity(test_fn, future_features=future_features)
    namer = naming.Namer(namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=recursive),
        autograph_module=None)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    ctx = converter.EntityContext(namer, entity_info, program_ctx)
    origin_info.resolve(node, source, test_fn)
    node = converter.standard_analysis(node, ctx, is_initial=True)
    return node, ctx
Example #13
    def transform_ast(self, node, ctx):
        # TODO(mdan): Insert list_comprehensions somewhere.
        unsupported_features_checker.verify(node)

        node = converter.standard_analysis(node, ctx, is_initial=True)
        node = converter.apply_(node, ctx, functions)
        node = converter.apply_(node, ctx, directives)
        node = converter.apply_(node, ctx, break_statements)
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = converter.apply_(node, ctx, asserts)
        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        node = converter.apply_(node, ctx, continue_statements)
        node = converter.apply_(node, ctx, return_statements)
        if ctx.user.options.uses(converter.Feature.LISTS):
            node = converter.apply_(node, ctx, lists)
            node = converter.apply_(node, ctx, slices)
        node = converter.apply_(node, ctx, call_trees)
        node = converter.apply_(node, ctx, control_flow)
        node = converter.apply_(node, ctx, conditional_expressions)
        node = converter.apply_(node, ctx, logical_expressions)
        return node
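
The `future_features` threaded through the `prepare` helper above (Example #12) records which `__future__` imports the original function was compiled under, so they can be re-applied when the transformed source is compiled again. A small standard-library illustration; on Python 3 these particular flags are accepted but change nothing.

import __future__

flags = (__future__.division.compiler_flag |
         __future__.print_function.compiler_flag)
code = compile('print(3 / 2)', '<fragment>', 'exec', flags=flags)
exec(code)  # prints 1.5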
Example #14
0
def node_to_graph(node, context):
    """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext

  Returns:
    A tuple (node, deps):
        * node: A Python ast node, representing the converted code.
        * deps: A set of strings, the fully qualified names of entity
            dependencies that this node has.
  """
    # TODO(mdan): Insert list_comprehensions somewhere.
    unsupported_features_checker.verify(node)

    node = converter.standard_analysis(node, context, is_initial=True)
    node = converter.apply_(node, context, function_scopes)
    node = converter.apply_(node, context, arg_defaults)
    node = converter.apply_(node, context, directives)
    node = converter.apply_(node, context, break_statements)
    if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
        node = converter.apply_(node, context, asserts)
    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    node = converter.apply_(node, context, continue_statements)
    node = converter.apply_(node, context, return_statements)
    if context.program.options.uses(converter.Feature.LISTS):
        node = converter.apply_(node, context, lists)
        node = converter.apply_(node, context, slices)
    node = converter.apply_(node, context, call_trees)
    node = converter.apply_(node, context, control_flow)
    node = converter.apply_(node, context, conditional_expressions)
    node = converter.apply_(node, context, logical_expressions)
    return node
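
`unsupported_features_checker.verify` runs before any rewriting and rejects constructs that the later passes cannot handle. A self-contained sketch of that kind of up-front check; the constructs rejected here are examples only, not AutoGraph's actual list:

import ast

class UnsupportedFeaturesChecker(ast.NodeVisitor):
  def visit_Global(self, node):
    raise NotImplementedError('global statements are not supported')

  def visit_Nonlocal(self, node):
    raise NotImplementedError('nonlocal statements are not supported')

def verify(node):
  UnsupportedFeaturesChecker().visit(node)

verify(ast.parse('x = 1'))  # passes silently
try:
  verify(ast.parse('def f():\n  global x\n  x = 1\n'))
except NotImplementedError as e:
  print(e)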
Example #15
0
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

  The goal is to massage the source program into a form on which
  the `_AutoBatchingTransformer` below will be successful.

  Args:
    f: Function to analyze
    autobatch_functions: List of Python `str` names of autobatched functions.
      Arguments to these functions will be canonicalized to variable references,
      but others will not.

  Returns:
    node: A Python AST node representing the function, suitable for
      passing to `_AutoBatchingTransformer.visit`
    entity_info: An AutoGraph `EntityInfo` object, with some information
      about `f`.  Required for initializing `_AutoBatchingTransformer`.
  """
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    entity_info = transformer.EntityInfo(source_code='',
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                  entity_info=entity_info,
                                  program_ctx=program_ctx)

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx, is_initial=True)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    # Replacing if tests and autobatched function call arguments because
    # that's where divergence can happen.
    # Replacing all function calls because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        del field_name, child
        if not anno.hasanno(parent.func, anno.Basic.QN):
            return False
        func_name = anno.getanno(parent.func, anno.Basic.QN)
        if str(func_name) in autobatch_functions:
            return True
        return False

    anf_config = [
        (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
        (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
        (anf.ASTEdgePattern(gast.Call, 'args',
                            anf.ANY), maybe_replace_function_argument),
    ]
    node = anf.transform(node, ctx, config=anf_config)
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx
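
For reference, a hand-written illustration of the A-normal form that the `anf.transform` call above targets: every intermediate value is bound to its own name, so call arguments and `if` tests become plain variable references. `helper` and the `tmp_*` names are hypothetical, and the rewrite is done by hand rather than by `anf.transform`.

def helper(v):
  return v * 2

def original(x):
  return helper(helper(x) + 1)

def anf_form(x):
  tmp_1 = helper(x)
  tmp_2 = tmp_1 + 1
  tmp_3 = helper(tmp_2)
  return tmp_3

assert original(3) == anf_form(3)  # both compute helper(helper(3) + 1) == 14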