コード例 #1
0
 def visit_Assign(self, node):
     """Rewrite ``lhs = rhs`` into a mask-aware conditional update.

     The assigned value is replaced by the expression::

         lhs._update(rhs, matchbox.EXECUTION_MASK)
             if (('lhs' in vars() or 'lhs' in globals())
                 and isinstance(lhs, (matchbox.MaskedBatch,
                                      matchbox.TENSOR_TYPE)))
             else rhs

     so an already-bound batch/tensor variable is updated through its
     ``_update`` method with the current execution mask, while any other
     binding is assigned ``rhs`` directly.

     Raises:
         NotImplementedError: for multiple-target assignment
             (``a = b = c``) or when the single target is not a plain name.
     """
     if len(node.targets) > 1:
         raise NotImplementedError("cannot process multiple assignment")
     if not isinstance(node.targets[0], gast.Name):
         raise NotImplementedError("cannot process indexed assignment")
     target_id = node.targets[0].id

     def load(identifier):
         # A fresh Name in Load context. The assignment target itself has
         # Store context and must not be reused in an expression position.
         return gast.Name(identifier, gast.Load(), None)

     def matchbox_attr(attr):
         # ``matchbox.<attr>``; Attribute.attr must be an identifier string.
         return gast.Attribute(load('matchbox'), attr, gast.Load())

     def name_in(scope_fn):
         # ``'<target_id>' in <scope_fn>()``
         return gast.Compare(gast.Str(target_id), [gast.In()],
                             [gast.Call(load(scope_fn), [], [])])

     is_bound = gast.BoolOp(gast.Or(), [name_in('vars'), name_in('globals')])
     is_batched = gast.Call(load('isinstance'), [
         load(target_id),
         gast.Tuple([matchbox_attr('MaskedBatch'),
                     matchbox_attr('TENSOR_TYPE')], gast.Load()),
     ], [])
     masked_update = gast.Call(
         gast.Attribute(load(target_id), '_update', gast.Load()),
         [node.value, matchbox_attr('EXECUTION_MASK')], [])

     # Both branches reference the original right-hand side.
     node.value = gast.IfExp(
         gast.BoolOp(gast.And(), [is_bound, is_batched]),
         masked_update,
         node.value)
     return node
コード例 #2
0
    def visit_For(self, node):
        """Finalize a ``for`` loop after its children have been visited.

        Pops this loop's continue/break frames (presumably pushed by
        ``generic_visit`` when the loop was entered). For each frame that
        recorded a flag, ``<flag_prefix><id> = False`` is prepended to the
        loop body, where ``<id>`` is the stack depth before popping.

        When the break flag is set, or the innermost function frame records
        a return flag, the following is appended to the loop body::

            <keepgoing_flag> = not <exit condition>
            if <exit condition>:
                break

        Returns whatever ``generic_visit`` returned for *node*.
        """

        def flag_ref(prefix, flag_id, ctx):
            # Name node for the per-scope flag ``<prefix><flag_id>``.
            return gast.Name(id=prefix + str(flag_id),
                             ctx=ctx,
                             annotation=None,
                             type_comment=None)

        def reset_flag(prefix, flag_id):
            # ``<prefix><flag_id> = False``
            return gast.Assign(
                targets=[flag_ref(prefix, flag_id, gast.Store())],
                value=gast.Constant(value=False, kind=None))

        visited = self.generic_visit(node)

        continue_id = len(self.for_continued_stack)
        if self.for_continued_stack.pop():
            node.body.insert(0, reset_flag(self.continued_flag, continue_id))

        break_id = len(self.for_breaked_stack)
        exit_conditions = []
        if self.for_breaked_stack.pop():
            node.body.insert(0, reset_flag(self.breaked_flag, break_id))
            exit_conditions.append(
                flag_ref(self.breaked_flag, break_id, gast.Load()))

        if self.func_returned_stack:
            return_id = len(self.func_returned_stack)
            if self.func_returned_stack[-1]:
                exit_conditions.append(
                    flag_ref(self.returned_flag, return_id, gast.Load()))

        if exit_conditions:
            if len(exit_conditions) == 1:
                should_exit = exit_conditions[0]
            else:
                should_exit = gast.BoolOp(op=gast.Or(),
                                          values=exit_conditions)
            keep_going = gast.Assign(
                targets=[
                    gast.Name(id=self.keepgoing_flag,
                              ctx=gast.Store(),
                              annotation=None,
                              type_comment=None)
                ],
                value=gast.UnaryOp(op=gast.Not(), operand=should_exit))
            node.body.append(keep_going)
            node.body.append(
                gast.If(test=should_exit, body=[gast.Break()], orelse=[]))
        return visited
コード例 #3
0
    def generic_visit(self, node):
        """Visit *node*, guarding statements behind active control-flow flags.

        Non-statement nodes are visited unchanged. For a statement, a guard
        ``not <flag_prefix><depth>`` is collected for each stack that
        ``stack_has_flags`` reports active: continue, then break, then
        return. Here ``<depth>`` is that stack's current length.

        A ``For`` then pushes fresh continue/break frames, and a
        ``FunctionDef`` a fresh return frame, before its children are
        visited. If any guards were collected, the visited statement is
        wrapped in ``if <guards joined with and>: <stmt>``, with the
        original's location copied.
        """
        if not isinstance(node, gast.stmt):
            return super().generic_visit(node)

        def not_flag(prefix, stack):
            # ``not <prefix><len(stack)>``
            return gast.UnaryOp(op=gast.Not(),
                                operand=gast.Name(id=prefix + str(len(stack)),
                                                  ctx=gast.Load(),
                                                  annotation=None,
                                                  type_comment=None))

        # Guards are built before any new frame is pushed so that their ids
        # refer to the enclosing loop/function rather than to *node*.
        guards = []
        if self.stack_has_flags(self.for_continued_stack):
            guards.append(
                not_flag(self.continued_flag, self.for_continued_stack))
        if self.stack_has_flags(self.for_breaked_stack):
            guards.append(not_flag(self.breaked_flag, self.for_breaked_stack))
        if self.stack_has_flags(self.func_returned_stack):
            guards.append(
                not_flag(self.returned_flag, self.func_returned_stack))

        if isinstance(node, gast.For):
            self.for_continued_stack.append(False)
            self.for_breaked_stack.append(False)
        elif isinstance(node, gast.FunctionDef):
            self.func_returned_stack.append(False)

        visited = super().generic_visit(node)
        if not guards:
            return visited

        if len(guards) == 1:
            test = guards[0]
        else:
            test = gast.BoolOp(op=gast.And(), values=guards)
        guarded = gast.If(test=test, body=[visited], orelse=[])
        return gast.copy_location(guarded, node)
コード例 #4
0
    def visit_Return(self, node):
        """Lower a ``return`` statement into flag-based control flow.

        A unique flag name (prefixed with ``RETURN_PREFIX``) is generated and
        appended to ``self.return_name`` for the current function. The
        ancestors of *node* are then walked from its parent outwards until
        that function is reached:

        * in whichever statement list (``body`` or ``orelse``) holds the
          current path node, the return itself is rewritten by
          ``_replace_return_in_stmt_list`` (only in the list that directly
          contains it), and ``_replace_after_node_to_if_in_stmt_list`` is
          applied to the statements after the path node (by its name,
          presumably guarding them with an ``if`` on the flag);
        * an enclosing ``while`` gets ``and not <flag>`` appended to its test;
        * an enclosing ``for`` is converted by ``ForToWhileTransformer`` into
          a ``while`` guarded by ``not <flag>``, and that while node replaces
          the ``for`` in ``self.ancestor_nodes``.

        Returns None.
        """
        # Presumably the innermost FunctionDef being transformed.
        cur_func_node = self.function_def[-1]
        return_name = unique_name.generate(RETURN_PREFIX)
        self.return_name[cur_func_node].append(return_name)
        max_return_length = self.pre_analysis.get_func_max_return_length(
            cur_func_node)
        # Assumes self.ancestor_nodes is the root-to-node path ending at this
        # Return, so [-2] is its direct parent — TODO confirm.
        parent_node_of_return = self.ancestor_nodes[-2]

        # Walk outwards; ancestor_nodes[i] is treated as the parent of
        # ancestor_nodes[i + 1] on the path to the return.
        for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
            ancestor = self.ancestor_nodes[ancestor_index]
            cur_node = self.ancestor_nodes[ancestor_index + 1]
            if hasattr(ancestor,
                       "body") and index_in_list(ancestor.body, cur_node) != -1:
                # Only the list directly holding the return rewrites it;
                # every enclosing list still handles the statements after
                # the path node.
                if cur_node == node:
                    self._replace_return_in_stmt_list(
                        ancestor.body, cur_node, return_name, max_return_length,
                        parent_node_of_return)
                self._replace_after_node_to_if_in_stmt_list(
                    ancestor.body, cur_node, return_name, parent_node_of_return)
            elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse,
                                                               cur_node) != -1:
                if cur_node == node:
                    self._replace_return_in_stmt_list(
                        ancestor.orelse, cur_node, return_name,
                        max_return_length, parent_node_of_return)
                self._replace_after_node_to_if_in_stmt_list(
                    ancestor.orelse, cur_node, return_name,
                    parent_node_of_return)

            # If return node in while loop, add `not return_name` in gast.While.test
            if isinstance(ancestor, gast.While):
                cond_var_node = gast.UnaryOp(
                    op=gast.Not(),
                    operand=gast.Name(
                        id=return_name,
                        ctx=gast.Load(),
                        annotation=None,
                        type_comment=None))
                ancestor.test = gast.BoolOp(
                    op=gast.And(), values=[ancestor.test, cond_var_node])
                continue

            # If return node in for loop, convert the for loop into a while
            # loop whose test includes `not return_name`, and record the new
            # While node in place of the For in the ancestor path.
            if isinstance(ancestor, gast.For):
                cond_var_node = gast.UnaryOp(
                    op=gast.Not(),
                    operand=gast.Name(
                        id=return_name,
                        ctx=gast.Load(),
                        annotation=None,
                        type_comment=None))
                parent_node = self.ancestor_nodes[ancestor_index - 1]
                for_to_while = ForToWhileTransformer(parent_node, ancestor,
                                                     cond_var_node)
                new_stmts = for_to_while.transform()
                # The generated While loop is the last statement returned.
                while_node = new_stmts[-1]
                self.ancestor_nodes[ancestor_index] = while_node

            # Stop once the enclosing function has been processed.
            if ancestor == cur_func_node:
                break
コード例 #5
0
 def visit_BoolOp(self, node):
     """Rewrite ``a OP b ...`` as ``not (not a DUAL not b ...)``.

     Children are visited first. ``BOOL_INVERSIONS`` maps the operator
     class of *node* to the class used for the inner expression
     (presumably ``And`` <-> ``Or``, per De Morgan's laws).
     """
     self.generic_visit(node)
     dual_op_cls = BOOL_INVERSIONS[type(node.op)]
     negated_operands = [
         gast.UnaryOp(op=gast.Not(), operand=operand)
         for operand in node.values
     ]
     inner = gast.BoolOp(op=dual_op_cls(), values=negated_operands)
     return gast.UnaryOp(op=gast.Not(), operand=inner)
コード例 #6
0
 def visit_For(self, node):
     """Finalize a ``for`` loop whose body may use continue/break flags.

     Pops this loop's continue and break flag-name lists (presumably pushed
     by ``generic_visit`` on entry) and prepends ``<flag> = False`` to
     ``node.body`` for each of them.

     If any break flag exists, the following is appended to the loop::

         <keepgoing_flag> = not <break flags joined with or>
         if <break flags joined with or>:
             break

     The loop is located either directly, or as the first statement of an
     ``If`` returned by ``generic_visit``.

     Returns whatever ``generic_visit`` returned for *node*.
     """
     modified_node = self.generic_visit(node)

     def reset(flag):
         # ``<flag> = False``
         return gast.Assign(
             targets=[gast.Name(id=flag, ctx=gast.Store(), annotation=None)],
             value=gast.NameConstant(value=False))

     for flag in self.for_continue_stack.pop():
         node.body.insert(0, reset(flag))

     break_tests = []
     for flag in self.for_breaked_stack.pop():
         node.body.insert(0, reset(flag))
         break_tests.append(
             gast.Name(id=flag, ctx=gast.Load(), annotation=None))

     if not break_tests:
         return modified_node

     if len(break_tests) == 1:
         cond = break_tests[0]
     else:
         cond = gast.BoolOp(op=gast.Or(), values=break_tests)

     # generic_visit may have wrapped the loop in an ``if`` guard.
     if isinstance(modified_node, gast.For):
         loop = modified_node
     elif (isinstance(modified_node, gast.If)
           and isinstance(modified_node.body[0], gast.For)):
         loop = modified_node.body[0]
     else:
         loop = None

     if loop is not None:
         loop.body.append(
             gast.Assign(targets=[
                 gast.Name(id=self.keepgoing_flag,
                           ctx=gast.Store(),
                           annotation=None)
             ],
                         value=gast.UnaryOp(op=gast.Not(), operand=cond)))
         loop.body.append(gast.If(test=cond, body=[gast.Break()], orelse=[]))
     return modified_node
コード例 #7
0
 def make_Iterator(self, gen):
     """Return the iterable a comprehension generator *gen* should loop over.

     Without ``if`` clauses this is ``gen.iter`` itself. Otherwise it is
     ``MODULE.IFILTER(lambda <target>: <conditions>, gen.iter)``, where
     several conditions are joined with ``and``. The lambda parameter is
     built from ``gen.target.id``, so the target must be a simple name.
     """
     if not gen.ifs:
         return gen.iter

     params = ast.arguments([ast.Name(gen.target.id, ast.Param(), None)],
                            None, [], [], None, [])
     if len(gen.ifs) > 1:
         predicate = ast.BoolOp(ast.And(), gen.ifs)
     else:
         predicate = gen.ifs[0]
     filter_fn = ast.Lambda(params, predicate)

     module_ref = ast.Name(id=MODULE, ctx=ast.Load(), annotation=None)
     ifilter_ref = ast.Attribute(value=module_ref,
                                 attr=IFILTER,
                                 ctx=ast.Load())
     return ast.Call(ifilter_ref, [filter_fn, gen.iter], [])
コード例 #8
0
    def get_for_stmt_nodes(self, node):
        """Lower a ``gast.For`` loop to init statements plus a ``gast.While``.

        ``ForNodeVisitor(node).parse()`` supplies the init statements, the
        loop condition and the body. The generated while loop tests
        ``<loop condition> and self.condition_node`` and reuses the
        for-loop's ``orelse`` block. When the visitor cannot parse the loop
        (it returns None), ``[node]`` is returned unchanged.

        NOTE(review): a while-loop's ``else`` runs whenever its test becomes
        false. If ``self.condition_node`` encodes a break/return flag, the
        ``else`` block will now also run after such an exit, unlike the
        original for-loop — confirm callers never convert loops with ``else``.
        """
        assert isinstance(
            node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes"

        parsed = ForNodeVisitor(node).parse()
        if parsed is None:
            return [node]
        stmts, loop_cond, loop_body = parsed

        guarded_test = gast.BoolOp(op=gast.And(),
                                   values=[loop_cond, self.condition_node])
        stmts.append(
            gast.While(test=guarded_test, body=loop_body, orelse=node.orelse))
        return stmts
コード例 #9
0
    def get_for_args_stmts(self, iter_name, args_list):
        '''
        Returns 3 gast stmt nodes for the arguments of a range() call.
        1. Initialization of the iterate variable
        2. Condition for the loop
        3. Statement for changing of iterate variable during the loop

        args_list holds the 1 to 3 argument nodes of range(). The loop
        condition is `iter < stop`, or `iter > stop` when the step is a
        negative integer literal such as `-2`, ANDed with
        self.condition_node. A step whose sign cannot be determined
        statically is assumed to be positive.
        '''

        def is_negative_int_literal(step):
            # `-2` parses as UnaryOp(USub, Constant(2)); an already folded
            # Constant(-2) is accepted as well. Anything else (names, calls,
            # ...) has an unknown sign.
            sign = 1
            if isinstance(step, gast.UnaryOp) and isinstance(
                    step.op, gast.USub):
                sign, step = -1, step.operand
            return (isinstance(step, gast.Constant)
                    and isinstance(step.value, int)
                    and not isinstance(step.value, bool)
                    and sign * step.value < 0)

        len_range_args = len(args_list)
        assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
        if len_range_args == 1:
            init_stmt = get_constant_variable_node(iter_name, 0)
        else:
            init_stmt = gast.Assign(targets=[
                gast.Name(id=iter_name,
                          ctx=gast.Store(),
                          annotation=None,
                          type_comment=None)
            ],
                                    value=args_list[0])

        range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
        step_node = args_list[2] if len_range_args == 3 else gast.Constant(
            value=1, kind=None)

        # range() yields values strictly before `stop`, so the test is a
        # strict comparison of the current value. A test such as
        # `iter + step <= stop` would drop the last element for steps > 1
        # (range(0, 5, 2) would lose 4) and mis-handle negative steps.
        if is_negative_int_literal(step_node):
            compare_op = gast.Gt()
        else:
            compare_op = gast.Lt()
        range_cond = gast.Compare(left=gast.Name(id=iter_name,
                                                 ctx=gast.Load(),
                                                 annotation=None,
                                                 type_comment=None),
                                  ops=[compare_op],
                                  comparators=[range_max_node])
        cond_stmt = gast.BoolOp(op=gast.And(),
                                values=[range_cond, self.condition_node])

        change_stmt = gast.AugAssign(target=gast.Name(id=iter_name,
                                                      ctx=gast.Store(),
                                                      annotation=None,
                                                      type_comment=None),
                                     op=gast.Add(),
                                     value=step_node)

        return init_stmt, cond_stmt, change_stmt
コード例 #10
0
 def generic_visit(self, node):
     """Visit *node*, guarding statements behind pending loop flags.

     Non-statement nodes are visited unchanged. For a statement, the flag
     names currently on top of the continue and break stacks are
     collected. A ``For`` then pushes fresh (empty) flag frames for its own
     body before its children are visited. If any flags were collected,
     the visited statement is wrapped in
     ``if not <flag1> and not <flag2> ...: <stmt>``.
     """
     if not isinstance(node, gast.stmt):
         return super().generic_visit(node)

     # Collect the active flags before pushing frames for a nested loop.
     active_flags = []
     for stack in (self.for_continue_stack, self.for_breaked_stack):
         if stack:
             active_flags.extend(stack[-1])
     guards = [
         gast.UnaryOp(op=gast.Not(),
                      operand=gast.Name(id=flag,
                                        ctx=gast.Load(),
                                        annotation=None))
         for flag in active_flags
     ]

     if isinstance(node, gast.For):
         self.for_continue_stack.append([])
         self.for_breaked_stack.append([])
     visited = super().generic_visit(node)
     if not guards:
         return visited

     if len(guards) == 1:
         test = guards[0]
     else:
         test = gast.BoolOp(op=gast.And(), values=guards)
     guarded = gast.If(test=test, body=[visited], orelse=[])
     return gast.copy_location(guarded, visited)
コード例 #11
0
    def visit_While(self, node):
        self.generic_visit(node.test)
        scope = anno.getanno(node, 'body_scope')

        break_var = self.namer.new_symbol('break_requested', scope.referenced)
        self.break_uses.append([False, break_var])
        node.body = self._manual_visit_list(node.body)
        if self.break_uses[-1][0]:
            node.test = gast.BoolOp(gast.And(), [
                node.test,
                gast.UnaryOp(gast.Not(), gast.Name(break_var, gast.Load(),
                                                   None))
            ])
            final_nodes = [self._create_break_init(), node]
        else:
            final_nodes = node
        self.break_uses.pop()

        for n in node.orelse:
            self.generic_visit(n)
        return final_nodes
コード例 #12
0
    def visit_Break(self, node):
        """Lower a ``break`` statement into a boolean flag variable.

        A unique flag name V (prefixed with ``BREAK_NAME_PREFIX``) is
        generated. Statements after the break are removed and V is set
        there (via ``_remove_stmts_after_break_continue``). Statements in
        the enclosing blocks up to the loop are guarded on V (via
        ``_replace_if_stmt``), and an assignment of ``False`` to V is added
        before the loop. Finally the loop is made to stop once V is set:

        * a ``while`` loop gets ``and not V`` appended to its test;
        * a ``for`` loop is converted by ``ForToWhileTransformer`` into a
          while loop guarded by ``not V``.

        Returns None.
        """
        loop_node_index = self._find_ancestor_loop_index(node)
        # NOTE(review): assert is stripped under `python -O`; a raised
        # exception would keep this check in optimized runs.
        assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
        loop_node = self.ancestor_nodes[loop_node_index]

        # 1. Map the 'break/continue' stmt to a unique boolean variable V.
        variable_name = unique_name.generate(BREAK_NAME_PREFIX)

        # 2. Find the first ancestor block containing this 'break/continue'; a
        # block can be any node containing a stmt list. Remove all stmts
        # after the 'break/continue' and set V to True there.
        first_block_index = self._remove_stmts_after_break_continue(
            node, variable_name, loop_node_index)

        # 3. Guard the stmts in ancestor blocks between the first one
        # (exclusive) and the ancestor loop (inclusive) with a check on V.
        self._replace_if_stmt(loop_node_index, first_block_index, variable_name)

        # 4. For 'break', initialise V to False before the loop and add it
        # to the loop condition.
        assign_false_node = create_fill_constant_node(variable_name, False)
        self._add_stmt_before_cur_node(loop_node_index, assign_false_node)

        # Condition expression `not V`.
        cond_var_node = gast.UnaryOp(
            op=gast.Not(),
            operand=gast.Name(
                id=variable_name,
                ctx=gast.Load(),
                annotation=None,
                type_comment=None))
        if isinstance(loop_node, gast.While):
            loop_node.test = gast.BoolOp(
                op=gast.And(), values=[loop_node.test, cond_var_node])
        elif isinstance(loop_node, gast.For):
            parent_node = self.ancestor_nodes[loop_node_index - 1]
            for_to_while = ForToWhileTransformer(parent_node, loop_node,
                                                 cond_var_node)
            # NOTE(review): the statements returned by transform() are
            # discarded and self.ancestor_nodes[loop_node_index] still holds
            # the original For node — confirm later passes do not rely on it.
            for_to_while.transform()
コード例 #13
0
 def visit_Compare(self, node):
     """Rewrite a comparison as the negation of its inverted form.

     ``a < b`` becomes ``not a >= b``. A chain ``a < b < c`` becomes
     ``not (a >= b or b >= c)``. ``OP_INVERSIONS`` maps each comparison
     operator class to its inverse. Children are visited first.

     NOTE(review): in the chained form, inner operands (``b`` above)
     appear twice and may be evaluated twice. Inverting an ordering
     operator also assumes a total order, which does not hold for NaN or
     for set comparisons.
     """
     self.generic_visit(node)
     operands = [node.left, *node.comparators]
     if len(operands) != 2:
         # Chained form: invert each adjacent pair and OR the results.
         clauses = [
             gast.Compare(left=lhs,
                          ops=[OP_INVERSIONS[type(op)]()],
                          comparators=[rhs])
             for lhs, op, rhs in zip(operands, node.ops, operands[1:])
         ]
         return gast.UnaryOp(op=gast.Not(),
                             operand=gast.BoolOp(op=gast.Or(),
                                                 values=clauses))

     inverse_cls = OP_INVERSIONS[type(node.ops[0])]
     inverted = gast.Compare(left=node.left,
                             ops=[inverse_cls()],
                             comparators=node.comparators)
     return gast.UnaryOp(op=gast.Not(), operand=inverted)
コード例 #14
0
def negate(node):
    """Return an AST node that is the logical negation of *node*.

    Two kinds of input are accepted:

    * a boolean-context expression (``UnaryOp``, ``BoolOp``, ``Compare``,
      ``Constant``, or an ``Attribute`` named ``True``/``False``), for
      which an equivalent negated expression is returned;
    * a comparison operator instance (``Eq()``, ``Lt()``, ``In()``,
      ``Is()``, ...), for which the complementary operator instance is
      returned.

    Complementing ordering operators (``Lt`` -> ``GtE`` and so on) assumes
    a total order. It is not valid for NaN or for partially ordered values
    such as sets.

    Raises:
        UnsupportedExpression: when *node* cannot be negated without type
            information (a bare ``Name``) or is of an unhandled kind.
    """
    if isinstance(node, ast.Name):
        # No type information, could be anything :(
        raise UnsupportedExpression()

    if isinstance(node, ast.UnaryOp):
        # not ~x <=> ~x == 0 <=> x == ~0 <=> x == -1
        if isinstance(node.op, ast.Invert):
            return ast.Compare(node.operand,
                               [ast.Eq()],
                               [ast.Constant(-1, None)])
        # not not x <=> x (in a boolean context)
        if isinstance(node.op, ast.Not):
            return node.operand
        # not +x <=> +x == 0 <=> x == 0 <=> not x
        # not -x <=> -x == 0 <=> x == 0 <=> not x
        if isinstance(node.op, (ast.UAdd, ast.USub)):
            return ast.UnaryOp(ast.Not(), node.operand)

    if isinstance(node, ast.BoolOp):
        new_values = [ast.UnaryOp(ast.Not(), v) for v in node.values]
        # not (x or y) <=> not x and not y
        if isinstance(node.op, ast.Or):
            return ast.BoolOp(ast.And(), new_values)
        # not (x and y) <=> not x or not y
        if isinstance(node.op, ast.And):
            return ast.BoolOp(ast.Or(), new_values)

    if isinstance(node, ast.Compare):
        # not (a < b < c) <=> a >= b or b >= c
        cmps = [ast.Compare(x, [negate(o)], [y])
                for x, o, y
                in zip([node.left] + node.comparators[:-1], node.ops,
                       node.comparators)]
        if len(cmps) == 1:
            return cmps[0]
        return ast.BoolOp(ast.Or(), cmps)

    # Comparison operators map to their complements.
    if isinstance(node, ast.Eq):
        return ast.NotEq()
    if isinstance(node, ast.NotEq):
        return ast.Eq()
    if isinstance(node, ast.Gt):
        return ast.LtE()
    if isinstance(node, ast.GtE):
        return ast.Lt()
    if isinstance(node, ast.Lt):
        return ast.GtE()
    if isinstance(node, ast.LtE):
        return ast.Gt()
    if isinstance(node, ast.In):
        return ast.NotIn()
    if isinstance(node, ast.NotIn):
        return ast.In()
    # Identity tests complement exactly.
    if isinstance(node, ast.Is):
        return ast.IsNot()
    if isinstance(node, ast.IsNot):
        return ast.Is()

    # A literal's truth value is known statically.
    if isinstance(node, ast.Constant):
        return ast.Constant(not node.value, None)

    if isinstance(node, ast.Attribute):
        if node.attr == 'False':
            return ast.Constant(True, None)
        if node.attr == 'True':
            return ast.Constant(False, None)

    raise UnsupportedExpression()
コード例 #15
0
 def build(left, right):
     """Join *left* and *right* into a two-operand ``BoolOp`` using ``op``.

     ``op`` is a free variable taken from the enclosing scope.
     """
     operands = [left, right]
     return gast.BoolOp(values=operands, op=op)