def visit(self, node):
  """Visits `node`, tracking origin info and propagating it to replacements.

  While the node is being visited, `self.ctx.current_origin` is set to the
  node's ORIGIN annotation (if any) and is always restored afterwards. If the
  visit returns replacement node(s), each replacement that lacks an ORIGIN
  annotation inherits the visited node's origin (or the enclosing origin).

  Args:
    node: gast.AST, the node to visit. Must be a single node, not a list.

  Returns:
    The visit result: the node itself, a replacement node, a list/tuple of
    nodes, or None.

  Raises:
    ValueError: if `node` is not a gast.AST instance.
  """
  if not isinstance(node, gast.AST):
    # This is not that uncommon a mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`.  The error needs to be raised before the exception handler
    # below is installed, because said handler will mess up if `node` is not,
    # in fact, a node.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
    return node

  parent_origin = self.ctx.current_origin
  if anno.hasanno(node, anno.Basic.ORIGIN):
    self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

  try:
    processing_expr_node = isinstance(node, gast.Expr)
    # Only read below when processing_expr_node is True.
    entry_expr_value = node.value if processing_expr_node else None

    result = super(Base, self).visit(node)

    # Adjust for consistency: replacing the value of an Expr with
    # an Assign node removes the need for the Expr node.
    if (processing_expr_node and isinstance(result, gast.Expr) and
        (result.value is not entry_expr_value)):
      # When the replacement is a list, it is assumed that the list came
      # from a template that contained a number of statements, which
      # themselves are standalone and don't require an enclosing Expr.
      if isinstance(result.value,
                    (list, tuple, gast.Assign, gast.AugAssign)):
        result = result.value

    # By default, all replacements receive the origin info of the replaced
    # node.
    if result is not node and result is not None:
      inherited_origin = anno.getanno(
          node, anno.Basic.ORIGIN, default=parent_origin)
      if inherited_origin is not None:
        nodes_to_adjust = (
            result if isinstance(result, (list, tuple)) else (result,))
        for n in nodes_to_adjust:
          if not anno.hasanno(n, anno.Basic.ORIGIN):
            anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)
  finally:
    # Restore the enclosing origin even if the visit raised.
    self.ctx.current_origin = parent_origin

  return result
def get_definition_directive(self, node, directive, arg, default):
  """Returns the unique directive argument for a symbol.

  See lang/directives.py for details on directives.

  Example:
     # Given a directive in the code:
     ag.foo_directive(bar, baz=1)

     # One can write for an AST node Name(id='bar'):
     get_definition_directive(node, ag.foo_directive, 'baz', default)

  Args:
    node: ast.AST, the node representing the symbol for which the directive
      argument is needed.
    directive: Callable[..., Any], the directive to search.
    arg: str, the directive argument to return.
    default: Any, returned when no original definition of the symbol carries
      a value for `arg` of `directive`.

  Returns:
    The directive argument value. When several definitions reach the symbol,
    they must all match (per `ast_util.matches`), and the first one is
    returned.

  Raises:
    ValueError: if conflicting annotations have been found
  """
  definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
  if not definitions:
    return default

  candidates = [
      d.directives[directive][arg]
      for d in definitions
      if directive in d.directives and arg in d.directives[directive]
  ]
  if not candidates:
    return default
  if len(candidates) == 1:
    return candidates[0]

  # If multiple annotations reach the symbol, they must all match. If they do,
  # return any of them.
  reference = candidates[0]
  for candidate in candidates[1:]:
    if ast_util.matches(reference, candidate):
      continue
    qn = anno.getanno(node, anno.Basic.QN)
    raise ValueError(
        '%s has ambiguous annotations for %s(%s): %s, %s' %
        (qn, directive.__name__, arg,
         parser.unparse(candidate).strip(),
         parser.unparse(reference).strip()))
  return reference
def visit_While(self, node):
  """Lowers `break` statements inside a `while` loop.

  The body is processed with a fresh `break_` flag variable. If a break was
  used, the loop is rebuilt so that the flag is initialized to False, the loop
  test also requires the flag to be unset, and the `else` clause is guarded
  by the flag. DIRECTIVES annotations are copied onto the rebuilt loop.

  Returns:
    The original (updated) node if no break was used, otherwise the list of
    statements produced by the template.
  """
  original_node = node
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

  node.test = self.visit(node.test)
  node.body, break_used = self._process_body(node.body, break_var)
  # A break in the else clause applies to the containing scope.
  node.orelse = self.visit_block(node.orelse)

  if not break_used:
    return node

  # Python's else clause only triggers if the loop exited cleanly (e.g.
  # break did not trigger).
  guarded_orelse = self._guard_if_present(node.orelse, break_var)

  template = """
    var_name = False
    while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
      body
    else:
      orelse
  """
  replacement = templates.replace(
      template,
      var_name=break_var,
      test=node.test,
      body=node.body,
      orelse=guarded_orelse)

  # The second statement of the expanded template is the loop itself.
  anno.copyanno(original_node, replacement[1], anno.Basic.DIRECTIVES)
  return replacement
def visit_Lambda(self, node):
  """Wraps the body of an outermost lambda in `ag__.with_function_scope`.

  Lambdas whose function nesting level exceeds 2 are visited but left
  unwrapped. Wrapped lambdas are annotated with 'function_context_name'.
  """
  self.state[_Function].enter()
  node = self.generic_visit(node)

  # Only wrap the top-level function. Theoretically, we can and should wrap
  # everything, but that can lead to excessive boilerplate when lambdas are
  # nested.
  # TODO(mdan): Looks more closely for use cases that actually require this.
  if self.state[_Function].level <= 2:
    lambda_scope = anno.getanno(node, anno.Static.SCOPE)
    context_name = self.ctx.namer.new_symbol(
        'lscope', lambda_scope.referenced)
    self.state[_Function].context_name = context_name
    anno.setanno(node, 'function_context_name', context_name)

    template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
    node.body = templates.replace_as_expression(
        template,
        options=self._function_scope_options().to_ast(),
        function_context=context_name,
        function_context_name=gast.Constant(context_name, kind=None),
        body=node.body)

  self.state[_Function].exit()
  return node
def visit_FunctionDef(self, node):
  """Visits a function and adjusts its decorators by nesting depth.

  The function context name recorded on the node (annotation
  'function_context_name') is made current while visiting the arguments and
  body. Functions at nesting level < 2 have their decorators cleared. Deeper
  ones get `ag__.autograph_artifact` appended as a decorator.
  """
  self.state[_Function].enter()

  # Note: if the conversion process ever creates helper functions, this
  # assumption will no longer hold.
  assert anno.hasanno(node, 'function_context_name'), (
      'The function_scopes converter always creates a scope for functions.'
  )
  context_name = anno.getanno(node, 'function_context_name')
  self.state[_Function].context_name = context_name

  node.args = self.visit(node.args)
  node.body = self.visit_block(node.body)

  is_outermost = self.state[_Function].level < 2
  if not is_outermost:
    # TODO(mdan): Fix the tests so that we can always add this decorator.
    # Inner functions are converted already, so we insert a decorator to
    # prevent double conversion. Double conversion would work too, but this
    # saves the overhead.
    node.decorator_list.append(
        parser.parse_expression('ag__.autograph_artifact'))
  else:
    # Top-level functions lose their decorator because the conversion is
    # always just-in-time and by the time it happens the decorators are
    # already set to be applied.
    node.decorator_list = []

  if node.returns:
    node.returns = self.visit(node.returns)

  self.state[_Function].exit()
  return node
def visit_Attribute(self, node):
  """Annotates `a.b` with a QN built from the QN of `a`, when `a` has one."""
  node = self.generic_visit(node)
  if not anno.hasanno(node.value, anno.Basic.QN):
    return node
  base_qn = anno.getanno(node.value, anno.Basic.QN)
  anno.setanno(node, anno.Basic.QN, QN(base_qn, attr=node.attr))
  return node
def visit_FunctionDef(self, node):
  """Runs the analysis for a function's subgraph, then visits its children.

  Preorder tree processing:
    1. if this is a child function, the parent was already analyzed and it
       has the proper state value for the subgraph's entry
    2. analyze the current function body
    3. recursively walk the subtree; child functions will be processed
  """
  parent_analyzer = self.current_analyzer
  analyzer = Analyzer(self.graphs[node], self.definition_factory)

  if parent_analyzer is not None:
    # Wire the state between the two subgraphs' analyzers.
    entry_index = parent_analyzer.graph.index[node]
    parent_out_state = parent_analyzer.out[entry_index]
    # Exception: symbols modified in the child function are local to it
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    parent_out_state -= body_scope.modified
    analyzer.extra_in[node.args] = parent_out_state

  # Complete the analysis for the local function and annotate its body.
  analyzer.visit_forward()

  # Recursively process any remaining subfunctions.
  self.current_analyzer = analyzer
  # Note: not visiting name, decorator_list and returns because they don't
  # apply to this analysis.
  # TODO(mdan): Should we still process the function name?
  node.args = self.visit(node.args)
  node.body = self.visit_block(node.body)
  self.current_analyzer = parent_analyzer

  return node
def visit_For(self, node):
  """Adds a `for` loop statement to the CFG being built.

  The loop body, together with any EXTRA_LOOP_TEST annotation, is placed in a
  loop section entered with `node.iter`. The `else` clause is visited after
  both the loop section and the node's lexical scope are closed, but still
  inside the statement's section.
  """
  self.builder.begin_statement(node)
  self._enter_lexical_scope(node)

  self.builder.enter_section(node)

  # Note: Strictly speaking, this should be node.target + node.iter.
  # However, the activity analysis accounts for this inconsistency,
  # so dataflow analysis produces the correct values.
  self.builder.enter_loop_section(node, node.iter)
  # Also include the "extra loop test" annotation, to capture things like the
  # control variable for return and break in for loops.
  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    self._process_basic_statement(
        anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
  for stmt in node.body:
    self.visit(stmt)
  self.builder.exit_loop_section(node)

  # Note: although the orelse is technically part of the loop node,
  # they don't count as loop bodies.  For example, a break in the loop's
  # orelse will affect the parent loop, not the current one.
  self._exit_lexical_scope(node)

  for stmt in node.orelse:
    self.visit(stmt)

  self.builder.exit_section(node)
  self.builder.end_statement(node)
def _node_sets_self_attribute(self, node):
  """Returns True if `node`'s qualified name is an attribute of `self`.

  For example, `self.x` qualifies; `x`, `self`, `other.x` and `self[0]` do not.
  Nodes without a QN annotation never qualify.
  """
  if not anno.hasanno(node, anno.Basic.QN):
    return False
  qn = anno.getanno(node, anno.Basic.QN)
  # `has_attr` is a QN predicate method (like `is_composite()`). Testing the
  # bound method itself is always truthy, which would let non-attribute names
  # reach `qn.parent`, and that fails for simple names.
  # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
  if qn.has_attr() and qn.parent.qn == ('self',):
    return True
  return False
def visit_Attribute(self, node):
  """Propagates STATIC_VALUE to attributes of statically known modules.

  If the attribute's base has a STATIC_VALUE that is a module and that module
  has the attribute, the attribute node is annotated with the attribute value.
  """
  node = self.generic_visit(node)
  parent_val = anno.getanno(node.value, STATIC_VALUE, default=None)
  base_is_module = parent_val is not None and inspect.ismodule(parent_val)
  if base_is_module and hasattr(parent_val, node.attr):
    anno.setanno(node, STATIC_VALUE, getattr(parent_val, node.attr))
  return node
def _postprocess_statement(self, node):
  """Updates return tracking after a statement and picks a continuation.

  Returns:
    A (node, continuation) pair. For an `if` whose body definitely returns,
    the continuation is its `orelse`. Otherwise, for an `if` whose `orelse`
    definitely returns, it is its `body`. Otherwise it is None.
  """
  # If the node definitely returns (e.g. it's a with statement with a
  # return statement in it), then the current block also definitely returns.
  if anno.getanno(node, STMT_DEFINITELY_RETURNS, default=False):
    self.state[_RewriteBlock].definitely_returns = True

  # The special case: collapse a typical conditional return pattern into
  # a single conditional with possibly returns on both branches. This
  # reduces the use of None return values, which don't work with TF
  # conditionals.
  if not isinstance(node, gast.If):
    return node, None
  if anno.getanno(node, BODY_DEFINITELY_RETURNS, default=False):
    return node, node.orelse
  if anno.getanno(node, ORELSE_DEFINITELY_RETURNS, default=False):
    return node, node.body
  return node, None
def _get_loop_vars(self, node, modified):
  """Computes the loop variables of a loop node.

  Args:
    node: the loop node, which must carry BODY_SCOPE, DEFINED_VARS_IN,
      LIVE_VARS_IN and LIVE_VARS_OUT annotations.
    modified: the symbols modified by the loop.

  Returns:
    A tuple (loop_vars, reserved_symbols, undefined_lives): the basic and
    composite loop variables as a tuple, the symbols referenced in the loop
    body, and the basic loop variables not defined on loop entry.
  """
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

  basic_vars = self._get_basic_loop_vars(modified, live_in, live_out)
  composite_vars = self._get_composite_loop_vars(modified, live_in)

  # Variables that are used or defined inside the loop, but not defined
  # before entering the loop. Only simple variables must be defined. The
  # composite ones will be implicitly checked at runtime.
  undefined_lives = basic_vars - defined_in

  return (tuple(basic_vars | composite_vars), body_scope.referenced,
          undefined_lives)
def visit_Name(self, node):
  """Resolves loaded names with no local definitions from the namespace.

  A `Load` of a name that has no DEFINITIONS annotation entries but exists in
  `self.ctx.info.namespace` is annotated with the namespace value as its
  STATIC_VALUE.
  """
  node = self.generic_visit(node)
  if not isinstance(node.ctx, gast.Load):
    return node

  if anno.getanno(node, anno.Static.DEFINITIONS, ()):
    return node

  namespace = self.ctx.info.namespace
  if node.id in namespace:
    anno.setanno(node, STATIC_VALUE, namespace[node.id])
  return node
def _process_symbol_directive(self, call_node, directive):
  """Records a symbol directive on every original definition of its target.

  The target is the first positional argument of the directive call. Each
  definition gets its own result of `_map_args(call_node, directive)`.

  Raises:
    ValueError: if the call has no positional arguments.
  """
  if not call_node.args:
    raise ValueError('"%s" requires a positional first argument'
                     ' as the target' % directive.__name__)
  target = call_node.args[0]
  for definition in anno.getanno(target, anno.Static.ORIG_DEFINITIONS):
    definition.directives[directive] = _map_args(call_node, directive)
  return call_node
def visit_ExceptHandler(self, node):
  """Visits an `except` clause in its own scope.

  The exception variable (if any) is registered as an isolated name of that
  scope.
  """
  self._enter_scope(False)
  # try/except oddity: as expected, it leaks any names you defined inside the
  # except block, but not the name of the exception variable.
  exception_name = node.name
  if exception_name is not None:
    self.scope.isolated_names.add(
        anno.getanno(exception_name, anno.Basic.QN))
  node = self.generic_visit(node)
  self._exit_scope()
  return node
def visit_FunctionDef(self, node):
  """Rewrites the `return` statements of a function.

  The body is converted with `_visit_statement_block`. If any return was used
  (`_Block.return_used`), the converted body is wrapped in a template.
  Depending on `self.default_to_null_return`, the template either initializes
  the do_return/retval variables and returns `ag__.retval(retval)`, or simply
  appends `return retval`. A leading Expr whose value is a Constant is treated
  as the docstring and kept as the first statement.
  """
  self.state[_Function].enter()
  self.state[_Block].enter()
  self.state[_Block].is_function = True

  scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  # Fresh names for the control variables; new_symbol is given the symbols
  # referenced in the body, presumably to avoid clashes.
  do_return_var_name = self.ctx.namer.new_symbol('do_return',
                                                 scope.referenced)
  retval_var_name = self.ctx.namer.new_symbol('retval_', scope.referenced)
  self.state[_Function].do_return_var_name = do_return_var_name
  self.state[_Function].retval_var_name = retval_var_name

  converted_body = self._visit_statement_block(node, node.body)

  # Avoid placing statements before any eventual docstring.
  # TODO(mdan): Should a docstring even be included in the output?
  docstring = None
  if converted_body:
    if (isinstance(converted_body[0], gast.Expr) and
        isinstance(converted_body[0].value, gast.Constant)):
      docstring = converted_body[0]
      converted_body = converted_body[1:]

  if self.state[_Block].return_used:

    if self.default_to_null_return:
      # TODO(mdan): Remove the (do_return_var_name,) below.
      # Currently, that line ensures the variable is both defined and alive
      # throughout the function.
      template = """
        do_return_var_name = False
        retval_var_name = ag__.UndefinedReturnValue()
        body
        (do_return_var_name,)
        return ag__.retval(retval_var_name)
      """
    else:
      template = """
        body
        return retval_var_name
      """
    node.body = templates.replace(
        template,
        body=converted_body,
        do_return_var_name=do_return_var_name,
        retval_var_name=retval_var_name)

    # The docstring (if any) is re-inserted ahead of the templated body.
    if docstring:
      node.body.insert(0, docstring)

  self.state[_Block].exit()
  self.state[_Function].exit()
  return node
def _block_statement_live_in(self, node, entry_node):
  """Annotates `node` with the LIVE_VARS_IN set of `entry_node`.

  If `entry_node` is a node of the current analyzer's CFG, its analyzed
  live-in set is used (as a frozenset). Otherwise `entry_node` must already
  carry a LIVE_VARS_IN annotation, which is copied.
  """
  cfg_index = self.current_analyzer.graph.index
  if entry_node not in cfg_index:
    assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
        'If not matching a CFG node, must be a block statement:'
        ' {}'.format(entry_node))
    live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
  else:
    live_in = frozenset(self.current_analyzer.in_[cfg_index[entry_node]])

  anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
  return node
def visit_Lambda(self, node):
  """Visits a lambda, tracking its function context name if it has one."""
  if not anno.hasanno(node, 'function_context_name'):
    # Lambda functions created during the conversion process have no
    # context manager.
    return self.generic_visit(node)

  self.state[_Function].enter()
  self.state[_Function].context_name = anno.getanno(
      node, 'function_context_name')
  node = self.generic_visit(node)
  self.state[_Function].exit()
  return node
def _process_name_node(self, node):
  """Renames a Name node whose QN is a key of `self.name_map`.

  The replacement keeps the original `ctx` and all annotations. Names not in
  the map are visited generically.
  """
  qn = anno.getanno(node, anno.Basic.QN)
  if qn not in self.name_map:
    return self.generic_visit(node)

  renamed = gast.Name(
      str(self.name_map[qn]),
      ctx=node.ctx,
      annotation=None,
      type_comment=None)
  # All annotations get carried over.
  for key in anno.keys(node):
    anno.copyanno(node, renamed, key)
  return renamed
def visit_While(self, node):
  """Converts a `while` loop into a call to `ag__.pt_while_stmt`.

  The body and the test are moved into generated local functions
  (`loop_body`, `loop_test`). Loop variables are exposed through generated
  state getter/setter functions. Assignments built by
  `_create_undefined_assigns` are emitted before the call, and loop options
  come from `_create_loop_options`.

  Returns:
    The list of statements produced by the template.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

  # Loop variables are derived from the symbols modified in the body.
  loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
      node, body_scope.modified)

  undefined_assigns = self._create_undefined_assigns(possibly_undefs)

  nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

  # new_symbol is given reserved_symbols, presumably so that generated names
  # do not clash with symbols used by the loop.
  state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
  state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
  state_functions = self._create_state_functions(
      loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  opts = self._create_loop_options(node)

  template = """
    state_functions
    def body_name():
      nonlocal_declarations
      body
    def test_name():
      return test
    undefined_assigns
    ag__.pt_while_stmt(
        test_name,
        body_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        opts)
  """
  return templates.replace(
      template,
      body=node.body,
      body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
      nonlocal_declarations=nonlocal_declarations,
      opts=opts,
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      # Symbol names are passed to the runtime as string constants.
      symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
      test=node.test,
      test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
      undefined_assigns=undefined_assigns)
def _process_statement_directive(self, call_node, directive):
  """Records a statement directive on the enclosing loop node.

  The directive's mapped arguments are stored under `directive` in the loop
  node's DIRECTIVES annotation (created if absent).

  Raises:
    ValueError: if the directive is not the first statement of the loop
      block, or if the loop-scope nesting level is below 2.
  """
  if self.state[_LoopScope].statements_visited > 1:
    raise ValueError(
        '"%s" must be the first statement in the loop block' % (
            directive.__name__))
  if self.state[_LoopScope].level < 2:
    raise ValueError(
        '"%s" must be used inside a statement' % directive.__name__)

  loop_node = self.state[_LoopScope].ast_node
  loop_directives = anno.getanno(loop_node, anno.Basic.DIRECTIVES, {})
  loop_directives[directive] = _map_args(call_node, directive)
  anno.setanno(loop_node, anno.Basic.DIRECTIVES, loop_directives)
  return call_node
def _create_loop_options(self, node):
  """Builds a gast.Dict literal holding the loop's `set_loop_options` args.

  Returns:
    A gast.Dict mapping option names (as string constants) to their AST
    values. It is empty when the loop has no DIRECTIVES annotation, no
    `directives.set_loop_options` entry, or an empty options mapping.
  """
  if not anno.hasanno(node, anno.Basic.DIRECTIVES):
    return gast.Dict([], [])

  loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
  if directives.set_loop_options not in loop_directives:
    return gast.Dict([], [])

  opts_dict = loop_directives[directives.set_loop_options]
  # An empty mapping would make `zip(*opts_dict.items())` yield nothing, and
  # unpacking that into two names raises ValueError.
  if not opts_dict:
    return gast.Dict([], [])

  str_keys, values = zip(*opts_dict.items())
  keys = [gast.Constant(s, kind=None) for s in str_keys]
  values = list(values)  # ast and gast don't play well with tuples.
  return gast.Dict(keys, values)
def visit_Subscript(self, node):
  """Annotates `a[i]` with a QN when both `a` and the index are nameable.

  Only `gast.Index` slices are handled. A constant index becomes a
  NumberLiteral QN, and an index with its own QN is used as-is. Any other
  index (an arbitrary expression) leaves the node unannotated.
  """
  # TODO(mdan): This may no longer apply if we overload getitem.
  node = self.generic_visit(node)
  index = node.slice
  if not isinstance(index, gast.Index):
    # TODO(mdan): Support range and multi-dimensional indices.
    # Continuing silently because some demos use these.
    return node

  index_value = index.value
  if isinstance(index_value, gast.Constant):
    subscript = QN(NumberLiteral(index_value.value))
  elif anno.hasanno(index_value, anno.Basic.QN):
    subscript = anno.getanno(index_value, anno.Basic.QN)
  else:
    # The index may be an expression, case in which a name doesn't make sense.
    return node

  if anno.hasanno(node.value, anno.Basic.QN):
    base_qn = anno.getanno(node.value, anno.Basic.QN)
    anno.setanno(node, anno.Basic.QN, QN(base_qn, subscript=subscript))
  return node
def _replace_pop_call(self, node):
  """Replaces a `target.pop()` call expression with a fresh variable name.

  Expressions that use pop() are converted to a statement + expression.

  For example:

    print(target.pop())

  ... is converted to:

    target, target_pop = ag__.list_pop(target)
    print(target_pop)

  Here, we just generate the variable name and swap it in,
  and _generate_pop_operation will handle the rest.

  Multiple uses of pop() are allowed:

    print(target.pop(), target.pop())
    print(target.pop().pop())

  The (call node, variable name) pair is appended to the current statement's
  `pop_uses` list.
  """
  assert isinstance(node.func, gast.Attribute)
  args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  target_node = node.func.value

  # Attempt to use a related name if one exists. Otherwise use something
  # generic.
  if not anno.hasanno(target_node, anno.Basic.QN):
    base_name = 'list_'
  else:
    base_name = anno.getanno(target_node, anno.Basic.QN).ssf()
  pop_var_name = self.ctx.namer.new_symbol(base_name, args_scope.referenced)

  statement_state = self.state[_Statement]
  if statement_state.pop_uses is None:
    statement_state.pop_uses = []
  statement_state.pop_uses.append((node, pop_var_name))

  return templates.replace_as_expression('var_name', var_name=pop_var_name)
def _track_symbol(self, node, composite_writes_alter_parent=False):
  """Records the use of the symbol `node` in the current scope.

  Depending on `node.ctx`, the symbol's QN is recorded as modified and bound
  (Store), read (Load), bound parameter (Param), or read, bound and deleted
  (Del). Inside comprehensions, uses of comprehension targets are ignored and
  stores are recorded as comprehension targets.

  Args:
    node: the AST node of the symbol. Nodes without a QN annotation are
      ignored.
    composite_writes_alter_parent: if True, a Store to a composite symbol
      (e.g. `a.b`) also marks its parent (`a`) as modified.

  Raises:
    ValueError: if `node.ctx` is not a recognized context.
  """
  # A QN may be missing when we have an attribute (or subscript) on a function
  # call. Example: a().b
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  # When inside a comprehension, ignore reads to any of the comprehension's
  # targets. This includes attributes or slices of those arguments.
  for comprehension in self.state[_Comprehension]:
    if qn in comprehension.targets:
      return
    if qn.owner_set & set(comprehension.targets):
      return

  if isinstance(node.ctx, gast.Store):
    # In comprehensions, modified symbols are the comprehension targets.
    if self.state[_Comprehension].level > 0:
      self.state[_Comprehension].targets.add(qn)
      # List comprehension targets leak in Python 2.
      # For details, see:
      # https://stackoverflow.com/questions/4198906/list-comprehension-rebinds-names-even-after-scope-of-comprehension-is-this-righ
      if not (six.PY2 and self.state[_Comprehension].is_list_comp):
        return

    self.scope.modified.add(qn)
    self.scope.bound.add(qn)
    # `is_composite` is a method; referencing it without calling it is always
    # truthy and would request the parent of simple names.
    if qn.is_composite() and composite_writes_alter_parent:
      self.scope.modified.add(qn.parent)
    if self._in_aug_assign:
      self.scope.read.add(qn)

  elif isinstance(node.ctx, gast.Load):
    self.scope.read.add(qn)

  elif isinstance(node.ctx, gast.Param):
    self.scope.bound.add(qn)
    self.scope.mark_param(qn, self.state[_FunctionOrClass].node)

  elif isinstance(node.ctx, gast.Del):
    # The read matches the Python semantics - attempting to delete an
    # undefined symbol is illegal.
    self.scope.read.add(qn)
    # Targets of del are considered bound:
    # https://docs.python.org/3/reference/executionmodel.html#binding-of-names
    self.scope.bound.add(qn)
    self.scope.deleted.add(qn)

  else:
    raise ValueError('Unknown context {} for node "{}".'.format(
        type(node.ctx), qn))
def _determine_aliased_symbols(self, scope, node_defined_in, block):
  """Returns the simple, non-global symbols that need aliasing for `block`.

  These are symbols modified in `scope`, defined in `node_defined_in`, and
  live on entry to the first statement of `block`. An empty `block` yields
  an empty set.
  """
  block_live_in = set()
  if block:
    block_live_in = set(anno.getanno(block[0], anno.Static.LIVE_VARS_IN))

  modified_live = scope.modified & node_defined_in & block_live_in

  # Composite symbols are handled elsewhere, see _create_state_functions
  aliased = set()
  for symbol in modified_live:
    if symbol.is_composite():
      continue
    if symbol in self.state[_Function].scope.globals:
      continue
    aliased.add(symbol)
  return aliased
def visit_Call(self, node):
  """Rewrites a call into `ag__.converted_call`, except for exempt callees.

  Left unconverted (arguments are still visited):
    * calls into the `ag__` module;
    * calls on the current function context (`<context_name>.`);
    * `pdb.set_trace`, `ipdb.set_trace` and `breakpoint`, which log a one-time
      warning;
    * `print`, unless the BUILTIN_FUNCTIONS feature is enabled.
  """
  global set_trace_warned

  full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
  function_context_name = self.state[_Function].context_name
  node = self.generic_visit(node)

  # TODO(mdan): Refactor converted_call as a 'Call' operator.

  # Calls to the internal 'ag__' module are never converted (though their
  # arguments might be).
  if full_name.startswith('ag__.'):
    return node

  # Calls to the function context manager (inserted by function_scopes) are
  # also safe.
  if full_name.startswith(function_context_name + '.'):
    return node

  # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
  # the normal mechanisms to bypass these literals because they are sensitive
  # to the frame they are being called from.
  # TODO(mdan): Generalize this to a "static whitelist" config.
  is_debugger_call = full_name in ('pdb.set_trace', 'ipdb.set_trace',
                                   'breakpoint')
  if is_debugger_call:
    if not set_trace_warned:
      # TODO(mdan): Update and shorten once available on tensorflow.org.
      ag_logging.warn(
          'Detected `pdb.set_trace()` in user code. The code'
          ' generated by AutoGraph is not optimized for step-by-step'
          ' debugging. See https://github.com/tensorflow/tensorflow/'
          'blob/master/tensorflow/python/autograph/g3doc/reference/'
          'debugging.md.')
      set_trace_warned = True
    return node

  if (full_name == 'print' and
      not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
    return node

  template = """
    ag__.converted_call(func, args, kwargs, function_ctx)
  """
  return templates.replace_as_expression(
      template,
      func=node.func,
      args=self._args_to_tuple(node),
      kwargs=self._kwargs_to_dict(node),
      function_ctx=function_context_name)
def visit_For(self, node):
  """Records activity scopes for a `for` loop.

  Three kinds of scope are recorded:
    * a scope covering the visits of both the target and the iterated
      expression, recorded on `node.iter`;
    * a scope containing the target (and the EXTRA_LOOP_TEST annotation, if
      present), recorded on `node` under NodeAnno.ITERATE_SCOPE;
    * the body and `else` blocks, processed as parallel blocks under
      NodeAnno.BODY_SCOPE and NodeAnno.ORELSE_SCOPE.
  """
  self._enter_scope(False)
  node.target = self.visit(node.target)
  node.iter = self.visit(node.iter)
  self._exit_and_record_scope(node.iter)

  self._enter_scope(False)
  # The target is visited a second time so that its activity is also
  # captured in the iterate scope; the return value is intentionally unused.
  self.visit(node.target)
  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    self._process_statement(anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
  self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE)

  node = self._process_parallel_blocks(node,
                                       ((node.body, NodeAnno.BODY_SCOPE),
                                        (node.orelse, NodeAnno.ORELSE_SCOPE)))
  return node
def visit_Expr(self, node):
  """Processes directive calls that appear as expression statements.

  Every expression statement counts toward the loop scope's
  `statements_visited`. Calls whose callee has a STATIC_VALUE of
  `directives.set_element_type` or `directives.set_loop_options` are
  processed and removed from the output (None is returned).
  """
  self.state[_LoopScope].statements_visited += 1
  node = self.generic_visit(node)
  if not isinstance(node.value, gast.Call):
    return node

  call_node = node.value
  static_val = anno.getanno(call_node.func, STATIC_VALUE, default=None)
  if static_val is None:
    return node

  # Note: directive calls are not output in the generated code, hence
  # the removal from the code by returning None.
  if static_val is directives.set_element_type:
    self._process_symbol_directive(call_node, static_val)
    return None
  if static_val is directives.set_loop_options:
    self._process_statement_directive(call_node, static_val)
    return None
  return node
def visit_For(self, node):
  """Lowers `break` statements inside a `for` loop.

  The body is processed with a fresh `break_` flag variable. If a break was
  used, the loop is rebuilt so that the flag is initialized to False and the
  `else` clause is guarded by it. The rebuilt loop is annotated with an
  EXTRA_LOOP_TEST of `ag__.not_(flag)` and receives a copy of the original
  loop's DIRECTIVES annotation.

  Returns:
    The original (updated) node if no break was used, otherwise the list of
    statements produced by the template.
  """
  original_node = node
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

  node.target = self.visit(node.target)
  node.iter = self.visit(node.iter)
  node.body, break_used = self._process_body(node.body, break_var)
  # A break in the else clause applies to the containing scope.
  node.orelse = self.visit_block(node.orelse)

  if not break_used:
    return node

  # Python's else clause only triggers if the loop exited cleanly (e.g.
  # break did not trigger).
  guarded_orelse = self._guard_if_present(node.orelse, break_var)
  extra_test = templates.replace_as_expression(
      'ag__.not_(var_name)', var_name=break_var)

  # The extra test is hidden in the AST, which will confuse the static
  # analysis. To mitigate that, we insert a no-op statement that ensures
  # the control variable is marked as used.
  # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
  template = """
    var_name = False
    for target in iter_:
      (var_name,)
      body
    else:
      orelse
  """
  replacement = templates.replace(
      template,
      var_name=break_var,
      iter_=node.iter,
      target=node.target,
      body=node.body,
      orelse=guarded_orelse)

  # The second statement of the expanded template is the loop itself.
  loop_node = replacement[1]
  anno.setanno(loop_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
  anno.copyanno(original_node, loop_node, anno.Basic.DIRECTIVES)
  return replacement