def visit_For(self, node):
    """Add flag resets and an early-exit check to a ``for`` loop body.

    After visiting the loop's children, this loop's entries are popped from
    ``self.for_continued_stack`` and ``self.for_breaked_stack``. Each flag
    variable is named ``<prefix><depth>``, where depth is the stack length
    read before the pop.

    - If the continue entry is truthy, ``<continued_flag><depth> = False`` is
      prepended to the body.
    - If the break entry is truthy, ``<breaked_flag><depth> = False`` is
      prepended (ending up before the continue reset). The flag also becomes
      part of the exit condition.
    - If ``self.func_returned_stack`` is non-empty and its top entry is
      truthy, ``<returned_flag><len(stack)>`` also becomes part of the exit
      condition. The stack is read, not popped.

    When the exit condition ``cond`` (one flag, or several or-ed together)
    exists, the body is extended with::

        <keepgoing_flag> = not cond
        if cond:
            break

    Args:
        node: the ``gast.For`` node being visited.

    Returns:
        The value returned by ``self.generic_visit(node)``.
    """
    import copy

    def flag_name(name, ctx):
        return gast.Name(id=name, ctx=ctx, annotation=None, type_comment=None)

    def reset_to_false(name):
        # `<name> = False`
        return gast.Assign(targets=[flag_name(name, gast.Store())],
                           value=gast.Constant(value=False, kind=None))

    modified_node = self.generic_visit(node)

    # The depth suffix must be read before this loop's entry is popped.
    continued_id = len(self.for_continued_stack)
    if self.for_continued_stack.pop():
        node.body.insert(0, reset_to_false(self.continued_flag + str(continued_id)))

    breaked_id = len(self.for_breaked_stack)
    bool_values = []
    if self.for_breaked_stack.pop():
        breaked_name = self.breaked_flag + str(breaked_id)
        node.body.insert(0, reset_to_false(breaked_name))
        bool_values.append(flag_name(breaked_name, gast.Load()))

    if self.func_returned_stack and self.func_returned_stack[-1]:
        returned_name = self.returned_flag + str(len(self.func_returned_stack))
        bool_values.append(flag_name(returned_name, gast.Load()))

    if bool_values:
        if len(bool_values) == 1:
            cond = bool_values[0]
        else:
            cond = gast.BoolOp(op=gast.Or(), values=bool_values)
        node.body.append(
            gast.Assign(targets=[flag_name(self.keepgoing_flag, gast.Store())],
                        value=gast.UnaryOp(op=gast.Not(), operand=cond)))
        # Give the `if` its own copy so one node object is not placed at two
        # positions in the tree (later in-place passes would see it twice).
        node.body.append(
            gast.If(test=copy.deepcopy(cond), body=[gast.Break()], orelse=[]))
    return modified_node
示例#2
0
 def visit_Assign(self, node):
     """Turn ``lhs = rhs`` into a masked update when ``lhs`` already holds a batch.

     ``node.value`` is replaced by the expression::

         lhs._update(rhs, matchbox.EXECUTION_MASK)
         if ('lhs' in vars() or 'lhs' in globals())
            and isinstance(lhs, (matchbox.MaskedBatch, matchbox.TENSOR_TYPE))
         else rhs

     Args:
         node: a ``gast.Assign`` node.

     Returns:
         The same ``node``, with ``node.value`` replaced.

     Raises:
         NotImplementedError: if the assignment has several targets, or its
             target is not a plain ``gast.Name``.
     """
     import copy

     if len(node.targets) > 1:
         raise NotImplementedError("cannot process multiple assignment")
     if not isinstance(node.targets[0], gast.Name):
         raise NotImplementedError("cannot process indexed assignment")
     target_id = node.targets[0].id
     rhs = node.value

     def load(name):
         # Fresh Name in load context. `ctx` must be an instance, not the class.
         return gast.Name(name, gast.Load(), None)

     def matchbox_attr(attr):
         # `matchbox.<attr>`. Attribute.attr is a plain identifier string.
         return gast.Attribute(load('matchbox'), attr, gast.Load())

     def name_in(scope_fn):
         # `'<target_id>' in <scope_fn>()`
         return gast.Compare(gast.Str(target_id), [gast.In()],
                             [gast.Call(load(scope_fn), [], [])])

     is_defined = gast.BoolOp(gast.Or(), [name_in('vars'), name_in('globals')])
     # Read the target through a new Load-context Name. The assignment's own
     # target node has a Store context and must stay where it is.
     is_batch = gast.Call(load('isinstance'), [
         load(target_id),
         gast.Tuple([matchbox_attr('MaskedBatch'), matchbox_attr('TENSOR_TYPE')],
                    gast.Load()),
     ], [])
     masked_update = gast.Call(
         gast.Attribute(load(target_id), '_update', gast.Load()),
         [rhs, matchbox_attr('EXECUTION_MASK')], [])
     # The else-branch gets its own copy of rhs so the tree has no shared nodes.
     node.value = gast.IfExp(gast.BoolOp(gast.And(), [is_defined, is_batch]),
                             masked_update, copy.deepcopy(rhs))
     return node
示例#3
0
 def visit_For(self, node):
     """Add flag resets and an early-exit check to a ``for`` loop.

     After visiting the children, this loop's entries are popped from
     ``self.for_continue_stack`` and ``self.for_breaked_stack``. Each entry
     is an iterable of flag names. For every name, ``<name> = False`` is
     inserted at the front of ``node.body``. The break flags are or-ed into
     ``cond``. When at least one break flag exists, the loop body is extended
     with::

         <keepgoing_flag> = not cond
         if cond:
             break

     Args:
         node: the ``gast.For`` node being visited.

     Returns:
         The value returned by ``self.generic_visit(node)``.
     """
     import copy

     def store_false(flag):
         # `<flag> = False`
         return gast.Assign(
             targets=[gast.Name(id=flag, ctx=gast.Store(), annotation=None)],
             value=gast.NameConstant(value=False))

     modified_node = self.generic_visit(node)

     for flag in self.for_continue_stack.pop():
         node.body.insert(0, store_false(flag))

     bool_values = []
     for flag in self.for_breaked_stack.pop():
         node.body.insert(0, store_false(flag))
         bool_values.append(gast.Name(id=flag, ctx=gast.Load(), annotation=None))

     if not bool_values:
         return modified_node

     # The loop to extend is the visited result itself, or, when the visit
     # returned an If, the For that is its first statement. NOTE(review): that
     # wrapping presumably comes from an overridden generic_visit not shown here.
     if isinstance(modified_node, gast.For):
         loop = modified_node
     elif (isinstance(modified_node, gast.If)
           and isinstance(modified_node.body[0], gast.For)):
         loop = modified_node.body[0]
     else:
         return modified_node

     if len(bool_values) == 1:
         cond = bool_values[0]
     else:
         cond = gast.BoolOp(op=gast.Or(), values=bool_values)
     loop.body.append(
         gast.Assign(targets=[gast.Name(id=self.keepgoing_flag,
                                        ctx=gast.Store(),
                                        annotation=None)],
                     value=gast.UnaryOp(op=gast.Not(), operand=cond)))
     # Separate copy so one node object is not placed twice in the tree.
     loop.body.append(
         gast.If(test=copy.deepcopy(cond), body=[gast.Break()], orelse=[]))
     return modified_node
示例#4
0
 def visit_Compare(self, node):
     """Rewrite a comparison as the negation of its inverted form.

     ``a < b`` becomes ``not a >= b``, and a chain ``a < b < c`` becomes
     ``not (a >= b or b >= c)``. The inverse of each operator is looked up in
     ``OP_INVERSIONS`` by operator type. Children are visited first.
     """
     self.generic_visit(node)

     def inverted(op):
         return OP_INVERSIONS[type(op)]()

     operands = [node.left] + node.comparators
     if len(operands) != 2:
         # Chained comparison: one inverted pairwise Compare per link, or-ed.
         clauses = [
             gast.Compare(left=lhs, ops=[inverted(op)], comparators=[rhs])
             for lhs, op, rhs in zip(operands, node.ops, operands[1:])
         ]
         return gast.UnaryOp(op=gast.Not(),
                             operand=gast.BoolOp(op=gast.Or(), values=clauses))

     # Single comparison: keep the original comparator list.
     flipped = gast.Compare(left=node.left,
                            ops=[inverted(node.ops[0])],
                            comparators=node.comparators)
     return gast.UnaryOp(op=gast.Not(), operand=flipped)
示例#5
0
# Comparison operator -> operator whose result is its logical negation.
_NEGATED_CMPOP = {
    ast.Eq: ast.NotEq,
    ast.NotEq: ast.Eq,
    ast.Gt: ast.LtE,
    ast.GtE: ast.Lt,
    ast.Lt: ast.GtE,
    ast.LtE: ast.Gt,
    ast.In: ast.NotIn,
    ast.NotIn: ast.In,
    ast.Is: ast.IsNot,
    ast.IsNot: ast.Is,
}


def negate(node):
    """Return an expression equivalent to ``not node``, without a top-level ``not``.

    Also accepts a comparison *operator* node (``ast.Eq``, ``ast.Lt``,
    ``ast.Is``, ...), returning a new instance of the opposite operator.

    Caveats:
    - The rewrites for unary ``~``/``+``/``-`` assume a numeric operand, as
      their derivations below do.
    - Flipping ordering operators is not exact for NaN: ``not (nan < 1)`` is
      True, but ``nan >= 1`` is False.
    - For a chained comparison, each middle operand appears in two clauses,
      so it is evaluated twice in the result.

    Raises:
        UnsupportedExpression: for a bare name (no type information) or any
            node form not handled below.
    """
    if isinstance(node, ast.Name):
        # No type info, could be anything.
        raise UnsupportedExpression()

    if isinstance(node, ast.UnaryOp):
        if isinstance(node.op, ast.Invert):
            # not ~x <=> ~x == 0 <=> x == ~0 <=> x == -1
            return ast.Compare(node.operand,
                               [ast.Eq()],
                               [ast.Constant(-1, None)])
        if isinstance(node.op, ast.Not):
            # not not x <=> x
            return node.operand
        if isinstance(node.op, (ast.UAdd, ast.USub)):
            # not +x <=> +x == 0 <=> x == 0 <=> not x (likewise for -x).
            # Negate the operand; returning it unchanged would invert the truth value.
            return negate(node.operand)

    if isinstance(node, ast.BoolOp):
        negated_values = [ast.UnaryOp(ast.Not(), v) for v in node.values]
        # De Morgan: not (x or y) <=> not x and not y
        if isinstance(node.op, ast.Or):
            return ast.BoolOp(ast.And(), negated_values)
        # De Morgan: not (x and y) <=> not x or not y
        if isinstance(node.op, ast.And):
            return ast.BoolOp(ast.Or(), negated_values)

    if isinstance(node, ast.Compare):
        # not (a < b < c) <=> a >= b or b >= c
        lefts = [node.left] + node.comparators[:-1]
        cmps = [ast.Compare(lhs, [negate(op)], [rhs])
                for lhs, op, rhs in zip(lefts, node.ops, node.comparators)]
        if len(cmps) == 1:
            return cmps[0]
        return ast.BoolOp(ast.Or(), cmps)

    negated_op = _NEGATED_CMPOP.get(type(node))
    if negated_op is not None:
        return negated_op()

    if isinstance(node, ast.Attribute):
        if node.attr == 'False':
            return ast.Constant(True, None)
        if node.attr == 'True':
            return ast.Constant(False, None)

    raise UnsupportedExpression()
示例#6
0
 def test_exception(self):
     """``cmpop_node_to_str`` raises ``KeyError`` for the boolean operator ``gast.Or``."""
     self.assertRaises(KeyError, cmpop_node_to_str, gast.Or())