def test_anf_some_function_calls(self):
  # A custom configuration that differs from the default: every argument of a
  # call to an allowlisted function is moved out into a temporary, while calls
  # to any other function are left exactly as written.
  allowlist = ('foo',)

  def only_allowlisted_calls(parent, field, child):
    # The decision depends solely on which function the enclosing `Call`
    # invokes, so the edge's field name and child node are irrelevant.
    del field, child
    return str(parent.func.id) in allowlist

  config = [
      (anf.ASTEdgePattern(gast.Call, anf.ANY, anf.ANY), only_allowlisted_calls),
  ]

  def test_function(x, foo, bar):
    y = foo(x, x + 1, 2)
    return bar(y, y + 1, 2)

  def expected_result(x, foo, bar):
    tmp_1001 = x + 1
    tmp_1002 = 2
    y = foo(x, tmp_1001, tmp_1002)
    return bar(y, y + 1, 2)

  self.assert_body_anfs_as_expected(expected_result, test_function, config)
def test_constants_in_function_calls(self):
  # An example specific configuration that differs from the default: moving
  # literals out of being directly passed to functions, but nothing else.
  try:
    # Older gast releases expose one node type per kind of literal.
    literals = (gast.Num, gast.Str, gast.Bytes, gast.NameConstant)
  except AttributeError:
    # gast 0.3+ dropped those classes and represents every literal as
    # `Constant` (the node type the other tests here already rely on).
    literals = (gast.Constant,)
  config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, literals), anf.REPLACE)]

  def test_function(x, frob):
    return frob(x, x + 1, 2)

  def expected_result(x, frob):
    tmp_1001 = 2
    return frob(x, x + 1, tmp_1001)

  self.assert_body_anfs_as_expected(expected_result, test_function, config)
def test_touching_name_constant(self):
  # A configuration must be able to manipulate the nodes for `True`, `False`
  # and `None`. That is non-trivial: in Python 2 those singletons are `Name`
  # nodes -- the same node type used for variable references -- so the
  # pattern matches `Name` as well as `Constant`.
  singleton_node_types = (gast.Name, gast.Constant)
  call_argument_edge = anf.ASTEdgePattern(
      gast.Call, anf.ANY, singleton_node_types)
  config = [(call_argument_edge, anf.REPLACE)]

  def test_function(f):
    return f(True, False, None)

  def expected_result(f):
    tmp_1001 = True
    tmp_1002 = False
    tmp_1003 = None
    return f(tmp_1001, tmp_1002, tmp_1003)

  self.assert_body_anfs_as_expected(expected_result, test_function, config)
def _parse_and_analyze(f, autobatch_functions):
  """Performs preliminary analyses and transformations.

  The goal is to massage the source program into a form on which the
  `_AutoBatchingTransformer` below will be successful. The steps, in order,
  are:
  1. parse `f`;
  2. canonicalize away `break` statements;
  3. canonicalize away `continue` statements;
  4. force a single return;
  5. convert to A-normal form (ANF);
  6. re-run AutoGraph's standard analysis on the result.

  Args:
    f: Function to analyze.
    autobatch_functions: List of Python `str` names of autobatched functions.
      Arguments to these functions will be canonicalized to variable
      references, but others will not.

  Returns:
    node: A Python AST node representing the function, suitable for passing
      to `_AutoBatchingTransformer.visit`.
    ctx: The AutoGraph `converter.EntityContext` that all of the passes above
      were run with. It is constructed from a `transformer.EntityInfo`
      describing `f`.
      NOTE(review): presumably this is the object `_AutoBatchingTransformer`
      is initialized with -- confirm against the caller.
  """
  namespace = {}

  # Get the AST of the function.
  future_features = inspect_utils.getfutureimports(f)
  node, _ = parser.parse_entity(f, future_features=future_features)

  # Boilerplate for AutoGraph transforms.
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file=None,
      future_features=future_features,
      namespace=namespace)
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(recursive=True),
      autograph_module=None)
  ctx = converter.EntityContext(
      namer=naming.Namer(namespace),
      entity_info=entity_info,
      program_ctx=program_ctx)

  # Standard analysis is re-run before every transform below, presumably
  # because each transform leaves the previous annotations stale; only the
  # first run is marked `is_initial`.

  # Canonicalize away break statements.
  node = converter.standard_analysis(node, ctx, is_initial=True)
  node = break_statements.transform(node, ctx)

  # Canonicalize away continue statements.
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = continue_statements.transform(node, ctx)

  # Force single returns.
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = return_statements.transform(node, ctx, default_to_null_return=False)

  # Transform into ANF.
  # Replacing if tests and autobatched function call arguments because
  # that's where divergence can happen.
  # Replacing all function calls because the downstream transformation
  # expects calls to lead directly to assignments.
  def maybe_replace_function_argument(parent, field_name, child):
    """Returns True iff `parent` is a call to an autobatched function."""
    del field_name, child
    # Calls whose callee carries no qualified-name annotation are left alone.
    if not anno.hasanno(parent.func, anno.Basic.QN):
      return False
    func_name = anno.getanno(parent.func, anno.Basic.QN)
    return str(func_name) in autobatch_functions

  anf_config = [
      (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
      (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
      (anf.ASTEdgePattern(gast.Call, 'args', anf.ANY),
       maybe_replace_function_argument),
  ]
  node = anf.transform(node, ctx, config=anf_config)
  node = converter.standard_analysis(node, ctx, is_initial=False)

  return node, ctx