def _parse_and_analyze(self, test_fn):
    """Parses `test_fn` and annotates its AST with reaching-definitions info.

    Args:
        test_fn: The Python callable to analyze.

    Returns:
        The analyzed AST node for `test_fn`.
    """
    ast_node, src = parser.parse_entity(test_fn, future_features=())
    info = transformer.EntityInfo(
        source_code=src, source_file=None, future_features=(), namespace={})
    ast_node = qual_names.resolve(ast_node)
    analysis_ctx = transformer.Context(info)
    ast_node = activity.resolve(ast_node, analysis_ctx)
    # The CFG is built after activity analysis and then fed to the
    # reaching-definitions pass.
    cfgs = cfg.build(ast_node)
    ast_node = reaching_definitions.resolve(
        ast_node, analysis_ctx, cfgs, reaching_definitions.Definition)
    return ast_node
def _parse_and_analyze(self, test_fn):
    """Parses `test_fn` and runs activity analysis over it.

    Args:
        test_fn: The Python callable to analyze.

    Returns:
        A tuple of (analyzed AST node, the EntityInfo built for `test_fn`).
    """
    # TODO(mdan): Use a custom FunctionTransformer here.
    ast_node, src = parser.parse_entity(test_fn, future_features=())
    info = transformer.EntityInfo(
        name=test_fn.__name__,
        source_code=src,
        source_file=None,
        future_features=(),
        namespace={})
    ast_node = qual_names.resolve(ast_node)
    ctx = transformer.Context(info, naming.Namer({}), None)
    ast_node = activity.resolve(ast_node, ctx)
    return ast_node, info
def mlir_gen_internal(node, entity_info):
    """Returns mlir module for unprocessed node `node`."""
    # The CFG is built on the raw node, exactly as received (original order).
    cfgs = cfg.build(node)
    ctx = transformer.Context(entity_info, naming.Namer({}), None)
    # Static-analysis pipeline: each pass annotates the AST in turn.
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = reaching_definitions.resolve(node, ctx, cfgs)
    node = reaching_fndefs.resolve(node, ctx, cfgs)
    node = liveness.resolve(node, ctx, cfgs)
    generator = MLIRGen(ctx)
    generator.visit(node)
    return generator.prog
def _parse_and_analyze(self, test_fn):
    """Parses `test_fn` and annotates its AST with liveness information."""
    ast_node, _, src = parser.parse_entity(test_fn, future_imports=())
    info = transformer.EntityInfo(
        source_code=src,
        source_file=None,
        namespace={},
        arg_values=None,
        arg_types=None)
    ast_node = qual_names.resolve(ast_node)
    ctx = transformer.Context(info)
    ast_node = activity.resolve(ast_node, ctx)
    cfgs = cfg.build(ast_node)
    # NOTE(review): the liveness pass's return value is discarded here;
    # presumably it annotates the node in place — confirm.
    liveness.resolve(ast_node, ctx, cfgs)
    return ast_node
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
    """Parses `test_fn` and runs the full legacy static-analysis pipeline."""
    ast_node, src = parser.parse_entity(test_fn)
    info = transformer.EntityInfo(
        source_code=src,
        source_file=None,
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types)
    ast_node = qual_names.resolve(ast_node)
    cfgs = cfg.build(ast_node)
    ctx = transformer.Context(info)
    ast_node = activity.resolve(ast_node, ctx)
    ast_node = reaching_definitions.resolve(
        ast_node, ctx, cfgs, reaching_definitions.Definition)
    # live_values runs twice — once before and once after type_info — which
    # preserves the original pass ordering.
    ast_node = live_values.resolve(ast_node, ctx, {})
    ast_node = type_info.resolve(ast_node, ctx)
    ast_node = live_values.resolve(ast_node, ctx, {})
    return ast_node
def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The method receives the original AST. The result is passed as-is to the
    output of `transform`.

    Args:
        fn: A function or lambda.
        user_context: An opaque object (may be None) that is forwarded to
            transform_ast, through the ctx.user_context argument.

    Returns:
        Tuple[Any, transformer.Context]: the transformed node — renamed to a
        freshly allocated symbol, with lambdas wrapped in an assignment to
        that symbol — together with the transformation context.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)
    # Track original source positions so results can be mapped back to `fn`.
    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    # Allocate a fresh symbol so the transformed entity does not collide with
    # existing names in `fn`'s namespace.
    new_name = namer.new_symbol(self.get_transformed_name(node), ())
    entity_info = transformer.EntityInfo(
        name=new_name,
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = transformer.Context(entity_info, namer, user_context)

    node = self._erase_arg_defaults(node)
    node = self.transform_ast(node, context)

    if isinstance(node, gast.Lambda):
        # A lambda has no name to rebind, so wrap it in an assignment to the
        # newly allocated symbol instead.
        node = gast.Assign(
            targets=[
                gast.Name(
                    new_name,
                    ctx=gast.Store(),
                    annotation=None,
                    type_comment=None)
            ],
            value=node)
    else:
        node.name = new_name

    return node, context
def prepare(self, test_fn, namespace, recursive=True):
    """Parses `test_fn` and runs AutoGraph's standard initial analysis."""
    namespace['ConversionOptions'] = converter.ConversionOptions
    futures = ('print_function', 'division')
    ast_node, src = parser.parse_entity(test_fn, future_features=futures)
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(recursive=recursive),
        autograph_module=None)
    info = transformer.EntityInfo(
        name=test_fn.__name__,
        source_code=src,
        source_file='<fragment>',
        future_features=futures,
        namespace=namespace)
    ctx = transformer.Context(info, naming.Namer(namespace), program_ctx)
    origin_info.resolve_entity(ast_node, src, test_fn)
    ast_node = converter.standard_analysis(ast_node, ctx, is_initial=True)
    return ast_node, ctx
def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The method receives the original AST. The result is passed as-is to the
    output of `transform`.

    Args:
        fn: A function or lambda.
        user_context: An opaque object (may be None) that is forwarded to
            transform_ast, through the ctx.user_context argument.

    Returns:
        Tuple[Any, Any]. By default it returns the output of transform_ast,
        together with a `transformer.Context` containing information about
        the transformation process.
    """
    futures = inspect_utils.getfutureimports(fn)
    ast_node, src = parser.parse_entity(fn, future_features=futures)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, src)
    origin_info.resolve_entity(ast_node, src, fn)

    ns = inspect_utils.getnamespace(fn)
    namer = naming.Namer(ns)
    renamed = namer.new_symbol(self.get_transformed_name(ast_node), ())
    context = transformer.Context(
        transformer.EntityInfo(
            name=renamed,
            source_code=src,
            source_file='<fragment>',
            future_features=futures,
            namespace=ns),
        namer,
        user_context)

    ast_node = self._erase_arg_defaults(ast_node)
    return self.transform_ast(ast_node, context), context
def _transform_function(self, fn, user_context):
    """Performs source code transformation on a function."""
    futures = inspect_utils.getfutureimports(fn)
    ast_node, src = parser.parse_entity(fn, future_features=futures)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, src)
    origin_info.resolve_entity(ast_node, src, fn)

    ns = inspect_utils.getnamespace(fn)
    namer = naming.Namer(ns)
    # Allocate a fresh symbol for the transformed entity.
    renamed = namer.new_symbol(self.get_transformed_name(ast_node), ())
    info = transformer.EntityInfo(
        name=renamed,
        source_code=src,
        source_file='<fragment>',
        future_features=futures,
        namespace=ns)
    context = transformer.Context(info, namer, user_context)

    ast_node = self._erase_arg_defaults(ast_node)
    ast_node = self.transform_ast(ast_node, context)

    if not isinstance(ast_node, gast.Lambda):
        ast_node.name = renamed
    else:
        # Lambdas have no name; bind the result to the new symbol instead.
        target = gast.Name(
            renamed, ctx=gast.Store(), annotation=None, type_comment=None)
        ast_node = gast.Assign(targets=[target], value=ast_node)

    return ast_node, context
def _parse_and_analyze(f, autobatch_functions):
    """Performs preliminary analyses and transformations.

    The goal is to massage the source program into a form on which
    the `_AutoBatchingTransformer` below will be successful.

    Args:
        f: Function to analyze
        autobatch_functions: List of Python `str` names of autobatched
            functions. Arguments to these functions will be canonicalized
            to variable references, but others will not.

    Returns:
        node: A Python AST node representing the function, suitable for
            passing to `_AutoBatchingTransformer.visit`
        entity_info: An AutoGraph `EntityInfo` object, with some information
            about `f`. Required for initializing
            `_AutoBatchingTransformer`.
    """
    # TODO(mdan): Replace all this boilerplate with FunctionTranspiler.
    namespace = {}

    # Get the AST of the function
    future_features = inspect_utils.getfutureimports(f)
    node, _ = parser.parse_entity(f, future_features=future_features)

    # Boilerplate for AutoGraph transforms. The two branches build the same
    # kind of context through APIs that changed between TF releases.
    if hasattr(converter, 'EntityContext'):
        # TF 2.2-: EntityInfo has no `name` field and the context type lives
        # in `converter`.
        entity_info = transformer.EntityInfo(
            source_code='',
            source_file=None,
            future_features=future_features,
            namespace=namespace)
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(recursive=True),
            autograph_module=None)
        ctx = converter.EntityContext(
            namer=naming.Namer(namespace),
            entity_info=entity_info,
            program_ctx=program_ctx)
    else:
        # TF 2.3+: EntityInfo requires `name` and the generic
        # `transformer.Context` is used instead.
        entity_info = transformer.EntityInfo(
            name=f.__name__,
            source_code='',
            source_file=None,
            future_features=future_features,
            namespace=namespace)
        program_ctx = converter.ProgramContext(
            options=converter.ConversionOptions(recursive=True),
            autograph_module=None)
        ctx = transformer.Context(
            info=entity_info,
            namer=naming.Namer(namespace),
            user_context=program_ctx)

    # Canonicalize away break statements
    node = converter.standard_analysis(node, ctx)
    node = break_statements.transform(node, ctx)

    # Canonicalize away continue statements
    node = converter.standard_analysis(node, ctx)
    node = continue_statements.transform(node, ctx)

    # Force single returns
    node = converter.standard_analysis(node, ctx)
    node = return_statements.transform(node, ctx, default_to_null_return=False)

    # Transform into ANF
    # Replacing if tests and autobatched function call arguments because
    # that's where divergence can happen.
    # Replacing all function calls because the downstream transformation
    # expects calls to lead directly to assignments.
    def maybe_replace_function_argument(parent, field_name, child):
        # Replace only arguments of calls to autobatched functions; calls to
        # anything else are left alone.
        del field_name, child
        if not anno.hasanno(parent.func, anno.Basic.QN):
            return False
        func_name = anno.getanno(parent.func, anno.Basic.QN)
        if str(func_name) in autobatch_functions:
            return True
        return False

    anf_config = [
        (anf.ASTEdgePattern(gast.If, 'test', anf.ANY), anf.REPLACE),
        (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.Call), anf.REPLACE),
        (anf.ASTEdgePattern(gast.Call, 'args', anf.ANY),
         maybe_replace_function_argument),
    ]
    node = anf.transform(node, ctx, config=anf_config)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)

    return node, ctx
def _simple_context(self):
    """Builds a transformer.Context around an empty EntityInfo."""
    return transformer.Context(
        transformer.EntityInfo(
            source_code=None,
            source_file=None,
            future_features=(),
            namespace=None))
def _live_tensors(f, attr_name="inputs"):
    """Returns the indices of the used inputs.

    Note: This currently only handles direct index accesses e.g.
    op.inputs[1]. If the function has slicing or list comprehension on
    attr_name then returns _ALL. This ensure that this is correct even if
    inefficient.

    Args:
        f: A grad function, taking the op as first argument.
        attr_name: op attr to track. "inputs" or "outputs".

    Returns:
        Either one of:
            * set of integers representing individual indices of inputs used
            * the value _ALL, if indices are used but cannot be determined
              which
            * empty set, if no inputs are used
    """
    node, _ = parser.parse_entity(f, ())
    entity_info = transformer.EntityInfo(
        name=f.__name__,
        source_code=None,
        source_file=None,
        future_features=(),
        namespace=sys.modules[f.__module__].__dict__)
    ctx = transformer.Context(entity_info, None, None)

    # Static analyses: CFG, qualified names, activity, function defs,
    # and finally liveness, which is what this routine consumes.
    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = liveness.resolve(node, ctx, graphs)

    # `f`'s first argument is the op; track accesses to op.<attr_name>.
    op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
    op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name)

    special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name,))
    node = special_tracker.visit(node)

    live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN)
    inputs_outputs_used_qns = set()
    for v in special_tracker.complex_reads:
        # Complicated patterns like op.inputs[:3]. Could be smarter about
        # them if they matter much.
        if v == op_inputs_outputs_name:
            return _ALL
    for v in live_vars_in:
        if v in special_tracker.reads:
            if (v.has_subscript() and v.parent == op_inputs_outputs_name):
                inputs_outputs_used_qns.add(v)
            elif v == op_inputs_outputs_name:
                # When op.{attr_name} is used directly, assume all tensors are
                # used for now. In that case, no point digging further.
                # TODO(mdan): We can descend into tuple expansions.
                return _ALL

    # Recurse into functions that receive the op as an argument; their used
    # indices count toward this function's result.
    function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name)
    node = function_calls_tracker.visit(node)

    input_output_indices = set()

    for called_f in function_calls_tracker.calls:
        child_indices = _live_tensors(called_f, attr_name=attr_name)
        if child_indices is _ALL:
            return _ALL
        input_output_indices |= child_indices

    for v in inputs_outputs_used_qns:
        assert v.has_subscript()
        _, subscript = v.qn
        if not subscript.is_simple():
            # Not a number, assuming it can be anything.
            return _ALL
        subscript_val, = subscript.qn
        # NOTE(review): with `and`, a subscript that is not a Literal but
        # whose `.value` happens to be an int slips through, and evaluating
        # `.value` on a non-Literal may raise — this looks like it was
        # intended to be `or`; confirm against qual_names semantics.
        if (not isinstance(subscript_val, qual_names.Literal) and
                not isinstance(subscript_val.value, int)):
            # Not a number, assuming it can be anything.
            return _ALL
        input_output_indices.add(subscript_val.value)
    return input_output_indices