Ejemplo n.º 1
0
    def visit_Return(self, node):
        """Rewrite a ``return`` so the enclosing function can be lowered.

        A fresh flag name is generated for this return and recorded in
        ``self.return_name`` for the current function. Walking outwards from
        the return towards the function definition:

        * in the statement list that directly holds the return, the return
          itself is rewritten by ``_replace_return_in_stmt_list``;
        * in every enclosing statement list (``body`` or ``orelse``), the
          statements after the current node are guarded by
          ``_replace_after_node_to_if_in_stmt_list``;
        * enclosing ``while`` loops get ``and not <flag>`` appended to their
          test;
        * enclosing ``for`` loops are converted to ``while`` loops carrying
          ``not <flag>`` via ``ForToWhileTransformer``.

        Args:
            node: the ``gast.Return`` node being visited.
        """
        cur_func_node = self.function_def[-1]
        # One unique flag per return statement, remembered per function.
        return_name = unique_name.generate(RETURN_PREFIX)
        self.return_name[cur_func_node].append(return_name)
        max_return_length = self.pre_analysis.get_func_max_return_length(
            cur_func_node)
        # presumably ancestor_nodes[-1] is `node` itself, so [-2] is the
        # return's direct parent -- TODO confirm
        parent_node_of_return = self.ancestor_nodes[-2]

        # Walk outwards; `cur_node` is the child of `ancestor` on the path
        # down to the return.
        for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
            ancestor = self.ancestor_nodes[ancestor_index]
            cur_node = self.ancestor_nodes[ancestor_index + 1]
            if hasattr(ancestor,
                       "body") and index_in_list(ancestor.body, cur_node) != -1:
                # Only the list that directly holds the return rewrites it.
                if cur_node == node:
                    self._replace_return_in_stmt_list(
                        ancestor.body, cur_node, return_name, max_return_length,
                        parent_node_of_return)
                # Guard the statements after cur_node with `not return_name`.
                self._replace_after_node_to_if_in_stmt_list(
                    ancestor.body, cur_node, return_name, parent_node_of_return)
            elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse,
                                                               cur_node) != -1:
                if cur_node == node:
                    self._replace_return_in_stmt_list(
                        ancestor.orelse, cur_node, return_name,
                        max_return_length, parent_node_of_return)
                self._replace_after_node_to_if_in_stmt_list(
                    ancestor.orelse, cur_node, return_name,
                    parent_node_of_return)

            # If return node in while loop, add `not return_name` in gast.While.test
            if isinstance(ancestor, gast.While):
                cond_var_node = gast.UnaryOp(
                    op=gast.Not(),
                    operand=gast.Name(
                        id=return_name,
                        ctx=gast.Load(),
                        annotation=None,
                        type_comment=None))
                ancestor.test = gast.BoolOp(
                    op=gast.And(), values=[ancestor.test, cond_var_node])
                continue

            # If return node in for loop, convert the loop to a while loop
            # whose test includes `not return_name`
            if isinstance(ancestor, gast.For):
                cond_var_node = gast.UnaryOp(
                    op=gast.Not(),
                    operand=gast.Name(
                        id=return_name,
                        ctx=gast.Load(),
                        annotation=None,
                        type_comment=None))
                parent_node = self.ancestor_nodes[ancestor_index - 1]
                for_to_while = ForToWhileTransformer(parent_node, ancestor,
                                                     cond_var_node)
                new_stmts = for_to_while.transform()
                while_node = new_stmts[-1]
                # Keep the ancestor stack pointing at the rewritten loop.
                self.ancestor_nodes[ancestor_index] = while_node

            # Stop once the enclosing function definition has been handled.
            if ancestor == cur_func_node:
                break
Ejemplo n.º 2
0
    def generic_visit(self, node):
        """Guard statements that follow a set continue/break/return flag.

        When ``stack_has_flags`` reports a flag on any of
        ``for_continued_stack``, ``for_breaked_stack`` or
        ``func_returned_stack``, the visited statement is wrapped in
        ``if not <flag_a> and not <flag_b> ...:``. Each flag name is the
        matching prefix (``continued_flag``/``breaked_flag``/``returned_flag``)
        followed by the current depth of its stack. Entering a ``For`` pushes
        ``False`` onto the continue and break stacks; entering a
        ``FunctionDef`` pushes ``False`` onto the return stack. Both pushes
        happen before the children are visited.

        Args:
            node: any gast node; only ``gast.stmt`` instances can be wrapped.

        Returns:
            The node produced by the base ``generic_visit``, possibly wrapped
            in a new ``gast.If``.
        """
        if isinstance(node, gast.stmt):
            if self.stack_has_flags(
                    self.for_continued_stack) or self.stack_has_flags(
                        self.for_breaked_stack) or self.stack_has_flags(
                            self.func_returned_stack):
                # Build the guard from the current stack depths, before any
                # entry for `node` itself is pushed below.
                bool_values = []
                if self.stack_has_flags(self.for_continued_stack):
                    continued_id = len(self.for_continued_stack)
                    bool_values.append(
                        gast.UnaryOp(op=gast.Not(),
                                     operand=gast.Name(id=self.continued_flag +
                                                       str(continued_id),
                                                       ctx=gast.Load(),
                                                       annotation=None,
                                                       type_comment=None)))
                if self.stack_has_flags(self.for_breaked_stack):
                    breaked_id = len(self.for_breaked_stack)
                    bool_values.append(
                        gast.UnaryOp(op=gast.Not(),
                                     operand=gast.Name(id=self.breaked_flag +
                                                       str(breaked_id),
                                                       ctx=gast.Load(),
                                                       annotation=None,
                                                       type_comment=None)))
                if self.stack_has_flags(self.func_returned_stack):
                    returned_id = len(self.func_returned_stack)
                    bool_values.append(
                        gast.UnaryOp(op=gast.Not(),
                                     operand=gast.Name(id=self.returned_flag +
                                                       str(returned_id),
                                                       ctx=gast.Load(),
                                                       annotation=None,
                                                       type_comment=None)))

                # Open a new scope for loops/functions before visiting children.
                if isinstance(node, gast.For):
                    self.for_continued_stack.append(False)
                    self.for_breaked_stack.append(False)
                elif isinstance(node, gast.FunctionDef):
                    self.func_returned_stack.append(False)

                modified_node = super().generic_visit(node)
                if len(bool_values) == 1:
                    cond = bool_values[0]
                else:
                    cond = gast.BoolOp(op=gast.And(), values=bool_values)
                replacement = gast.If(test=cond,
                                      body=[modified_node],
                                      orelse=[])
                ret = gast.copy_location(replacement, node)
            else:
                # No active flags: just open the scope and visit normally.
                if isinstance(node, gast.For):
                    self.for_continued_stack.append(False)
                    self.for_breaked_stack.append(False)
                elif isinstance(node, gast.FunctionDef):
                    self.func_returned_stack.append(False)
                ret = super().generic_visit(node)
        else:
            ret = super().generic_visit(node)
        return ret
Ejemplo n.º 3
0
 def visit_BoolOp(self, node):
     """Rewrite ``a op b ...`` as ``not ((not a) flipped_op (not b) ...)``.

     The flipped operator comes from ``BOOL_INVERSIONS`` (De Morgan's law),
     so the result is logically equivalent to the visited expression.
     """
     self.generic_visit(node)
     negated_values = [
         gast.UnaryOp(op=gast.Not(), operand=value) for value in node.values
     ]
     inverted_op = BOOL_INVERSIONS[type(node.op)]
     inner = gast.BoolOp(op=inverted_op(), values=negated_values)
     return gast.UnaryOp(op=gast.Not(), operand=inner)
Ejemplo n.º 4
0
 def visit_For(self, node):
     """Insert flag resets and the break check into a visited ``for`` loop.

     After the children are visited, the innermost entries of
     ``for_continue_stack`` and ``for_breaked_stack`` (lists of flag names)
     are popped. Each flag gets a ``<flag> = False`` assignment inserted at
     the start of the loop body. If there are break flags, the loop body also
     ends with ``<keepgoing_flag> = not <cond>`` and ``if <cond>: break``,
     where ``<cond>`` is the single break flag or an ``or`` of all of them.

     Args:
         node: the ``gast.For`` node being visited.

     Returns:
         The node returned by ``generic_visit``. The code handles both a bare
         ``For`` and an ``If`` whose first body statement is the ``For``
         (presumably the guard wrapper added by ``generic_visit``).
     """
     modified_node = self.generic_visit(node)
     continue_flags = self.for_continue_stack.pop()
     for flag in continue_flags:
         # Reset the flag at the top of every iteration.
         node.body.insert(
             0,
             gast.Assign(targets=[
                 gast.Name(id=flag, ctx=gast.Store(), annotation=None)
             ],
                         value=gast.NameConstant(value=False)))
     breaked_flags = self.for_breaked_stack.pop()
     bool_values = []
     for flag in breaked_flags:
         node.body.insert(
             0,
             gast.Assign(targets=[
                 gast.Name(id=flag, ctx=gast.Store(), annotation=None)
             ],
                         value=gast.NameConstant(value=False)))
         bool_values.append(
             gast.Name(id=flag, ctx=gast.Load(), annotation=None))
     if len(bool_values) > 0:
         if len(bool_values) == 1:
             cond = bool_values[0]
         elif len(bool_values) > 1:
             cond = gast.BoolOp(op=gast.Or(), values=bool_values)
         if isinstance(modified_node, gast.For):
             # Record whether to keep going, then leave the loop if any
             # break flag was set during this iteration.
             modified_node.body.append(
                 gast.Assign(targets=[
                     gast.Name(id=self.keepgoing_flag,
                               ctx=gast.Store(),
                               annotation=None)
                 ],
                             value=gast.UnaryOp(op=gast.Not(),
                                                operand=cond)))
             modified_node.body.append(
                 gast.If(test=cond, body=[gast.Break()], orelse=[]))
         elif isinstance(modified_node, gast.If):
             # The loop was wrapped in a guard `If`; patch the inner loop.
             if isinstance(modified_node.body[0], gast.For):
                 modified_node.body[0].body.append(
                     gast.Assign(targets=[
                         gast.Name(id=self.keepgoing_flag,
                                   ctx=gast.Store(),
                                   annotation=None)
                     ],
                                 value=gast.UnaryOp(op=gast.Not(),
                                                    operand=cond)))
                 modified_node.body[0].body.append(
                     gast.If(test=cond, body=[gast.Break()], orelse=[]))
     return modified_node
Ejemplo n.º 5
0
    def visit_If(self, node):
        """Simplify an ``if`` statement whose test or branches are trivial.

        * If the test is a literal (``ast.literal_eval`` succeeds), the node is
          replaced by the body (truthy test) or by the ``else`` list (falsy
          test), unless the node carries ``OMPDirective`` metadata.
        * If only the ``else`` branch has non-``pass`` statements, the branches
          are swapped and the test is negated.
        * If neither branch has non-``pass`` statements, the ``if`` becomes
          ``pass`` when its test is in ``self.pure_expressions``; otherwise it
          becomes an expression statement that still evaluates the test.

        ``self.update`` is set to True whenever the tree is changed.

        Returns:
            The replacement node, a list of statements, or ``node``.
        """
        self.generic_visit(node)

        try:
            if ast.literal_eval(node.test):
                # OpenMP-annotated ifs are kept even with a constant test.
                if not metadata.get(node, OMPDirective):
                    self.update = True
                    return node.body
            else:
                if not metadata.get(node, OMPDirective):
                    self.update = True
                    return node.orelse
        except ValueError:
            # not a constant expression
            pass

        have_body = any(not isinstance(x, ast.Pass) for x in node.body)
        have_else = any(not isinstance(x, ast.Pass) for x in node.orelse)
        # If the "body" is empty but "else content" is useful, switch branches
        # and remove else content
        if not have_body and have_else:
            test = ast.UnaryOp(op=ast.Not(), operand=node.test)
            self.update = True
            return ast.If(test=test, body=node.orelse, orelse=list())
        # if neither "if" and "else" are useful, keep test if it is not pure
        elif not have_body:
            self.update = True
            if node.test in self.pure_expressions:
                return ast.Pass()
            else:
                node = ast.Expr(value=node.test)
                self.generic_visit(node)
        return node
Ejemplo n.º 6
0
 def visit_loop_successor(self, node):
     for successor in self.cfg.successors(node):
         if successor is not node.body[0]:
             if isinstance(node, ast.While):
                 bound_range(self.result, self.aliases,
                             ast.UnaryOp(ast.Not(), node.test))
             return [successor]
Ejemplo n.º 7
0
    def visit_For(self, node):
        """Finalize a ``for`` loop after its body has been visited.

        Pops this loop's entries from ``for_continued_stack`` and
        ``for_breaked_stack``. Flag names are the flag prefix plus the stack
        depth measured before the pop. When the continue entry is truthy,
        ``<continued_flag><id> = False`` is inserted at the start of the loop
        body; the break entry is handled the same way.

        If this loop has a break flag, or the innermost ``func_returned_stack``
        entry is truthy, the body also ends with
        ``<keepgoing_flag> = not <cond>`` and ``if <cond>: break``. ``<cond>``
        is the single flag or an ``or`` of the break flag and the returned
        flag.

        Args:
            node: the ``gast.For`` node being visited.

        Returns:
            The node returned by ``generic_visit``. The insertions are made on
            ``node.body`` directly.
        """
        modified_node = self.generic_visit(node)
        # Depth computed before the pop, i.e. while this loop's entry is
        # still on the stack.
        continued_id = len(self.for_continued_stack)
        continued_flags = self.for_continued_stack.pop()
        if continued_flags:
            node.body.insert(
                0,
                gast.Assign(targets=[
                    gast.Name(id=self.continued_flag + str(continued_id),
                              ctx=gast.Store(),
                              annotation=None,
                              type_comment=None)
                ],
                            value=gast.Constant(value=False, kind=None)))
        breaked_id = len(self.for_breaked_stack)
        breaked_flags = self.for_breaked_stack.pop()
        bool_values = []
        if breaked_flags:
            node.body.insert(
                0,
                gast.Assign(targets=[
                    gast.Name(id=self.breaked_flag + str(breaked_id),
                              ctx=gast.Store(),
                              annotation=None,
                              type_comment=None)
                ],
                            value=gast.Constant(value=False, kind=None)))
            bool_values.append(
                gast.Name(id=self.breaked_flag + str(breaked_id),
                          ctx=gast.Load(),
                          annotation=None,
                          type_comment=None))

        # A `return` inside the loop must also stop the loop.
        if len(self.func_returned_stack) > 0:
            returned_id = len(self.func_returned_stack)
            returned_flags = self.func_returned_stack[-1]
            if returned_flags:
                bool_values.append(
                    gast.Name(id=self.returned_flag + str(returned_id),
                              ctx=gast.Load(),
                              annotation=None,
                              type_comment=None))

        if len(bool_values) > 0:
            if len(bool_values) == 1:
                cond = bool_values[0]
            elif len(bool_values) > 1:
                cond = gast.BoolOp(op=gast.Or(), values=bool_values)

            node.body.append(
                gast.Assign(targets=[
                    gast.Name(id=self.keepgoing_flag,
                              ctx=gast.Store(),
                              annotation=None,
                              type_comment=None)
                ],
                            value=gast.UnaryOp(op=gast.Not(), operand=cond)))
            node.body.append(gast.If(test=cond, body=[gast.Break()],
                                     orelse=[]))
        return modified_node
Ejemplo n.º 8
0
    def _replace_after_node_to_if_in_stmt_list(
            self, stmt_list, node, return_name, parent_node_of_return):
        i = index_in_list(stmt_list, node)
        if i < 0 or i >= len(stmt_list):
            return False
        if i == len(stmt_list) - 1:
            # No need to add, we consider this as added successfully
            return True

        if_stmt = gast.If(test=gast.UnaryOp(
            op=gast.Not(),
            operand=gast.Name(
                id=return_name,
                ctx=gast.Store(),
                annotation=None,
                type_comment=None)),
                          body=stmt_list[i + 1:],
                          orelse=[])

        stmt_list[i + 1:] = [if_stmt]

        # Here assume that the parent node of return is gast.If
        if isinstance(parent_node_of_return, gast.If):
            # Prepend control flow boolean nodes such as '__return@1 = False'
            node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format(
                return_name,
                ast_to_source_code(parent_node_of_return.test).strip())
            assign_false_node = gast.parse(node_str).body[0]

            stmt_list[i:i] = [assign_false_node]
        return True
Ejemplo n.º 9
0
 def test_fix_missing_locations(self):
     # The child Num has no position of its own, so fix_missing_locations
     # must copy the parent's lineno/col_offset onto it.
     child = gast.Num(n=6)
     parent = gast.UnaryOp(gast.USub(), child)
     parent.lineno, parent.col_offset = 1, 2
     gast.fix_missing_locations(parent)
     for attr in ('lineno', 'col_offset'):
         self.assertEqual(getattr(child, attr), getattr(parent, attr))
Ejemplo n.º 10
0
 def test_fix_missing_locations(self):
     # A position-less Constant under a positioned UnaryOp should inherit
     # the parent's location after fix_missing_locations.
     leaf = gast.Constant(value=6, kind=None)
     root = gast.UnaryOp(gast.USub(), leaf)
     root.lineno = 1
     root.col_offset = 2
     gast.fix_missing_locations(root)
     for attr in ('lineno', 'col_offset'):
         self.assertEqual(getattr(leaf, attr), getattr(root, attr))
Ejemplo n.º 11
0
 def generic_visit(self, node):
     """Guard statements inside a loop that has continue/break flags.

     If the innermost entry of ``for_continue_stack`` or ``for_breaked_stack``
     is a non-empty list of flag names, the visited statement is wrapped in
     ``if not <flag_1> and not <flag_2> ...:``. Entering a ``For`` pushes a
     new empty list onto both stacks before its children are visited.

     Args:
         node: any gast node; only ``gast.stmt`` instances can be wrapped.

     Returns:
         The node produced by the base ``generic_visit``, possibly wrapped in
         a new ``gast.If``.
     """
     if isinstance(node, gast.stmt):
         if (len(self.for_continue_stack) > 0
                 and len(self.for_continue_stack[-1]) > 0) or (
                     len(self.for_breaked_stack) > 0
                     and len(self.for_breaked_stack[-1]) > 0):
             # Build the guard from the flags of the enclosing loop, before
             # a nested loop (if `node` is one) pushes its own entries.
             bool_values = []
             if (len(self.for_continue_stack) > 0
                     and len(self.for_continue_stack[-1]) > 0):
                 for flag in self.for_continue_stack[-1]:
                     bool_values.append(
                         gast.UnaryOp(op=gast.Not(),
                                      operand=gast.Name(id=flag,
                                                        ctx=gast.Load(),
                                                        annotation=None)))
             if (len(self.for_breaked_stack) > 0
                     and len(self.for_breaked_stack[-1]) > 0):
                 for flag in self.for_breaked_stack[-1]:
                     bool_values.append(
                         gast.UnaryOp(op=gast.Not(),
                                      operand=gast.Name(id=flag,
                                                        ctx=gast.Load(),
                                                        annotation=None)))
             if isinstance(node, gast.For):
                 self.for_continue_stack.append([])
                 self.for_breaked_stack.append([])
             node = super().generic_visit(node)
             if len(bool_values) == 1:
                 cond = bool_values[0]
             else:
                 cond = gast.BoolOp(op=gast.And(), values=bool_values)
             replacement = gast.If(test=cond, body=[node], orelse=[])
             ret = gast.copy_location(replacement, node)
         else:
             if isinstance(node, gast.For):
                 self.for_continue_stack.append([])
                 self.for_breaked_stack.append([])
             ret = super().generic_visit(node)
     else:
         ret = super().generic_visit(node)
     return ret
Ejemplo n.º 12
0
    def test_NodeTransformer(self):
        # The transformer mutates the Constant in place, so the original
        # leaf object must observe the doubled value.
        leaf = gast.Constant(value=6, kind=None)
        tree = gast.UnaryOp(gast.USub(), leaf)

        class Doubler(gast.NodeTransformer):
            def visit_Constant(self, node):
                node.value = node.value * 2
                return node

        tree = Doubler().visit(tree)

        self.assertEqual(leaf.value, 12)
Ejemplo n.º 13
0
    def visit_BinOp(self, node):
        """Constant-fold and algebraically simplify a binary operation.

        Children are visited first, so nested expressions are simplified
        bottom-up. With ``c`` denoting a ``gast.Num`` literal, the rules are:

        * ``c1 op c2`` for ``*``, ``+``, ``-``, ``/`` and ``**`` is folded to a
          single literal.
        * When evaluating that fold would raise (``c / 0``, ``0 ** -1``, float
          overflow), the node is kept so the error happens at run time rather
          than inside this pass.
        * ``0 * x`` / ``x * 0`` -> ``0``. This drops ``x``, including any side
          effects of evaluating it.
        * ``1 * x`` / ``x * 1`` -> ``x``.
        * ``0 + x`` / ``x + 0`` -> ``x``; ``x - 0`` -> ``x``; ``0 - x`` -> ``-x``.
        * ``x / 1`` -> ``x``.
        * ``1 ** x`` -> ``1``; ``x ** 0`` -> ``1``; ``x ** 1`` -> ``x``.

        Returns:
            The replacement node, or ``node`` itself when no rule applies.
        """
        self.generic_visit(node)
        left_val = node.left
        right_val = node.right
        left_is_num = isinstance(left_val, gast.Num)
        right_is_num = isinstance(right_val, gast.Num)

        if isinstance(node.op, gast.Mult):
            if left_is_num and right_is_num:
                return gast.Num(left_val.n * right_val.n)
            if left_is_num:
                if left_val.n == 0:
                    return gast.Num(0)
                elif left_val.n == 1:
                    return right_val
            if right_is_num:
                if right_val.n == 0:
                    return gast.Num(0)
                elif right_val.n == 1:
                    return left_val
        elif isinstance(node.op, gast.Add):
            if left_is_num and right_is_num:
                return gast.Num(left_val.n + right_val.n)
            if left_is_num and left_val.n == 0:
                return right_val
            if right_is_num and right_val.n == 0:
                return left_val
        elif isinstance(node.op, gast.Sub):
            if left_is_num and right_is_num:
                return gast.Num(left_val.n - right_val.n)
            if left_is_num and left_val.n == 0:
                return gast.UnaryOp(op=gast.USub(), operand=right_val)
            if right_is_num and right_val.n == 0:
                return left_val
        elif isinstance(node.op, gast.Div):
            # Folding `c / 0` would raise ZeroDivisionError inside this pass;
            # leave such a node alone so the error surfaces at run time.
            if left_is_num and right_is_num and right_val.n != 0:
                return gast.Num(left_val.n / right_val.n)
            if right_is_num and right_val.n == 1:
                return left_val
        elif isinstance(node.op, gast.Pow):
            if left_is_num and right_is_num:
                try:
                    return gast.Num(left_val.n**right_val.n)
                except (ZeroDivisionError, OverflowError):
                    # e.g. `0 ** -1`, or a float result out of range.
                    return node
            # `0 ** x` is deliberately not rewritten: it equals 1 when
            # x == 0 and raises for negative x, so `0` is not a valid result.
            if left_is_num and left_val.n == 1:
                return gast.Num(1)
            if right_is_num:
                if right_val.n == 0:
                    return gast.Num(1)
                elif right_val.n == 1:
                    return left_val
        return node
Ejemplo n.º 14
0
    def test_NodeTransformer(self):
        # The transformer mutates the Num in place, so the original leaf
        # object must observe the doubled value.
        leaf = gast.Num(n=6)
        tree = gast.UnaryOp(gast.USub(), leaf)

        class Doubler(gast.NodeTransformer):

            def visit_Num(self, node):
                node.n = node.n * 2
                return node

        tree = Doubler().visit(tree)

        self.assertEqual(leaf.n, 12)
Ejemplo n.º 15
0
    def test_NodeVisitor(self):
        # The single Constant in the tree should be collected exactly once.
        tree = gast.UnaryOp(gast.USub(), gast.Constant(value=6, kind=None))

        class Collector(gast.NodeTransformer):
            def __init__(self):
                self.state = []

            def visit_Constant(self, node):
                self.state.append(node.value)

        collector = Collector()
        collector.visit(tree)

        self.assertEqual(collector.state, [6])
Ejemplo n.º 16
0
 def visit_Compare(self, node):
     """Rewrite a comparison as the negation of its inverted form.

     ``a < b`` becomes ``not a >= b``; a chain ``a < b < c`` becomes
     ``not (a >= b or b >= c)``. Inverted operators come from
     ``OP_INVERSIONS``.
     """
     self.generic_visit(node)
     operands = [node.left] + node.comparators
     if len(operands) != 2:
         # Chained comparison: invert every adjacent pair and OR them.
         inverted_pairs = [
             gast.Compare(left=lhs,
                          ops=[OP_INVERSIONS[type(cmp_op)]()],
                          comparators=[rhs])
             for lhs, cmp_op, rhs in zip(operands[:-1], node.ops,
                                         operands[1:])
         ]
         return gast.UnaryOp(op=gast.Not(),
                             operand=gast.BoolOp(op=gast.Or(),
                                                 values=inverted_pairs))
     # Single comparison: invert its only operator.
     single_inverted = OP_INVERSIONS[type(node.ops[0])]
     return gast.UnaryOp(op=gast.Not(),
                         operand=gast.Compare(left=node.left,
                                              ops=[single_inverted()],
                                              comparators=node.comparators))
Ejemplo n.º 17
0
 def _build_print_call_node(self, node):
     """Build ``fluid.layers.Print(node, summarize=-1, print_phase='forward')``."""
     print_func = gast.parse('fluid.layers.Print').body[0].value
     summarize_kw = gast.keyword(
         arg='summarize',
         value=gast.UnaryOp(op=gast.USub(),
                            operand=gast.Constant(value=1, kind=None)))
     phase_kw = gast.keyword(
         arg='print_phase',
         value=gast.Constant(value='forward', kind=None))
     return gast.Call(func=print_func,
                      args=[node],
                      keywords=[summarize_kw, phase_kw])
Ejemplo n.º 18
0
    def visit_Compare(self, node):
        """Replace a comparison recognised by ``match_is_none`` with a call.

        When ``match_is_none`` returns a variable, the comparison becomes
        ``__builtin__.pythran.is_none(<var>)``, wrapped in ``not`` when the
        match is reported as negated. Other comparisons are returned
        unchanged.
        """
        self.generic_visit(node)
        noned_var, negated = self.match_is_none(node)
        if noned_var is not None:
            builtin = ast.Name('__builtin__', ast.Load(), None)
            pythran_mod = ast.Attribute(builtin, 'pythran', ast.Load())
            is_none_fn = ast.Attribute(pythran_mod, 'is_none', ast.Load())
            call = ast.Call(is_none_fn, [noned_var], [])
            return ast.UnaryOp(ast.Not(), call) if negated else call
        return node
Ejemplo n.º 19
0
    def test_NodeVisitor(self):
        # The single Num in the tree should be collected exactly once.
        tree = gast.UnaryOp(gast.USub(), gast.Num(n=6))

        class Collector(gast.NodeTransformer):

            def __init__(self):
                self.state = []

            def visit_Num(self, node):
                self.state.append(node.n)

        collector = Collector()
        collector.visit(tree)

        self.assertEqual(collector.state, [6])
Ejemplo n.º 20
0
    def inlineFixedSizeArrayUnaryOp(self, node):
        """Expand a unary op on a fixed-size array into an element-wise array.

        If ``fixedSizeArray`` reports a base and a size ``n`` for the operand,
        the node becomes ``numpy.array((op e0, ..., op e{n-1}))``, where each
        ``ei`` is built by ``make_array_index``. ``self.update`` is set and
        the ``numpy.array`` attribute is registered in ``self.aliases``.
        Number, list and tuple operands, and operands without a base, are
        returned unchanged.
        """
        if isinstance(node.operand, (ast.Num, ast.List, ast.Tuple)):
            return node

        base, size = self.fixedSizeArray(node.operand)
        if not base:
            return node

        self.update = True

        op_type = type(node.op)
        elements = []
        for index in range(size):
            element = self.make_array_index(base, size, index)
            elements.append(ast.UnaryOp(op_type(), element))
        array_call = ast.Call(path_to_attr(('numpy', 'array')),
                              [ast.Tuple(elements, ast.Load())], [])
        self.aliases[array_call.func] = {path_to_node(('numpy', 'array'))}
        return array_call
Ejemplo n.º 21
0
 def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
                                            return_name):
     """Wrap every statement after ``node`` in ``if not <return_name>:``.

     ``stmt_list`` is modified in place: the statements following ``node``
     are moved into the body of a single new ``gast.If``.

     Args:
         stmt_list: statement list expected to contain ``node``.
         node: the statement after which the guard starts.
         return_name: name of the boolean return flag to test.

     Returns:
         False if ``index_in_list`` does not locate ``node``; True otherwise,
         including when ``node`` is already the last statement.
     """
     i = index_in_list(stmt_list, node)
     if i < 0 or i >= len(stmt_list):
         return False
     if i == len(stmt_list) - 1:
         # No need to add, we consider this as added successfully
         return True
     # The flag is only *read* in the guard, so it needs a Load context;
     # a Store-context Name in an expression position is invalid and is
     # rejected when the tree is compiled.
     if_stmt = gast.If(test=gast.UnaryOp(
         op=gast.Not(),
         operand=gast.Name(
             id=return_name,
             ctx=gast.Load(),
             annotation=None,
             type_comment=None)),
                       body=stmt_list[i + 1:],
                       orelse=[])
     stmt_list[i + 1:] = [if_stmt]
     return True
Ejemplo n.º 22
0
    def visit_For(self, node):
        """Lower ``break`` statements in a ``for`` loop body to a flag.

        Target and iterable are visited first. A new ``break_requested``
        symbol is obtained from ``self.namer.new_symbol``, which is given the
        body scope's referenced names (presumably to avoid collisions). The
        symbol is pushed onto ``self.break_uses`` with a "used" marker while
        the body is visited.

        If the marker was set during the body visit (presumably by the
        ``break`` visitor), the loop gets an ``extra_cond`` annotation
        ``not <break_var>`` and the result is
        ``[self._create_break_init(), node]``. Otherwise the loop is returned
        alone. The ``else`` clause is visited after the marker is popped, so
        breaks there belong to an outer loop.
        """
        self.generic_visit(node.target)
        self.generic_visit(node.iter)
        scope = anno.getanno(node, 'body_scope')

        break_var = self.namer.new_symbol('break_requested', scope.referenced)
        # [used?, flag name] for the loop whose body is being visited.
        self.break_uses.append([False, break_var])
        node.body = self._manual_visit_list(node.body)
        if self.break_uses[-1][0]:
            anno.setanno(
                node, 'extra_cond',
                gast.UnaryOp(gast.Not(), gast.Name(break_var, gast.Load(),
                                                   None)))
            final_nodes = [self._create_break_init(), node]
        else:
            final_nodes = node
        self.break_uses.pop()

        for n in node.orelse:
            self.generic_visit(n)
        return final_nodes
Ejemplo n.º 23
0
    def visit_While(self, node):
        """Lower ``break`` statements in a ``while`` loop body to a flag.

        A new ``break_requested`` symbol is obtained from
        ``self.context.namer.new_symbol``, which is given the body scope's
        referenced names. The symbol is pushed onto ``self.break_uses`` with
        a "used" marker while the body is visited.

        If the marker was set during the body visit (presumably by the
        ``break`` visitor), the loop test becomes
        ``<test> and not <break_var>`` and the result is
        ``[self._create_break_init(), node]``. Otherwise the loop is returned
        alone. The ``else`` clause is visited after the marker is popped.
        """
        self.generic_visit(node.test)
        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)

        break_var = self.context.namer.new_symbol('break_requested',
                                                  scope.referenced)
        # [used?, flag name] for the loop whose body is being visited.
        self.break_uses.append([False, break_var])
        node.body = self._manual_visit_list(node.body)
        if self.break_uses[-1][0]:
            # Stop iterating once the flag has been set.
            node.test = gast.BoolOp(gast.And(), [
                node.test,
                gast.UnaryOp(gast.Not(), gast.Name(break_var, gast.Load(),
                                                   None))
            ])
            final_nodes = [self._create_break_init(), node]
        else:
            final_nodes = node
        self.break_uses.pop()

        for n in node.orelse:
            self.generic_visit(n)
        return final_nodes
Ejemplo n.º 24
0
 def visit(self, node):
     """Visit ``node``, dropping statements that are marked for removal.

     A node found in ``self.to_remove`` sets ``self.remove``. A node annotated
     with ``pri_call`` or ``adj_call`` sets ``self.is_call``. When a
     statement (``grammar.STATEMENTS``) finishes visiting with ``remove`` set
     and ``is_call`` unset, it is replaced by ``None``. Both flags are reset
     after every statement.

     Afterwards:

     * an ``If`` left with an empty body is dropped when it has no ``else``;
       otherwise its test is negated and its branches are swapped;
     * a ``While``/``For`` left with an empty body is replaced by its
       ``orelse`` statements.
     """
     if node in self.to_remove:
         self.remove = True
     if anno.hasanno(node, 'pri_call') or anno.hasanno(node, 'adj_call'):
         # We don't remove function calls for now; removing them also
         # removes the push statements inside of them, but not the
         # corresponding pop statements
         self.is_call = True
     new_node = super(Remove, self).visit(node)
     if isinstance(node, grammar.STATEMENTS):
         if self.remove and not self.is_call:
             new_node = None
         self.remove = self.is_call = False
     # NOTE(review): the checks below inspect `node`, not `new_node`; this
     # assumes the base visitor mutated `node` in place.
     if isinstance(node, gast.If) and not node.body:
         # If we optimized away an entire if block, we need to handle that
         if not node.orelse:
             return
         else:
             node.test = gast.UnaryOp(op=gast.Not(), operand=node.test)
             node.body, node.orelse = node.orelse, node.body
     elif isinstance(node, (gast.While, gast.For)) and not node.body:
         return node.orelse
     return new_node
    def visit_Break(self, node):
        """Replace a ``break`` with a boolean flag and a loop-condition update.

        The numbered steps below:

        1. Create a unique flag variable.
        2. Remove the statements after the ``break`` in its first enclosing
           block, and set the flag there.
        3. Guard statements in the intermediate blocks with the flag.
        4. Insert ``<flag> = False`` via ``_add_stmt_before_cur_node``, then
           add ``not <flag>`` to the loop condition. A ``for`` loop is
           converted to a ``while`` loop carrying that condition.

        Raises:
            AssertionError: if the ``break`` has no enclosing loop.
        """
        loop_node_index = self._find_ancestor_loop_index(node)
        assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
        loop_node = self.ancestor_nodes[loop_node_index]

        # 1. Map the 'break/continue' stmt with an unique boolean variable V.
        variable_name = unique_name.generate(BREAK_NAME_PREFIX)

        # 2. Find the first ancestor block containing this 'break/continue', a
        # block can be a node containing stmt list. We should remove all stmts
        # after the 'break/continue' and set the V to True here.
        first_block_index = self._remove_stmts_after_break_continue(
            node, variable_name, loop_node_index)

        # 3. Add 'if V' for stmts in ancestor blocks between the first one
        # (exclusive) and the ancestor loop (inclusive)
        self._replace_if_stmt(loop_node_index, first_block_index, variable_name)

        # 4. For 'break' add break into condition of the loop.
        assign_false_node = create_fill_constant_node(variable_name, False)
        self._add_stmt_before_cur_node(loop_node_index, assign_false_node)

        cond_var_node = gast.UnaryOp(
            op=gast.Not(),
            operand=gast.Name(
                id=variable_name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None))
        if isinstance(loop_node, gast.While):
            loop_node.test = gast.BoolOp(
                op=gast.And(), values=[loop_node.test, cond_var_node])
        elif isinstance(loop_node, gast.For):
            # A for loop has no test to extend; rewrite it as a while loop.
            parent_node = self.ancestor_nodes[loop_node_index - 1]
            for_to_while = ForToWhileTransformer(parent_node, loop_node,
                                                 cond_var_node)
            for_to_while.transform()
Ejemplo n.º 26
0
 def test_iter_child_nodes(self):
     # A UnaryOp has exactly two direct child nodes: its operator (USub) and
     # its operand (Constant). Exercise iter_child_nodes itself, which this
     # test is named after, rather than iter_fields.
     tree = gast.UnaryOp(gast.USub(), gast.Constant(value=1, kind=None))
     self.assertEqual(len(list(gast.iter_child_nodes(tree))), 2)
Ejemplo n.º 27
0
 def invert(self, node):
     """Return ``node`` wrapped in a logical ``not``."""
     negation = gast.UnaryOp(op=gast.Not(), operand=node)
     return negation
Ejemplo n.º 28
0
    def visit_If(self, node):
        """ Compute value ranges across both branches of an ``if`` statement.

        >>> import gast as ast
        >>> from pythran import passmanager, backend
        >>> pm = passmanager.PassManager("test")

        >>> node = ast.parse('''
        ... def foo(a):
        ...     if a > 1: b = 1
        ...     else: b = 3
        ...     pass''')

        >>> res = pm.gather(RangeValues, node)
        >>> res['b']
        Interval(low=1, high=3)

        >>> node = ast.parse('''
        ... def foo(a):
        ...     if a > 1: b = a
        ...     else: b = 3
        ...     pass''')
        >>> res = pm.gather(RangeValues, node)
        >>> res['b']
        Interval(low=2, high=inf)

        >>> node = ast.parse('''
        ... def foo(a):
        ...     if 0 < a < 4: b = a
        ...     else: b = 3
        ...     pass''')
        >>> res = pm.gather(RangeValues, node)
        >>> res['b']
        Interval(low=1, high=3)

        >>> node = ast.parse('''
        ... def foo(a):
        ...     if (0 < a) and (a < 4): b = a
        ...     else: b = 3
        ...     pass''')
        >>> res = pm.gather(RangeValues, node)
        >>> res['b']
        Interval(low=1, high=3)

        >>> node = ast.parse('''
        ... def foo(a):
        ...     if (a == 1) or (a == 2): b = a
        ...     else: b = 3
        ...     pass''')
        >>> res = pm.gather(RangeValues, node)
        >>> res['b']
        Interval(low=1, high=3)

        >>> node = ast.parse('''
        ... def foo(a):
        ...     b = 5
        ...     if a > 0: b = a
        ...     pass''')
        >>> res = pm.gather(RangeValues, node)
        >>> res['a'], res['b']
        (Interval(low=-inf, high=inf), Interval(low=1, high=inf))

        >>> node = ast.parse('''
        ... def foo(a):
        ...     if a > 3: b = 1
        ...     else: b = 2
        ...     if a > 1: b = 2
        ...     pass''')
        >>> res = pm.gather(RangeValues, node)
        >>> res['b']
        Interval(low=2, high=2)
        """
        # handling each branch becomes too costly, opt for a simpler,
        # less accurate algorithm.
        if self.no_if_split == 4:
            raise RangeValueTooCostly()

        self.no_if_split += 1

        # `1 in test_range` / `0 in test_range` tell whether the test may be
        # true / may be false, i.e. whether each branch is reachable.
        test_range = self.visit(node.test)
        # Snapshot so each branch starts from the state before the `if`.
        init_state = self.result.copy()

        if 1 in test_range:
            # True branch: narrow ranges using the test itself.
            bound_range(self.result, self.aliases, node.test)
            self.cfg_visit(node.body[0])

        visited_successors = {node.body[0]}

        if node.orelse:
            if 0 in test_range:
                # False branch: restart from the snapshot, narrow with
                # `not test`, then merge with the true-branch result.
                prev_state = self.result
                self.result = init_state.copy()
                bound_range(self.result, self.aliases,
                            ast.UnaryOp(ast.Not(), node.test))
                self.cfg_visit(node.orelse[0])
                self.unionify(prev_state)
            visited_successors.add(node.orelse[0])

        elif 0 in test_range:
            # No else branch: the false path flows to the other CFG
            # successors, visited under `not test`.
            successors = self.cfg.successors(node)
            for successor in list(successors):
                # no else branch
                if successor not in visited_successors:
                    self.result, prev_state = init_state.copy(), self.result
                    bound_range(self.result, self.aliases,
                                ast.UnaryOp(ast.Not(), node.test))
                    self.cfg_visit(successor)
                    self.unionify(prev_state)

        self.no_if_split -= 1
Ejemplo n.º 29
0
 def generate_UnaryOp(self):
   """Build a ``UnaryOp`` from ``generate_Name()`` and an op from ``UnaryOpSampler``."""
   operand = self.generate_Name()
   op_class = UnaryOpSampler().sample()
   return gast.UnaryOp(op_class(), operand)
Ejemplo n.º 30
0
def negate(node):
    """Return an expression equivalent to the logical negation of ``node``.

    Comparison operator nodes (``Eq``, ``Lt``, ``In``, ``Is`` ...) are mapped
    to their complementary operator instead. This is how ``Compare`` nodes
    are negated pairwise.

    Raises:
        UnsupportedExpression: when no negated form is known for ``node``,
            notably for bare names, whose type is unknown.
    """
    if isinstance(node, ast.Name):
        # Not type info, could be anything :(
        raise UnsupportedExpression()

    if isinstance(node, ast.UnaryOp):
        # !~x <> ~x == 0 <> x == ~0 <> x == -1
        if isinstance(node.op, ast.Invert):
            return ast.Compare(node.operand,
                               [ast.Eq()],
                               [ast.Constant(-1, None)])
        # !!x <> x
        if isinstance(node.op, ast.Not):
            return node.operand
        # !+x <> +x == 0 <> x == 0
        # !-x <> -x == 0 <> x == 0
        # Both reduce to `x == 0`. Returning `x` itself would give the
        # opposite truth value.
        if isinstance(node.op, (ast.UAdd, ast.USub)):
            return ast.Compare(node.operand,
                               [ast.Eq()],
                               [ast.Constant(0, None)])

    if isinstance(node, ast.BoolOp):
        new_values = [ast.UnaryOp(ast.Not(), v) for v in node.values]
        # !(x or y) <> !x and !y
        if isinstance(node.op, ast.Or):
            return ast.BoolOp(ast.And(), new_values)
        # !(x and y) <> !x or !y
        if isinstance(node.op, ast.And):
            return ast.BoolOp(ast.Or(), new_values)

    if isinstance(node, ast.Compare):
        # !(a op1 b op2 c) <> (a !op1 b) or (b !op2 c)
        cmps = [ast.Compare(x, [negate(o)], [y])
                for x, o, y
                in zip([node.left] + node.comparators[:-1], node.ops,
                       node.comparators)]
        if len(cmps) == 1:
            return cmps[0]
        return ast.BoolOp(ast.Or(), cmps)

    if isinstance(node, ast.Eq):
        return ast.NotEq()
    if isinstance(node, ast.NotEq):
        return ast.Eq()
    if isinstance(node, ast.Gt):
        return ast.LtE()
    if isinstance(node, ast.GtE):
        return ast.Lt()
    if isinstance(node, ast.Lt):
        return ast.GtE()
    if isinstance(node, ast.LtE):
        return ast.Gt()
    if isinstance(node, ast.In):
        return ast.NotIn()
    if isinstance(node, ast.NotIn):
        return ast.In()
    if isinstance(node, ast.Is):
        return ast.IsNot()
    if isinstance(node, ast.IsNot):
        return ast.Is()

    if isinstance(node, ast.Attribute):
        if node.attr == 'False':
            return ast.Constant(True, None)
        if node.attr == 'True':
            return ast.Constant(False, None)

    raise UnsupportedExpression()