Beispiel #1
0
    def attach_data(self, node):
        '''Attach pending OpenMP directives to *node*, then visit it.

        Generic method called for visit_XXXX() with XXXX in the
        GatherOMPData.statements list.

        Returns the node to keep in the tree. This is either *node* itself,
        or a new ``if 1:`` wrapper around it. The wrapper is used when a
        scoping directive ('parallel', 'task' or 'section') is attached to a
        statement that has no statement-list field of its own. In that case
        the directives are moved from *node* onto the wrapper.
        '''
        # Flush the directives gathered so far onto this statement.
        pending = self.current
        if pending:
            for raw_directive in pending:
                metadata.add(node, OMPDirective(raw_directive))
            self.current = list()

        # A directive that ends a statement list needs a statement to hold
        # it, so append a Pass after it.
        for name, stmts in ast.iter_fields(node):
            if name not in GatherOMPData.statement_lists:
                continue
            ends_with_directive = (
                stmts and
                isinstance(stmts[-1], ast.Expr) and
                self.isompdirective(stmts[-1].value))
            if ends_with_directive:
                stmts.append(ast.Pass())

        self.generic_visit(node)

        directives = metadata.get(node, OMPDirective)
        if not directives:
            return node
        own_fields = {name for name, _ in ast.iter_fields(node)}
        if not own_fields.isdisjoint(GatherOMPData.statement_lists):
            # The statement already owns a scope.
            return node
        joined = ''.join(d.s for d in directives)
        if not any(kw in joined for kw in ('parallel', 'task', 'section')):
            return node

        # Create an artificial scope and move the directives onto it.
        metadata.clear(node, OMPDirective)
        wrapper = ast.If(ast.Num(1), [node], [])
        for directive in directives:
            metadata.add(wrapper, directive)
        return wrapper
Beispiel #2
0
    def attach_data(self, node):
        '''Attach pending OpenMP directives to *node*, then visit it.

        Generic method called for visit_XXXX() with XXXX in the
        GatherOMPData.statements list.

        Returns either *node* itself or an ``if 1:`` wrapper around it. The
        wrapper is created when a scoping directive ('parallel', 'task' or
        'section') is attached to a statement that has no statement-list
        field of its own. The directives are then moved to the wrapper.
        '''
        if self.current:
            for curr in self.current:
                md = OMPDirective(curr)
                metadata.add(node, md)
            self.current = list()
        # add a Pass to hold some directives
        for field_name, field in ast.iter_fields(node):
            if field_name in GatherOMPData.statement_lists:
                if (field and
                        isinstance(field[-1], ast.Expr) and
                        self.isompdirective(field[-1].value)):
                    field.append(ast.Pass())
        self.generic_visit(node)

        # add an If to hold scoping OpenMP directives
        directives = metadata.get(node, OMPDirective)
        field_names = {n for n, _ in ast.iter_fields(node)}
        has_no_scope = field_names.isdisjoint(GatherOMPData.statement_lists)
        if directives and has_no_scope:
            # some directives create a scope, but the holding stmt may not
            # artificially create one here if needed
            sdirective = ''.join(d.s for d in directives)
            scoping = ('parallel', 'task', 'section')
            if any(s in sdirective for s in scoping):
                # Move the directives onto the wrapper so they apply to the
                # artificial scope. If they stayed on the inner statement,
                # the wrapper would serve no purpose.
                metadata.clear(node, OMPDirective)
                node = ast.If(ast.Num(1), [node], [])
                for directive in directives:
                    metadata.add(node, directive)
        return node
Beispiel #3
0
 def _insert_func_nodes(self, node):
     """
     Splice generated functions in front of the statements that use them.

     ``self.new_func_nodes`` maps a statement (the ``layers.cond`` one) to
     the list of function definitions (``true_func``/``false_func``) created
     for it. The definitions are inserted right before that statement,
     rather than all at the top of the parent's body. The functions may
     modify private variables of the class or of another outer scope
     (e.g. ``self.var_dict["key"]``), and keeping each definition next to
     its use keeps the nesting of the generated functions easy to follow.

     ``node`` may be a list of statements or an AST node. A list is scanned
     from the end, and freshly inserted functions are themselves searched
     recursively. For an AST node, every field is searched. Consumed
     entries are removed from ``self.new_func_nodes``, and the search
     stops as soon as the mapping is empty.
     """
     if not self.new_func_nodes:
         return
     if not isinstance(node, list):
         if isinstance(node, gast.AST):
             for _, field_value in gast.iter_fields(node):
                 self._insert_func_nodes(field_value)
         return

     pos = len(node) - 1
     while pos >= 0:
         stmt = node[pos]
         if stmt in self.new_func_nodes:
             funcs = self.new_func_nodes.pop(stmt)
             node[pos:pos] = funcs
             # Continue with the last inserted function, so nested
             # generated functions inside it are handled too.
             pos += len(funcs) - 1
         else:
             self._insert_func_nodes(stmt)
             pos -= 1
Beispiel #4
0
    def _transform_var_shape_if_necessary(self, cond):
        """Replace variable-shape accesses inside ``cond`` in place.

        Every node of ``cond`` is inspected. Two kinds of node qualify:

        * an ``Attribute`` accepted by ``self.is_var_shape``;
        * a ``Name`` whose id is a key of ``self.name_to_var_shape``.

        A qualifying node is replaced inside its parent by
        ``create_convert_shape_node(...)``. The parent slot may be a plain
        field or an element of a list field (e.g. ``Compare.comparators``).

        Returns:
            True if at least one qualifying node was found, else False.
        """
        found_any = False
        for candidate in gast.walk(cond):
            shape_node = None
            if isinstance(candidate, gast.Attribute):
                if self.is_var_shape(candidate):
                    shape_node = candidate
            elif isinstance(candidate, gast.Name):
                if candidate.id in self.name_to_var_shape:
                    shape_node = self.name_to_var_shape[candidate.id]
            if not shape_node:
                continue

            found_any = True
            parent = self.node_to_wrapper_map.get(candidate).parent.node
            for field_name, field_value in gast.iter_fields(parent):
                if field_value is candidate:
                    setattr(parent, field_name,
                            create_convert_shape_node(shape_node))
                    break
                if isinstance(field_value, list):
                    slot = next((pos for pos, item in enumerate(field_value)
                                 if item is candidate), None)
                    if slot is not None:
                        field_value[slot] = create_convert_shape_node(
                            shape_node)
                        break
        return found_any
Beispiel #5
0
    def insert_comment_in_parent_before_node(self, line, lineno, node):
        """Insert ``line`` as a comment right before ``node`` in its parent.

        The parent comes from ``self.ancestors.parent(node)``. In every list
        field of the parent that contains ``node`` (compared with ``==``), a
        ``CommentLine(line.strip(), lineno)`` is inserted just before it.
        """
        owner = self.ancestors.parent(node)
        text = line.strip()

        for _, siblings in gast.iter_fields(owner):
            if not isinstance(siblings, list) or node not in siblings:
                continue
            siblings.insert(siblings.index(node), CommentLine(text, lineno))
Beispiel #6
0
    def generic_visit(self, pattern):
        """
        Tell whether ``pattern`` matches the checked node ``self.node``.

        ``pattern`` matches when:
            - it is an instance of ``type(self.node)``, and
            - ``self.field_match`` accepts every field of ``self.node``
              paired with the same-named field of ``pattern``.
        Fields are compared in order, and checking stops at the first
        mismatch.
        """
        if not isinstance(pattern, type(self.node)):
            return False
        for field_name, node_value in iter_fields(self.node):
            pattern_value = getattr(pattern, field_name)
            if not self.field_match(node_value, pattern_value):
                return False
        return True
    def generic_visit(self, pattern):
        """
        Return True when ``pattern`` matches the checked node ``self.node``.

        Matching means ``pattern`` is an instance of ``type(self.node)`` and
        ``self.field_match`` accepts every (node field, pattern field) pair.
        The pairs are checked in field order, stopping at the first one
        that is rejected.
        """
        same_type = isinstance(pattern, type(self.node))
        if not same_type:
            return same_type
        # Lazily pair each field of the checked node with the pattern's
        # attribute of the same name.
        field_pairs = ((value, getattr(pattern, name))
                       for name, value in iter_fields(self.node))
        return all(self.field_match(*pair) for pair in field_pairs)
Beispiel #8
0
 def generic_visit(self, node):
     # visit_Return rewrites ancestor nodes instead of the current node. The
     # stock NodeTransformer.generic_visit could therefore visit a node that
     # has already been deleted and put it back into the transformed AST.
     # This version only visits the AST children: the results of
     # self.visit() are discarded and nothing is written back into ``node``.
     for _, field_value in gast.iter_fields(node):
         if isinstance(field_value, list):
             candidates = field_value
         else:
             candidates = (field_value,)
         for candidate in candidates:
             if isinstance(candidate, gast.AST):
                 self.visit(candidate)
 def generic_visit(self, node):
     # TODO: visit_Break/visit_Continue modify ancestor nodes rather than the
     # current one. The stock NodeTransformer.generic_visit could therefore
     # visit nodes that were already deleted and re-add them to the
     # transformed AST. This hand-written traversal (visit children only,
     # never write results back) avoids that. It is a workaround, though;
     # consider refactoring this whole class.
     for _, value in gast.iter_fields(node):
         if isinstance(value, gast.AST):
             self.visit(value)
             continue
         if not isinstance(value, list):
             continue
         for element in value:
             if isinstance(element, gast.AST):
                 self.visit(element)
Beispiel #10
0
 def generic_visit(self, node):
     """Visit every field of ``node`` and splice queued statements in place.

     - List fields whose ``(type(node), field)`` is in ``grammar.BLOCKS``
       (statement blocks) are visited through ``self.visit_statements``.
       Fresh deques are pushed on ``to_prepend_block``, ``to_append_block``
       and ``to_insert`` while this happens. Afterwards the queued
       prepend/append statements are added at the start/end of the block.
       A trailing ``Return`` is kept last.
     - Other list fields are visited element by element. A ``None`` result
       drops the element, and a non-AST result is spliced in with
       ``extend``.
     - For a ``FunctionDef`` body, statements queued on ``to_insert_top``
       are added at the front.
     - Single AST fields are replaced by the visit result, or deleted with
       ``delattr`` when the result is ``None``.

     The outermost call (the one that finds ``self._top`` set) also runs
     ``Remove(self.to_remove).visit(node)`` at the end. Returns ``node``,
     which is modified in place.
     """
     # Detect the outermost invocation; nested calls see _top already False.
     is_top = False
     if self._top:
         is_top = True
         self._top = False
     for field, old_value in gast.iter_fields(node):
         if isinstance(old_value, list):
             if (type(node), field) in grammar.BLOCKS:
                 # Open a fresh queue level for this block; the prepend and
                 # append levels are popped again below.
                 self.to_prepend_block.append(deque())
                 self.to_append_block.append(deque())
                 self.to_insert.append(deque())
                 new_values = copy(self.visit_statements(old_value))
                 self.to_insert.pop()
             else:
                 new_values = []
                 for value in old_value:
                     if isinstance(value, gast.AST):
                         value = self.visit(value)
                         # None means "delete this element".
                         if value is None:
                             continue
                         # A non-AST result is a sequence of replacements.
                         elif not isinstance(value, gast.AST):
                             new_values.extend(value)
                             continue
                     new_values.append(value)
             # NOTE(review): extendleft needs new_values to be a deque here;
             # presumably FunctionDef.body is in grammar.BLOCKS, so it came
             # from visit_statements. extendleft also reverses the iteration
             # order of its argument - presumably the queues are filled
             # accordingly; confirm.
             if isinstance(node, gast.FunctionDef) and field == 'body':
                 new_values.extendleft(self.to_insert_top)
                 self.to_insert_top = deque([])
             if (type(node), field) in grammar.BLOCKS:
                 new_values.extendleft(self.to_prepend_block.pop())
                 # Keep a trailing Return as the last statement, after any
                 # appended statements.
                 return_ = None
                 if new_values and isinstance(new_values[-1], gast.Return):
                     return_ = new_values.pop()
                 new_values.extend(self.to_append_block.pop())
                 if return_:
                     new_values.append(return_)
             # Replace the list contents in place (same list object).
             old_value[:] = new_values
         elif isinstance(old_value, gast.AST):
             new_node = self.visit(old_value)
             if new_node is None:
                 delattr(node, field)
             else:
                 setattr(node, field, new_node)
     if is_top and self.to_remove:
         Remove(self.to_remove).visit(node)
     return node
Beispiel #11
0
    def walk_symbol(ast, qualifiers=None, field_name=""):
        """Recursively label symbol references in the tree rooted at ``ast``.

        Sets ``is_symbol`` on every visited node.

        - An unqualified ``Name`` is labelled as a symbol, with ``symbol``
          and ``base_symbol`` both set to its id.
        - For a qualified chain of ``Attribute``/``Subscript`` nodes such as
          ``a.b[i]``, only the outermost node is labelled. Its ``symbol`` is
          the rendered chain (``"a.b[i]"``) and its ``base_symbol`` is the
          root name (``"a"``).
        - Names inside a subscript's slice or inside any other
          sub-expression (call arguments, operands, ...) are independent
          expressions and are labelled on their own.

        Args:
            ast: node to inspect (note: shadows the ``ast`` module here).
            qualifiers: ``Attribute``/``Subscript`` nodes of the chain that
                ``ast`` is the ``value`` of, outermost first. ``None`` (the
                default) means ``ast`` is not inside such a chain.
            field_name: name of the parent field that holds ``ast``.
        """
        if qualifiers is None:
            qualifiers = []
        # assume is not symbol unless proven otherwise
        ast.is_symbol = False
        # check if name is unqualified symbol
        if isinstance(ast, Name) and len(qualifiers) == 0:
            ast.is_symbol, ast.symbol, ast.base_symbol = True, ast.id, ast.id
        # append qualifiers of an incomplete qualified symbol
        elif isinstance(ast, (Subscript, Attribute)):
            qualifiers = qualifiers + [ast]
        # check if name is part of a qualified symbol and not of a subscript's slice
        elif isinstance(ast, Name) and field_name != "slice":
            # qualifiers are recorded outermost first, so render them reversed
            base_sym = symbol = ast.id
            for qualifier in reversed(qualifiers):
                if isinstance(qualifier, Attribute):
                    symbol += f".{qualifier.attr}"
                elif isinstance(qualifier, Subscript):
                    # render slice AST as source code
                    src_slice = unparse(qualifier.slice).rstrip()
                    symbol += f"[{src_slice}]"
            # label top level symbol
            top_attr = qualifiers[0]
            top_attr.is_symbol, top_attr.symbol, top_attr.base_symbol = (
                True,
                symbol,
                base_sym,
            )

        extends_chain = isinstance(ast, (Subscript, Attribute))
        # recursively inspect child nodes
        for name, value in gast.iter_fields(ast):
            # extract AST nodes from fields
            if isinstance(value, AST):
                nodes = [value]
            elif isinstance(value, Iterable) and all(
                isinstance(v, AST) for v in value
            ):
                nodes = value
            else:
                # non node field-skip
                continue

            # Only the ``value`` of a Subscript/Attribute continues the
            # qualified chain. Any other child (a slice, call arguments,
            # operands, ...) starts a fresh, unqualified context. Otherwise
            # e.g. the ``i`` in ``a[i + 1]`` would relabel the subscript as
            # ``i[i + 1]``, and ``f(x).y`` would be labelled ``x.y``.
            if extends_chain and name == "value":
                child_qualifiers = qualifiers
            else:
                child_qualifiers = []
            for node in nodes:
                walk_symbol(node, child_qualifiers, field_name=name)
Beispiel #12
0
    def _transform_tensor_shape_if_necessary(self, cond):
        """Replace tensor-shape accesses inside ``cond`` in place.

        Every node of ``cond`` is inspected. Two kinds of node qualify:

        * an ``Attribute`` accepted by ``self.is_tensor_shape``;
        * a ``Name`` whose id is a key of ``self.name_to_tensor_shape``.

        A qualifying node is replaced inside its parent by
        ``create_api_shape_node(...)``. The parent slot may be a plain field
        or an element of a list field.
        """
        for child_node in gast.walk(cond):
            tensor_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_tensor_shape(child_node):
                    tensor_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_tensor_shape:
                    tensor_shape_node = self.name_to_tensor_shape[
                        child_node.id]

            if tensor_shape_node:
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_api_shape_node(tensor_shape_node))
                        break
                    # The node may also be an element of a list field, such
                    # as Compare.comparators, BoolOp.values or Call.args.
                    # Without this branch those occurrences were left
                    # untouched.
                    if isinstance(value, list):
                        has_converted_shape = False
                        for i, v in enumerate(value):
                            if child_node is v:
                                value[i] = create_api_shape_node(
                                    tensor_shape_node)
                                has_converted_shape = True
                                break
                        if has_converted_shape:
                            break
Beispiel #13
0
    def _transform_var_shape_if_necessary(self, cond):
        """Rewrite the shape accesses found inside ``cond`` in place.

        For every ``Name``/``Attribute``/``Subscript`` node of ``cond``:

        * if its source text is a key of ``self.name_to_var_shape``, it is
          replaced by ``create_choose_shape_node(text, mapped_value)``;
        * otherwise, if ``self._is_var_shape`` accepts it, it is replaced by
          ``create_convert_shape_node(node, None, True)``.

        The replacement happens in the parent node, either in a plain field
        or inside a list field (e.g. ``Compare.comparators``).

        Returns:
            True if at least one node was selected for replacement.
        """
        shape_kinds = (gast.Name, gast.Attribute, gast.Subscript)
        transformed = False
        for candidate in gast.walk(cond):
            selected = None
            if isinstance(candidate, shape_kinds):
                source_text = ast_to_source_code(candidate).strip()
                if source_text in self.name_to_var_shape:
                    selected = create_choose_shape_node(
                        source_text, self.name_to_var_shape[source_text])
                elif self._is_var_shape(candidate):
                    selected = candidate
            if not selected:
                continue
            transformed = True

            def make_replacement(shape=selected, origin=candidate):
                # A raw shape access is wrapped in a convert-shape call; a
                # choose-shape node built above is used as is.
                if shape is origin:
                    return create_convert_shape_node(shape, None, True)
                return shape

            parent = self.node_to_wrapper_map.get(candidate).parent.node
            for field_name, field_value in gast.iter_fields(parent):
                if field_value is candidate:
                    setattr(parent, field_name, make_replacement())
                    break
                if isinstance(field_value, list):
                    slot = next((pos for pos, item in enumerate(field_value)
                                 if item is candidate), None)
                    if slot is not None:
                        field_value[slot] = make_replacement()
                        break
        return transformed
    def _transform_var_shape_if_necessary(self, cond):
        """Replace variable-shape accesses inside ``cond`` in place.

        Every node of ``cond`` is inspected. Two kinds of node qualify:

        * an ``Attribute`` accepted by ``self.is_var_shape``;
        * a ``Name`` whose id is a key of ``self.name_to_var_shape``.

        A qualifying node is replaced inside its parent by
        ``create_convert_shape_node(...)``. The parent slot may be a plain
        field or an element of a list field.

        Returns:
            True if at least one qualifying node was found, else False.
        """
        need_transformed = False
        for child_node in gast.walk(cond):
            var_shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_var_shape(child_node):
                    var_shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_var_shape:
                    var_shape_node = self.name_to_var_shape[child_node.id]

            if var_shape_node:
                need_transformed = True
                wrapper_node = self.node_to_wrapper_map.get(child_node)
                parent_node = wrapper_node.parent.node
                for field, value in gast.iter_fields(parent_node):
                    if child_node is value:
                        setattr(parent_node, field,
                                create_convert_shape_node(var_shape_node))
                        break
                    # The node may also be an element of a list field, such
                    # as Compare.comparators, BoolOp.values or Call.args.
                    # Without this branch those occurrences were reported as
                    # transformed but left untouched.
                    if isinstance(value, list):
                        has_converted_shape = False
                        for i, v in enumerate(value):
                            if child_node is v:
                                value[i] = create_convert_shape_node(
                                    var_shape_node)
                                has_converted_shape = True
                                break
                        if has_converted_shape:
                            break
        return need_transformed
Beispiel #15
0
 def test_iter_child_nodes(self):
     tree = gast.UnaryOp(gast.USub(), gast.Constant(value=1, kind=None))
     # Exercise iter_child_nodes, which this test is named after (it called
     # iter_fields). The operator and the operand are the two child nodes.
     self.assertEqual(len(list(gast.iter_child_nodes(tree))), 2)
Beispiel #16
0
 def test_iter_fields(self):
     constant = gast.Constant(value=1, kind=None)
     # iter_fields yields (name, value) pairs; keep only the names.
     field_names = set(dict(gast.iter_fields(constant)))
     self.assertEqual(field_names, {'value', 'kind'})
Beispiel #17
0
 def test_iter_child_nodes(self):
     tree = gast.UnaryOp(gast.USub(), gast.Num(n=1))
     # Exercise iter_child_nodes, which this test is named after (it called
     # iter_fields). The operator and the operand are the two child nodes.
     self.assertEqual(len(list(gast.iter_child_nodes(tree))),
                      2)
Beispiel #18
0
 def test_iter_fields(self):
     number = gast.Num(n=1)
     actual = {field for field, _value in gast.iter_fields(number)}
     self.assertEqual(actual, {'n'})