Example #1
0
    def dfs_visit(self, node):
        """Depth-first visit of ``node`` that builds the wrapper tree.

        A wrapper is created for ``node`` (or reused if one is already
        registered in ``self.node_to_wrapper_map``), linked to the wrapper on
        top of ``self.ancestor_wrappers`` and then every child node is visited
        recursively. Function definitions are visited inside a dedicated
        function scope of ``self.var_env``.

        Args:
            node: The gast node to visit.

        Returns:
            The value computed by ``self._get_node_var_type`` for the wrapper
            of ``node``; it is also stored on the wrapper as ``node_var_type``.
        """
        # AST reuses some gast.nodes, such as Param node of expr_context
        if node in self.node_to_wrapper_map:
            cur_wrapper = self.node_to_wrapper_map[node]
        else:
            cur_wrapper = AstNodeWrapper(node)
            self.node_to_wrapper_map[node] = cur_wrapper

        # the very first wrapper seen becomes the root of the wrapper tree
        if self.node_wrapper_root is None:
            self.node_wrapper_root = cur_wrapper

        # hook the wrapper under its closest ancestor, if there is one
        if self.ancestor_wrappers:
            parent_wrapper = self.ancestor_wrappers[-1]
            parent_wrapper.children.append(cur_wrapper)
            cur_wrapper.parent = parent_wrapper

        self.ancestor_wrappers.append(cur_wrapper)
        function_defs = (gast.FunctionDef, gast.AsyncFunctionDef)
        for child in gast.iter_child_nodes(node):
            if not isinstance(child, function_defs):
                self.dfs_visit(child)
                continue
            # TODO: current version is function name mapping to its type
            # consider complex case involving parameters
            self.var_env.enter_scope(child.name,
                                     AstVarScope.SCOPE_TYPE_FUNCTION)
            self.dfs_visit(child)
            self.var_env.exit_scope()
        self.ancestor_wrappers.pop()

        # the type is resolved only after all children have been processed
        cur_wrapper.node_var_type = self._get_node_var_type(cur_wrapper)
        return cur_wrapper.node_var_type
Example #2
0
def analyze_convert_fn(ast: AST) -> AST:
    """Find and annotate the target convert function on the given AST.

    The direct children of `ast` are scanned for `FunctionDef` nodes. When
    exactly one is found it is stored on the top level node as `convert_fn`,
    otherwise `convert_fn` is set to `None`.
    If `convert_fn` is present and its `n_args` attribute equals 1, it is
    additionally annotated with:
    - `plotter_name` - the `id` of its single positional argument.

    NOTE(review): `n_args` is not set here; presumably `analyze_func()` has to
    annotate the AST beforehand - confirm against the caller.

    Args:
        ast:
            AST to scan and annotate the target convert function.

    Returns:
        The given AST with the target convert function annotated as `convert_fn`
    """
    # TODO(mrzzy): Allow convert_fn with default args ie def convert_fn(g, a=2, b=3):
    # collect the function definitions that are direct children of the AST
    top_level_fns = []
    for child in gast.iter_child_nodes(ast):
        if isinstance(child, FunctionDef):
            top_level_fns.append(child)

    # only an unambiguous (single) candidate is accepted as the target
    convert_fn = top_level_fns[0] if len(top_level_fns) == 1 else None
    ast.convert_fn = convert_fn

    # extract name of plotter argument if present
    if convert_fn is None or convert_fn.n_args != 1:
        return ast
    convert_fn.plotter_name = convert_fn.args.args[0].id
    return ast
Example #3
0
    def walk_const(ast, part_of=None):
        """Recursively label every node of `ast` with an `is_constant` flag.

        Args:
            ast: Node whose subtree is labelled.
            part_of: Collection node (list / tuple) the current node is
                considered an element of, or `None` when it is not inside one.
        """
        collection_types = (ListAST, Tuple)
        context_types = (Load, Store, Del)

        # every node starts out as non constant
        ast.is_constant = False
        # expression context nodes carry no value and are left untouched
        if not isinstance(ast, context_types):
            in_collection = part_of is not None
            if in_collection and not isinstance(
                ast, collection_types + (Constant,)
            ):
                # a non constant element makes the enclosing collection non constant
                part_of.is_constant = False
            elif not in_collection and isinstance(ast, Constant):
                # standalone constant that is not an element of a larger collection
                ast.is_constant = True
            elif isinstance(ast, collection_types):
                # when unpacking from the iterable the elements should be
                # recognised as constants instead of the collection itself
                is_unpacked = hasattr(ast, "assign") and ast.assign.is_unpack
                if not is_unpacked:
                    # children are now elements of this collection, which
                    # stays constant unless one of them overrides the flag
                    part_of = ast
                    ast.is_constant = True

        # recursively inspect child nodes for constants
        for child in gast.iter_child_nodes(ast):
            walk_const(child, part_of)
Example #4
0
    def propagate_activity(ast):
        """Merge symbol activity bottom-up through the tree rooted at `ast`.

        The activity dicts returned for all child nodes are merged. If `ast`
        is a code block (`ast.is_block`), its own activity attributes are
        merged on top and the merged dicts are written back onto the node.

        Returns:
            Tuple of four dicts:
            `(input_syms, output_syms, base_in_syms, base_out_syms)`.
        """
        activity_attrs = ("input_syms", "output_syms", "base_in_syms", "base_out_syms")
        merged = tuple({} for _ in activity_attrs)

        # recursively obtain symbol activity from child code blocks
        for child in gast.iter_child_nodes(ast):
            for collected, child_syms in zip(merged, propagate_activity(child)):
                collected.update(child_syms)

        if ast.is_block:
            # include symbol activity from this block and store the merged
            # result back on the block node
            for attr, collected in zip(activity_attrs, merged):
                collected.update(getattr(ast, attr, {}))
                setattr(ast, attr, collected)

        return merged
Example #5
0
    def walk_parent(ast, parent=None):
        """Annotate every node below `ast` with a `parent` attribute.

        Args:
            ast: Node whose subtree is annotated.
            parent: Parent of `ast`; when `None` (the default, used for the
                root) no `parent` attribute is set on `ast`.
        """
        if parent is not None:
            ast.parent = parent

        # each child gets the current node as its parent
        for child in gast.iter_child_nodes(ast):
            walk_parent(child, ast)
Example #6
0
 def walk_block(ast, block=None):
     """Annotate every node below `ast` with its enclosing code block.

     Each node gets an `is_block` flag; nodes visited with a non-`None`
     `block` additionally get a `block` attribute pointing at it.

     Args:
         ast: Node whose subtree is annotated.
         block: Closest enclosing block node, or `None` if there is none yet.
     """
     if block is not None:
         ast.block = block
     # a node is treated as a code block when it has a `body` or `orelse` attribute
     ast.is_block = hasattr(ast, "body") or hasattr(ast, "orelse")
     # children of a block node belong to it, otherwise the current block carries on
     enclosing_block = ast if ast.is_block else block
     for child in gast.iter_child_nodes(ast):
         walk_block(child, enclosing_block)
Example #7
0
    def _check_wrapper(self, wrapper, node_to_wrapper_map):
        """Recursively assert that a wrapper subtree mirrors its gast subtree.

        Checks that the map resolves the wrapped node back to ``wrapper``,
        that ``wrapper`` is listed among its parent's children, and that the
        wrapper's children correspond to the child nodes of the wrapped node.

        Args:
            wrapper: The wrapper to verify.
            node_to_wrapper_map: Mapping from gast node to its wrapper.
        """
        self.assertEqual(node_to_wrapper_map[wrapper.node], wrapper)
        parent_wrapper = wrapper.parent
        if parent_wrapper is not None:
            self.assertTrue(wrapper in parent_wrapper.children)

        expected_child_nodes = list(gast.iter_child_nodes(wrapper.node))
        self.assertEqual(len(wrapper.children), len(expected_child_nodes))
        for child_wrapper in wrapper.children:
            self.assertTrue(child_wrapper.node in expected_child_nodes)
            self._check_wrapper(child_wrapper, node_to_wrapper_map)
Example #8
0
    def walk_resolve(ast, symbol_table=None):
        """Resolve symbol definitions for every node in the subtree of `ast`.

        Assignments and function definitions record definitions in the current
        frame of the symbol table. Any other node carrying a `symbol` attribute
        is labelled with `definition` (latest known definition or `None`) and
        `definitions` (all known definitions).

        Args:
            ast: Node whose subtree is resolved.
            symbol_table: Stack (deque) of symbol frames, each mapping a symbol
                to its list of definitions. When omitted, a fresh table with a
                single empty frame is created for this call.
        """
        # a mutable default would be shared between calls and leak definitions
        # recorded while resolving one AST into the resolution of the next
        if symbol_table is None:
            symbol_table = deque([{}])
        # get current stack frame of the symbol table
        symbol_frame = symbol_table[-1]
        new_scope = False

        def push_definition(symbol_frame, target_sym, assign_value):
            # record definition (assign value) for symbol in symbol table
            symbol_frame.setdefault(target_sym, []).append(assign_value)

        if isinstance(ast, (Assign, AnnAssign)):
            # update frame with definitions for symbol defined in assignment.
            # targets are paired positionally with their values so repeated
            # targets keep every definition aligned with the right value;
            # targets without a `symbol` attribute define nothing resolvable
            for target, assign_value in zip(ast.tgts, ast.values):
                target_sym = getattr(target, "symbol", None)
                if target_sym is not None:
                    push_definition(symbol_frame, target_sym, assign_value)
        elif isinstance(ast, FunctionDef):
            # record symbol defined by function definition
            fn_def = ast
            push_definition(symbol_frame, target_sym=fn_def.name, assign_value=fn_def)
            # record arguments defined in function as symbols
            for arg in fn_def.args.args:
                push_definition(symbol_frame, target_sym=arg.id, assign_value=arg)
            # function definition creates a new scope
            new_scope = True
        elif hasattr(ast, "symbol"):
            # try to resolve symbol definitions
            definitions = symbol_frame.get(ast.symbol, [])
            # label latest definition of symbol on symbol AST node
            ast.definition = definitions[-1] if definitions else None
            # label all resolved definitions of symbol on symbol AST node
            ast.definitions = definitions

        # create a new stack frame if in new scope; the frame starts as a copy
        # of the enclosing one so definitions made inside do not escape
        if new_scope:
            symbol_table.append(deepcopy(symbol_frame))
        # recursively resolve attributes of child nodes
        for node in gast.iter_child_nodes(ast):
            walk_resolve(node, symbol_table)
        # pop stack frame from symbol table to revert to previous frame
        if new_scope:
            symbol_table.pop()
Example #9
0
    def _go(subtree, parent_func, parent_loop):
        """Recursively process a subtree.

        The high level strategy is to recursively walk down the tree. When we
        see a function or loop node, we update the `parent_func` or
        `parent_loop` argument, and then continue descending. When we reach a
        `break`, `continue`, or `return`, we then connect these nodes to the
        corresponding innermost function or loop by appending an edge tuple
        to `result`.

        Args:
          subtree: Current subtree to process.
          parent_func: The AST node corresponding to the (innermost)
            FunctionDef or AsyncFunctionDef that contains this subtree.
          parent_loop: The AST node corresponding to the (innermost) For,
            AsyncFor or While loop that contains this subtree.
        """
        if isinstance(subtree, gast.Return):
            assert parent_func, "return outside function"
            if from_return:
                result.append(
                    (ast_to_node_id[id(subtree)],
                     ast_to_node_id[id(parent_func)], JUMPS_OUT_OF_EDGE_TYPE))
            if from_retval and subtree.value:
                result.append(
                    (ast_to_node_id[id(subtree.value)],
                     ast_to_node_id[id(parent_func)], JUMPS_OUT_OF_EDGE_TYPE))
        elif isinstance(subtree, (gast.Break, gast.Continue)):
            assert parent_loop, "break or continue outside loop"
            if from_break_cont:
                result.append(
                    (ast_to_node_id[id(subtree)],
                     ast_to_node_id[id(parent_loop)], JUMPS_OUT_OF_EDGE_TYPE))
        elif isinstance(subtree, (gast.FunctionDef, gast.AsyncFunctionDef)):
            # Update current function; a function body starts outside any loop.
            # Async definitions are handled too, otherwise their `return`s
            # would be attributed to an enclosing function (or trip the assert).
            for stmt in subtree.body:
                _go(stmt, subtree, None)
        elif isinstance(subtree, (gast.For, gast.AsyncFor, gast.While)):
            # Update current loop
            for stmt in subtree.body:
                _go(stmt, parent_func, subtree)
            # The `else` clause is not part of the loop body: a `break` or
            # `continue` there targets the enclosing loop, but a `return`
            # there still has to be connected to the enclosing function.
            for stmt in subtree.orelse:
                _go(stmt, parent_func, parent_loop)
        else:
            for child in gast.iter_child_nodes(subtree):
                _go(child, parent_func, parent_loop)
Example #10
0
 def generic_visit(self, node):
     """Visit every child of ``node`` and record ``node`` if all are truthy.

     All children are visited (no short-circuit). ``node`` is added to
     ``self.result`` only when every child visit returned a truthy value.

     Returns:
         ``True`` if every child visit was truthy, ``False`` otherwise.
     """
     child_results = [self.visit(child) for child in ast.iter_child_nodes(node)]
     if not all(child_results):
         return False
     self.result.add(node)
     return True
Example #11
0
 def generic_visit(self, node):
     """Combine the callables produced by visiting the children of ``node``.

     Every child is visited eagerly; each visit is expected to return a
     callable taking a single ``ctx`` argument.

     Returns:
         A callable that, given ``ctx``, returns the sum of the results of
         all child callables (``0`` when there are no children).
     """
     child_fns = []
     for child in ast.iter_child_nodes(node):
         child_fns.append(self.visit(child))
     return lambda ctx: sum(fn(ctx) for fn in child_fns)
Example #12
0
 def generic_visit(self, node):
     """Return a callable summing the callables obtained from the children.

     ``self.visit`` is applied to every child of ``node`` up front; the
     returned callable evaluates each of those results with ``ctx`` and adds
     them together.
     """
     evaluators = list(map(self.visit, ast.iter_child_nodes(node)))
     return lambda ctx: sum(evaluate(ctx) for evaluate in evaluators)
Example #13
0
 def generic_visit(self, node):
     """Visit all children of ``node``; record ``node`` when all are truthy.

     Every child is visited even after a falsy result has been seen, so the
     whole subtree is always processed.

     Returns:
         ``True`` if every child visit returned a truthy value (in which case
         ``node`` was added to ``self.result``), ``False`` otherwise.
     """
     every_child_ok = True
     for child in ast.iter_child_nodes(node):
         if not self.visit(child):
             every_child_ok = False
     if every_child_ok:
         self.result.add(node)
     return every_child_ok
Example #14
0
 def generic_visit(self, node):
     """Visit every child of ``node`` and record ``node`` if all are truthy.

     ``node`` is added to ``self.result`` only when every child visit
     returned a truthy value.

     Returns:
         ``True`` if every child visit was truthy, ``False`` otherwise.
     """
     # Materialise the results first: in Python 3 ``map`` is lazy, so
     # ``all(map(...))`` would stop at the first falsy child and the
     # remaining subtrees would never be visited (nor recorded).
     is_pure = all([self.visit(x) for x in ast.iter_child_nodes(node)])
     if is_pure:
         self.result.add(node)
     return is_pure
Example #15
0
 def generic_visit(self, node):
     """Visit all children of ``node``; record ``node`` when all are truthy.

     Returns:
         ``True`` if every child visit returned a truthy value (in which case
         ``node`` was added to ``self.result``), ``False`` otherwise.
     """
     # ``all(map(...))`` short-circuits on a lazy Python 3 ``map`` and would
     # skip the children after the first falsy one; build the full list so
     # that every subtree is visited.
     child_results = [self.visit(child) for child in ast.iter_child_nodes(node)]
     is_pure = all(child_results)
     if is_pure:
         self.result.add(node)
     return is_pure