def visit(self, node):
  """Visits `node`, propagating origin information to replacement nodes.

  Args:
    node: the gast.AST node to visit.

  Returns:
    The result of the underlying visitor: a node, a list/tuple of nodes, or
    None. When an `Expr` node's value is replaced by a list/tuple of
    statements or by an `Assign`/`AugAssign`, that replacement is returned
    without the enclosing `Expr`.

  Raises:
    ValueError: if `node` is not a gast.AST (e.g. a list of statements).
  """
  if not isinstance(node, gast.AST):
    # This is not that uncommon a mistake: various node bodies are lists, for
    # example, posing a land mine for transformers that need to recursively
    # call `visit`. The error needs to be raised before the exception handler
    # below is installed, because said handler will mess up if `node` is not,
    # in fact, a node.
    msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
           ' visit lists of nodes, use "visit_block" instead').format(
               type(node))
    raise ValueError(msg)

  if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
    return node

  parent_origin = self.ctx.current_origin
  if anno.hasanno(node, anno.Basic.ORIGIN):
    self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

  try:
    processing_expr_node = isinstance(node, gast.Expr)
    # Only consulted when processing_expr_node is True; bound unconditionally
    # so it is never referenced while unassigned.
    entry_expr_value = node.value if processing_expr_node else None

    result = super(Base, self).visit(node)

    # Adjust for consistency: replacing the value of an Expr with
    # an Assign node removes the need for the Expr node.
    if (processing_expr_node and isinstance(result, gast.Expr) and
        (result.value is not entry_expr_value)):
      # When the replacement is a list, it is assumed that the list came
      # from a template that contained a number of statements, which
      # themselves are standalone and don't require an enclosing Expr.
      if isinstance(result.value,
                    (list, tuple, gast.Assign, gast.AugAssign)):
        result = result.value

    # By default, all replacements receive the origin info of the replaced
    # node.
    if result is not node and result is not None:
      inherited_origin = anno.getanno(
          node, anno.Basic.ORIGIN, default=parent_origin)
      if inherited_origin is not None:
        if isinstance(result, (list, tuple)):
          nodes_to_adjust = result
        else:
          nodes_to_adjust = (result,)
        for n in nodes_to_adjust:
          # Never overwrite origin info the replacement already carries.
          if not anno.hasanno(n, anno.Basic.ORIGIN):
            anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)
  finally:
    # Restore the caller's origin even if the visit raised.
    self.ctx.current_origin = parent_origin

  return result
def _node_sets_self_attribute(self, node):
  """Tells whether `node` names an attribute of `self` (e.g. `self.x`).

  Returns:
    True if `node` carries a QN annotation that is an attribute whose parent
    QN is exactly `('self',)`; False otherwise, including when the node has
    no QN annotation.
  """
  if not anno.hasanno(node, anno.Basic.QN):
    return False
  qual_name = anno.getanno(node, anno.Basic.QN)
  # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
  return bool(qual_name.has_attr and qual_name.parent.qn == ('self',))
def visit_Attribute(self, node):
  """Annotates an attribute access with a QN derived from its base.

  Children are visited first, so `node.value` already has its QN (if any)
  when this runs. The attribute gets `QN(<base QN>, attr=node.attr)` only
  when the base expression has a QN.
  """
  node = self.generic_visit(node)
  if not anno.hasanno(node.value, anno.Basic.QN):
    return node
  base_qn = anno.getanno(node.value, anno.Basic.QN)
  anno.setanno(node, anno.Basic.QN, QN(base_qn, attr=node.attr))
  return node
def visit_FunctionDef(self, node):
  """Converts a function definition within a new `_Function` state frame.

  The arguments, body and return annotation are visited. Functions at
  nesting level below 2 have their decorators removed; deeper (inner)
  functions get an `ag__.autograph_artifact` decorator appended.
  """
  self.state[_Function].enter()

  # Note: if the conversion process ever creates helper functions, this
  # assumption will no longer hold.
  assert anno.hasanno(node, 'function_context_name'), (
      'The function_scopes converter always creates a scope for functions.'
  )
  self.state[_Function].context_name = anno.getanno(
      node, 'function_context_name')

  node.args = self.visit(node.args)
  node.body = self.visit_block(node.body)

  if self.state[_Function].level >= 2:
    # TODO(mdan): Fix the tests so that we can always add this decorator.
    # Inner functions are converted already, so we insert a decorator to
    # prevent double conversion. Double conversion would work too, but this
    # saves the overhead.
    node.decorator_list.append(
        parser.parse_expression('ag__.autograph_artifact'))
  else:
    # Top-level functions lose their decorator because the conversion is
    # always just-in-time and by the time it happens the decorators are
    # already set to be applied.
    node.decorator_list = []

  if node.returns:
    node.returns = self.visit(node.returns)

  self.state[_Function].exit()
  return node
def visit_For(self, node):
  """Adds the CFG fragment for a `for` loop.

  The optional EXTRA_LOOP_TEST expression and the loop body are visited
  inside a loop section opened with `node.iter`. The `orelse` block is
  visited after the loop section and the lexical scope are both closed, but
  before the enclosing section for `node` is exited.

  Args:
    node: the gast.For node being added to the graph.
  """
  self.builder.begin_statement(node)
  self._enter_lexical_scope(node)

  self.builder.enter_section(node)

  # Note: Strictly speaking, this should be node.target + node.iter.
  # However, the activity analysis accounts for this inconsistency,
  # so dataflow analysis produces the correct values.
  self.builder.enter_loop_section(node, node.iter)
  # Also include the "extra loop test" annotation, to capture things like the
  # control variable for return and break in for loops.
  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    self._process_basic_statement(
        anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
  for stmt in node.body:
    self.visit(stmt)
  self.builder.exit_loop_section(node)

  # Note: although the orelse is technically part of the loop node,
  # they don't count as loop bodies. For example, a break in the loop's
  # orelse will affect the parent loop, not the current one.
  self._exit_lexical_scope(node)

  for stmt in node.orelse:
    self.visit(stmt)

  self.builder.exit_section(node)
  self.builder.end_statement(node)
def _block_statement_live_in(self, node, entry_node):
  """Copies the live-in variables of `entry_node` onto `node`.

  If `entry_node` is indexed in the current analyzer's CFG, its live-in set
  is taken from the analyzer (frozen). Otherwise `entry_node` must already
  carry a LIVE_VARS_IN annotation, which is reused as-is.

  Returns:
    `node`, annotated with LIVE_VARS_IN.
  """
  graph_index = self.current_analyzer.graph.index
  if entry_node not in graph_index:
    assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
        'If not matching a CFG node, must be a block statement:'
        ' {}'.format(entry_node))
    live_vars = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
  else:
    live_vars = frozenset(
        self.current_analyzer.in_[graph_index[entry_node]])

  anno.setanno(node, anno.Static.LIVE_VARS_IN, live_vars)
  return node
def visit_Lambda(self, node):
  """Visits a lambda, in its own `_Function` frame when it has a context name.

  Lambdas annotated with `function_context_name` are visited inside a fresh
  `_Function` state frame whose `context_name` is that annotation. Other
  lambdas are visited without touching the function state.
  """
  if not anno.hasanno(node, 'function_context_name'):
    # Lambda functions created during the conversion process have no
    # context manager.
    return self.generic_visit(node)

  self.state[_Function].enter()
  self.state[_Function].context_name = anno.getanno(
      node, 'function_context_name')
  visited = self.generic_visit(node)
  self.state[_Function].exit()
  return visited
def visit_Subscript(self, node):
  """Annotates a subscript with a QN when both base and index are nameable.

  A constant index becomes a `NumberLiteral` subscript. An index that
  carries its own QN is used directly. Slices, multi-dimensional indices and
  other expressions leave `node` without a QN.

  Both gast AST layouts are accepted. Older gast versions wrap a simple
  index in `gast.Index`. Newer versions, following Python 3.9's removal of
  `ast.Index`, no longer provide `gast.Index` and store the index expression
  directly in `slice`. The original code assumed `gast.Index` always exists.

  Returns:
    `node`, possibly annotated with a QN.
  """
  # TODO(mdan): This may no longer apply if we overload getitem.
  node = self.generic_visit(node)

  s = node.slice
  index_type = getattr(gast, 'Index', None)
  if index_type is not None:
    if not isinstance(s, index_type):
      # TODO(mdan): Support range and multi-dimensional indices.
      # Continuing silently because some demos use these.
      return node
    s = s.value
  elif isinstance(s, (gast.Slice, gast.Tuple)):
    # Same as above, for gast versions without `Index`: ranges and
    # multi-dimensional indices are not supported.
    return node

  if isinstance(s, gast.Constant):
    subscript = QN(NumberLiteral(s.value))
  elif anno.hasanno(s, anno.Basic.QN):
    subscript = anno.getanno(s, anno.Basic.QN)
  else:
    # The index may be an expression, case in which a name doesn't make sense.
    return node

  if anno.hasanno(node.value, anno.Basic.QN):
    anno.setanno(node, anno.Basic.QN,
                 QN(anno.getanno(node.value, anno.Basic.QN),
                    subscript=subscript))
  return node
def _create_loop_options(self, node):
  """Builds the options dict literal passed to the loop runtime call.

  Returns:
    A gast.Dict mapping string constants (the option names recorded for
    `directives.set_loop_options` on `node`) to the recorded option values.
    The dict is empty when `node` has no directives, no
    `set_loop_options` directive, or a `set_loop_options` directive with no
    options. The original code raised ValueError in that last case, because
    `zip(*{}.items())` cannot be unpacked into two names.
  """
  if not anno.hasanno(node, anno.Basic.DIRECTIVES):
    return gast.Dict([], [])

  loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
  if directives.set_loop_options not in loop_directives:
    return gast.Dict([], [])

  opts_dict = loop_directives[directives.set_loop_options]
  opts_items = list(opts_dict.items())
  keys = [gast.Constant(name, kind=None) for name, _ in opts_items]
  # Use a list, not a tuple: ast and gast don't play well with tuples.
  values = [value for _, value in opts_items]
  return gast.Dict(keys, values)
def _track_symbol(self, node, composite_writes_alter_parent=False):
  """Records the symbol referenced by `node` in the current scope.

  Nodes without a QN annotation are ignored, and so are symbols that are
  targets (or are owned by targets) of an enclosing comprehension.
  Otherwise the effect depends on `node.ctx`:
    * Store: the symbol becomes a comprehension target when inside a
      comprehension, and is also marked modified and bound, except inside a
      non-leaking comprehension, where it only becomes a target. Inside an
      augmented assignment it is marked read as well.
    * Load: marked read.
    * Param: marked bound and registered as a parameter of the enclosing
      function or class node.
    * Del: marked read, bound and deleted.

  Args:
    node: an AST node. It is expected to have a `ctx` attribute whenever it
      carries a QN annotation.
    composite_writes_alter_parent: if True, a write to a composite symbol
      (an attribute or subscript) also marks its parent as modified.

  Raises:
    ValueError: if `node.ctx` is none of Store, Load, Param or Del.
  """
  # A QN may be missing when we have an attribute (or subscript) on a function
  # call. Example: a().b
  if not anno.hasanno(node, anno.Basic.QN):
    return
  qn = anno.getanno(node, anno.Basic.QN)

  # When inside a comprehension, ignore reads to any of the comprehension's
  # targets. This includes attributes or slices of those arguments.
  for l in self.state[_Comprehension]:
    if qn in l.targets:
      return
    if qn.owner_set & set(l.targets):
      return

  if isinstance(node.ctx, gast.Store):
    # In comprehensions, modified symbols are the comprehension targets.
    if self.state[_Comprehension].level > 0:
      self.state[_Comprehension].targets.add(qn)
      # List comprehension targets leak in Python 2.
      # For details, see:
      # https://stackoverflow.com/questions/4198906/list-comprehension-rebinds-names-even-after-scope-of-comprehension-is-this-righ
      if not (six.PY2 and self.state[_Comprehension].is_list_comp):
        return

    self.scope.modified.add(qn)
    self.scope.bound.add(qn)
    if qn.is_composite and composite_writes_alter_parent:
      self.scope.modified.add(qn.parent)
    # An augmented assignment (e.g. `x += 1`) also reads its target.
    if self._in_aug_assign:
      self.scope.read.add(qn)

  elif isinstance(node.ctx, gast.Load):
    self.scope.read.add(qn)

  elif isinstance(node.ctx, gast.Param):
    self.scope.bound.add(qn)
    self.scope.mark_param(qn, self.state[_FunctionOrClass].node)

  elif isinstance(node.ctx, gast.Del):
    # The read matches the Python semantics - attempting to delete an
    # undefined symbol is illegal.
    self.scope.read.add(qn)
    # Targets of del are considered bound:
    # https://docs.python.org/3/reference/executionmodel.html#binding-of-names
    self.scope.bound.add(qn)
    self.scope.deleted.add(qn)

  else:
    raise ValueError('Unknown context {} for node "{}".'.format(
        type(node.ctx), qn))
def visit_For(self, node):
  """Records activity scopes for a `for` loop.

  Three groups of scopes are recorded:
    1. A scope covering the target and the iterated expression, recorded
       via `_exit_and_record_scope(node.iter)`.
    2. The ITERATE_SCOPE of `node`, covering the target again plus the
       optional EXTRA_LOOP_TEST expression.
    3. The BODY_SCOPE and ORELSE_SCOPE of the two blocks, both handed to
       `_process_parallel_blocks`.

  Args:
    node: the gast.For node to analyze.

  Returns:
    The node returned by `_process_parallel_blocks`.
  """
  self._enter_scope(False)
  node.target = self.visit(node.target)
  node.iter = self.visit(node.iter)
  self._exit_and_record_scope(node.iter)

  # The target is visited a second time so that it is also tracked in the
  # per-iteration scope, together with the extra loop test.
  self._enter_scope(False)
  self.visit(node.target)
  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    self._process_statement(anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
  self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE)

  node = self._process_parallel_blocks(node,
                                       ((node.body, NodeAnno.BODY_SCOPE),
                                        (node.orelse, NodeAnno.ORELSE_SCOPE)))

  return node
def _replace_pop_call(self, node):
  """Replaces a `<target>.pop()` call with a freshly named variable.

  Expressions that use pop() are converted to a statement + expression.
  For example:

      print(target.pop())

  ... is converted to:

      target, target_pop = ag__.list_pop(target)
      print(target_pop)

  Here, we just generate the variable name and swap it in, and
  _generate_pop_operation will handle the rest.

  Multiple uses of pop() are allowed:

      print(target.pop(), target.pop())
      print(target.pop().pop())

  Args:
    node: a call node whose `func` is an attribute access (the `.pop`).

  Returns:
    An expression referencing the new variable, built by
    `templates.replace_as_expression`.
  """
  assert isinstance(node.func, gast.Attribute)
  args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
  popped_from = node.func.value

  # Attempt to use a related name if one exists. Otherwise use something
  # generic.
  if not anno.hasanno(popped_from, anno.Basic.QN):
    base_name = 'list_'
  else:
    base_name = anno.getanno(popped_from, anno.Basic.QN).ssf()
  new_var_name = self.ctx.namer.new_symbol(base_name, args_scope.referenced)

  # Record the (call, variable) pair on the current statement's state;
  # presumably consumed later by _generate_pop_operation.
  statement_state = self.state[_Statement]
  if statement_state.pop_uses is None:
    statement_state.pop_uses = []
  statement_state.pop_uses.append((node, new_var_name))

  return templates.replace_as_expression('var_name', var_name=new_var_name)
def visit_node(self, node):
  """Recomputes the live-in and live-out symbol sets of a CFG node.

  live_out is the union of the live-in sets of all successors. If the AST
  node has a SCOPE annotation, live_in is
  `(read | extra_gen) | (live_out - (modified | deleted))`. Otherwise the
  node is assumed to touch no symbols and live_in equals live_out.

  Args:
    node: the CFG node to (re)evaluate.

  Returns:
    True if the node's live-in set differs from its previous value.
  """
  prev_live_in = self.in_[node]

  # Both branches need the union of the successors' live-in sets.
  live_out = set()
  for successor in node.next:
    live_out |= self.in_[successor]

  ast_node = node.ast_node
  if anno.hasanno(ast_node, anno.Static.SCOPE):
    node_scope = anno.getanno(ast_node, anno.Static.SCOPE)
    gen = node_scope.read | self.extra_gen.get(ast_node, frozenset())
    # TODO(mdan): verify whether composites' parents need to be added.
    # E.g. whether x needs to be added if x.y is live. Theoretically the
    # activity analysis should have both so that wouldn't be needed.
    kill = node_scope.modified | node_scope.deleted
    live_in = gen | (live_out - kill)
  else:
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols.
    # This Name node below is a literal name, e.g. False
    assert isinstance(ast_node, (gast.Name, gast.Continue, gast.Break,
                                 gast.Pass, gast.Global, gast.Nonlocal)), type(
                                     ast_node)
    live_in = live_out

  self.in_[node] = live_in
  self.out[node] = live_out

  # TODO(mdan): Move this to the superclass?
  return prev_live_in != live_in
def visit_node(self, node):
  """Recomputes the reaching definitions into and out of a CFG node.

  Incoming definitions are the union of the definitions leaving every
  predecessor, seeded with any `extra_in` entry for the node's AST node.
  Outgoing definitions are the node's generated definitions plus the
  incoming ones whose symbols the node does not kill. Generated definitions
  are created once per node and cached in `self.gen_map`.

  Bug fix: the `global`/`nonlocal` branch used to assign `kill` only when
  creating the cached definitions. Any later visit of the same node then
  raised UnboundLocalError. `kill` is now computed on every visit.

  Args:
    node: the CFG node to (re)evaluate.

  Returns:
    True if the node's outgoing definitions differ from their previous value.
  """
  prev_defs_out = self.out[node]

  defs_in = _NodeState(self.extra_in.get(node.ast_node, None))
  for n in node.prev:
    defs_in |= self.out[n]

  if anno.hasanno(node.ast_node, anno.Static.SCOPE):
    node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
    # The definition objects created by each node must be singletons because
    # their ids are used in equality checks.
    if node not in self.gen_map:
      node_symbols = {}
      # Every modification receives a definition.
      for s in node_scope.modified:
        def_ = self._definition_factory()
        node_symbols[s] = def_
      # Every param receives a definition. Params are not necessarily
      # considered as "modified".
      for s, p in node_scope.params.items():
        def_ = self._definition_factory()
        def_.param_of = weakref.ref(p)
        node_symbols[s] = def_
      self.gen_map[node] = _NodeState(node_symbols)

    gen = self.gen_map[node]
    kill = node_scope.modified | node_scope.deleted
    defs_out = gen | (defs_in - kill)

  elif isinstance(node.ast_node, (gast.Global, gast.Nonlocal)):
    # Special case for global and nonlocal: they generate a definition,
    # but are not tracked by activity analysis.
    # The kill set is computed on every visit, not only when gen_map is
    # populated, because this node may be visited again.
    kill = set(qual_names.QN(s) for s in node.ast_node.names)
    if node not in self.gen_map:
      node_symbols = {}
      for s in node.ast_node.names:
        qn = qual_names.QN(s)
        # TODO(mdan): If definitions exist, should we preserve those instead?
        # Incoming definitions may be present when this is a local function.
        # In that case, the definitions of the nonlocal symbol from the
        # enclosing function are available here. See self.extra_in.
        def_ = self._definition_factory()
        node_symbols[qn] = def_
      self.gen_map[node] = _NodeState(node_symbols)

    gen = self.gen_map[node]
    defs_out = gen | (defs_in - kill)

  else:
    # Nodes that don't have a scope annotation are assumed not to touch any
    # symbols.
    # This Name node below is a literal name, e.g. False
    # This can also happen if activity.py forgot to annotate the node with a
    # scope object.
    assert isinstance(node.ast_node,
                      (gast.Name, gast.Break, gast.Continue, gast.Raise,
                       gast.Pass)), (node.ast_node, node)
    defs_out = defs_in

  self.in_[node] = defs_in
  self.out[node] = defs_out

  # TODO(mdan): Move this to the superclass?
  return prev_defs_out != defs_out
def visit_Attribute(self, node):
  """Renames an attribute expression when it resolves to a qualified name.

  Nodes with a QN annotation are handled by `_process_name_node`. All other
  attribute nodes only have their children visited.
  """
  if not anno.hasanno(node, anno.Basic.QN):
    # Renaming attributes is not supported.
    return self.generic_visit(node)
  return self._process_name_node(node)
def visit_For(self, node):
  """Converts a `for` loop into a call to `ag__.for_stmt`.

  The loop body becomes a nested function (named from 'loop_body') that
  receives the iterates as a single argument and unpacks them into
  `node.target`. Loop variables are exposed through the state getter/setter
  functions built by `_create_state_functions`. An EXTRA_LOOP_TEST
  annotation, if present, becomes a separate generated function; otherwise
  `None` is passed in its place.

  Args:
    node: the gast.For node to convert.

  Returns:
    The replacement produced by `templates.replace` for the final template.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

  loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
      node, body_scope.modified | iter_scope.modified)

  undefined_assigns = self._create_undefined_assigns(possibly_undefs)

  nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

  # NOTE(review): the order of the namer.new_symbol calls in this method may
  # affect the generated names; keep it stable.
  state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
  state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
  state_functions = self._create_state_functions(
      loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  opts = self._create_loop_options(node)

  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
    extra_test_name = self.ctx.namer.new_symbol(
        'extra_test', reserved_symbols)
    template = """
      def extra_test_name():
        nonlocal_declarations
        return extra_test_expr
    """
    # NOTE(review): `loop_vars` does not appear in this template, so that
    # keyword looks unused here; confirm before removing it.
    extra_test_function = templates.replace(
        template,
        extra_test_expr=extra_test,
        extra_test_name=extra_test_name,
        loop_vars=loop_vars,
        nonlocal_declarations=nonlocal_declarations)
  else:
    # No extra test: pass a literal `None` and emit no helper function.
    extra_test_name = parser.parse_expression('None')
    extra_test_function = []

  # iterate_arg_name holds a single arg with the iterates, which may be a
  # tuple.
  iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved_symbols)
  template = """
    iterates = iterate_arg_name
  """
  iterate_expansion = templates.replace(
      template, iterate_arg_name=iterate_arg_name, iterates=node.target)

  template = """
    state_functions

    def body_name(iterate_arg_name):
      nonlocal_declarations
      iterate_expansion
      body

    extra_test_function

    undefined_assigns

    ag__.for_stmt(
        iterated,
        extra_test_name,
        body_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        opts)
  """
  return templates.replace(
      template,
      body=node.body,
      body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
      extra_test_function=extra_test_function,
      extra_test_name=extra_test_name,
      iterate_arg_name=iterate_arg_name,
      iterate_expansion=iterate_expansion,
      iterated=node.iter,
      nonlocal_declarations=nonlocal_declarations,
      opts=opts,
      symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      undefined_assigns=undefined_assigns)