def visit_Attribute(self, node):
    """Annotates an attribute access with its qualified name.

    Children are visited first. If ``node.value`` ended up with a QN
    annotation, ``node`` receives ``QN(<value QN>, attr=node.attr)``;
    otherwise ``node`` is left unannotated.

    Args:
      node: the attribute AST node.

    Returns:
      The visited node.
    """
    node = self.generic_visit(node)
    if not anno.hasanno(node.value, anno.Basic.QN):
        return node
    parent_qn = anno.getanno(node.value, anno.Basic.QN)
    anno.setanno(node, anno.Basic.QN, QN(parent_qn, attr=node.attr))
    return node
def visit_Lambda(self, node):
    """Wraps the body of an outermost lambda in a function scope.

    The lambda's children are always visited. If the function-state nesting
    level exceeds 2 after entering, the lambda is returned unwrapped.
    Presumably this means it is nested inside another function scope.

    Otherwise the following happens:
      * A new ``lscope`` symbol is allocated, with ``scope.referenced``
        passed to the namer.
      * The symbol is stored as the current function context name and as the
        ``'function_context_name'`` annotation.
      * The body is replaced with
        ``ag__.with_function_scope(lambda <ctx>: <body>, '<ctx>', options)``.

    Args:
      node: the lambda AST node.

    Returns:
      The (possibly rewritten) lambda node.
    """
    self.state[_Function].enter()
    node = self.generic_visit(node)

    # Only wrap the top-level function. Theoretically, we can and should wrap
    # everything, but that can lead to excessive boilerplate when lambdas are
    # nested.
    # TODO(mdan): Looks more closely for use cases that actually require this.
    if self.state[_Function].level > 2:
        # Every exit path must balance the enter() above.
        self.state[_Function].exit()
        return node

    scope = anno.getanno(node, anno.Static.SCOPE)
    function_context_name = self.ctx.namer.new_symbol(
        'lscope', scope.referenced)
    self.state[_Function].context_name = function_context_name
    anno.setanno(node, 'function_context_name', function_context_name)

    template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
    node.body = templates.replace_as_expression(
        template,
        options=self._function_scope_options().to_ast(),
        # The lambda's parameter takes the freshly allocated symbol name...
        function_context=function_context_name,
        # ...and the same name is also passed as a string constant.
        function_context_name=gast.Constant(function_context_name, kind=None),
        body=node.body)

    self.state[_Function].exit()
    return node
def _block_statement_live_out(self, node):
    """Annotates a block statement with the variables live on exit.

    The live-out set is the union of the analyzer's live-in sets for every
    successor of ``node`` in ``graph.stmt_next``. It is stored as a frozenset
    under ``anno.Static.LIVE_VARS_OUT``.

    Args:
      node: the block statement node.

    Returns:
      ``node``.
    """
    analyzer = self.current_analyzer
    successors = analyzer.graph.stmt_next[node]
    live_out = frozenset().union(*(analyzer.in_[s] for s in successors))
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
    return node
def _aggregate_predecessors_defined_in(self, node):
    """Annotates ``node`` with the variables defined on entry to it.

    Takes the union of the keys of ``out[pred].value`` over every
    predecessor ``pred`` of ``node`` in ``graph.stmt_prev``. The result is
    stored as a frozenset under ``anno.Static.DEFINED_VARS_IN``. Nothing is
    returned.

    Args:
      node: the statement node to annotate.
    """
    analyzer = self.current_analyzer
    defined_in = set()
    for pred in analyzer.graph.stmt_prev[node]:
        defined_in.update(analyzer.out[pred].value.keys())
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(defined_in))
def visit_With(self, node):
    """Visits a ``with`` statement and flags it when its body always returns.

    Args:
      node: the ``with`` AST node.

    Returns:
      The visited node. If ``_visit_statement_block`` reports that the body
      definitely returns, the node is annotated with
      ``STMT_DEFINITELY_RETURNS``.
    """
    node.items = self.visit_block(node.items)
    new_body, body_returns = self._visit_statement_block(node, node.body)
    node.body = new_body
    if body_returns:
        anno.setanno(node, STMT_DEFINITELY_RETURNS, True)
    return node
def visit(self, node):
    """Visits ``node`` and records live-in variables on statements.

    After the base visit, a ``gast.stmt`` that appears in the current CFG
    index gets the live-in set of its CFG node, as a frozenset, under
    ``anno.Static.LIVE_VARS_IN``. The node is returned unannotated in these
    cases:
      * no analyzer is active;
      * the node is not a statement;
      * the statement is not in the index.

    Args:
      node: the AST node to visit.

    Returns:
      The visited node.
    """
    node = super(Annotator, self).visit(node)
    analyzer = self.current_analyzer
    if analyzer is None or not isinstance(node, gast.stmt):
        return node
    cfg_index = analyzer.graph.index
    if node not in cfg_index:
        return node
    live_in = frozenset(analyzer.in_[cfg_index[node]])
    anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
    return node
def visit_Attribute(self, node):
    """Propagates static values through attribute accesses on modules.

    The value is recorded when all of the following hold:
      * the accessed object has a known ``STATIC_VALUE``;
      * that value is a module;
      * the module has ``node.attr``.
    In that case the attribute's value is recorded on ``node`` under
    ``STATIC_VALUE``.

    Args:
      node: the attribute AST node.

    Returns:
      The visited node.
    """
    node = self.generic_visit(node)
    owner = anno.getanno(node.value, STATIC_VALUE, default=None)
    if owner is None or not inspect.ismodule(owner):
        return node
    if hasattr(owner, node.attr):
        anno.setanno(node, STATIC_VALUE, getattr(owner, node.attr))
    return node
def visit(self, node):
    """Visits ``node`` while keeping origin information consistent.

    Nodes annotated with ``anno.Basic.SKIP_PROCESSING`` are returned
    untouched.

    During the visit, ``self.ctx.current_origin`` is set to the node's own
    ``anno.Basic.ORIGIN`` (when present). It is restored in a ``finally``
    block, so this also happens when the visit raises.

    When the visit replaces an ``Expr`` node's value with a list/tuple,
    ``Assign`` or ``AugAssign``, the enclosing ``Expr`` wrapper is dropped.
    Any replacement node(s) without their own origin inherit the origin of
    the replaced node, falling back to the enclosing origin.

    Args:
      node: a single ``gast.AST`` node (not a list of nodes).

    Returns:
      The visit result: a node, a list/tuple of nodes, or ``None``.

    Raises:
      ValueError: if ``node`` is not a ``gast.AST`` instance.
    """
    if not isinstance(node, gast.AST):
        # This is not that uncommon a mistake: various node bodies are lists,
        # for example, posing a land mine for transformers that need to
        # recursively call `visit`. The error needs to be raised before the
        # origin bookkeeping below runs, because that code assumes `node` is,
        # in fact, a node.
        msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
               ' visit lists of nodes, use "visit_block" instead').format(
                   type(node))
        raise ValueError(msg)

    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
        return node

    parent_origin = self.ctx.current_origin
    if anno.hasanno(node, anno.Basic.ORIGIN):
        self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    try:
        processing_expr_node = isinstance(node, gast.Expr)
        if processing_expr_node:
            entry_expr_value = node.value

        result = super(Base, self).visit(node)

        # Adjust for consistency: replacing the value of an Expr with
        # an Assign node removes the need for the Expr node.
        if (processing_expr_node and isinstance(result, gast.Expr) and
                (result.value is not entry_expr_value)):
            # When the replacement is a list, it is assumed that the list came
            # from a template that contained a number of statements, which
            # themselves are standalone and don't require an enclosing Expr.
            if isinstance(result.value,
                          (list, tuple, gast.Assign, gast.AugAssign)):
                result = result.value

        # By default, all replacements receive the origin info of the replaced
        # node.
        if result is not node and result is not None:
            inherited_origin = anno.getanno(
                node, anno.Basic.ORIGIN, default=parent_origin)
            if inherited_origin is not None:
                # A visitor may return either a single node or a sequence.
                if isinstance(result, (list, tuple)):
                    nodes_to_adjust = result
                else:
                    nodes_to_adjust = (result,)
                for n in nodes_to_adjust:
                    if not anno.hasanno(n, anno.Basic.ORIGIN):
                        anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)
    finally:
        self.ctx.current_origin = parent_origin

    return result
def visit_Name(self, node):
    """Attaches namespace values to free names that are read.

    A name is annotated when all of the following hold:
      * it is in a ``gast.Load`` context;
      * it has no recorded ``anno.Static.DEFINITIONS``;
      * its identifier is a key of ``self.ctx.info.namespace``.
    The namespace value is then attached as ``STATIC_VALUE``.

    Args:
      node: the name AST node.

    Returns:
      The visited node.
    """
    node = self.generic_visit(node)
    if not isinstance(node.ctx, gast.Load):
        return node
    if anno.getanno(node, anno.Static.DEFINITIONS, ()):
        return node
    namespace = self.ctx.info.namespace
    if node.id in namespace:
        anno.setanno(node, STATIC_VALUE, namespace[node.id])
    return node
def visit_While(self, node):
    """Analyzes activity in a ``while`` loop.

    The condition is visited inside its own scope. ``_exit_and_record_scope``
    records that scope on ``node.test``, and it is also stored on the loop as
    ``NodeAnno.COND_SCOPE``. The body and ``else`` blocks are then handed to
    ``_process_parallel_blocks`` with ``BODY_SCOPE`` and ``ORELSE_SCOPE``
    tags respectively.

    Args:
      node: the ``while`` AST node.

    Returns:
      The node returned by ``_process_parallel_blocks``.
    """
    self._enter_scope(False)
    node.test = self.visit(node.test)
    cond_scope = self._exit_and_record_scope(node.test)
    anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)

    parallel_blocks = (
        (node.body, NodeAnno.BODY_SCOPE),
        (node.orelse, NodeAnno.ORELSE_SCOPE),
    )
    return self._process_parallel_blocks(node, parallel_blocks)
def _block_statement_live_in(self, node, entry_node):
    """Copies the live-in variables of ``entry_node`` onto ``node``.

    If ``entry_node`` is in the current CFG index, its live-in set is read
    from the analyzer. Otherwise ``entry_node`` must be a block statement
    that already carries ``anno.Static.LIVE_VARS_IN``, and that annotation is
    reused.

    Args:
      node: the node to annotate.
      entry_node: the node whose live-in set applies to ``node``.

    Returns:
      ``node``.
    """
    cfg_index = self.current_analyzer.graph.index
    if entry_node not in cfg_index:
        assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
            'If not matching a CFG node, must be a block statement:'
            ' {}'.format(entry_node))
        live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
    else:
        live_in = frozenset(
            self.current_analyzer.in_[cfg_index[entry_node]])
    anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
    return node
def _process_statement_directive(self, call_node, directive):
    """Records a statement directive on the current loop-scope AST node.

    The mapped arguments of the call are stored in the target node's
    ``anno.Basic.DIRECTIVES`` mapping, keyed by ``directive``.

    Args:
      call_node: the directive call node.
      directive: the directive function. It is used as the mapping key and
        its ``__name__`` appears in error messages.

    Returns:
      ``call_node``, unchanged.

    Raises:
      ValueError: if the directive is not the first statement in the loop
        block, or if the loop-scope level is below 2.
    """
    loop_scope = self.state[_LoopScope]
    if loop_scope.statements_visited > 1:
        raise ValueError(
            '"%s" must be the first statement in the loop block' %
            (directive.__name__))
    if loop_scope.level < 2:
        raise ValueError(
            '"%s" must be used inside a statement' % directive.__name__)

    target = loop_scope.ast_node
    directives = anno.getanno(target, anno.Basic.DIRECTIVES, {})
    directives[directive] = _map_args(call_node, directive)
    anno.setanno(target, anno.Basic.DIRECTIVES, directives)
    return call_node
def _attach_origin_info(self, node):
    """Annotates ``node`` with its source origin.

    Builds an ``OriginInfo`` from four pieces:
      * the node's absolute location;
      * the name of the function on top of ``self._function_stack`` (``None``
        when the stack is empty);
      * the source line;
      * the comment mapped to that line, if any.
    The result is stored under ``anno.Basic.ORIGIN``, and ``node.lineno`` is
    also recorded under the ``'lineno'`` annotation.

    Args:
      node: an AST node with a ``lineno`` attribute.
    """
    lineno = node.lineno
    function_name = (
        self._function_stack[-1].name if self._function_stack else None)
    source_code_line = self._source_lines[lineno - 1]
    comment = self._comments_map.get(lineno)

    location = Location(self._filepath, self._absolute_lineno(node),
                        self._absolute_col_offset(node))
    origin = OriginInfo(location, function_name, source_code_line, comment)

    anno.setanno(node, 'lineno', lineno)
    anno.setanno(node, anno.Basic.ORIGIN, origin)
def visit_If(self, node):
    """Visits an ``if`` statement and tracks which branches always return.

    A branch that definitely returns is flagged on ``node`` with
    ``BODY_DEFINITELY_RETURNS`` or ``ORELSE_DEFINITELY_RETURNS``. When both
    branches do, the current rewrite block is marked as definitely
    returning.

    Args:
      node: the ``if`` AST node.

    Returns:
      The visited node.
    """
    node.test = self.visit(node.test)

    branch_outcomes = []
    for field, returns_tag in (('body', BODY_DEFINITELY_RETURNS),
                               ('orelse', ORELSE_DEFINITELY_RETURNS)):
        new_block, always_returns = self._visit_statement_block(
            node, getattr(node, field))
        setattr(node, field, new_block)
        if always_returns:
            anno.setanno(node, returns_tag, True)
        branch_outcomes.append(always_returns)

    if all(branch_outcomes):
        self.state[_RewriteBlock].definitely_returns = True
    return node
def visit_For(self, node):
    """Lowers ``break`` statements inside a ``for`` loop.

    A fresh ``break_`` symbol is allocated (with the body scope's referenced
    names passed to the namer), and the body is processed with it. The
    ``else`` block is visited separately, since a break there belongs to the
    enclosing scope.

    If a ``break`` was used, the loop is rebuilt from the template

        var_name = False
        for target in iter_:
          (var_name,)
          body
        else:
          orelse

    with these details:
      * ``orelse`` is the result of ``_guard_if_present``, presumably so it
        only runs when no break occurred;
      * the new ``for`` node (element 1 of the template output) is given
        ``ag__.not_(var_name)`` as its ``anno.Basic.EXTRA_LOOP_TEST``;
      * the new node inherits the original loop's ``anno.Basic.DIRECTIVES``.

    Args:
      node: the ``for`` AST node.

    Returns:
      ``node`` when no break was used, otherwise the statements produced by
      ``templates.replace``.
    """
    original_node = node
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
        # Python's else clause only triggers if the loop exited cleanly (i.e.
        # break did not trigger).
        guarded_orelse = self._guard_if_present(node.orelse, break_var)
        extra_test = templates.replace_as_expression(
            'ag__.not_(var_name)', var_name=break_var)

        # The extra test is hidden in the AST, which will confuse the static
        # analysis. To mitigate that, we insert a no-op statement that ensures
        # the control variable is marked as used.
        # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
        template = """
          var_name = False
          for target in iter_:
            (var_name,)
            body
          else:
            orelse
        """
        node = templates.replace(
            template,
            var_name=break_var,
            iter_=node.iter,
            target=node.target,
            body=node.body,
            orelse=guarded_orelse)

        # Index 1 is the `for` statement; index 0 is the `var_name = False`
        # initialization from the template above.
        new_for_node = node[1]
        anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
        anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)

    return node
def visit_Name(self, node):
    """Annotates a name with the definitions that reach it.

    The definitions for the name's QN are stored as a tuple under
    ``anno.Static.DEFINITIONS``:
      * names in a ``gast.Load`` context use the analyzer's ``in_`` state of
        the current CFG node;
      * names in any other context use its ``out`` state.
    Names visited while no analyzer is active are returned unchanged.

    Args:
      node: the name AST node.

    Returns:
      ``node``.
    """
    analyzer = self.current_analyzer
    if analyzer is None:
        # Names may appear outside function defs - for example in class
        # definitions.
        return node

    cfg_node = self.current_cfg_node
    assert cfg_node is not None, (
        'name node, %s, outside of any statement?' % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    flow_state = (
        analyzer.in_ if isinstance(node.ctx, gast.Load) else analyzer.out)
    definitions = tuple(flow_state[cfg_node].value.get(qn, ()))
    anno.setanno(node, anno.Static.DEFINITIONS, definitions)
    return node
def visit_Subscript(self, node):
    """Annotates a subscript expression with its qualified name.

    Only ``gast.Index`` slices are handled. The subscript part of the QN is
    chosen as follows:
      * a ``NumberLiteral`` QN for constant indices;
      * otherwise, the index expression's own QN.
    The node is left unannotated when neither is available or the
    subscripted value has no QN.

    Args:
      node: the subscript AST node.

    Returns:
      The visited node.
    """
    # TODO(mdan): This may no longer apply if we overload getitem.
    node = self.generic_visit(node)
    index = node.slice
    if not isinstance(index, gast.Index):
        # TODO(mdan): Support range and multi-dimensional indices.
        # Continuing silently because some demos use these.
        return node

    key = index.value
    if isinstance(key, gast.Constant):
        subscript = QN(NumberLiteral(key.value))
    elif anno.hasanno(key, anno.Basic.QN):
        subscript = anno.getanno(key, anno.Basic.QN)
    else:
        # The index may be an expression, case in which a name doesn't make
        # sense.
        return node

    if anno.hasanno(node.value, anno.Basic.QN):
        parent_qn = anno.getanno(node.value, anno.Basic.QN)
        anno.setanno(node, anno.Basic.QN, QN(parent_qn, subscript=subscript))
    return node
def visit_FunctionDef(self, node):
    """Wraps a function's body in an ``ag__.FunctionScope`` context manager.

    A new ``fscope`` symbol is allocated, with the body scope's referenced
    names passed to the namer. It is stored as the current function context
    name and as the ``'function_context_name'`` annotation.

    After the children are visited, the body is rewritten to
    ``with ag__.FunctionScope(name, context_name, options) as <context>:``.
    A leading docstring, meaning a string-literal expression statement, is
    kept as the first statement outside the ``with`` block.

    Args:
      node: the function definition AST node.

    Returns:
      The rewritten function node.
    """
    self.state[_Function].enter()
    scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

    function_context_name = self.ctx.namer.new_symbol(
        'fscope', scope.referenced)
    self.state[_Function].context_name = function_context_name
    anno.setanno(node, 'function_context_name', function_context_name)

    node = self.generic_visit(node)

    docstring_node = None
    if node.body:
        first_statement = node.body[0]
        # Python only treats a leading *string* literal as a docstring. Other
        # constant expression statements (e.g. `...` or `0`) are ordinary
        # statements and should stay inside the function scope.
        if (isinstance(first_statement, gast.Expr) and
                isinstance(first_statement.value, gast.Constant) and
                isinstance(first_statement.value.value, str)):
            docstring_node = first_statement
            node.body = node.body[1:]

    template = """
      with ag__.FunctionScope(
          function_name, context_name, options) as function_context:
        body
    """
    wrapped_body = templates.replace(
        template,
        function_name=gast.Constant(node.name, kind=None),
        context_name=gast.Constant(function_context_name, kind=None),
        options=self._function_scope_options().to_ast(),
        function_context=function_context_name,
        body=node.body)

    if docstring_node is not None:
        wrapped_body = [docstring_node] + wrapped_body

    node.body = wrapped_body

    self.state[_Function].exit()
    return node
def visit_For(self, node):
    """Visits a ``for`` loop, adding a stop condition when a return is used.

    If ``self.state[_Block].return_used`` is set after the body is visited,
    the loop's ``anno.Basic.EXTRA_LOOP_TEST`` is set to
    ``ag__.not_(<do_return var>)``. Any pre-existing extra test is kept and
    combined with it through ``ag__.and_``.

    Args:
      node: the ``for`` AST node.

    Returns:
      The visited node.
    """
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)

    # Add the check for return to the loop condition.
    node.body = self._visit_statement_block(node, node.body)
    if self.state[_Block].return_used:
        prior_test = anno.getanno(
            node, anno.Basic.EXTRA_LOOP_TEST, default=None)
        do_return_var = self.state[_Function].do_return_var_name
        if prior_test is None:
            loop_test = templates.replace_as_expression(
                'ag__.not_(control_var)',
                control_var=do_return_var)
        else:
            loop_test = templates.replace_as_expression(
                'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
                extra_test=prior_test,
                control_var=do_return_var)
        anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, loop_test)

    node.orelse = self._visit_statement_block(node, node.orelse)
    return node
def visit_Expr(self, node):
    """Annotates an expression statement with its live-out variables.

    The analyzer's ``out`` set for the node's CFG node is stored as a
    frozenset under ``anno.Static.LIVE_VARS_OUT``.

    NOTE(review): there is no check here that an analyzer is active or that
    ``node`` is in the CFG index; callers presumably guarantee both.

    Args:
      node: the expression statement node.

    Returns:
      The visited node.
    """
    node = self.generic_visit(node)
    analyzer = self.current_analyzer
    live_out = analyzer.out[analyzer.graph.index[node]]
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(live_out))
    return node
def visit_Name(self, node):
    """Annotates a name with a qualified name built from its identifier.

    Args:
      node: the name AST node.

    Returns:
      The visited node, annotated with ``QN(node.id)`` under
      ``anno.Basic.QN``.
    """
    node = self.generic_visit(node)
    name_qn = QN(node.id)
    anno.setanno(node, anno.Basic.QN, name_qn)
    return node
def _exit_and_record_scope(self, node, tag=anno.Static.SCOPE):
    """Closes the current scope and attaches it to ``node``.

    Args:
      node: the AST node that receives the scope annotation.
      tag: the annotation key to store the scope under.

    Returns:
      The scope object returned by ``_exit_scope``.
    """
    exited_scope = self._exit_scope()
    anno.setanno(node, tag, exited_scope)
    return exited_scope
def visit_Print(self, node):
    """Analyzes activity in a ``Print`` node.

    The printed values are visited inside their own scope.
    ``_exit_and_record_scope`` records that scope on ``node`` under its
    default tag, and it is also stored as ``NodeAnno.ARGS_SCOPE``.

    Args:
      node: the ``Print`` AST node.

    Returns:
      The visited node.
    """
    self._enter_scope(False)
    node.values = self.visit_block(node.values)
    args_scope = self._exit_and_record_scope(node)
    anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
    return node