def _simple_source_info(self):
  """Returns a `transformer.EntityInfo` whose fields are all None.

  Returns:
    A `transformer.EntityInfo` with source_code, source_file, namespace,
    arg_values, arg_types and owner_type all set to None.
  """
  unset_fields = {
      'source_code': None,
      'source_file': None,
      'namespace': None,
      'arg_values': None,
      'arg_types': None,
      'owner_type': None,
  }
  return transformer.EntityInfo(**unset_fields)
def prepare(self, test_fn, namespace, recursive=True):
  """Parses `test_fn` and annotates it with CFG-based static analyses.

  Args:
    test_fn: The function whose source is parsed.
    namespace: Dict used as the entity namespace. It is mutated: the key
      'ConversionOptions' is bound to `converter.ConversionOptions`.
    recursive: Forwarded to `converter.ConversionOptions(recursive=...)`.

  Returns:
    A `(node, ctx)` tuple: the annotated AST and the `transformer.Context`
    that was used to analyze it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  features = ('print_function', 'division')
  tree, src = parser.parse_entity(test_fn, future_features=features)

  fn_namer = naming.Namer(namespace)
  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(options=options, autograph_module=None)
  info = transformer.EntityInfo(
      name=test_fn.__name__,
      source_code=src,
      source_file='<fragment>',
      future_features=features,
      namespace=namespace)
  ctx = transformer.Context(info, fn_namer, program_ctx)

  origin_info.resolve_entity(tree, src, test_fn)

  # The CFG is built from the tree as parsed, before any resolve pass runs.
  cfgs = cfg.build(tree)
  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, ctx, None)
  tree = reaching_definitions.resolve(tree, ctx, cfgs)

  # Copy map handed to anno.dup; presumably source key -> destination key,
  # i.e. DEFINITIONS annotations are duplicated under ORIG_DEFINITIONS.
  definitions_copy_map = {
      anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
  }
  anno.dup(tree, definitions_copy_map)

  return tree, ctx
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: The function to convert.
    program_ctx: Program-wide context; supplies `autograph_module`,
      `new_namer` and `update_name_map`.
    arg_values: Stored on the `transformer.EntityInfo`.
    arg_types: Stored on the `transformer.EntityInfo`.
    owner_type: Stored on the `transformer.EntityInfo` and passed to
      `namer.compiled_function_name`.

  Returns:
    A `([node], new_name, namespace)` tuple.

  Raises:
    NotImplementedError: if the namer did not rename the function and the
      converted node's name differs from `f.__name__`.
  """
  module_node, source = parser.parse_entity(f)
  fn_node = module_node.body[0]

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(fn_node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  fn_node = node_to_graph(
      fn_node, converter.EntityContext(namer, info, program_ctx))

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(
      f.__name__, f, owner_type)
  if not did_rename:
    new_name = f.__name__
    if fn_node.name != new_name:
      raise NotImplementedError(
          'Strange corner case. Send us offending code!')
  fn_node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  return [fn_node], new_name, namespace
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: The function or lambda to convert.
    program_ctx: Program-wide context; supplies `autograph_module`,
      `new_namer` and `update_name_map`.
    arg_values: Stored on the `transformer.EntityInfo`.
    arg_types: Stored on the `transformer.EntityInfo`.
    owner_type: Stored on the `transformer.EntityInfo` and passed to
      `namer.compiled_function_name`.

  Returns:
    A `([node], new_name, namespace)` tuple. For lambdas, `node` is an
    assignment of the converted lambda to a fresh `tf__lambda` symbol.

  Raises:
    ValueError: if the reported source does not contain exactly one
      definition matching `f`.
  """
  module_node, source = parser.parse_entity(f)
  candidate_root = module_node.body[0]

  # In general, the output of inspect.getsource is inexact because it uses
  # regex matching to adjust the exact location around the line number that
  # CPython records. This is particularly problematic for lambda functions,
  # where the entire containing lines are returned.
  matches = ast_util.find_matching_definitions(candidate_root, f)
  if len(matches) != 1:
    if f.__name__ == '<lambda>':
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    raise ValueError(
        'Unable to identify source code of function {}. The source code'
        ' reported by Python did not include exactly one matching signature:'
        '\n{}\n. This is an extremely rare occurrence. Please report it to'
        ' the TensorFlow team.'.format(f, source))
  node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  node = node_to_graph(node, converter.EntityContext(namer, info, program_ctx))

  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    node = gast.Assign(targets=[target], value=node)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
      new_name = f.__name__
      assert node.name == new_name
    else:
      node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  return [node], new_name, namespace
def prepare(self,
            test_fn,
            namespace,
            namer=None,
            arg_types=None,
            owner_type=None,
            recursive=True,
            strip_decorators=()):
  """Parses `test_fn` and runs the initial `converter.standard_analysis`.

  Args:
    test_fn: The function whose source is parsed; the first statement of the
      parsed module is analyzed.
    namespace: Dict used as the entity namespace. It is mutated: the key
      'ConversionOptions' is bound to `converter.ConversionOptions`.
    namer: Namer for the entity context; a `FakeNamer()` is used when None.
    arg_types: Stored on the `transformer.EntityInfo`.
    owner_type: Stored on the `transformer.EntityInfo`.
    recursive: Forwarded to `converter.ConversionOptions`.
    strip_decorators: Forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` tuple with the analyzed AST and its
    `converter.EntityContext`.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  module_node, source = parser.parse_entity(test_fn)
  fn_node = module_node.body[0]

  effective_namer = FakeNamer() if namer is None else namer

  options = converter.ConversionOptions(
      recursive=recursive, strip_decorators=strip_decorators)
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  ctx = converter.EntityContext(effective_namer, info, program_ctx)

  fn_node = converter.standard_analysis(fn_node, ctx, is_initial=True)
  return fn_node, ctx
def _simple_context(self):
  """Returns a `transformer.Context` over a placeholder entity.

  The entity is named 'Test_fn'. It has no source, no source file, no future
  features and no namespace. The remaining two Context arguments are None.
  """
  return transformer.Context(
      transformer.EntityInfo(
          name='Test_fn',
          source_code=None,
          source_file=None,
          future_features=(),
          namespace=None),
      None,
      None)
def _simple_context(self):
  """Returns a `transformer.Context` over an entity with no source info."""
  no_source = dict(
      source_code=None,
      source_file=None,
      namespace=None,
      arg_values=None,
      arg_types=None)
  return transformer.Context(transformer.EntityInfo(**no_source))
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Specialization of `convert_entity_to_ast` for callable functions.

  Args:
    f: The function or lambda to convert.
    program_ctx: Program-wide context. Its `autograph_module` is added to the
      function's namespace, and it is forwarded to `converter.EntityContext`.
    do_rename: If True, a non-lambda function is renamed via
      `namer.function_name`; otherwise it keeps `f.__name__`.

  Returns:
    A `((node,), new_name, entity_info)` tuple. For lambdas, `node` assigns
    the converted lambda to a fresh `tf__lambda` symbol.

  Raises:
    ValueError: if `f` is a lambda and its source does not contain exactly one
      matching definition.
  """
  future_features = inspect_utils.getfutureimports(f)
  tree, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)

  # Parsed AST should contain future imports and one function def node.

  # In general, the output of inspect.getsource is inexact for lambdas because
  # it uses regex matching to adjust the exact location around the line number
  # that CPython records. Then, the entire containing line is returned, which
  # we may have trouble disambiguating. For example:
  # x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    lambda_matches = ast_util.find_matching_definitions(tree, f)
    if len(lambda_matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    tree, = lambda_matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve_entity(tree, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  if isinstance(tree, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
  else:
    new_name = namer.function_name(f.__name__) if do_rename else f.__name__

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
  tree = node_to_graph(tree, context)

  # node_to_graph may return a different node, so its type is checked again.
  if isinstance(tree, gast.Lambda):
    assign_target = gast.Name(
        new_name, ctx=gast.Store(), annotation=None, type_comment=None)
    tree = gast.Assign(targets=[assign_target], value=tree)
  elif do_rename:
    tree.name = new_name
  else:
    assert tree.name == new_name

  return (tree,), new_name, entity_info
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs qualified-name and activity analysis.

  Returns:
    A `(node, entity_info)` tuple: the annotated AST and the
    `transformer.EntityInfo` built for `test_fn` (empty namespace, no
    future features).
  """
  tree, src = parser.parse_entity(test_fn, future_features=())
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      future_features=(),
      namespace={})
  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, transformer.Context(info))
  return tree, info
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: The function or lambda to convert.
    program_ctx: Program-wide context; supplies `autograph_module`,
      `new_namer` and `update_name_map`.
    arg_values: Stored on the `transformer.EntityInfo`.
    arg_types: Stored on the `transformer.EntityInfo`.
    owner_type: Stored on the `transformer.EntityInfo` and passed to
      `namer.compiled_function_name`.

  Returns:
    A `([node], new_name, namespace)` tuple. For lambdas, `node` is an
    assignment of the converted lambda to a fresh `tf__lambda` symbol.

  Raises:
    ValueError: if `f` is a lambda and the source line does not contain
      exactly one lambda with matching argument names.
  """
  node, source = parser.parse_entity(f)
  node = node.body[0]

  # TODO(mdan): Can we convert everything and scoop the lambda afterwards?
  if f.__name__ == '<lambda>':
    nodes = ast_util.find_matching_lambda_definitions(node, f)
    if len(nodes) != 1:
      # len(nodes) may be 0 (no match) or >1 (ambiguous); the message must be
      # correct in both cases, not only the "multiple lambdas" one.
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain exactly one lambda'
          ' with matching argument names. To avoid ambiguity, define each'
          ' lambda in a separate expression.'.format(f, source))
    node, = nodes

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  node = node_to_graph(node, context)

  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    node = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=node)
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if did_rename:
      node.name = new_name
    else:
      new_name = f.__name__
      assert node.name == new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  return [node], new_name, namespace
def mlir_gen(func):
  """Parse a function and return TFProgram.

  Args:
    func: The Python function to parse.

  Returns:
    The result of `mlir_gen_internal` on the parsed AST.
  """
  tree, src = parser.parse_entity(func, future_features=())
  func_namespace = inspect_utils.getnamespace(func)
  info = transformer.EntityInfo(
      name=func.__name__,
      source_code=src,
      source_file=None,
      future_features=(),
      namespace=func_namespace)
  return mlir_gen_internal(tree, info)
def get_node_and_ctx(f):
  """Parses `f` and wraps it in a minimal `transformer.Context`.

  The entity info is always named 'f', has no future features and no
  namespace. The two remaining Context arguments are None.

  Returns:
    A `(node, ctx)` tuple.
  """
  parsed, src = parser.parse_entity(f, ())
  info = transformer.EntityInfo(
      name='f',
      source_code=src,
      source_file=None,
      future_features=(),
      namespace=None)
  return parsed, transformer.Context(info, None, None)
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: The function or lambda to convert.
    program_ctx: Program-wide context. Its `autograph_module` is added to the
      namespace, and it is forwarded to `converter.EntityContext`.
    arg_values: Stored on the `transformer.EntityInfo`.
    arg_types: Stored on the `transformer.EntityInfo`.
    do_rename: If True, a non-lambda function is renamed via
      `namer.function_name`; otherwise it keeps `f.__name__`.

  Returns:
    A `([node], new_name, namespace)` tuple. For lambdas, `node` assigns the
    converted lambda to a fresh `tf__lambda` symbol.

  Raises:
    ValueError: if `f` is a lambda whose source line does not contain exactly
      one matching lambda.
    errors.InternalError: if `node_to_graph` raises ValueError,
      AttributeError, KeyError or NotImplementedError.
  """
  parsed, source = parser.parse_entity(f)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  node = parsed.body[0]

  # In general, the output of inspect.getsource is inexact for lambdas because
  # it uses regex matching to adjust the exact location around the line number
  # that CPython records. Then, the entire containing line is returned, which
  # we may have trouble disambiguating. For example:
  # x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    lambda_nodes = ast_util.find_matching_definitions(node, f)
    if len(lambda_nodes) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = lambda_nodes

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types)
  context = converter.EntityContext(namer, info, program_ctx)

  # TODO(mdan): Catch and rethrow syntax errors.
  try:
    node = node_to_graph(node, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as err:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    raise errors.InternalError('conversion', err)

  if isinstance(node, gast.Lambda):
    new_name = namer.new_symbol('tf__lambda', ())
    target = gast.Name(new_name, gast.Store(), None)
    node = gast.Assign(targets=[target], value=node)
  elif not do_rename:
    new_name = f.__name__
    assert node.name == new_name
  else:
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name = namer.function_name(f.__name__)
    node.name = new_name

  return [node], new_name, namespace
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs qualified-name and activity analysis.

  Returns:
    A `(node, entity_info)` tuple: the annotated AST and the
    `transformer.EntityInfo` built for `test_fn` (empty namespace).
  """
  tree, src, _ = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None)
  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, transformer.Context(info))
  return tree, info
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and annotates it with reaching-definitions analysis.

  Qualified names and activity are resolved first. The CFG is then built from
  the activity-annotated tree and fed to `reaching_definitions.resolve`.

  Returns:
    The annotated AST node.
  """
  tree, src = parser.parse_entity(test_fn, future_features=())
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      future_features=(),
      namespace={})
  tree = qual_names.resolve(tree)
  ctx = transformer.Context(info)
  tree = activity.resolve(tree, ctx)
  cfgs = cfg.build(tree)
  return reaching_definitions.resolve(
      tree, ctx, cfgs, reaching_definitions.Definition)
def mlir_gen_from_source(source=None, src_file=None):
  """Parses Python source from a string or a file and returns a TFProgram.

  Args:
    source: Python source code to parse. Takes precedence over `src_file`.
    src_file: Path of a file to read the source from when `source` is None.

  Returns:
    The result of `mlir_gen_internal` on the parsed module.

  Raises:
    ValueError: if both `source` and `src_file` are None.
  """
  if source is None:
    if src_file is None:
      raise ValueError('Either source or src_file must be provided.')
    # A context manager closes the handle deterministically; a bare
    # open(...).read() leaves it open until garbage collection.
    with open(src_file) as src_handle:
      source = src_handle.read()
  node = ast.parse(source)
  entity_info = transformer.EntityInfo(
      name='mlir_module',
      source_code=source,
      source_file=None,
      future_features=(),
      namespace={})
  return mlir_gen_internal(node, entity_info)
def _transform_function(self, fn, user_context):
  """Performs source code transformation on a function.

  Args:
    fn: A function or lambda.
    user_context: Opaque object passed to `transformer.Context`.

  Returns:
    A `(node, context)` tuple. `node` is the output of `transform_ast`:
    renamed to a fresh symbol, or for lambdas wrapped in an assignment to
    that symbol. `context` is the `transformer.Context` that was used.

  Raises:
    ValueError: if `fn` is a lambda whose source does not contain exactly one
      matching lambda.
  """
  future_features = inspect_utils.getfutureimports(fn)
  node, source = parser.parse_entity(fn, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

  # In general, the output of inspect.getsource is inexact for lambdas
  # because it uses regex matching to adjust the exact location around
  # the line number that CPython records. Then, the entire containing line
  # is returned, which we may have trouble disambiguating.
  # For example:
  # x, y = lambda: 1, lambda: 2
  is_lambda = fn.__name__ == '<lambda>'
  if is_lambda:
    matches = ast_util.find_matching_definitions(node, fn)
    if len(matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}.'
          ' It was defined in this code:\n'
          '{}\n'
          'This code must contain a single distinguishable lambda.'
          ' To avoid this problem, define each lambda in a separate'
          ' expression.'.format(fn, source))
    node, = matches

  origin_info.resolve_entity(node, source, fn)

  namespace = inspect_utils.getnamespace(fn)
  namer = naming.Namer(namespace)
  new_name = namer.new_symbol(self.get_transformed_name(node), ())
  entity_info = transformer.EntityInfo(
      name=new_name,
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = transformer.Context(entity_info, namer, user_context)

  node = self.transform_ast(self._erase_arg_defaults(node), context)

  if not is_lambda:
    node.name = new_name
    return node, context

  assign_target = gast.Name(
      new_name, ctx=gast.Store(), annotation=None, type_comment=None)
  return gast.Assign(targets=[assign_target], value=node), context
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs qualified-name and activity analysis.

  Returns:
    A `(node, entity_info)` tuple: the annotated AST and the
    `transformer.EntityInfo` (named after `test_fn`, empty namespace).
  """
  # TODO(mdan): Use a custom FunctionTransformer here.
  tree, src = parser.parse_entity(test_fn, future_features=())
  info = transformer.EntityInfo(
      name=test_fn.__name__,
      source_code=src,
      source_file=None,
      future_features=(),
      namespace={})
  tree = qual_names.resolve(tree)
  ctx = transformer.Context(info, naming.Namer({}), None)
  tree = activity.resolve(tree, ctx)
  return tree, info
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs activity and liveness analysis.

  The `EntityInfo` itself (not a Context) is passed to the analyses.

  Returns:
    The annotated AST node.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      owner_type=None)
  tree = qual_names.resolve(tree)
  tree = activity.resolve(tree, info)
  # The return value of liveness.resolve is not used here; presumably it
  # annotates `tree` in place — confirm against the liveness module.
  liveness.resolve(tree, info, cfg.build(tree))
  return tree
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs activity and liveness analysis.

  Returns:
    The annotated AST node.
  """
  tree, _, src = parser.parse_entity(test_fn, future_imports=())
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None)
  tree = qual_names.resolve(tree)
  ctx = transformer.Context(info)
  tree = activity.resolve(tree, ctx)
  # The return value of liveness.resolve is not used here; presumably it
  # annotates `tree` in place — confirm against the liveness module.
  liveness.resolve(tree, ctx, cfg.build(tree))
  return tree
def _parse_and_analyze(f):
  """Performs preliminary analyses and transformations.

  The goal is to massage the source program into a form on which the
  `_AutoBatchingTransformer` below will be successful.

  Args:
    f: Function to analyze

  Returns:
    node: A Python AST node representing the function, suitable for
      passing to `_AutoBatchingTransformer.visit`
    ctx: An AutoGraph `converter.EntityContext` holding the namer, the
      `EntityInfo` for `f` and the program context. Required for
      initializing `_AutoBatchingTransformer`.
  """
  namespace = {}

  # Get the AST of the function
  future_features = inspect_utils.getfutureimports(f)
  node, _ = parser.parse_entity(f, future_features=future_features)

  # Boilerplate for AutoGraph transforms
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file=None,
      future_features=future_features,
      namespace=namespace)
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(recursive=True),
      autograph_module=None)
  ctx = converter.EntityContext(
      namer=naming.Namer(namespace),
      entity_info=entity_info,
      program_ctx=program_ctx)

  # Canonicalize away break statements
  node = converter.standard_analysis(node, ctx, is_initial=True)
  node = break_statements.transform(node, ctx)

  # Canonicalize away continue statements
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = continue_statements.transform(node, ctx)

  # Force single returns
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = return_statements.transform(node, ctx, default_to_null_return=False)

  # Transform into ANF
  node = anf.transform(node, ctx)
  node = converter.standard_analysis(node, ctx, is_initial=False)

  return node, ctx
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
  """Parses `test_fn` and runs the activity/definitions/values/types passes.

  Args:
    test_fn: The function whose source is parsed.
    namespace: Namespace stored on the `transformer.EntityInfo`.
    arg_types: Stored on the `transformer.EntityInfo`.

  Returns:
    The annotated AST node.
  """
  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types)
  tree = qual_names.resolve(tree)
  cfgs = cfg.build(tree)
  ctx = transformer.Context(info)

  tree = activity.resolve(tree, ctx)
  tree = reaching_definitions.resolve(
      tree, ctx, cfgs, reaching_definitions.Definition)
  # live_values runs on both sides of type_info; each call gets its own
  # fresh empty dict.
  tree = live_values.resolve(tree, ctx, {})
  tree = type_info.resolve(tree, ctx)
  tree = live_values.resolve(tree, ctx, {})
  return tree
def prepare(self, test_fn, namespace, recursive=True):
  """Parses `test_fn` and runs the initial `converter.standard_analysis`.

  Args:
    test_fn: The function whose source is parsed.
    namespace: Dict used as the entity namespace. It is mutated: the key
      'ConversionOptions' is bound to `converter.ConversionOptions`.
    recursive: Forwarded to `converter.ConversionOptions(recursive=...)`.

  Returns:
    A `(node, ctx)` tuple with the analyzed AST and its
    `converter.EntityContext`.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  features = ('print_function', 'division')
  tree, src = parser.parse_entity(test_fn, future_features=features)

  fn_namer = naming.Namer(namespace)
  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(options=options, autograph_module=None)
  info = transformer.EntityInfo(
      source_code=src,
      source_file='<fragment>',
      future_features=features,
      namespace=namespace)
  ctx = converter.EntityContext(fn_namer, info, program_ctx)

  return converter.standard_analysis(tree, ctx, is_initial=True), ctx
def transform_function(self, fn, user_context):
  """Transforms a function.

  Subclasses may override this method. The return value is opaque.

  The method receives the original AST. The result is passed as-is to the
  output of `transform`.

  Args:
    fn: A function or lambda.
    user_context: An opaque object (may be None) that is forwarded to
      transform_ast, through the ctx.user_context argument.

  Returns:
    Tuple[Any, Any]. By default it returns the output of transform_ast,
    renamed to a fresh symbol (or, if it is a lambda, wrapped in an
    assignment to that symbol), together with the `transformer.Context`
    used for the transformation.
  """
  future_features = inspect_utils.getfutureimports(fn)
  node, source = parser.parse_entity(fn, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

  origin_info.resolve_entity(node, source, fn)

  namespace = inspect_utils.getnamespace(fn)
  namer = naming.Namer(namespace)
  new_name = namer.new_symbol(self.get_transformed_name(node), ())
  entity_info = transformer.EntityInfo(
      name=new_name,
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = transformer.Context(entity_info, namer, user_context)

  node = self._erase_arg_defaults(node)
  node = self.transform_ast(node, context)

  if isinstance(node, gast.Lambda):
    node = gast.Assign(
        targets=[
            gast.Name(
                new_name, ctx=gast.Store(), annotation=None, type_comment=None)
        ],
        value=node)
  else:
    node.name = new_name

  return node, context
def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
  """Parses `test_fn`, resolves origin info and runs `standard_analysis`.

  Args:
    test_fn: The function whose source is parsed.
    namespace: Dict used as the entity namespace. It is mutated: the key
      'ConversionOptions' is bound to `converter.ConversionOptions`.
    arg_types: Stored on the `transformer.EntityInfo`.
    recursive: Forwarded to `converter.ConversionOptions(recursive=...)`.

  Returns:
    A `(node, ctx)` tuple with the analyzed AST and its
    `converter.EntityContext`.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  tree, src, _ = parser.parse_entity(test_fn)

  fn_namer = naming.Namer(namespace)
  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(options=options, autograph_module=None)
  info = transformer.EntityInfo(
      source_code=src,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types)
  ctx = converter.EntityContext(fn_namer, info, program_ctx)

  origin_info.resolve(tree, src, test_fn)
  return converter.standard_analysis(tree, ctx, is_initial=True), ctx
def _parse_and_analyze(self,
                       test_fn,
                       namespace,
                       literals=None,
                       arg_types=None):
  """Parses `test_fn` and runs the activity/definitions/values/types passes.

  The `EntityInfo` itself (not a Context) is passed to every analysis.

  Args:
    test_fn: The function whose source is parsed.
    namespace: Namespace stored on the `transformer.EntityInfo`.
    literals: Dict passed to both `live_values.resolve` calls; a new empty
      dict is used when falsy.
    arg_types: Stored on the `transformer.EntityInfo`.

  Returns:
    The annotated AST node.
  """
  if not literals:
    literals = {}

  tree, src = parser.parse_entity(test_fn)
  info = transformer.EntityInfo(
      source_code=src,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=None)
  tree = qual_names.resolve(tree)
  cfgs = cfg.build(tree)

  tree = activity.resolve(tree, info)
  tree = reaching_definitions.resolve(
      tree, info, cfgs, reaching_definitions.Definition)
  # live_values runs on both sides of type_info with the same literals dict.
  tree = live_values.resolve(tree, info, literals)
  tree = type_info.resolve(tree, info)
  tree = live_values.resolve(tree, info, literals)
  return tree
def transform_function(self, fn, user_context):
  """Transforms a function.

  Subclasses may override this method. The return value is opaque.

  The method receives the original AST. The result is passed as-is to the
  output of `transform`.

  Args:
    fn: A function or lambda.
    user_context: An opaque object (may be None) that is forwarded to
      transform_ast, through the ctx.user_context argument.

  Returns:
    Tuple[Any, Any]. By default it returns the output of transform_ast,
    together with a `transformer.Context` containing information about the
    transformation process.
  """
  futures = inspect_utils.getfutureimports(fn)
  tree, src = parser.parse_entity(fn, future_features=futures)
  logging.log(3, 'Source code of %s:\n\n%s\n', fn, src)
  origin_info.resolve_entity(tree, src, fn)

  fn_namespace = inspect_utils.getnamespace(fn)
  fn_namer = naming.Namer(fn_namespace)
  transformed_name = fn_namer.new_symbol(self.get_transformed_name(tree), ())
  info = transformer.EntityInfo(
      name=transformed_name,
      source_code=src,
      source_file='<fragment>',
      future_features=futures,
      namespace=fn_namespace)
  ctx = transformer.Context(info, fn_namer, user_context)

  tree = self._erase_arg_defaults(tree)
  return self.transform_ast(tree, ctx), ctx
def test_to_ast(self):
  """Round-trips ConversionOptions through `to_ast` and compares fields."""
  opts = converter.ConversionOptions()
  fake_namer = converter_testing.FakeNamer()
  program_ctx = converter.ProgramContext(
      options=opts,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=())
  info = transformer.EntityInfo(
      source_code='',
      source_file='<fragment>',
      namespace={},
      arg_values=None,
      arg_types={},
      owner_type=None)
  ctx = converter.EntityContext(fake_namer, info, program_ctx)

  template = '''
    def test_fn():
      return opts_ast
  '''
  opts_packed = templates.replace(template, opts_ast=opts.to_ast(ctx))

  reparsed, _ = compiler.ast_to_object(opts_packed)
  reparsed.__dict__['ag__'] = self.make_fake_mod(
      'fake_ag', converter.ConversionOptions, converter.Feature)
  reparsed_opts = reparsed.test_fn()

  # Compared in the same order as the individual checks they replace.
  for field in ('recursive', 'verbose', 'force_conversion',
                'internal_convert_user_code', 'optional_features'):
    self.assertEqual(getattr(opts, field), getattr(reparsed_opts, field))
def _transform_function(self, fn, user_context):
  """Performs source code transformation on a function.

  Args:
    fn: A function or lambda.
    user_context: Opaque object passed to `transformer.Context`.

  Returns:
    A `(node, context)` tuple. `node` is the output of `transform_ast`:
    renamed to a fresh symbol, or for lambdas wrapped in an assignment to
    that symbol. `context` is the `transformer.Context` that was used.
  """
  futures = inspect_utils.getfutureimports(fn)
  tree, src = parser.parse_entity(fn, future_features=futures)
  logging.log(3, 'Source code of %s:\n\n%s\n', fn, src)
  origin_info.resolve_entity(tree, src, fn)

  fn_namespace = inspect_utils.getnamespace(fn)
  fn_namer = naming.Namer(fn_namespace)
  fresh_name = fn_namer.new_symbol(self.get_transformed_name(tree), ())
  info = transformer.EntityInfo(
      name=fresh_name,
      source_code=src,
      source_file='<fragment>',
      future_features=futures,
      namespace=fn_namespace)
  ctx = transformer.Context(info, fn_namer, user_context)

  tree = self.transform_ast(self._erase_arg_defaults(tree), ctx)

  if not isinstance(tree, gast.Lambda):
    tree.name = fresh_name
    return tree, ctx

  lhs = gast.Name(
      fresh_name, ctx=gast.Store(), annotation=None, type_comment=None)
  return gast.Assign(targets=[lhs], value=tree), ctx
def _parse_and_analyze(f, autobatch_functions):
  """Performs preliminary analyses and transformations.

  The goal is to massage the source program into a form on which the
  `_AutoBatchingTransformer` below will be successful.

  Args:
    f: Function to analyze
    autobatch_functions: List of Python `str` names of autobatched functions.
      Arguments to these functions will be canonicalized to variable
      references, but others will not.

  Returns:
    node: A Python AST node representing the function, suitable for
      passing to `_AutoBatchingTransformer.visit`
    ctx: An AutoGraph `converter.EntityContext` holding the namer, the
      `EntityInfo` for `f` and the program context. Required for
      initializing `_AutoBatchingTransformer`.
  """
  namespace = {}

  # Get the AST of the function
  future_features = inspect_utils.getfutureimports(f)
  node, _ = parser.parse_entity(f, future_features=future_features)

  # Boilerplate for AutoGraph transforms
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file=None,
      future_features=future_features,
      namespace=namespace)
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(recursive=True),
      autograph_module=None)
  ctx = converter.EntityContext(
      namer=naming.Namer(namespace),
      entity_info=entity_info,
      program_ctx=program_ctx)

  # Canonicalize away break statements
  node = converter.standard_analysis(node, ctx, is_initial=True)
  node = break_statements.transform(node, ctx)

  # Canonicalize away continue statements
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = continue_statements.transform(node, ctx)

  # Force single returns
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = return_statements.transform(node, ctx, default_to_null_return=False)

  # Transform into ANF
  # Replacing if tests and autobatched function call arguments because
  # that's where divergence can happen.
  # Replacing all function calls because the downstream transformation
  # expects calls to lead directly to assignments.
  def maybe_replace_function_argument(parent, field_name, child):
    """True if `parent` calls a function named in `autobatch_functions`."""
    del field_name, child
    if not anno.hasanno(parent.func, anno.Basic.QN):
      return False
    func_name = anno.getanno(parent.func, anno.Basic.QN)
    return str(func_name) in autobatch_functions

  anf_config = [
      (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
      (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
      (anf.ASTEdgePattern(gast.Call, 'args', anf.ANY),
       maybe_replace_function_argument),
  ]
  node = anf.transform(node, ctx, config=anf_config)
  node = converter.standard_analysis(node, ctx, is_initial=False)

  return node, ctx