def _parse_and_analyze(f):
  """Performs preliminary analyses and transformations.

  The goal is to massage the source program into a form on which
  the `_AutoBatchingTransformer` below will be successful.

  Args:
    f: Function to analyze

  Returns:
    node: A Python AST node representing the function, suitable for
      passing to `_AutoBatchingTransformer.visit`
    ctx: An AutoGraph `converter.EntityContext` for `f`; its `EntityInfo`
      is available as `ctx.info`. Required for initializing
      `_AutoBatchingTransformer`.
  """
  namespace = {}

  # Get the AST of the function
  future_features = inspect_utils.getfutureimports(f)
  node, _ = parser.parse_entity(f, future_features=future_features)

  # Boilerplate for AutoGraph transforms. The same (initially empty)
  # namespace dict is shared by the EntityInfo and the Namer.
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file=None,
      future_features=future_features,
      namespace=namespace)
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(recursive=True),
      autograph_module=None)
  ctx = converter.EntityContext(
      namer=naming.Namer(namespace),
      entity_info=entity_info,
      program_ctx=program_ctx)

  # Each transform below runs on a freshly analyzed tree. Only the very
  # first analysis (on the untouched source) is marked `is_initial=True`.

  # Canonicalize away break statements
  node = converter.standard_analysis(node, ctx, is_initial=True)
  node = break_statements.transform(node, ctx)

  # Canonicalize away continue statements
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = continue_statements.transform(node, ctx)

  # Force single returns
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = return_statements.transform(node, ctx, default_to_null_return=False)

  # Transform into ANF
  node = anf.transform(node, ctx)
  node = converter.standard_analysis(node, ctx, is_initial=False)

  return node, ctx
def prepare(self, test_fn, namespace, namer=None, arg_types=None,
            owner_type=None, recursive=True, strip_decorators=()):
  """Parses `test_fn` and runs the initial static analysis over it.

  Side effect: `namespace['ConversionOptions']` is set to
  `converter.ConversionOptions`.

  Args:
    test_fn: the function under test.
    namespace: dict namespace for the entity; mutated in place.
    namer: optional namer; a `FakeNamer()` is used when omitted.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo`.
    recursive: forwarded to `converter.ConversionOptions`.
    strip_decorators: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` pair: the analyzed AST node and the
    `converter.EntityContext` built for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  parsed, source = parser.parse_entity(test_fn)
  # Take the first statement of the parsed body (presumably the entity's own
  # definition).
  fn_node = parsed.body[0]
  namer = FakeNamer() if namer is None else namer

  options = converter.ConversionOptions(
      recursive=recursive, strip_decorators=strip_decorators)
  program_ctx = converter.ProgramContext(
      options=options,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  entity_ctx = converter.EntityContext(namer, info, program_ctx)

  analyzed = converter.standard_analysis(fn_node, entity_ctx, is_initial=True)
  return analyzed, entity_ctx
def prepare(self, test_fn, namespace, namer=None, arg_types=None,
            owner_type=None, recursive=True, strip_decorators=()):
  """Parses `test_fn`, attaches origin info and runs initial analysis.

  Side effect: `namespace['ConversionOptions']` is set to
  `converter.ConversionOptions`.

  Args:
    test_fn: the function under test.
    namespace: dict namespace for the entity; mutated in place.
    namer: optional namer; a `FakeNamer()` is used when omitted.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo`.
    recursive: forwarded to `converter.ConversionOptions`.
    strip_decorators: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` pair: the analyzed AST node and the
    `converter.EntityContext` built for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  module_node, source = parser.parse_entity(test_fn)
  # Keep the first statement of the parsed body (presumably the entity's
  # definition).
  target = module_node.body[0]
  if namer is None:
    namer = FakeNamer()

  conversion_options = converter.ConversionOptions(
      recursive=recursive,
      strip_decorators=strip_decorators,
      verbose=True)
  program_ctx = converter.ProgramContext(
      options=conversion_options,
      partial_types=None,
      autograph_module=None,
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      owner_type=owner_type)
  entity_ctx = converter.EntityContext(namer, info, program_ctx)

  # Origin info is resolved against the original source before analysis.
  origin_info.resolve(target, source, test_fn)
  target = converter.standard_analysis(target, entity_ctx, is_initial=True)
  return target, entity_ctx
def converted(self, entity, converter_module, namespace, *tf_symbols):
  """Applies one or more converters to `entity` and compiles the result.

  Args:
    entity: the function to convert; passed to `self.prepare`.
    converter_module: a module exposing `transform(node, ctx)`, or a list or
      tuple of such modules, applied in order.
    namespace: dict namespace used both for preparation and compilation.
    *tf_symbols: extra symbols forwarded to `self.compiled`.

  Yields:
    The `result` bound by `self.compiled(node, namespace, *tf_symbols)`.
  """
  node, ctx = self.prepare(entity, namespace)

  if not isinstance(converter_module, (list, tuple)):
    converter_module = (converter_module,)
  for m in converter_module:
    node = m.transform(node, ctx)
    # `prepare` already performed the initial analysis on the untouched
    # source. Re-analysis of transformed code is not initial, so it must not
    # be flagged `is_initial=True`; this mirrors the analyze-then-transform
    # sequences elsewhere, which mark only the first pass as initial.
    node = converter.standard_analysis(node, ctx, is_initial=False)

  with self.compiled(node, namespace, *tf_symbols) as result:
    yield result
def converted(self, entity, converter_module, namespace, *tf_symbols):
  """Runs the given converter(s) over `entity` and compiles the output.

  Before each converter, the tree is re-analyzed. Only the pass before the
  first converter is flagged as the initial analysis.

  Args:
    entity: the function to convert; passed to `self.prepare`.
    converter_module: a module exposing `transform(node, ctx)`, or a list or
      tuple of such modules, applied in order.
    namespace: dict namespace used both for preparation and compilation.
    *tf_symbols: extra symbols forwarded to `self.compiled`.

  Yields:
    The `result` bound by `self.compiled(node, namespace, *tf_symbols)`.
  """
  node, ctx = self.prepare(entity, namespace)

  if isinstance(converter_module, (list, tuple)):
    converters = converter_module
  else:
    converters = (converter_module,)

  first_pass = True
  for conv in converters:
    node = converter.standard_analysis(node, ctx, is_initial=first_pass)
    first_pass = False
    node = conv.transform(node, ctx)

  with self.compiled(node, namespace, *tf_symbols) as result:
    yield result
def node_to_graph(node, context):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext

  Returns:
    A Python ast node, representing the converted code.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.
  # Presumably raises on source using unsupported language features.
  unsupported_features_checker.verify(node)

  node = converter.standard_analysis(node, context, is_initial=True)
  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  context.info.source_code = None

  node = converter.apply_(node, context, arg_defaults)
  node = converter.apply_(node, context, directives)
  node = converter.apply_(node, context, break_statements)
  if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
    node = converter.apply_(node, context, asserts)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = converter.apply_(node, context, continue_statements)
  node = converter.apply_(node, context, return_statements)
  if context.program.options.uses(converter.Feature.LISTS):
    node = converter.apply_(node, context, lists)
    node = converter.apply_(node, context, slices)
  if context.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS):
    node = converter.apply_(node, context, builtin_functions)
  node = converter.apply_(node, context, call_trees)
  node = converter.apply_(node, context, control_flow)
  node = converter.apply_(node, context, conditional_expressions)
  if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS):
    node = converter.apply_(node, context, logical_expressions)
  if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
    node = converter.apply_(node, context, side_effect_guards)
  # TODO(mdan): If function scopes ever does more, the toggle will need moving.
  if context.program.options.uses(converter.Feature.NAME_SCOPES):
    node = converter.apply_(node, context, function_scopes)
  if context.program.options.uses(converter.Feature.ERROR_REWRITING):
    node = converter.apply_(node, context, error_handlers)
  return node
def node_to_graph(node, context, rewrite_errors=True):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext
    rewrite_errors: Boolean, whether or not to rewrite the error traceback.

  Returns:
    A Python ast node, representing the converted code.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.

  node = converter.standard_analysis(node, context, is_initial=True)
  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  context.info.source_code = None

  if context.program.options.uses(converter.Feature.DECORATORS):
    node = converter.apply_(node, context, decorators)
  node = converter.apply_(node, context, arg_defaults)
  node = converter.apply_(node, context, directives)
  node = converter.apply_(node, context, break_statements)
  node = converter.apply_(node, context, asserts)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = converter.apply_(node, context, continue_statements)
  node = converter.apply_(node, context, return_statements)
  if context.program.options.uses(converter.Feature.LISTS):
    node = converter.apply_(node, context, lists)
    node = converter.apply_(node, context, slices)
  node = converter.apply_(node, context, builtin_functions)
  node = converter.apply_(node, context, call_trees)
  node = converter.apply_(node, context, control_flow)
  node = converter.apply_(node, context, conditional_expressions)
  node = converter.apply_(node, context, logical_expressions)
  if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
    node = converter.apply_(node, context, side_effect_guards)
  node = converter.apply_(node, context, function_scopes)
  if rewrite_errors:
    node = converter.apply_(node, context, error_handlers)
  return node
def prepare(self, test_fn, namespace, recursive=True):
  """Parses `test_fn` with fixed future features and analyzes it.

  Side effect: `namespace['ConversionOptions']` is set to
  `converter.ConversionOptions`.

  Args:
    test_fn: the function under test.
    namespace: dict namespace for the entity; mutated in place.
    recursive: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` pair: the analyzed AST node and the
    `converter.EntityContext` built for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  # Fixed set of future features, passed to both the parser and EntityInfo.
  features = ('print_function', 'division')
  parsed, source = parser.parse_entity(test_fn, future_features=features)

  namer = naming.Namer(namespace)
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(recursive=recursive),
      autograph_module=None)
  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=features,
      namespace=namespace)
  entity_ctx = converter.EntityContext(namer, info, program_ctx)

  return (converter.standard_analysis(parsed, entity_ctx, is_initial=True),
          entity_ctx)
def prepare(self, test_fn, namespace, arg_types=None, recursive=True):
  """Parses `test_fn`, resolves origin info and runs initial analysis.

  Side effect: `namespace['ConversionOptions']` is set to
  `converter.ConversionOptions`.

  Args:
    test_fn: the function under test.
    namespace: dict namespace for the entity; mutated in place.
    arg_types: forwarded to `transformer.EntityInfo`.
    recursive: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` pair: the analyzed AST node and the
    `converter.EntityContext` built for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  # In this version `parse_entity` yields three values; the third is unused.
  parsed, source, _ = parser.parse_entity(test_fn)

  namer = naming.Namer(namespace)
  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(
      options=options, autograph_module=None)
  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types)
  entity_ctx = converter.EntityContext(namer, info, program_ctx)

  origin_info.resolve(parsed, source, test_fn)
  analyzed = converter.standard_analysis(parsed, entity_ctx, is_initial=True)
  return analyzed, entity_ctx
def node_to_graph(node, context, rewrite_errors=True):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext
    rewrite_errors: Boolean, whether or not to rewrite the error traceback.

  Returns:
    A Python ast node, representing the converted code.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.

  node = converter.standard_analysis(node, context, is_initial=True)
  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  context.info.source_code = None

  node = converter.apply_(node, context, decorators)
  node = converter.apply_(node, context, directives)
  node = converter.apply_(node, context, break_statements)
  node = converter.apply_(node, context, asserts)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = converter.apply_(node, context, continue_statements)
  node = converter.apply_(node, context, return_statements)
  if context.program.options.uses(converter.Feature.LISTS):
    node = converter.apply_(node, context, lists)
    node = converter.apply_(node, context, slices)
  node = converter.apply_(node, context, builtin_functions)
  node = converter.apply_(node, context, call_trees)
  node = converter.apply_(node, context, control_flow)
  node = converter.apply_(node, context, conditional_expressions)
  node = converter.apply_(node, context, logical_expressions)
  if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
    node = converter.apply_(node, context, side_effect_guards)
  node = converter.apply_(node, context, function_scopes)
  if rewrite_errors:
    node = converter.apply_(node, context, error_handlers)
  return node
def prepare(self, test_fn, namespace, recursive=True):
  """Parses `test_fn`, resolves origin info and runs initial analysis.

  Side effect: `namespace['ConversionOptions']` is set to
  `converter.ConversionOptions`.

  Args:
    test_fn: the function under test.
    namespace: dict namespace for the entity; mutated in place.
    recursive: forwarded to `converter.ConversionOptions`.

  Returns:
    A `(node, ctx)` pair: the analyzed AST node and the
    `converter.EntityContext` built for it.
  """
  namespace['ConversionOptions'] = converter.ConversionOptions

  # Fixed set of future features, passed to both the parser and EntityInfo.
  features = ('print_function', 'division')
  tree, source = parser.parse_entity(test_fn, future_features=features)

  namer = naming.Namer(namespace)
  options = converter.ConversionOptions(recursive=recursive)
  program_ctx = converter.ProgramContext(
      options=options, autograph_module=None)
  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=features,
      namespace=namespace)
  entity_ctx = converter.EntityContext(namer, info, program_ctx)

  # Origin info is resolved against the original source before analysis.
  origin_info.resolve(tree, source, test_fn)
  tree = converter.standard_analysis(tree, entity_ctx, is_initial=True)
  return tree, entity_ctx
def transform_ast(self, node, ctx):
  """Runs the converter pipeline over `node`.

  Stages run in a fixed order. A stage gated by a feature runs only when
  `ctx.user.options.uses(feature)` is true.

  Args:
    node: AST of the entity being converted.
    ctx: the converter context; `ctx.user.options` selects optional stages.

  Returns:
    The transformed AST node.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.
  unsupported_features_checker.verify(node)
  node = converter.standard_analysis(node, ctx, is_initial=True)

  # (gating feature or None, converter modules applied in order).
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  pipeline = (
      (None, (functions, directives, break_statements)),
      (converter.Feature.ASSERT_STATEMENTS, (asserts,)),
      (None, (continue_statements, return_statements)),
      (converter.Feature.LISTS, (lists, slices)),
      (None, (call_trees, control_flow, conditional_expressions,
              logical_expressions)),
  )
  for feature, stage_modules in pipeline:
    if feature is not None and not ctx.user.options.uses(feature):
      continue
    for stage_module in stage_modules:
      node = converter.apply_(node, ctx, stage_module)
  return node
def node_to_graph(node, context):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext

  Returns:
    A Python ast node, representing the converted code.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.
  # Presumably raises on source using unsupported language features.
  unsupported_features_checker.verify(node)

  node = converter.standard_analysis(node, context, is_initial=True)
  node = converter.apply_(node, context, function_scopes)
  node = converter.apply_(node, context, arg_defaults)
  node = converter.apply_(node, context, directives)
  node = converter.apply_(node, context, break_statements)
  if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
    node = converter.apply_(node, context, asserts)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = converter.apply_(node, context, continue_statements)
  node = converter.apply_(node, context, return_statements)
  if context.program.options.uses(converter.Feature.LISTS):
    node = converter.apply_(node, context, lists)
    node = converter.apply_(node, context, slices)
  node = converter.apply_(node, context, call_trees)
  node = converter.apply_(node, context, control_flow)
  node = converter.apply_(node, context, conditional_expressions)
  node = converter.apply_(node, context, logical_expressions)
  return node
def _parse_and_analyze(f, autobatch_functions):
  """Performs preliminary analyses and transformations.

  The goal is to massage the source program into a form on which
  the `_AutoBatchingTransformer` below will be successful.

  Args:
    f: Function to analyze
    autobatch_functions: List of Python `str` names of autobatched functions.
      Arguments to these functions will be canonicalized to variable
      references, but others will not.

  Returns:
    node: A Python AST node representing the function, suitable for
      passing to `_AutoBatchingTransformer.visit`
    ctx: An AutoGraph `converter.EntityContext` for `f`; its `EntityInfo`
      is available as `ctx.info`. Required for initializing
      `_AutoBatchingTransformer`.
  """
  namespace = {}

  # Get the AST of the function
  future_features = inspect_utils.getfutureimports(f)
  node, _ = parser.parse_entity(f, future_features=future_features)

  # Boilerplate for AutoGraph transforms. The same (initially empty)
  # namespace dict is shared by the EntityInfo and the Namer.
  entity_info = transformer.EntityInfo(
      source_code='',
      source_file=None,
      future_features=future_features,
      namespace=namespace)
  program_ctx = converter.ProgramContext(
      options=converter.ConversionOptions(recursive=True),
      autograph_module=None)
  ctx = converter.EntityContext(
      namer=naming.Namer(namespace),
      entity_info=entity_info,
      program_ctx=program_ctx)

  # Canonicalize away break statements
  node = converter.standard_analysis(node, ctx, is_initial=True)
  node = break_statements.transform(node, ctx)

  # Canonicalize away continue statements
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = continue_statements.transform(node, ctx)

  # Force single returns
  node = converter.standard_analysis(node, ctx, is_initial=False)
  node = return_statements.transform(node, ctx, default_to_null_return=False)

  # Transform into ANF
  # Replacing if tests and autobatched function call arguments because
  # that's where divergence can happen.
  # Replacing all function calls because the downstream transformation
  # expects calls to lead directly to assignments.
  def maybe_replace_function_argument(parent, field_name, child):
    """Returns whether `parent` (a call) targets an autobatched function.

    Calls whose callee has no qualified-name annotation are never replaced.
    """
    del field_name, child
    if not anno.hasanno(parent.func, anno.Basic.QN):
      return False
    func_name = anno.getanno(parent.func, anno.Basic.QN)
    return str(func_name) in autobatch_functions

  anf_config = [
      (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
      (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
      (anf.ASTEdgePattern(gast.Call, 'args', anf.ANY),
       maybe_replace_function_argument),
  ]
  node = anf.transform(node, ctx, config=anf_config)
  node = converter.standard_analysis(node, ctx, is_initial=False)

  return node, ctx