Example #1
0
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn` and runs the activity and reaching-definitions passes.

     Args:
       test_fn: The function whose source is parsed into an AST.

     Returns:
       The AST node returned by `reaching_definitions.resolve`.
     """
     tree, src = parser.parse_entity(test_fn, future_features=())
     info = transformer.EntityInfo(
         source_code=src,
         source_file=None,
         future_features=(),
         namespace={},
     )
     tree = qual_names.resolve(tree)
     context = transformer.Context(info)
     tree = activity.resolve(tree, context)
     cfgs = cfg.build(tree)
     return reaching_definitions.resolve(
         tree, context, cfgs, reaching_definitions.Definition)
Example #2
0
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn` and runs qualified-name and activity analysis.

     Args:
       test_fn: The function whose source is parsed into an AST.

     Returns:
       A `(node, entity_info)` tuple: the annotated AST and the
       `transformer.EntityInfo` built for `test_fn`.
     """
     # TODO(mdan): Use a custom FunctionTransformer here.
     tree, src = parser.parse_entity(test_fn, future_features=())
     info = transformer.EntityInfo(
         name=test_fn.__name__,
         source_code=src,
         source_file=None,
         future_features=(),
         namespace={},
     )
     tree = qual_names.resolve(tree)
     context = transformer.Context(info, naming.Namer({}), None)
     tree = activity.resolve(tree, context)
     return tree, info
Example #3
0
def mlir_gen_internal(node, entity_info):
  """Returns mlir module for unprocessed node `node`.

  The CFG is built from the node before qualified-name resolution. The node
  then goes through activity, reaching-definitions, reaching-fndefs and
  liveness analysis before being visited by `MLIRGen`.

  Args:
    node: The unprocessed AST node.
    entity_info: The `transformer.EntityInfo` used to build the context.

  Returns:
    The `prog` attribute of the `MLIRGen` visitor after it has visited the
    analyzed node.
  """
  context_namer = naming.Namer({})
  control_flow_graphs = cfg.build(node)
  context = transformer.Context(entity_info, context_namer, None)
  annotated = qual_names.resolve(node)
  annotated = activity.resolve(annotated, context)
  # These passes share the same (node, ctx, graphs) signature; order matters.
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    annotated = analysis.resolve(annotated, context, control_flow_graphs)
  generator = MLIRGen(context)
  generator.visit(annotated)
  return generator.prog
Example #4
0
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn` and runs activity and liveness analysis on it.

     Args:
       test_fn: The function whose source is parsed into an AST.

     Returns:
       The AST node returned by `liveness.resolve`.
     """
     node, _, source = parser.parse_entity(test_fn, future_imports=())
     entity_info = transformer.EntityInfo(source_code=source,
                                          source_file=None,
                                          namespace={},
                                          arg_values=None,
                                          arg_types=None)
     node = qual_names.resolve(node)
     ctx = transformer.Context(entity_info)
     node = activity.resolve(node, ctx)
     graphs = cfg.build(node)
     # Rebind the result like every other pass does, instead of relying on
     # liveness.resolve annotating `node` in place.
     node = liveness.resolve(node, ctx, graphs)
     return node
Example #5
0
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
     """Parses `test_fn` and runs the static analysis passes on it.

     Args:
       test_fn: The function whose source is parsed into an AST.
       namespace: Namespace stored on the `transformer.EntityInfo`.
       arg_types: Optional argument types stored on the `EntityInfo`.

     Returns:
       The AST node returned by the final `live_values.resolve` pass.
     """
     tree, src = parser.parse_entity(test_fn)
     info = transformer.EntityInfo(
         source_code=src,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
     )
     tree = qual_names.resolve(tree)
     cfgs = cfg.build(tree)
     context = transformer.Context(info)
     tree = activity.resolve(tree, context)
     tree = reaching_definitions.resolve(
         tree, context, cfgs, reaching_definitions.Definition)
     # live_values runs both before and after type_info, each time with a
     # fresh empty dict.
     tree = live_values.resolve(tree, context, {})
     tree = type_info.resolve(tree, context)
     return live_values.resolve(tree, context, {})
Example #6
0
    def transform_function(self, fn, user_context):
        """Transforms a function.

        Subclasses may override this method. The return value is opaque.

        The method receives the original AST. The result is passed as-is to the
        output of `transform`.

        Args:
          fn: A function or lambda.
          user_context: An opaque object (may be None) that is forwarded to
            transform_ast, through the ctx.user_context argument.

        Returns:
          Tuple[Any, transformer.Context]. The first element is the output of
          transform_ast, bound to a fresh symbol name. A lambda result is
          wrapped in an assignment to that name; any other node has its
          `name` attribute set to it. The second element is the
          `transformer.Context` used for the transformation.
        """
        future_features = inspect_utils.getfutureimports(fn)
        node, source = parser.parse_entity(fn, future_features=future_features)
        logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

        origin_info.resolve_entity(node, source, fn)

        namespace = inspect_utils.getnamespace(fn)
        namer = naming.Namer(namespace)
        new_name = namer.new_symbol(self.get_transformed_name(node), ())
        entity_info = transformer.EntityInfo(name=new_name,
                                             source_code=source,
                                             source_file='<fragment>',
                                             future_features=future_features,
                                             namespace=namespace)
        context = transformer.Context(entity_info, namer, user_context)

        node = self._erase_arg_defaults(node)
        node = self.transform_ast(node, context)

        if isinstance(node, gast.Lambda):
            # A lambda is bound to the new name via an assignment statement
            # rather than renamed.
            node = gast.Assign(targets=[
                gast.Name(new_name,
                          ctx=gast.Store(),
                          annotation=None,
                          type_comment=None)
            ],
                               value=node)
        else:
            node.name = new_name

        return node, context
Example #7
0
    def prepare(self, test_fn, namespace, recursive=True):
        """Parses `test_fn` and runs the standard converter analysis on it.

        `namespace` is modified in place: a 'ConversionOptions' entry is added.

        Args:
          test_fn: The function whose source is parsed into an AST.
          namespace: Namespace dict for the entity; also given to the namer.
          recursive: Forwarded to `converter.ConversionOptions`.

        Returns:
          A `(node, ctx)` tuple: the AST returned by
          `converter.standard_analysis` and the `transformer.Context` used.
        """
        namespace['ConversionOptions'] = converter.ConversionOptions

        features = ('print_function', 'division')
        tree, src = parser.parse_entity(test_fn, future_features=features)
        options = converter.ConversionOptions(recursive=recursive)
        program_ctx = converter.ProgramContext(options=options,
                                               autograph_module=None)
        info = transformer.EntityInfo(
            name=test_fn.__name__,
            source_code=src,
            source_file='<fragment>',
            future_features=features,
            namespace=namespace)
        ctx = transformer.Context(info, naming.Namer(namespace), program_ctx)
        origin_info.resolve_entity(tree, src, test_fn)
        analyzed = converter.standard_analysis(tree, ctx, is_initial=True)
        return analyzed, ctx
Example #8
0
  def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The method receives the original AST. The result is passed as-is to the
    output of `transform`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.

    Returns:
      Tuple[Any, Any]. By default, the output of transform_ast paired with
      the `transformer.Context` that was used for the transformation.
    """
    features = inspect_utils.getfutureimports(fn)
    tree, src = parser.parse_entity(fn, future_features=features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, src)

    origin_info.resolve_entity(tree, src, fn)

    ns = inspect_utils.getnamespace(fn)
    namer = naming.Namer(ns)
    transformed_name = namer.new_symbol(self.get_transformed_name(tree), ())
    info = transformer.EntityInfo(
        name=transformed_name,
        source_code=src,
        source_file='<fragment>',
        future_features=features,
        namespace=ns)
    context = transformer.Context(info, namer, user_context)

    tree = self._erase_arg_defaults(tree)
    return self.transform_ast(tree, context), context
Example #9
0
  def _transform_function(self, fn, user_context):
    """Performs source code transformation on a function.

    Args:
      fn: A function or lambda.
      user_context: Forwarded to the `transformer.Context` as its user context.

    Returns:
      A `(node, context)` tuple. `node` is the output of transform_ast bound to
      a freshly generated symbol name. `context` is the `transformer.Context`
      used for the transformation.
    """
    features = inspect_utils.getfutureimports(fn)
    tree, src = parser.parse_entity(fn, future_features=features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, src)

    origin_info.resolve_entity(tree, src, fn)

    ns = inspect_utils.getnamespace(fn)
    namer = naming.Namer(ns)
    transformed_name = namer.new_symbol(self.get_transformed_name(tree), ())
    info = transformer.EntityInfo(
        name=transformed_name, source_code=src, source_file='<fragment>',
        future_features=features, namespace=ns)
    context = transformer.Context(info, namer, user_context)

    tree = self.transform_ast(self._erase_arg_defaults(tree), context)

    if not isinstance(tree, gast.Lambda):
      tree.name = transformed_name
      return tree, context

    # A lambda is bound to the new name via an assignment statement instead of
    # being renamed.
    target = gast.Name(
        transformed_name, ctx=gast.Store(), annotation=None, type_comment=None)
    return gast.Assign(targets=[target], value=tree), context
Example #10
0
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze.
      autobatch_functions: List of Python `str` names of autobatched functions.
        Arguments to these functions will be canonicalized to variable
        references, but others will not.

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`.
      ctx: The AutoGraph context used for the transformations: a
        `converter.EntityContext` when `converter` has that attribute (TF 2.2
        and earlier), otherwise a `transformer.Context` (TF 2.3+). It wraps
        the `EntityInfo` built for `f`. Required for initializing
        `_AutoBatchingTransformer`.
    """
    # TODO(mdan): Replace all this boilerplate with FunctionTranspiler.
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms. The program context is the same
    # for every TF version; only the EntityInfo fields and context type differ.
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    if hasattr(converter, 'EntityContext'):
        # TF 2.2-
        entity_info = transformer.EntityInfo(source_code='',
                                             source_file=None,
                                             future_features=future_features,
                                             namespace=namespace)
        ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                      entity_info=entity_info,
                                      program_ctx=program_ctx)
    else:
        # TF 2.3+
        entity_info = transformer.EntityInfo(name=f.__name__,
                                             source_code='',
                                             source_file=None,
                                             future_features=future_features,
                                             namespace=namespace)
        ctx = transformer.Context(info=entity_info,
                                  namer=naming.Namer(namespace),
                                  user_context=program_ctx)

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    # Replacing if tests and autobatched function call arguments because
    # that's where divergence can happen.
    # Replacing all function calls because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        """Returns True iff `parent` calls a function in `autobatch_functions`."""
        del field_name, child
        if not anno.hasanno(parent.func, anno.Basic.QN):
            return False
        func_name = anno.getanno(parent.func, anno.Basic.QN)
        return str(func_name) in autobatch_functions

    anf_config = [
        (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
        (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
        (anf.ASTEdgePattern(gast.Call, 'args',
                            anf.ANY), maybe_replace_function_argument),
    ]
    node = anf.transform(node, ctx, config=anf_config)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)

    return node, ctx
 def _simple_context(self):
   """Returns a `transformer.Context` over a minimal `EntityInfo`.

   The `EntityInfo` has `None` for its source code, source file and
   namespace, and an empty tuple of future features.
   """
   return transformer.Context(
       transformer.EntityInfo(source_code=None,
                              source_file=None,
                              future_features=(),
                              namespace=None))
Example #12
0
def _live_tensors(f, attr_name="inputs"):
    """Returns the indices of the used inputs.

    Note: This currently only handles direct index accesses e.g. op.inputs[1].
    If the function has slicing or list comprehension on attr_name then returns
    _ALL. This ensures that the result is correct even if inefficient.

    Functions collected by `_FunctionCallsTracker` (seeded with the name of
    `f`'s first argument) are analyzed recursively, and their indices are
    merged into the result.

    Args:
      f: A grad function, taking the op as first argument.
      attr_name: op attr to track. "inputs" or "outputs".

    Returns:
      Either one of:
        * set of integers representing individual indices of inputs used
        * the value _ALL, if indices are used but it cannot be determined which
        * empty set, if no inputs are used
    """
    node, _ = parser.parse_entity(f, ())
    entity_info = transformer.EntityInfo(
        name=f.__name__,
        source_code=None,
        source_file=None,
        future_features=(),
        namespace=sys.modules[f.__module__].__dict__)
    ctx = transformer.Context(entity_info, None, None)

    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = liveness.resolve(node, ctx, graphs)

    op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
    op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name)

    special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name, ))
    node = special_tracker.visit(node)

    live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN)
    inputs_outputs_used_qns = set()
    # Complicated patterns like op.inputs[:3]. Could be smarter about them
    # if they matter much.
    if any(v == op_inputs_outputs_name for v in special_tracker.complex_reads):
        return _ALL
    for v in live_vars_in:
        if v in special_tracker.reads:
            if (v.has_subscript() and v.parent == op_inputs_outputs_name):
                inputs_outputs_used_qns.add(v)
            elif v == op_inputs_outputs_name:
                # When op.{attr_name} is used directly, assume all tensors are
                # used for now. In that case, no point digging further.
                # TODO(mdan): We can descend into tuple expansions.
                return _ALL

    function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name)
    node = function_calls_tracker.visit(node)

    input_output_indices = set()

    for called_f in function_calls_tracker.calls:
        child_indices = _live_tensors(called_f, attr_name=attr_name)
        if child_indices is _ALL:
            return _ALL
        input_output_indices |= child_indices

    for v in inputs_outputs_used_qns:
        assert v.has_subscript()
        _, subscript = v.qn
        if not subscript.is_simple():
            # Not a number, assuming it can be anything.
            return _ALL
        subscript_val, = subscript.qn
        # Only an integer Literal is a usable index. `or` short-circuits so
        # that `.value` is read only once `subscript_val` is known to be a
        # Literal, and any non-int literal also bails out.
        if (not isinstance(subscript_val, qual_names.Literal)
                or not isinstance(subscript_val.value, int)):
            # Not a number, assuming it can be anything.
            return _ALL
        input_output_indices.add(subscript_val.value)
    return input_output_indices