Example #1
0
    def test_missing_orelse(self):
        # `test_fn` returns only on the `x > 0` branch and has no `else`;
        # converting it with return_statements is expected to raise ValueError.
        def test_fn(x):
            if x > 0:
                return x

        fn_node, fn_ctx = self.prepare(test_fn, {})
        self.assertRaises(
            ValueError, return_statements.transform, fn_node, fn_ctx)
  def test_missing_orelse(self):
    # `test_fn` returns only on the `x > 0` branch and has no `else`;
    # converting it with return_statements is expected to raise ValueError.
    def test_fn(x):
      if x > 0:
        return x

    fn_node, fn_ctx = self.prepare(test_fn, {})
    self.assertRaises(
        ValueError, return_statements.transform, fn_node, fn_ctx)
Example #3
0
    def test_loop(self):
        # `test_fn` returns from inside a `for` loop body; converting it with
        # return_statements is expected to raise ValueError.
        def test_fn(x):
            for _ in range(10):
                return x
            return x

        fn_node, fn_ctx = self.prepare(test_fn, {})
        self.assertRaises(
            ValueError, return_statements.transform, fn_node, fn_ctx)
  def test_loop(self):
    # `test_fn` returns from inside a `for` loop body; converting it with
    # return_statements is expected to raise ValueError.
    def test_fn(x):
      for _ in range(10):
        return x
      return x

    fn_node, fn_ctx = self.prepare(test_fn, {})
    self.assertRaises(
        ValueError, return_statements.transform, fn_node, fn_ctx)
    def transform_ast(self, node, ctx):
        """Runs the initial analyses, then the conversion passes in order.

        The passes execute strictly in the order written; at least one
        ordering dependency (continue canonicalization vs. loop conversion)
        is called out inline, so do not reorder them casually.

        Args:
          node: The AST node to convert.
          ctx: The conversion context. `ctx.user.options` decides whether the
            optional assert pass and the list/slice passes run.

        Returns:
          The converted AST node.
        """
        # TODO(mdan): Insert list_comprehensions somewhere.
        # Presumably rejects (raises on) constructs the converter cannot
        # handle before any work is done -- confirm in
        # unsupported_features_checker.
        unsupported_features_checker.verify(node)

        # Run initial analysis.
        graphs = cfg.build(node)
        node = qual_names.resolve(node)
        node = activity.resolve(node, ctx, None)
        node = reaching_definitions.resolve(node, ctx, graphs)
        # NOTE(review): appears to copy the DEFINITIONS annotations to
        # ORIG_DEFINITIONS so the pre-conversion reaching definitions stay
        # available after the passes below -- confirm against anno.dup.
        anno.dup(
            node,
            {
                anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
            },
        )

        node = functions.transform(node, ctx)
        node = directives.transform(node, ctx)
        node = break_statements.transform(node, ctx)
        # Optional pass, enabled through the user's conversion options.
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = asserts.transform(node, ctx)
        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        node = continue_statements.transform(node, ctx)
        node = return_statements.transform(node, ctx)
        # Optional passes, enabled through the user's conversion options.
        if ctx.user.options.uses(converter.Feature.LISTS):
            node = lists.transform(node, ctx)
            node = slices.transform(node, ctx)
        node = call_trees.transform(node, ctx)
        node = control_flow.transform(node, ctx)
        node = conditional_expressions.transform(node, ctx)
        node = logical_expressions.transform(node, ctx)
        return node
def _apply_py_to_tf_passes(node, ctx):
  """Apply transformations from PyToTF to match tf.function tracing.

  Only a subset of the PyToTF passes is applied: return-statement
  canonicalization, followed by control-flow conversion.

  Args:
    node: The AST node to transform.
    ctx: The conversion context forwarded to each pass.

  Returns:
    The transformed AST node.
  """
  # TODO(fengliuai): we don't know which passes are required, thus we evaluate
  # each one when the corresponding node is handled.
  # Copied from PyToTF.transform_ast.
  # The flag is passed by keyword (as at the other call sites) so its meaning
  # is visible here instead of being a bare positional `False`.
  node = return_statements.transform(node, ctx, default_to_null_return=False)
  node = control_flow.transform(node, ctx)
  return node
Example #7
0
def _parse_and_analyze(f):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`
      ctx: An AutoGraph `EntityContext` object wrapping the `EntityInfo`
        built for `f`.  Required for initializing `_AutoBatchingTransformer`.
    """
    # Shared by the EntityInfo and the Namer below.
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    entity_info = transformer.EntityInfo(source_code='',
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                  entity_info=entity_info,
                                  program_ctx=program_ctx)

    # Each pass below is preceded by a fresh `standard_analysis`, presumably
    # because the preceding transform leaves the analysis annotations stale.

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx, is_initial=True)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    node = anf.transform(node, ctx)
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx
    def test_method(self):
        # Converts a class whose method calls a nested function, then checks
        # which name scopes the resulting ops were created under.
        class TestClass(object):
            def test_fn(self, l):
                def inner_fn(i):
                    return i + 1

                l += 1
                return l, inner_fn(l)

        converted, entity_ctx = self.prepare(
            TestClass, {'TestClass': TestClass})
        converted = functions.transform(converted, entity_ctx)
        converted = return_statements.transform(converted, entity_ctx)

        with self.compiled(converted, {}, (ops.name_scope, )) as module:
            instance = module.TestClass()
            outer_out, inner_out = instance.test_fn(constant_op.constant(1))
            # The first output is named under test_fn but not under inner_fn.
            self.assertIn('test_fn/', outer_out.op.name)
            self.assertNotIn('inner_fn', outer_out.op.name)
            # The second output's input is named under test_fn/inner_fn.
            self.assertIn('test_fn/inner_fn/', inner_out.op.inputs[0].name)
Example #9
0
    def transform_ast(self, node, ctx):
        """Runs the initial analysis, then the conversion passes in order.

        The passes execute strictly in the order written; at least one
        ordering dependency (continue canonicalization vs. loop conversion)
        is called out inline, so do not reorder them casually.

        Args:
          node: The AST node to convert.
          ctx: The conversion context. `ctx.user.options` decides whether the
            optional assert pass and the list/slice passes run.

        Returns:
          The converted AST node.
        """
        # Presumably rejects (raises on) constructs the converter cannot
        # handle before any work is done -- confirm in
        # unsupported_features_checker.
        unsupported_features_checker.verify(node)
        node = self.initial_analysis(node, ctx)

        node = functions.transform(node, ctx)
        node = directives.transform(node, ctx)
        node = break_statements.transform(node, ctx)
        # Optional pass, enabled through the user's conversion options.
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = asserts.transform(node, ctx)
        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        node = continue_statements.transform(node, ctx)
        node = return_statements.transform(node, ctx)
        # Optional passes, enabled through the user's conversion options.
        if ctx.user.options.uses(converter.Feature.LISTS):
            node = lists.transform(node, ctx)
            node = slices.transform(node, ctx)
        node = call_trees.transform(node, ctx)
        node = control_flow.transform(node, ctx)
        node = conditional_expressions.transform(node, ctx)
        node = logical_expressions.transform(node, ctx)
        node = variables.transform(node, ctx)
        return node
Example #10
0
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
      f: Function to analyze
      autobatch_functions: List of Python `str` names of autobatched functions.
        Arguments to these functions will be canonicalized to variable
        references, but others will not.

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`
      ctx: An AutoGraph `EntityContext` object wrapping the `EntityInfo`
        built for `f`.  Required for initializing `_AutoBatchingTransformer`.
    """
    # Shared by the EntityInfo and the Namer below.
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    entity_info = transformer.EntityInfo(source_code='',
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                  entity_info=entity_info,
                                  program_ctx=program_ctx)

    # Each pass below is preceded by a fresh `standard_analysis`, presumably
    # because the preceding transform leaves the analysis annotations stale.

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx, is_initial=True)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    # Replacing if tests and autobatched function call arguments because
    # that's where divergence can happen.
    # Replacing all function calls because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        """Returns whether `parent` (a call node) targets an autobatched function.

        Calls whose callee has no qualified-name annotation are never
        treated as autobatched.
        """
        # Only the parent call matters; the other arguments are part of the
        # predicate signature used in `anf_config` below.
        del field_name, child
        if not anno.hasanno(parent.func, anno.Basic.QN):
            return False
        func_name = anno.getanno(parent.func, anno.Basic.QN)
        return str(func_name) in autobatch_functions

    anf_config = [
        (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
        (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
        (anf.ASTEdgePattern(gast.Call, 'args',
                            anf.ANY), maybe_replace_function_argument),
    ]
    node = anf.transform(node, ctx, config=anf_config)
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx