def dfs_visit(self, node):
    """Wrap ``node``'s subtree in AstNodeWrapper objects and return its var type.

    The wrapper of ``node`` is linked to the wrapper on top of
    ``self.ancestor_wrappers`` (its parent) and pushed on that stack while the
    children are visited. The first wrapper ever created becomes
    ``self.node_wrapper_root``. Child function definitions are visited inside a
    function scope named after the function, entered on ``self.var_env``.

    Args:
        node: gast node to visit.

    Returns:
        The var type computed by ``self._get_node_var_type`` for the wrapper of
        ``node``, which is also stored as ``node_var_type`` on that wrapper.
    """
    # gast reuses some node instances (such as the Param expr_context node),
    # so an already created wrapper is reused instead of wrapping twice.
    if node not in self.node_to_wrapper_map:
        cur_wrapper = AstNodeWrapper(node)
        self.node_to_wrapper_map[node] = cur_wrapper
    else:
        cur_wrapper = self.node_to_wrapper_map[node]

    if self.node_wrapper_root is None:
        self.node_wrapper_root = cur_wrapper

    if self.ancestor_wrappers:
        last_wrapper = self.ancestor_wrappers[-1]
        last_wrapper.children.append(cur_wrapper)
        cur_wrapper.parent = last_wrapper

    self.ancestor_wrappers.append(cur_wrapper)
    for child in gast.iter_child_nodes(node):
        if isinstance(child, (gast.FunctionDef, gast.AsyncFunctionDef)):
            # TODO: current version is function name mapping to its type
            # consider complex case involving parameters
            self.var_env.enter_scope(child.name,
                                     AstVarScope.SCOPE_TYPE_FUNCTION)
            self.dfs_visit(child)
            self.var_env.exit_scope()
        else:
            self.dfs_visit(child)
    self.ancestor_wrappers.pop()

    cur_wrapper.node_var_type = self._get_node_var_type(cur_wrapper)
    return cur_wrapper.node_var_type
def analyze_convert_fn(ast: AST) -> AST:
    """Find and annotate the target convert function of the AST.

    Requires `analyze_func()` to have analyzed the AST first.
    The target convert function is searched for among the direct (top level)
    children of the given AST only: it is recognised when exactly one
    `FunctionDef` is found there.

    Annotates the given top level node with the target convert `FunctionDef`
    node as `convert_fn`, or `None` when zero or several candidate functions
    were found. Additionally annotates `convert_fn`, if present and it takes
    exactly one argument, with:
    - `plotter_name` - name of the plotter instance passed to `convert_fn`

    Args:
        ast: AST to scan and annotate the target convert function.
    Returns:
        The given AST with the target convert function annotated as `convert_fn`.
    """
    # TODO(mrzzy): Allow convert_fn with default args ie def convert_fn(g, a=2, b=3):
    # collect function definitions among the top level (direct child) nodes only
    candidate_fns = [
        n for n in gast.iter_child_nodes(ast) if isinstance(n, FunctionDef)
    ]
    # ambiguous (several) or missing candidates yield no convert function
    ast.convert_fn = candidate_fns[0] if len(candidate_fns) == 1 else None

    # extract name of plotter argument if present
    # NOTE(review): `n_args` is not a gast field; presumably it is set by
    # analyze_func() — confirm.
    if ast.convert_fn is not None and ast.convert_fn.n_args == 1:
        ast.convert_fn.plotter_name = ast.convert_fn.args.args[0].id
    return ast
def walk_const(ast, part_of=None):
    """Recursively label every AST node with an ``is_constant`` flag.

    Every visited node first gets ``is_constant = False``. Then:
    - a ``Constant`` that is not inside a collection is labelled constant;
    - a ``ListAST``/``Tuple`` collection that is not being unpacked from
      (``ast.assign.is_unpack``) becomes the collection its descendants belong
      to, and is labelled constant until a descendant that is neither a
      ``Constant`` nor a collection clears the flag;
    - expression context nodes (``Load``/``Store``/``Del``) are skipped.

    NOTE(review): a nested collection becomes the ``part_of`` of its own
    elements, so a non-constant element inside a nested collection does not
    clear the flag of the outer collection — confirm this is intended.

    Args:
        ast: AST node to label, together with all of its descendants.
        part_of: Innermost collection node that ``ast`` belongs to, if any.
    """
    ast.is_constant = False

    if not isinstance(ast, (Load, Store, Del)):
        is_literal = isinstance(ast, Constant)
        is_collection = isinstance(ast, (ListAST, Tuple))

        def claims_elements():
            # a collection that is unpacked from does not own its elements,
            # so those are judged without it (evaluated lazily on purpose)
            return is_collection and not (
                hasattr(ast, "assign") and ast.assign.is_unpack
            )

        if part_of is None:
            if is_literal:
                # a standalone constant, not part of any larger collection
                ast.is_constant = True
            elif claims_elements():
                part_of = ast
                ast.is_constant = True
        elif not (is_literal or is_collection):
            # a non-constant element spoils the enclosing collection
            part_of.is_constant = False
        elif claims_elements():
            part_of = ast
            ast.is_constant = True

    for child in gast.iter_child_nodes(ast):
        walk_const(child, part_of)
def propagate_activity(ast):
    """Aggregate symbol activity bottom-up and store it on code block nodes.

    Every node in the tree must carry an ``is_block`` flag. Activity dicts
    returned by the children are merged in child order (later children win on
    key clashes). For a block node, its own ``input_syms``/``output_syms``/
    ``base_in_syms``/``base_out_syms`` attributes (if any) are merged on top,
    and the merged dicts are written back to those attributes.

    Args:
        ast: root AST node of the subtree to aggregate.

    Returns:
        A tuple ``(input_syms, output_syms, base_in_syms, base_out_syms)`` of
        the activity aggregated over the whole subtree.
    """
    activity_attrs = ("input_syms", "output_syms", "base_in_syms", "base_out_syms")
    merged = ({}, {}, {}, {})

    # fold in the activity reported by every child subtree
    for child in gast.iter_child_nodes(ast):
        for accumulated, child_syms in zip(merged, propagate_activity(child)):
            accumulated.update(child_syms)

    if ast.is_block:
        # a block's own activity overrides its children's, then is recorded
        for accumulated, attr in zip(merged, activity_attrs):
            accumulated.update(getattr(ast, attr, {}))
            setattr(ast, attr, accumulated)

    return merged
def walk_parent(ast, parent=None):
    """Recursively label every node in the AST with its parent as ``parent``.

    The node the walk starts from (called with ``parent=None``) is left without
    a ``parent`` attribute.
    NOTE(review): a node instance shared by several parents ends up labelled
    with the last parent visited.

    Args:
        ast: AST node whose subtree should be labelled.
        parent: parent node of ``ast``, or ``None`` for the root.
    """
    if parent is not None:
        ast.parent = parent
    # recursively resolve parents of child nodes
    for node in gast.iter_child_nodes(ast):
        walk_parent(node, ast)
def walk_block(ast, block=None):
    """Recursively label AST nodes with their enclosing code block.

    Every node gets ``is_block``: True when it has a ``body`` or ``orelse``
    attribute. Every node below the starting node gets ``block``: the nearest
    ancestor (or the node's own parent chain) flagged as a block, as tracked by
    this walk. The starting node (``block=None``) gets no ``block`` attribute.
    NOTE(review): expression nodes such as IfExp and Lambda also have a
    ``body`` field, so they are treated as blocks too — confirm intended.

    Args:
        ast: AST node whose subtree should be labelled.
        block: innermost enclosing code block of ``ast``, or ``None``.
    """
    if block is not None:
        ast.block = block
    # detect code blocks by checking for attributes
    ast.is_block = any(hasattr(ast, attr) for attr in ["body", "orelse"])
    if ast.is_block:
        block = ast
    # recursively resolve code blocks of child nodes
    for node in gast.iter_child_nodes(ast):
        walk_block(node, block)
def _check_wrapper(self, wrapper, node_to_wrapper_map):
    """Recursively assert that ``wrapper``'s subtree mirrors the gast tree.

    Checks that each wrapper is the one registered for its node, is listed
    among its parent's children, and has exactly as many child wrappers as its
    node has gast children, each wrapping one of those children.

    Args:
        wrapper: wrapper whose subtree should be checked.
        node_to_wrapper_map: mapping from gast node to its wrapper.
    """
    self.assertEqual(node_to_wrapper_map[wrapper.node], wrapper)
    if wrapper.parent is not None:
        self.assertIn(wrapper, wrapper.parent.children)

    children_ast_nodes = list(gast.iter_child_nodes(wrapper.node))
    self.assertEqual(len(wrapper.children), len(children_ast_nodes))
    for child in wrapper.children:
        self.assertIn(child.node, children_ast_nodes)
        self._check_wrapper(child, node_to_wrapper_map)
def walk_resolve(ast, symbol_table=None):
    """Recursively resolve symbol references in the AST to their definitions.

    Definitions are recorded from assignments (``Assign``/``AnnAssign`` nodes
    carrying ``tgts``/``values`` attributes), function definitions (the
    function name in the enclosing scope, its arguments in the function's own
    scope) and looked up for any node carrying a ``symbol`` attribute, which is
    then labelled with:
    - ``definition``: latest definition seen so far, or ``None``;
    - ``definitions``: list of definitions of the symbol in the current scope.

    Args:
        ast: AST node to resolve symbols in (recursively).
        symbol_table: stack of symbol frames; each frame maps a symbol name to
            its list of definition nodes and the last frame is the current
            scope. Defaults to a fresh table holding one empty frame.
    """
    # build a fresh table per call: a shared default table would leak
    # definitions from one walk into every later, unrelated walk
    if symbol_table is None:
        symbol_table = deque([{}])
    # get current stack frame of the symbol table
    symbol_frame = symbol_table[-1]
    new_scope = False

    def push_definition(frame, target_sym, assign_value):
        # record definition (assign value) for symbol in symbol frame
        frame.setdefault(target_sym, []).append(assign_value)

    if isinstance(ast, (Assign, AnnAssign)):
        # update frame with definitions for symbols defined in assignment;
        # one entry per target keeps targets aligned with their values even
        # for duplicate targets or targets without a resolved symbol
        assign = ast
        target_syms = [getattr(t, "symbol", None) for t in assign.tgts]
        for target_sym, assign_value in zip(target_syms, assign.values):
            if target_sym is not None:
                push_definition(symbol_frame, target_sym, assign_value)
    elif isinstance(ast, FunctionDef):
        fn_def = ast
        # the function name is bound in the enclosing scope
        push_definition(symbol_frame, target_sym=fn_def.name, assign_value=fn_def)
        # function definition creates a new scope: copy the definition lists
        # so inner definitions do not leak out, but share the AST nodes
        # themselves so resolved definitions point into the real tree
        symbol_frame = {sym: list(defs) for sym, defs in symbol_frame.items()}
        symbol_table.append(symbol_frame)
        new_scope = True
        # arguments are bound inside the function's own scope only
        for arg in fn_def.args.args:
            push_definition(symbol_frame, target_sym=arg.id, assign_value=arg)
    elif hasattr(ast, "symbol"):
        # try to resolve symbol definitions
        definitions = symbol_frame.get(ast.symbol, [])
        # label latest definition of symbol on symbol AST node
        ast.definition = definitions[-1] if definitions else None
        # label all resolved definitions of symbol on symbol AST node
        ast.definitions = definitions

    try:
        # recursively resolve attributes of child nodes
        for node in gast.iter_child_nodes(ast):
            walk_resolve(node, symbol_table)
    finally:
        # pop stack frame from symbol table to revert to previous frame
        if new_scope:
            symbol_table.pop()
def _go(subtree, parent_func, parent_loop):
    """Recursively process a subtree.

    The high level strategy is to recursively walk down the tree. When we see
    a function or loop node, we update the `parent_func` or `parent_loop`
    argument, and then continue descending. When we reach a `break`,
    `continue`, or `return`, we then connect these nodes to the corresponding
    innermost function or loop.

    Async functions and loops open the same kind of scope as their synchronous
    counterparts. The `else` clause of a loop is also walked, but jumps found
    there belong to the loop enclosing that loop, since the `else` clause runs
    after the loop has finished.

    Args:
      subtree: Current subtree to process.
      parent_func: The AST node corresponding to the (innermost) FunctionDef
        that contains this subtree.
      parent_loop: The AST node corresponding to the (innermost) For or While
        loop that contains this subtree.
    """
    if isinstance(subtree, gast.Return):
        assert parent_func, "return outside function"
        if from_return:
            result.append(
                (ast_to_node_id[id(subtree)], ast_to_node_id[id(parent_func)],
                 JUMPS_OUT_OF_EDGE_TYPE))
        if from_retval and subtree.value is not None:
            result.append(
                (ast_to_node_id[id(subtree.value)],
                 ast_to_node_id[id(parent_func)], JUMPS_OUT_OF_EDGE_TYPE))
    elif isinstance(subtree, (gast.Break, gast.Continue)):
        assert parent_loop, "break or continue outside loop"
        if from_break_cont:
            result.append(
                (ast_to_node_id[id(subtree)], ast_to_node_id[id(parent_loop)],
                 JUMPS_OUT_OF_EDGE_TYPE))
    elif isinstance(subtree, (gast.FunctionDef, gast.AsyncFunctionDef)):
        # Update current function; loops outside it cannot be jumped out of
        for stmt in subtree.body:
            _go(stmt, subtree, None)
    elif isinstance(subtree, (gast.For, gast.AsyncFor, gast.While)):
        # Update current loop
        for stmt in subtree.body:
            _go(stmt, parent_func, subtree)
        # The else clause runs after the loop ends: its jumps target the
        # enclosing loop (if any), not this one
        for stmt in subtree.orelse:
            _go(stmt, parent_func, parent_loop)
    else:
        for child in gast.iter_child_nodes(subtree):
            _go(child, parent_func, parent_loop)
def generic_visit(self, node):
    """Visit every child of ``node`` and record ``node`` when all are pure.

    All children are visited (no short-circuit), so each child visit gets the
    chance to record its own pure nodes in ``self.result``.

    Returns:
        True when every child visit reported purity (vacuously True for a node
        without children), in which case ``node`` is added to ``self.result``;
        False otherwise.
    """
    verdicts = list(map(self.visit, ast.iter_child_nodes(node)))
    if not all(verdicts):
        return False
    self.result.add(node)
    return True
def generic_visit(self, node):
    """Combine the callables produced for ``node``'s children into one.

    Each child is visited immediately. Visiting must yield a callable taking a
    ``ctx`` argument.

    Returns:
        A callable that, given ``ctx``, returns the ``sum`` of the children's
        callables applied to ``ctx`` (0 for a node without children).
    """
    child_evaluators = list(map(self.visit, ast.iter_child_nodes(node)))

    def evaluate(ctx):
        return sum(evaluator(ctx) for evaluator in child_evaluators)

    return evaluate
def generic_visit(self, node):
    """Visit every child of ``node`` and record ``node`` when all are pure.

    Returns:
        True when every child visit reported purity (vacuously True for a node
        without children), in which case ``node`` is added to ``self.result``;
        False otherwise.
    """
    # Materialize all verdicts before calling `all`: a child visit may itself
    # record pure nodes in self.result (as this method does), so stopping at
    # the first impure child, as `all(map(...))` would, silently skips
    # recording pure nodes in the later siblings.
    child_purity = [self.visit(child) for child in ast.iter_child_nodes(node)]
    is_pure = all(child_purity)
    if is_pure:
        self.result.add(node)
    return is_pure