def test_anf_some_function_calls(self):
        # Another configuration that departs from the default: hoist every
        # argument of selected calls (here, calls to `foo`) into temporaries,
        # while calls to any other function (here, `bar`) stay untouched.
        hoisted_callees = ['foo']

        def transform(parent, field, child):
            # Only the callee decides; the edge's field and child are ignored.
            del field, child
            callee = str(parent.func.id)
            return callee in hoisted_callees

        edge = anf.ASTEdgePattern(gast.Call, anf.ANY, anf.ANY)
        config = [(edge, transform)]

        def test_function(x, foo, bar):
            y = foo(x, x + 1, 2)
            return bar(y, y + 1, 2)

        def expected_result(x, foo, bar):
            tmp_1001 = x + 1
            tmp_1002 = 2
            y = foo(x, tmp_1001, tmp_1002)
            return bar(y, y + 1, 2)

        self.assert_body_anfs_as_expected(
            expected_result, test_function, config)
Exemplo n.º 2
0
    def test_constants_in_function_calls(self):
        # An example specific configuration that differs from the default: Moving
        # literals out of being directly passed to functions, but nothing else.
        try:
            # gast before 0.3 has a separate node type per kind of literal.
            literals = (gast.Num, gast.Str, gast.Bytes, gast.NameConstant)
        except AttributeError:
            # gast 0.3+ dropped those node types; every literal is a `Constant`
            # (matching the node type the name-constant test relies on).
            literals = (gast.Constant,)
        config = [(anf.ASTEdgePattern(gast.Call, anf.ANY,
                                      literals), anf.REPLACE)]

        def test_function(x, frob):
            return frob(x, x + 1, 2)

        def expected_result(x, frob):
            tmp_1001 = 2
            return frob(x, x + 1, tmp_1001)

        self.assert_body_anfs_as_expected(expected_result, test_function,
                                          config)
Exemplo n.º 3
0
  def test_touching_name_constant(self):
    # `True`, `False` and `None` must be reachable by a configuration.  That is
    # not automatic: Python 2 represents them as `Name` nodes -- the same node
    # type used for variable references -- so both `Name` and `Constant` call
    # arguments are targeted here.
    name_or_constant = (gast.Name, gast.Constant)
    pattern = anf.ASTEdgePattern(gast.Call, anf.ANY, name_or_constant)
    config = [(pattern, anf.REPLACE)]

    def test_function(f):
      return f(True, False, None)

    def expected_result(f):
      tmp_1001 = True
      tmp_1002 = False
      tmp_1003 = None
      return f(tmp_1001, tmp_1002, tmp_1003)

    self.assert_body_anfs_as_expected(
        expected_result, test_function, config)
Exemplo n.º 4
0
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.  The passes are,
    in order: canonicalize away `break`, canonicalize away `continue`, force a
    single return, and convert to A-normal form (ANF).  AutoGraph's standard
    analysis is (re)run before each of the first three passes and once more
    after the ANF conversion.

    Args:
      f: Function to analyze.
      autobatch_functions: Collection (e.g. list) of Python `str` names of
        autobatched functions.  Arguments to calls of these functions will be
        canonicalized to variable references, but arguments to other calls
        will not.

    Returns:
      node: A Python AST node representing the function, suitable for
        passing to `_AutoBatchingTransformer.visit`.
      ctx: The AutoGraph `converter.EntityContext` used for the transforms.
        It wraps the `EntityInfo` built for `f` (empty namespace, no source
        code recorded), a `naming.Namer`, and the `ProgramContext`.  Required
        for initializing `_AutoBatchingTransformer`.
    """
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms
    entity_info = transformer.EntityInfo(source_code='',
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=namespace)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=True),
        autograph_module=None)
    ctx = converter.EntityContext(namer=naming.Namer(namespace),
                                  entity_info=entity_info,
                                  program_ctx=program_ctx)

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx, is_initial=True)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx, is_initial=False)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    # Replacing if tests and autobatched function call arguments because
    # that's where divergence can happen.
    # Replacing all function calls because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        # Hoist an argument only when the callee carries a qualified-name
        # annotation naming one of `autobatch_functions`.
        # NOTE(review): analysis is not re-run after return_statements, so a
        # callee without a QN annotation is simply left alone.
        del field_name, child
        if not anno.hasanno(parent.func, anno.Basic.QN):
            return False
        func_name = anno.getanno(parent.func, anno.Basic.QN)
        return str(func_name) in autobatch_functions

    anf_config = [
        (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
        (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
        (anf.ASTEdgePattern(gast.Call, 'args',
                            anf.ANY), maybe_replace_function_argument),
    ]
    node = anf.transform(node, ctx, config=anf_config)
    node = converter.standard_analysis(node, ctx, is_initial=False)

    return node, ctx