Esempio n. 1
0
def split_critical_edges(func, cfg, phis):
    """
    Insert forwarding blocks on edges from branching blocks into phi blocks.

    For every block in `cfg` that has more than one successor, each
    successor `succ` with a truthy `phis[succ]` gets a fresh block (named via
    `func.temp("split_critical")` and placed after the branching block) that
    contains only a jump to `succ`. The branching block's `cbranch`
    terminator is then rewritten to target the fresh blocks instead.

    :param func: function being rewritten; supplies `temp` and `new_block`
    :param cfg:  control flow graph exposing `node` and `neighbors(block)`
    :param phis: mapping from block to its phis (only truthiness is used)
    """
    builder = Builder(func)
    for block in cfg.node:
        targets = cfg.neighbors(block)
        if len(targets) <= 1:
            # Only blocks with several outgoing edges are split.
            continue

        # original successor -> forwarding block inserted in front of it
        forwarders = {}
        for target in targets:
            if not phis[target]:
                continue
            label = func.temp("split_critical")
            forwarder = func.new_block(label, after=block)
            forwarders[target] = forwarder
            builder.position_at_end(forwarder)
            builder.jump(target)

        if not forwarders:
            continue

        # Re-point the conditional branch at the forwarding blocks; targets
        # without a forwarder are kept as they were.
        terminator = block.terminator
        assert terminator.opcode == 'cbranch', terminator
        test, truebb, falsebb = terminator.args
        on_true = forwarders.get(truebb, truebb)
        on_false = forwarders.get(falsebb, falsebb)
        terminator.set_args([test, on_true, on_false])
Esempio n. 2
0
def split_critical_edges(func, cfg, phis):
    """
    Give each (branching block -> phi block) edge its own intermediate block.

    A block is treated as branching when `cfg.neighbors(block)` has more
    than one entry. For each of its successors whose `phis` entry is
    non-empty, a new block holding a single `jump` to that successor is
    created right after the branching block, and the branching block's
    `cbranch` is redirected through it.
    """
    b = Builder(func)
    for branch_block in cfg.node:
        succs = cfg.neighbors(branch_block)
        if not len(succs) > 1:
            continue

        # Build the detour blocks first, remembering which successor each
        # one stands in for.
        detours = {}
        for succ in succs:
            if phis[succ]:
                name = func.temp("split_critical")
                detour = func.new_block(name, after=branch_block)
                detours[succ] = detour
                b.position_at_end(detour)
                b.jump(succ)

        if detours:
            # Rewrite the terminator's two destinations, leaving any
            # destination without a detour untouched.
            term = branch_block.terminator
            assert term.opcode == 'cbranch', term
            test, truebb, falsebb = term.args
            retargeted = [test]
            for dest in (truebb, falsebb):
                retargeted.append(detours.get(dest, dest))
            term.set_args(retargeted)
Esempio n. 3
0
def move_generator(func, consumer, empty_body):
    """
    Move the consumer's generator op into `empty_body` and terminate that
    block with a jump to the loop exit.

    The generator op is unlinked from wherever it currently lives, emitted
    at the end of `empty_body`, and followed by a jump to the block returned
    by `determine_loop_exit(consumer.loop)`.
    """
    generator_op = consumer.generator
    generator_op.unlink()

    builder = Builder(func)
    builder.position_at_end(empty_body)
    builder.emit(generator_op)

    # The exit block is looked up while positioned at the end of empty_body,
    # exactly where the jump is emitted.
    with builder.at_end(empty_body):
        exit_target = determine_loop_exit(consumer.loop)
        builder.jump(exit_target)
Esempio n. 4
0
def move_generator(func, consumer, empty_body):
    """
    Relocate `consumer.generator` to the end of `empty_body`, then close the
    block with a jump to the exit of `consumer.loop`.
    """
    moved = consumer.generator
    moved.unlink()  # detach before re-emitting in the new block

    b = Builder(func)
    b.position_at_end(empty_body)
    b.emit(moved)

    with b.at_end(empty_body):
        after_loop = determine_loop_exit(consumer.loop)
        b.jump(after_loop)
Esempio n. 5
0
class TestBuilder(unittest.TestCase):
    """Tests for Builder op emission, block splitting and loop generation."""

    def setUp(self):
        # Each test starts with function 'testfunc' (single argument 'a'),
        # one 'entry' block, and the builder positioned at that block's end.
        signature = types.Function(types.Float32, [types.Int32])
        self.f = Function("testfunc", ['a'], signature)
        self.b = Builder(self.f)
        entry = self.f.add_block('entry')
        self.b.position_at_end(entry)
        self.a = self.f.get_arg('a')

    def test_basic_builder(self):
        # Emit alloca / mul / convert / store / load / ret and compare the
        # printed function against the expected text.
        slot = self.b.alloca(types.Pointer(types.Float32), [])
        product = self.b.mul(types.Int32, [self.a, self.a], result='r')
        as_float = self.b.convert(types.Float32, [product])
        self.b.store(as_float, slot)
        loaded = self.b.load(types.Float32, [slot])
        self.b.ret(loaded)
        rendered = str(self.f).strip()
        self.assertEqual(rendered, basic_expected)

    def test_splitblock(self):
        # Split the entry block, then emit one op into each half.
        first_half, second_half = self.b.splitblock('newblock')
        with self.b.at_front(first_half):
            self.b.add(types.Int32, [self.a, self.a])
        with self.b.at_end(second_half):
            self.b.div(types.Int32, [self.a, self.a])
        self.assertEqual(split_expected, string(self.f))

    def test_loop_builder(self):
        # NOTE: this test only checks that building the loop does not raise;
        # the comparison against loop_expected is currently disabled.
        squared = self.b.mul(types.Int32, [self.a, self.a])
        converted = self.b.convert(types.Float32, [squared])
        self.b.position_after(squared)
        _, start_block = self.b.splitblock('start', terminate=True)
        self.b.position_at_end(start_block)

        const = partial(Const, type=types.Int32)
        cond, body, exit_block = self.b.gen_loop(const(5), const(10), const(2))
        with self.b.at_front(body):
            self.b.print_(converted)
        with self.b.at_end(exit_block):
            self.b.ret(converted)


# TestBuilder('test_basic_builder').debug()
# TestBuilder('test_splitblock').debug()
# TestBuilder('test_loop_builder').debug()
# unittest.main()
Esempio n. 6
0
class TestBuilder(unittest.TestCase):
    """Builder smoke tests: straight-line code, splitblock and gen_loop."""

    def setUp(self):
        # Fresh function with one argument 'a' and a single 'entry' block;
        # the builder is left positioned at the end of that block.
        func_type = types.Function(types.Float32, [types.Int32])
        self.f = Function("testfunc", ['a'], func_type)
        self.b = Builder(self.f)
        entry_block = self.f.add_block('entry')
        self.b.position_at_end(entry_block)
        self.a = self.f.get_arg('a')

    def test_basic_builder(self):
        # The printed IR of a small straight-line function must match
        # basic_expected.
        ptr = self.b.alloca(types.Pointer(types.Float32), [])
        squared = self.b.mul(types.Int32, [self.a, self.a], result='r')
        float_val = self.b.convert(types.Float32, [squared])
        self.b.store(float_val, ptr)
        reloaded = self.b.load(types.Float32, [ptr])
        self.b.ret(reloaded)
        actual = str(self.f).strip()
        self.assertEqual(actual, basic_expected)

    def test_splitblock(self):
        # After splitting, ops can be emitted into both resulting blocks.
        before, after = self.b.splitblock('newblock')
        with self.b.at_front(before):
            self.b.add(types.Int32, [self.a, self.a])
        with self.b.at_end(after):
            self.b.div(types.Int32, [self.a, self.a])
        self.assertEqual(split_expected, string(self.f))

    def test_loop_builder(self):
        # NOTE: no assertion is made here; the expected-output check for the
        # generated loop is disabled, so this only verifies nothing raises.
        sq = self.b.mul(types.Int32, [self.a, self.a])
        sq_float = self.b.convert(types.Float32, [sq])
        self.b.position_after(sq)
        _, loop_start = self.b.splitblock('start', terminate=True)
        self.b.position_at_end(loop_start)

        const = partial(Const, type=types.Int32)
        cond, body, loop_exit = self.b.gen_loop(const(5), const(10), const(2))
        with self.b.at_front(body):
            self.b.print_(sq_float)
        with self.b.at_end(loop_exit):
            self.b.ret(sq_float)

# TestBuilder('test_basic_builder').debug()
# TestBuilder('test_splitblock').debug()
# TestBuilder('test_loop_builder').debug()
# unittest.main()
Esempio n. 7
0
def detach_loop(func, consumer):
    """
    Delete the consumer's loop blocks from `func` and return a new block
    split off after the block that used to jump into the loop.

    The block holding `consumer.iter` must end in a `jump` to the loop head;
    that jump is deleted, and the block is split (with a terminating jump
    into the new block) at its end.

    :return: the newly created block
    """
    loop = consumer.loop
    iter_op = consumer.iter

    for loop_block in loop.blocks:
        func.del_block(loop_block)

    func.reset_uses()

    builder = Builder(func)
    entry_jump = iter_op.block.terminator
    # The block containing the iterator must fall straight into the loop.
    assert entry_jump.opcode == 'jump'
    assert entry_jump.args[0] == loop.head
    entry_jump.delete()

    builder.position_at_end(iter_op.block)
    _, newblock = builder.splitblock(terminate=True)
    return newblock
Esempio n. 8
0
def detach_loop(func, consumer):
    """
    Remove every block of `consumer.loop` from `func`, drop the jump that
    led into the loop head, and return a fresh block split off the end of
    the block containing `consumer.iter`.
    """
    loop = consumer.loop
    iterator = consumer.iter
    pre_loop_block = iterator.block

    for doomed in loop.blocks:
        func.del_block(doomed)

    func.reset_uses()

    b = Builder(func)
    jump = pre_loop_block.terminator
    # Sanity check: the only way out of the pre-loop block is the loop head.
    assert jump.opcode == 'jump'
    assert jump.args[0] == loop.head
    jump.delete()

    b.position_at_end(iterator.block)
    split_result = b.splitblock(terminate=True)
    _, newblock = split_result
    return newblock
Esempio n. 9
0
def splitblock(block, trailing, name=None, terminate=False, preserve_exc=True):
    """Split `block`, returning (old_block, new_block).

    A new block is created directly after `block`, and every op in
    `trailing` is unlinked from its current position and appended to it.

    :param block:        block to split
    :param trailing:     iterable of ops to move into the new block; it is
                         materialized once, so one-shot iterators are fine
    :param name:         name for the new block (a temp name if omitted)
    :param terminate:    if true and `block` has no terminator after the
                         move, end it with a jump to the new block
    :param preserve_exc: if true, call `preserve_exceptions` on the pair
    :return:             the tuple (block, newblock)
    """
    from pykit.analysis import cfa
    from pykit.ir import Builder

    func = block.parent

    # Successors are deduced before any op is moved; they are handed to
    # patch_phis together with the new block below.
    if block.is_terminated():
        successors = cfa.deduce_successors(block)
    else:
        successors = []

    # -------------------------------------------------
    # Sanity check

    # Allow splitting only after leaders and before terminator
    # TODO: error check

    # -------------------------------------------------
    # Split

    blockname = name or func.temp('Block')
    newblock = func.new_block(blockname, after=block)

    # -------------------------------------------------
    # Move ops after the split to new block

    # `trailing` is walked twice (unlink, then extend). Take a snapshot so a
    # generator argument is not exhausted by the first pass and so unlinking
    # cannot disturb the sequence being iterated.
    trailing = list(trailing)
    for op in trailing:
        op.unlink()
    newblock.extend(trailing)

    if terminate and not block.is_terminated():
        # Terminate
        b = Builder(func)
        b.position_at_end(block)
        b.jump(newblock)

    # Update phis and preserve exception blocks
    patch_phis(block, newblock, successors)
    if preserve_exc:
        preserve_exceptions(block, newblock)

    return block, newblock
Esempio n. 10
0
def consume_yields(func, consumer, generator_func, valuemap):
    """
    Replace each 'yield' op found in the inlined blocks by the consumer's
    loop body.

    Only blocks of `func` that appear among `valuemap.values()` are
    scanned. For every 'yield' op found, the block is split after the op,
    the consumer's loop blocks are added to `func` after the yield's block,
    jumps are emitted into the first loop block and from the last loop block
    to the split-off remainder, uses of `consumer.next` are replaced by the
    yielded value, and the yield op is deleted. `consumer.next` itself is
    deleted at the end.

    `generator_func` is accepted but not referenced in the body.

    NOTE(review): `func.blocks` and `block.ops` are modified (add_block /
    splitblock / op.delete) while they are being iterated; whether that is
    safe depends on those containers -- TODO confirm.
    """
    b = Builder(func)
    # Identity "copy": the loop's own block objects are re-inserted below
    # rather than duplicates of them.
    copier = lambda x : x

    loop = consumer.loop
    inlined_values = set(valuemap.values())

    for block in func.blocks:
        if block in inlined_values:
            for op in block.ops:
                if op.opcode == 'yield':
                    # -- Replace 'yield' by the loop body -- #
                    b.position_after(op)
                    # `resume` is the new block produced by the split
                    _, resume = b.splitblock()

                    # Copy blocks
                    blocks = [copier(block) for block in loop.blocks]

                    # Insert blocks
                    # NOTE(review): this loop rebinds the outer loop
                    # variable `block`; the enclosing `for op in block.ops`
                    # already holds its iterable, but the shadowing is easy
                    # to trip over.
                    prev = op.block
                    for block in blocks:
                        func.add_block(block, after=prev)
                        prev = block

                    # Fix wiring
                    # First jump is emitted at the builder's current
                    # position (wherever splitblock left it -- TODO confirm)
                    b.jump(blocks[0])
                    b.position_at_end(blocks[-1])
                    b.jump(resume)

                    # We just introduced a bunch of copied blocks
                    func.reset_uses()

                    # Update phis with new predecessor
                    # NOTE(review): the same call is issued twice with
                    # identical arguments -- possibly a copy/paste slip where
                    # the second call was meant to differ; confirm intent.
                    b.replace_predecessor(loop.tail, op.block, loop.head)
                    b.replace_predecessor(loop.tail, op.block, loop.head)

                    # Replace next() by value produced by yield
                    value = op.args[0]
                    consumer.next.replace_uses(value)
                    op.delete()

    # We don't need these anymore
    consumer.next.delete()
Esempio n. 11
0
def consume_yields(func, consumer, generator_func, valuemap):
    """
    Replace each 'yield' op found in the inlined blocks by the consumer's
    loop body.

    Only blocks of `func` that appear among `valuemap.values()` are
    scanned. For every 'yield' op found, the block is split after the op,
    the consumer's loop blocks are added to `func` after the yield's block,
    jumps are emitted into the first loop block and from the last loop block
    to the split-off remainder, uses of `consumer.next` are replaced by the
    yielded value, and the yield op is deleted. `consumer.next` itself is
    deleted at the end.

    `generator_func` is accepted but not referenced in the body.

    NOTE(review): `func.blocks` and `block.ops` are modified (add_block /
    splitblock / op.delete) while they are being iterated; whether that is
    safe depends on those containers -- TODO confirm.
    """
    b = Builder(func)
    # Identity "copy": the loop's own block objects are re-inserted below
    # rather than duplicates of them.
    copier = lambda x: x

    loop = consumer.loop
    inlined_values = set(valuemap.values())

    for block in func.blocks:
        if block in inlined_values:
            for op in block.ops:
                if op.opcode == 'yield':
                    # -- Replace 'yield' by the loop body -- #
                    b.position_after(op)
                    # `resume` is the new block produced by the split
                    _, resume = b.splitblock()

                    # Copy blocks
                    blocks = [copier(block) for block in loop.blocks]

                    # Insert blocks
                    # NOTE(review): this loop rebinds the outer loop
                    # variable `block`; the enclosing `for op in block.ops`
                    # already holds its iterable, but the shadowing is easy
                    # to trip over.
                    prev = op.block
                    for block in blocks:
                        func.add_block(block, after=prev)
                        prev = block

                    # Fix wiring
                    # First jump is emitted at the builder's current
                    # position (wherever splitblock left it -- TODO confirm)
                    b.jump(blocks[0])
                    b.position_at_end(blocks[-1])
                    b.jump(resume)

                    # We just introduced a bunch of copied blocks
                    func.reset_uses()

                    # Update phis with new predecessor
                    # NOTE(review): the same call is issued twice with
                    # identical arguments -- possibly a copy/paste slip where
                    # the second call was meant to differ; confirm intent.
                    b.replace_predecessor(loop.tail, op.block, loop.head)
                    b.replace_predecessor(loop.tail, op.block, loop.head)

                    # Replace next() by value produced by yield
                    value = op.args[0]
                    consumer.next.replace_uses(value)
                    op.delete()

    # We don't need these anymore
    consumer.next.delete()
Esempio n. 12
0
class TestBuilder(unittest.TestCase):
    """Builder tests that check results by interpreting the built function."""

    def setUp(self):
        # Fresh function 'testfunc' with one argument 'a' and an 'entry'
        # block; the builder is positioned at the end of that block.
        self.f = Function("testfunc", ['a'],
                          types.Function(types.Float32, [types.Int32]))
        self.b = Builder(self.f)
        self.b.position_at_end(self.f.new_block('entry'))
        self.a = self.f.get_arg('a')

    def test_basic_builder(self):
        # Build a*a, round-trip it through a stack slot, return it, and run
        # the function on 10.
        v = self.b.alloca(types.Pointer(types.Float32), [])
        result = self.b.mul(types.Int32, [self.a, self.a], result='r')
        c = self.b.convert(types.Float32, [result])
        self.b.store(c, v)
        val = self.b.load(types.Float32, [v])
        self.b.ret(val)
        # assertEqual rather than a bare `assert`: it is not stripped under
        # `python -O`, reports both values on failure, and matches the other
        # tests in this class.
        self.assertEqual(interp.run(self.f, args=[10]), 100)

    def test_splitblock(self):
        # Ops emitted into the two halves of a split block must show up in
        # block order.
        old, new = self.b.splitblock('newblock')
        with self.b.at_front(old):
            self.b.add(types.Int32, [self.a, self.a])
        with self.b.at_end(new):
            self.b.div(types.Int32, [self.a, self.a])
        self.assertEqual(opcodes(self.f), ['add', 'div'])

    def test_loop_builder(self):
        # Compute a*a before the loop, print it in the generated loop body,
        # and return it from the loop's exit block.
        square = self.b.mul(types.Int32, [self.a, self.a])
        c = self.b.convert(types.Float32, [square])
        self.b.position_after(square)
        _, block = self.b.splitblock('start', terminate=True)
        self.b.position_at_end(block)

        const = partial(Const, type=types.Int32)
        cond, body, exit_block = self.b.gen_loop(const(5), const(10), const(2))
        with self.b.at_front(body):
            self.b.print(c)
        with self.b.at_end(exit_block):
            self.b.ret(c)

        self.assertEqual(interp.run(self.f, args=[10]), 100.0)
def splitblock(block, trailing, name=None, terminate=False, preserve_exc=True):
    """Split `block`, returning (old_block, new_block).

    A new block is created directly after `block`, and every op in
    `trailing` is unlinked from its current position and appended to it.

    :param block:        block to split
    :param trailing:     iterable of ops to move into the new block; it is
                         materialized once, so one-shot iterators are fine
    :param name:         name for the new block (a temp name if omitted)
    :param terminate:    if true and `block` has no terminator after the
                         move, end it with a jump to the new block
    :param preserve_exc: if true, call `preserve_exceptions` on the pair
    :return:             the tuple (block, newblock)
    """

    func = block.parent

    # Successors are deduced before any op is moved; they are handed to
    # patch_phis together with the new block below.
    if block.is_terminated():
        successors = deduce_successors(block)
    else:
        successors = []

    # -------------------------------------------------
    # Sanity check

    # Allow splitting only after leaders and before terminator
    # TODO: error check

    # -------------------------------------------------
    # Split

    blockname = name or func.temp('Block')
    newblock = func.new_block(blockname, after=block)

    # -------------------------------------------------
    # Move ops after the split to new block

    # `trailing` is walked twice (unlink, then extend). Take a snapshot so a
    # generator argument is not exhausted by the first pass and so unlinking
    # cannot disturb the sequence being iterated.
    trailing = list(trailing)
    for op in trailing:
        op.unlink()
    newblock.extend(trailing)

    if terminate and not block.is_terminated():
        # Terminate
        b = Builder(func)
        b.position_at_end(block)
        b.jump(newblock)

    # Update phis and preserve exception blocks
    patch_phis(block, newblock, successors)
    if preserve_exc:
        preserve_exceptions(block, newblock)

    return block, newblock
Esempio n. 14
0
class TestBuilder(unittest.TestCase):
    """Builder tests that check results by interpreting the built function."""

    def setUp(self):
        # Fresh function 'testfunc' with one argument 'a' and an 'entry'
        # block; the builder is positioned at the end of that block.
        self.f = Function("testfunc", ['a'],
                          types.Function(types.Float32, [types.Int32], False))
        self.b = Builder(self.f)
        self.b.position_at_end(self.f.new_block('entry'))
        self.a = self.f.get_arg('a')

    def test_basic_builder(self):
        # Build a*a, round-trip it through a stack slot, return it, and run
        # the function on 10.
        v = self.b.alloca(types.Pointer(types.Float32))
        result = self.b.mul(self.a, self.a, result='r')
        c = self.b.convert(types.Float32, result)
        self.b.store(c, v)
        val = self.b.load(v)
        self.b.ret(val)
        # assertEqual rather than a bare `assert`: it is not stripped under
        # `python -O`, reports both values on failure, and matches the other
        # tests in this class.
        self.assertEqual(interp.run(self.f, args=[10]), 100)

    def test_splitblock(self):
        # Ops emitted into the two halves of a split block must show up in
        # block order.
        old, new = self.b.splitblock('newblock')
        with self.b.at_front(old):
            self.b.add(self.a, self.a)
        with self.b.at_end(new):
            self.b.div(self.a, self.a)
        self.assertEqual(opcodes(self.f), ['add', 'div'])

    def test_loop_builder(self):
        # Compute a*a before the loop, print it in the generated loop body,
        # and return it from the loop's exit block.
        square = self.b.mul(self.a, self.a)
        c = self.b.convert(types.Float32, square)
        self.b.position_after(square)
        _, block = self.b.splitblock('start', terminate=True)
        self.b.position_at_end(block)

        const = partial(Const, type=types.Int32)
        cond, body, exit_block = self.b.gen_loop(const(5), const(10), const(2))
        with self.b.at_front(body):
            self.b.print(c)
        with self.b.at_end(exit_block):
            self.b.ret(c)

        self.assertEqual(interp.run(self.f, args=[10]), 100.0)

    def test_splitblock_preserve_phis(self):
        """
        block1:
            %0 = mul a a
            jump(newblock)

        newblock:
            %1 = phi([block1], [%0])
            ret %1
        """
        square = self.b.mul(self.a, self.a)
        _, new = self.b.splitblock('newblock', terminate=True)
        with self.b.at_front(new):
            phi = self.b.phi(types.Int32, [self.f.startblock], [square])
            self.b.ret(phi)

        # Now split block1
        self.b.position_after(square)
        _, split = self.b.splitblock(terminate=True)

        # 'newblock' must still hold exactly the phi and the ret, and the
        # phi's incoming block must now be the block split off block1.
        patched_phi, _ = new.ops
        blocks, _ = patched_phi.args
        self.assertEqual(blocks, [split])
Esempio n. 15
0
class PykitIRVisitor(c_ast.NodeVisitor):
    """
    Map pykit IR in the form of polymorphic C to in-memory pykit IR.

        int function(float x) {
            int i = 0;        /* I am a comment */
            while (i < 10) {  /*: { "unroll": true } :*/
                x = call_external("sqrt", x * x);
            }
            return (int) x;
        }

    Attributes:
        mod: Module that receives each function once its body is built.
        type_env: mapping from type name to type (the caller's mapping when
            a non-empty one is given); typedefs are recorded into it.
        func, builder: function under construction and the Builder working
            on it; `func` is None outside a function definition.
        local_vars: name -> declared type for the current function.
        allocas: name -> alloca op for the current function, created lazily.
        global_vars: name -> declared type for declarations outside functions.
        functions: initialized to an empty dict; not otherwise used here.
        in_function: True while a function definition is being visited.
        in_typed_context: True only while a cast is supplying the result
            type for the function call it directly wraps.
    """

    in_function = False
    # Class-level default so visit_FuncCall can test the flag even when no
    # cast has set it yet (it used to raise AttributeError in that case).
    in_typed_context = False

    def __init__(self, type_env=None):
        self.mod = Module()
        self.type_env = type_env or {}

        self.func = None
        self.builder = None
        self.local_vars = None
        self.allocas = None

        self.global_vars = {}
        self.functions = {}

    # ______________________________________________________________________

    @property
    def vars(self):
        """Declaration table of the current scope (locals inside a function)."""
        if self.in_function:
            return self.local_vars
        else:
            return self.global_vars

    def enter_func(self):
        """Start a function scope with empty local and alloca tables."""
        self.in_function = True
        self.local_vars = {}
        self.allocas = {}

    def leave_func(self):
        """Add the finished function to the module and drop per-function state."""
        self.in_function = False
        self.mod.add_function(self.func)
        self.local_vars = None
        self.allocas = None
        self.func = None

    def visit(self, node, type=None):
        """
        Visit a node.

        :type: Whether we have a type for this opcode, which is an LHS type
               or a cast. E.g.:

              (Int) call(...)    // cast
              result = call(...) // assmnt, assuming 'result' is declared
              result = call(..., call(...)) // second 'call' isn't typed

        The type is stored in `self.type` and is overwritten by every nested
        visit, so a visitor that needs it must read it before visiting its
        children.
        """
        self.type = type
        method = 'visit_' + node.__class__.__name__
        # Nodes without a dedicated visit_* method fall back to generic_visit.
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def visitif(self, node):
        """Visit `node` if it is truthy; otherwise return None."""
        if node:
            return self.visit(node)

    def visits(self, node):
        """Visit every element of the iterable `node`; return the results as a list."""
        return list(map(self.visit, node))

    # ______________________________________________________________________

    def alloca(self, varname):
        """Return the alloca op for local `varname`, creating it on first use."""
        if varname not in self.allocas:
            # Allocate variable with alloca, at the front of the first block
            with self.builder.at_front(self.func.blocks[0]):
                type = self.local_vars[varname]
                self.allocas[varname] = self.builder.alloca(type, [], varname)

        return self.allocas[varname]

    def assign(self, varname, rhs):
        """Store the value of expression node `rhs` into local `varname`."""
        if not self.in_function:
            error(rhs, "Assignment only allowed in functions")

        # Make sure the slot exists before the right-hand side is visited
        # (same order as before, now via the shared helper instead of a
        # duplicated copy of its body).
        target = self.alloca(varname)
        self.builder.store(self.visit(rhs), target)

    # ______________________________________________________________________

    def visit_Decl(self, decl):
        """Record a declaration in the current scope and emit its initializer."""
        if decl.name in self.vars:
            error(decl, "Var '%s' already declared!" % (decl.name,))

        type = self.visit(decl.type)
        self.vars[decl.name] = type
        if decl.init:
            self.assign(decl.name, decl.init)

        return type

    def visit_TypeDecl(self, decl):
        """Return the type named by the declaration."""
        return self.visit(decl.type)

    visit_Typename = visit_TypeDecl

    def visit_PtrDecl(self, decl):
        """Return a Pointer to the declared base type."""
        return types.Pointer(self.visit(decl.type.type))

    def visit_FuncDecl(self, decl):
        """Return a Function type built from the return and parameter types."""
        return types.Function(self.visit(decl.type),
                              self.visits(decl.args.params))

    def visit_IdentifierType(self, node):
        """Look the (single) type name up in the type environment."""
        name, = node.names
        return self.type_env[name]

    def visit_Typedef(self, node):
        """Resolve a typedef and record it in the type environment."""
        if node.name in ("Type", "_list"):
            # These two names are expected to be in type_env already.
            type = self.type_env[node.name]
        else:
            type = self.visit(node.type)
            if type == types.Type:
                # A typedef of the base Type refers to the attribute of the
                # `types` module carrying the typedef's own name.
                type = getattr(types, node.name)

            self.type_env[node.name] = type

        return type

    def visit_Template(self, node):
        """Instantiate a parameterized type, or build a plain list of types."""
        left = self.visit(node.left)
        # Visit the arguments once (they used to be visited a second time in
        # the non-list branch).
        subtypes = self.visits(node.right)
        if left is list:
            return list(subtypes)
        else:
            assert issubclass(left, types.Type)
            return left(*subtypes)

    # ______________________________________________________________________

    def visit_FuncDef(self, node):
        """Build a Function with an 'entry' block and visit the body into it."""
        assert not node.param_decls
        self.enter_func()

        name = node.decl.name
        type = self.visit(node.decl.type)
        argnames = [p.name for p in node.decl.type.args.params]
        self.func = Function(name, argnames, type)
        self.func.add_block('entry')
        self.builder = Builder(self.func)
        self.builder.position_at_end(self.func.blocks[0])
        self.generic_visit(node.body)

        self.leave_func()

    # ______________________________________________________________________

    def visit_FuncCall(self, node):
        """
        Resolve a call to the builder method of the same name.

        Returns the pair (builder method, visited arguments); the caller
        (see visit_Cast) invokes the method with the result type.
        """
        name = node.name.name
        if not self.in_typed_context:
            error(node, "Expected a type for sub-expression "
                        "(add a cast or assignment)")
        if not hasattr(self.builder, name):
            error(node, "No opcode %s" % (name,))
        # The type applies to this call only, not to calls among its arguments.
        self.in_typed_context = False

        buildop = getattr(self.builder, name)
        args = self.visits(node.args.exprs)
        return buildop, args

    def visit_ID(self, node):
        """Load the current value of a local variable."""
        # NOTE(review): outside a function this falls through and returns
        # None -- confirm whether global names are meant to be supported.
        if self.in_function:
            if node.name not in self.local_vars:
                error(node, "Not a local: %r" % node.name)

            result = self.alloca(node.name)
            # NOTE(review): other builder calls in this class pass operands
            # as a list; here the op is passed bare -- confirm.
            return self.builder.load(result.type, result)

    def visit_Cast(self, node):
        """Emit a typed call for `(T) call(...)`, or a conversion otherwise."""
        type = self.visit(node.to_type)
        if isinstance(node.expr, c_ast.FuncCall):
            self.in_typed_context = True
            buildop, args = self.visit(node.expr)
            return buildop(type, args, "temp")
        else:
            result = self.visit(node.expr)
            if result.type == type:
                # Already of the requested type: no conversion needed.
                return result
            return self.builder.convert(type, [result], "temp")

    def visit_Assignment(self, node):
        """Handle `name = expr`; any other assignment form is an error."""
        if node.op != '=':
            error(node, "Only assignment with '=' is supported")
        if not isinstance(node.lvalue, c_ast.ID):
            error(node, "Can only assign to a name")
        self.assign(node.lvalue.name, node.rvalue)

    def visit_Constant(self, node):
        """Convert a literal to a Const using the type registered for its kind."""
        type = self.type_env[node.type]
        const = types.convert(node.value, type)
        return Const(const)

    def visit_UnaryOp(self, node):
        """Emit the builder op registered for the unary operator."""
        op = defs.unary_defs[node.op]
        buildop = getattr(self.builder, op)
        # Read the requested type before visiting the operand: the nested
        # visit() resets self.type.
        requested = self.type
        arg = self.visit(node.expr)
        type = requested or arg.type
        return buildop(type, [arg])

    def visit_BinaryOp(self, node):
        """Emit the builder op registered for the binary operator."""
        op = binary_defs[node.op]
        buildop = getattr(self.builder, op)
        # Read the requested type before visiting the operands: the nested
        # visit() calls reset self.type.
        requested = self.type
        left, right = self.visits([node.left, node.right])
        if not requested:
            # Without an explicit type the operand types must agree.
            assert left.type == right.type, (left, right)
        return buildop(requested or left.type, [left, right], "temp")

    def _loop(self, init, cond, next, body):
        """
        Emit a loop made of 'cond', 'body' and 'exit' blocks.

        `init` and `next` may be None (while loops). On return the builder
        is positioned at the end of the exit block.
        """
        _, exit_block = self.builder.splitblock("exit")
        _, body_block = self.builder.splitblock("body")
        _, cond_block = self.builder.splitblock("cond")

        self.visitif(init)
        self.builder.jump(cond_block)

        with self.builder.at_front(cond_block):
            cond = self.visit(cond, type=types.Bool)
            # A true condition enters the body (the branch used to target
            # cond_block itself, leaving the body unreachable).
            self.builder.cbranch(cond, body_block, exit_block)

        with self.builder.at_front(body_block):
            self.visit(body)
            self.visitif(next)
            self.builder.jump(cond_block)

        self.builder.position_at_end(exit_block)

    def visit_While(self, node):
        """Lower `while (cond) stmt`."""
        self._loop(None, node.cond, None, node.stmt)

    def visit_For(self, node):
        """Lower `for (init; cond; next) stmt`."""
        self._loop(node.init, node.cond, node.next, node.stmt)

    def visit_Return(self, node):
        """Emit a `ret` of the visited return expression."""
        self.builder.ret(self.visit(node.expr))
Esempio n. 16
0
class PykitIRVisitor(c_ast.NodeVisitor):
    """
    Map pykit IR in the form of polymorphic C to in-memory pykit IR.

        int function(float x) {
            int i = 0;        /* I am a comment */
            while (i < 10) {  /*: { "unroll": true } :*/
                x = call_external("sqrt", x * x);
            }
            return (int) x;
        }

    Attributes:
        mod: Module that receives each function once its body is built.
        type_env: mapping from type name to type (the caller's mapping when
            a non-empty one is given); typedefs are recorded into it.
        func, builder: function under construction and the Builder working
            on it; `func` is None outside a function definition.
        local_vars: name -> declared type for the current function.
        allocas: name -> alloca op for the current function, created lazily.
        global_vars: name -> declared type for declarations outside functions.
        functions: initialized to an empty dict; not otherwise used here.
        in_function: True while a function definition is being visited.
        in_typed_context: True only while a cast is supplying the result
            type for the function call it directly wraps.
    """

    in_function = False
    # Class-level default so visit_FuncCall can test the flag even when no
    # cast has set it yet (it used to raise AttributeError in that case).
    in_typed_context = False

    def __init__(self, type_env=None):
        self.mod = Module()
        self.type_env = type_env or {}

        self.func = None
        self.builder = None
        self.local_vars = None
        self.allocas = None

        self.global_vars = {}
        self.functions = {}

    # ______________________________________________________________________

    @property
    def vars(self):
        """Declaration table of the current scope (locals inside a function)."""
        if self.in_function:
            return self.local_vars
        else:
            return self.global_vars

    def enter_func(self):
        """Start a function scope with empty local and alloca tables."""
        self.in_function = True
        self.local_vars = {}
        self.allocas = {}

    def leave_func(self):
        """Add the finished function to the module and drop per-function state."""
        self.in_function = False
        self.mod.add_function(self.func)
        self.local_vars = None
        self.allocas = None
        self.func = None

    def visit(self, node, type=None):
        """
        Visit a node.

        :type: Whether we have a type for this opcode, which is an LHS type
               or a cast. E.g.:

              (Int) call(...)    // cast
              result = call(...) // assmnt, assuming 'result' is declared
              result = call(..., call(...)) // second 'call' isn't typed

        The type is stored in `self.type` and is overwritten by every nested
        visit, so a visitor that needs it must read it before visiting its
        children.
        """
        self.type = type
        method = 'visit_' + node.__class__.__name__
        # Nodes without a dedicated visit_* method fall back to generic_visit.
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def visitif(self, node):
        """Visit `node` if it is truthy; otherwise return None."""
        if node:
            return self.visit(node)

    def visits(self, node):
        """Visit every element of the iterable `node`; return the results as a list."""
        return list(map(self.visit, node))

    # ______________________________________________________________________

    def alloca(self, varname):
        """Return the alloca op for local `varname`, creating it on first use."""
        if varname not in self.allocas:
            # Allocate variable with alloca, at the front of the first block
            with self.builder.at_front(self.func.blocks[0]):
                type = self.local_vars[varname]
                self.allocas[varname] = self.builder.alloca(type, [], varname)

        return self.allocas[varname]

    def assign(self, varname, rhs):
        """Store the value of expression node `rhs` into local `varname`."""
        if not self.in_function:
            error(rhs, "Assignment only allowed in functions")

        # Make sure the slot exists before the right-hand side is visited
        # (same order as before, now via the shared helper instead of a
        # duplicated copy of its body).
        target = self.alloca(varname)
        self.builder.store(self.visit(rhs), target)

    # ______________________________________________________________________

    def visit_Decl(self, decl):
        """Record a declaration in the current scope and emit its initializer."""
        if decl.name in self.vars:
            error(decl, "Var '%s' already declared!" % (decl.name, ))

        type = self.visit(decl.type)
        self.vars[decl.name] = type
        if decl.init:
            self.assign(decl.name, decl.init)

        return type

    def visit_TypeDecl(self, decl):
        """Return the type named by the declaration."""
        return self.visit(decl.type)

    visit_Typename = visit_TypeDecl

    def visit_PtrDecl(self, decl):
        """Return a Pointer to the declared base type."""
        return types.Pointer(self.visit(decl.type.type))

    def visit_FuncDecl(self, decl):
        """Return a Function type built from the return and parameter types."""
        return types.Function(self.visit(decl.type),
                              self.visits(decl.args.params))

    def visit_IdentifierType(self, node):
        """Look the (single) type name up in the type environment."""
        name, = node.names
        return self.type_env[name]

    def visit_Typedef(self, node):
        """Resolve a typedef and record it in the type environment."""
        if node.name in ("Type", "_list"):
            # These two names are expected to be in type_env already.
            type = self.type_env[node.name]
        else:
            type = self.visit(node.type)
            if type == types.Type:
                # A typedef of the base Type refers to the attribute of the
                # `types` module carrying the typedef's own name.
                type = getattr(types, node.name)

            self.type_env[node.name] = type

        return type

    def visit_Template(self, node):
        """Instantiate a parameterized type, or build a plain list of types."""
        left = self.visit(node.left)
        # Visit the arguments once (they used to be visited a second time in
        # the non-list branch).
        subtypes = self.visits(node.right)
        if left is list:
            return list(subtypes)
        else:
            assert issubclass(left, types.Type)
            return left(*subtypes)

    # ______________________________________________________________________

    def visit_FuncDef(self, node):
        """Build a Function with an 'entry' block and visit the body into it."""
        assert not node.param_decls
        self.enter_func()

        name = node.decl.name
        type = self.visit(node.decl.type)
        argnames = [p.name for p in node.decl.type.args.params]
        self.func = Function(name, argnames, type)
        self.func.add_block('entry')
        self.builder = Builder(self.func)
        self.builder.position_at_end(self.func.blocks[0])
        self.generic_visit(node.body)

        self.leave_func()

    # ______________________________________________________________________

    def visit_FuncCall(self, node):
        """
        Resolve a call to the builder method of the same name.

        Returns the pair (builder method, visited arguments); the caller
        (see visit_Cast) invokes the method with the result type.
        """
        name = node.name.name
        if not self.in_typed_context:
            error(
                node, "Expected a type for sub-expression "
                "(add a cast or assignment)")
        if not hasattr(self.builder, name):
            error(node, "No opcode %s" % (name, ))
        # The type applies to this call only, not to calls among its arguments.
        self.in_typed_context = False

        buildop = getattr(self.builder, name)
        args = self.visits(node.args.exprs)
        return buildop, args

    def visit_ID(self, node):
        """Load the current value of a local variable."""
        # NOTE(review): outside a function this falls through and returns
        # None -- confirm whether global names are meant to be supported.
        if self.in_function:
            if node.name not in self.local_vars:
                error(node, "Not a local: %r" % node.name)

            result = self.alloca(node.name)
            # NOTE(review): other builder calls in this class pass operands
            # as a list; here the op is passed bare -- confirm.
            return self.builder.load(result.type, result)

    def visit_Cast(self, node):
        """Emit a typed call for `(T) call(...)`, or a conversion otherwise."""
        type = self.visit(node.to_type)
        if isinstance(node.expr, c_ast.FuncCall):
            self.in_typed_context = True
            buildop, args = self.visit(node.expr)
            return buildop(type, args, "temp")
        else:
            result = self.visit(node.expr)
            if result.type == type:
                # Already of the requested type: no conversion needed.
                return result
            return self.builder.convert(type, [result], "temp")

    def visit_Assignment(self, node):
        """Handle `name = expr`; any other assignment form is an error."""
        if node.op != '=':
            error(node, "Only assignment with '=' is supported")
        if not isinstance(node.lvalue, c_ast.ID):
            error(node, "Can only assign to a name")
        self.assign(node.lvalue.name, node.rvalue)

    def visit_Constant(self, node):
        """Convert a literal to a Const using the type registered for its kind."""
        type = self.type_env[node.type]
        const = types.convert(node.value, type)
        return Const(const)

    def visit_UnaryOp(self, node):
        """Emit the builder op registered for the unary operator."""
        op = defs.unary_defs[node.op]
        buildop = getattr(self.builder, op)
        # Read the requested type before visiting the operand: the nested
        # visit() resets self.type.
        requested = self.type
        arg = self.visit(node.expr)
        type = requested or arg.type
        return buildop(type, [arg])

    def visit_BinaryOp(self, node):
        """Emit the builder op registered for the binary operator."""
        op = binary_defs[node.op]
        buildop = getattr(self.builder, op)
        # Read the requested type before visiting the operands: the nested
        # visit() calls reset self.type.
        requested = self.type
        left, right = self.visits([node.left, node.right])
        if not requested:
            # Without an explicit type the operand types must agree.
            assert left.type == right.type, (left, right)
        return buildop(requested or left.type, [left, right], "temp")

    def _loop(self, init, cond, next, body):
        """
        Emit a loop made of 'cond', 'body' and 'exit' blocks.

        `init` and `next` may be None (while loops). On return the builder
        is positioned at the end of the exit block.
        """
        _, exit_block = self.builder.splitblock("exit")
        _, body_block = self.builder.splitblock("body")
        _, cond_block = self.builder.splitblock("cond")

        self.visitif(init)
        self.builder.jump(cond_block)

        with self.builder.at_front(cond_block):
            cond = self.visit(cond, type=types.Bool)
            # A true condition enters the body (the branch used to target
            # cond_block itself, leaving the body unreachable).
            self.builder.cbranch(cond, body_block, exit_block)

        with self.builder.at_front(body_block):
            self.visit(body)
            self.visitif(next)
            self.builder.jump(cond_block)

        self.builder.position_at_end(exit_block)

    def visit_While(self, node):
        """Lower `while (cond) stmt`."""
        self._loop(None, node.cond, None, node.stmt)

    def visit_For(self, node):
        """Lower `for (init; cond; next) stmt`."""
        self._loop(node.init, node.cond, node.next, node.stmt)

    def visit_Return(self, node):
        """Emit a `ret` of the visited return expression."""
        self.builder.ret(self.visit(node.expr))
Esempio n. 17
0
class PykitIRVisitor(c_ast.NodeVisitor):
    """
    Map pykit IR in the form of polymorphic C to in-memory pykit IR.

        int function(float x) {
            int i = 0;        /* I am a comment */
            while (i < 10) {  /*: { "unroll": true } :*/
                x = call_external("sqrt", x * x);
            }
            return (int) x;
        }

    Attributes:
        mod: Module that receives the functions and globals being built.
        type_env: mapping from type name to type object.
        func: Function currently under construction (None outside one).
        builder: Builder created for the most recent function definition.
        local_vars: name -> declared type for the current function.
        allocas: name -> stack slot (alloca op) for the current function.
        global_vars: name -> declared type for module-level variables.
        functions: initialised to an empty dict; not used by this class.
        in_function: True while a function definition is being visited.
        type: the expected type passed to the most recent visit() call.
    """

    in_function = False

    def __init__(self, type_env=None):
        """
        :param type_env: optional mapping from type name to type object; a
                         fresh dict is used when it is omitted or empty.
        """
        self.mod = Module()
        self.type_env = type_env or {}

        self.func = None
        self.builder = None
        self.local_vars = None
        self.allocas = None

        self.global_vars = {}
        self.functions = {}

    # ______________________________________________________________________

    @property
    def vars(self):
        """Declared variables (name -> type) of the scope being visited."""
        if self.in_function:
            return self.local_vars
        else:
            return self.global_vars

    def enter_func(self):
        """Open a function scope with empty local variable tables."""
        self.in_function = True
        self.local_vars = {}
        self.allocas = {}

    def leave_func(self):
        """Add the finished function to the module and close its scope."""
        self.in_function = False
        self.mod.add_function(self.func)
        self.local_vars = None
        self.allocas = None
        self.func = None

    def visit(self, node, type=None):
        """
        Visit a node.

        :type: Whether we have a type for this opcode, which is an LHS type
               or a cast. E.g.:

              (Int) call(...)    // cast
              result = call(...) // assmnt, assuming 'result' is declared
              result = call(..., call(...)) // second 'call' isn't typed

        The type is stored in ``self.type`` for the visitor method to pick
        up. Node classes without a ``visit_<ClassName>`` method are handled
        by ``generic_visit``.
        """
        self.type = type
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def visitif(self, node):
        """Visit `node` if it is truthy; otherwise return None."""
        if node:
            return self.visit(node)

    def visits(self, node):
        """Visit every node of an iterable and return the results as a list."""
        return list(map(self.visit, node))

    # ______________________________________________________________________

    def alloca(self, varname):
        """
        Return the stack slot of local variable `varname`.

        The slot is allocated at the front of the function's start block the
        first time it is requested and cached in ``self.allocas``.
        """
        if varname not in self.allocas:
            # Allocate variable with alloca
            with self.builder.at_front(self.func.startblock):
                type = types.Pointer(self.local_vars[varname])
                result = self.func.temp(varname)
                self.allocas[varname] = self.builder.alloca(type, [], result)

        return self.allocas[varname]

    def assignvar(self, varname, rhs):
        """Store the IR value `rhs` into the stack slot of `varname`."""
        self.builder.store(rhs, self.alloca(varname))

    def assign(self, varname, rhs):
        """
        Assign the AST expression `rhs` to the declared variable `varname`.

        Inside a function the visited value is stored to the variable's
        stack slot; at module level a GlobalValue is added to the module.
        The declared type is passed to visit() as the expected type.
        """
        if self.in_function:
            # Local variable
            type = self.local_vars[varname]
            self.assignvar(varname, self.visit(rhs, type=type))
        else:
            # Global variable. Use the declared type: self.type is transient
            # state that every visit() call overwrites, so it does not hold
            # the variable's type at this point.
            type = self.global_vars[varname]
            self.mod.add_global(GlobalValue(varname, type=type,
                                            value=self.visit(rhs, type=type)))

    # ______________________________________________________________________

    def visit_Decl(self, decl):
        """
        Record a declaration in the current scope and return its type.

        Redeclaring a name is reported through ``error``. An initializer is
        routed through assign(); an uninitialised module-level declaration
        becomes a GlobalValue that is external when the storage class says
        so.
        """
        if decl.name in self.vars:
            error(decl, "Var '%s' already declared!" % (decl.name,))

        type = self.visit(decl.type)
        self.vars[decl.name] = type
        if decl.init:
            self.assign(decl.name, decl.init)
        elif not self.in_function:
            extern = decl.storage == 'external'
            self.mod.add_global(GlobalValue(decl.name, type, external=extern))

        return type

    def visit_TypeDecl(self, decl):
        """Return the type denoted by the wrapped type node."""
        return self.visit(decl.type)

    visit_Typename = visit_TypeDecl

    def visit_PtrDecl(self, decl):
        """Return a Pointer to the type found two ``.type`` levels down."""
        return types.Pointer(self.visit(decl.type.type))

    def visit_FuncDecl(self, decl):
        """Return the Function type (result type, parameter types)."""
        if decl.args:
            params = self.visits(decl.args.params)
        else:
            params = []
        return types.Function(self.visit(decl.type), params)

    def visit_IdentifierType(self, node):
        """Look up a single-name type in the type environment."""
        name, = node.names
        return self.type_env[name]

    def visit_Typedef(self, node):
        """
        Resolve a typedef and return the resulting type.

        "Type" and "_list" are looked up in the type environment as-is. Any
        other name is resolved from the aliased type (or, when that alias is
        ``types.Type`` itself, from the attribute of the same name on
        ``types``) and recorded in the type environment.
        """
        if node.name in ("Type", "_list"):
            type = self.type_env[node.name]
        else:
            type = self.visit(node.type)
            if type == types.Type:
                type = getattr(types, node.name)

            self.type_env[node.name] = type

        return type

    def visit_Template(self, node):
        """
        Instantiate a parameterized type from `node.left` and `node.right`.

        The builtin ``list`` yields a plain list of the visited sub-types;
        any other left-hand side must be a ``types.Type`` subclass and is
        called with the sub-types as arguments.
        """
        left = self.visit(node.left)
        # Visit the sub-types once; the result serves both branches
        subtypes = self.visits(node.right)
        if left is list:
            return list(subtypes)
        else:
            assert issubclass(left, types.Type)
            return left(*subtypes)

    # ______________________________________________________________________

    def visit_FuncDef(self, node):
        """
        Build a Function from a C function definition.

        Creates the function with an 'entry' block, spills every argument to
        a stack variable, visits the body and finally registers the function
        in the module (via leave_func).
        """
        assert not node.param_decls
        self.enter_func()

        name = node.decl.name
        type = self.visit(node.decl.type)
        if node.decl.type.args:
            argnames = [p.name or "" for p in node.decl.type.args.params]
        else:
            argnames = []
        self.func = Function(name, argnames, type)
        self.func.new_block('entry')
        self.builder = Builder(self.func)
        self.builder.position_at_end(self.func.startblock)

        # Store arguments in stack variables
        for argname in argnames:
            self.assignvar(argname, self.func.get_arg(argname))

        self.generic_visit(node.body)
        self.leave_func()

    # ______________________________________________________________________

    def visit_FuncCall(self, node):
        """
        Translate a call expression into an IR operation.

        ``list(...)`` simply returns the visited arguments. Otherwise the
        callee name is taken to be a builder opcode; a name the builder does
        not know is emitted as a ``call`` when the module has a function of
        that name and reported through ``error`` otherwise. Non-void opcodes
        need an expected type (from a cast or an assignment).
        """
        # Capture the expected type before the arguments are visited, since
        # visit() overwrites self.type
        type = self.type
        opcode = node.name.name
        args = self.visits(node.args.exprs) if node.args else []

        if opcode == "list":
            return args
        elif not type and not ops.is_void(opcode):
            error(node, "Expected a type for sub-expression "
                        "(add a cast or assignment)")
        elif not hasattr(self.builder, opcode):
            if opcode in self.mod.functions:
                return self.builder.call(type, [self.mod.get_function(opcode),
                                                args])
            error(node, "No opcode %s" % (opcode,))

        buildop = getattr(self.builder, opcode)
        if ops.is_void(opcode):
            return buildop(*args)
        else:
            return buildop(type or "Unset", args)

    def visit_ID(self, node):
        """
        Resolve a name used inside a function.

        Locals are loaded from their stack slot; other names are looked up
        as a module function and then as a module global. Unknown names are
        reported through ``error``. Outside a function nothing is returned.
        """
        if self.in_function:
            if node.name in self.local_vars:
                result = self.alloca(node.name)
                return self.builder.load(result.type.base, [result])

            global_val = (self.mod.get_function(node.name) or
                          self.mod.get_global(node.name))

            if not global_val:
                error(node, "Not a local or global: %r" % node.name)

            return global_val

    def visit_Cast(self, node):
        """
        Translate a cast.

        A cast applied to a call supplies that call's result type; any other
        expression is converted unless it already has the target type.
        """
        type = self.visit(node.to_type)
        if isinstance(node.expr, c_ast.FuncCall):
            op = self.visit(node.expr, type=type)
            op.type = type
            return op
        else:
            result = self.visit(node.expr)
            if result.type == type:
                return result
            return self.builder.convert(type, [result])

    def visit_Assignment(self, node):
        """Translate ``name = expr``; other assignment forms are errors."""
        if node.op != '=':
            error(node, "Only assignment with '=' is supported")
        if not isinstance(node.lvalue, c_ast.ID):
            # The previous message ("Canot only assign ...") was garbled
            error(node, "Can only assign to a name")
        self.assign(node.lvalue.name, node.rvalue)

    def visit_Constant(self, node):
        """Wrap a literal in a Const, converted to its type-env type."""
        type = self.type_env[node.type]
        const = types.convert(node.value, types.resolve_typedef(type))
        if isinstance(const, basestring):
            const = const[1:-1] # slice away quotes
        return Const(const)

    def visit_UnaryOp(self, node):
        """Emit the builder operation for a unary operator."""
        op = defs.unary_defs[node.op]
        buildop = getattr(self.builder, op)
        arg = self.visit(node.expr)
        # NOTE(review): self.type is read after the operand was visited and
        # visit() overwrites it, so this is not necessarily the type that
        # was passed in for this node -- confirm that this is intended.
        type = self.type or arg.type
        return buildop(type, [arg])

    def visit_BinaryOp(self, node):
        """
        Emit the builder operation for a binary operator.

        Comparison operators always produce ``types.Bool``. Without an
        expected type the (typedef-resolved) operand types must agree.
        """
        op = binary_defs[node.op]
        buildop = getattr(self.builder, op)
        left, right = self.visits([node.left, node.right])
        # NOTE(review): self.type is read after the operands were visited
        # and visit() overwrites it, so this is not necessarily the type
        # that was passed in for this node -- confirm that this is intended.
        type = self.type
        if not type:
            l, r = map(types.resolve_typedef, [left.type, right.type])
            assert l == r, (l, r)
        if node.op in defs.compare_defs:
            type = types.Bool
        return buildop(type or left.type, [left, right])

    def visit_If(self, node):
        """
        Translate an if/else statement.

        Both branches end with a jump to the common exit block, where the
        builder is positioned afterwards.
        """
        cond = self.visit(node.cond)
        ifpos, elsepos, exit_block = self.builder.ifelse(cond)

        with ifpos:
            self.visit(node.iftrue)
            self.builder.jump(exit_block)

        with elsepos:
            if node.iffalse:
                self.visit(node.iffalse)
            self.builder.jump(exit_block)

        self.builder.position_at_end(exit_block)

    def _loop(self, init, cond, next, body):
        """
        Emit the control flow shared by ``while`` and ``for`` loops.

        :param init: optional AST node visited once before the loop
        :param cond: AST node of the loop condition (visited as types.Bool)
        :param next: optional AST node visited after the body on each
                     iteration
        :param body: AST node of the loop body

        On return the builder is positioned at the end of the exit block.
        """
        _, exit_block = self.builder.splitblock(self.func.temp("exit"))
        _, body_block = self.builder.splitblock(self.func.temp("body"))
        _, cond_block = self.builder.splitblock(self.func.temp("cond"))

        self.visitif(init)
        self.builder.jump(cond_block)

        with self.builder.at_front(cond_block):
            cond = self.visit(cond, type=types.Bool)
            self.builder.cbranch(cond, body_block, exit_block)

        with self.builder.at_front(body_block):
            self.visit(body)
            self.visitif(next)
            # Only loop back when the body did not already end in a
            # terminator
            bb = self.builder.basic_block
            if not bb.tail or not ops.is_terminator(bb.tail.opcode):
                self.builder.jump(cond_block)

        self.builder.position_at_end(exit_block)

    def visit_While(self, node):
        """Translate a while loop (no init and no step expression)."""
        self._loop(None, node.cond, None, node.stmt)

    def visit_For(self, node):
        """Translate a for loop."""
        # avoid silly 2to3 rewrite to 'node.__next__'
        next = getattr(node, 'next')
        self._loop(node.init, node.cond, next, node.stmt)

    def visit_Return(self, node):
        """Return the visited expression, converted to the result type."""
        b = self.builder
        value = self.visit(node.expr)
        b.ret(b.convert(self.func.type.restype, [value]))
Esempio n. 18
0
class PykitIRVisitor(c_ast.NodeVisitor):
    """
    Map pykit IR in the form of polymorphic C to in-memory pykit IR.

        int function(float x) {
            int i = 0;        /* I am a comment */
            while (i < 10) {  /*: { "unroll": true } :*/
                x = call_external("sqrt", x * x);
            }
            return (int) x;
        }

    Attributes:
        mod: Module that receives the functions and globals being built.
        type_env: mapping from type name to type object.
        func: Function currently under construction (None outside one).
        builder: Builder created for the most recent function definition.
        local_vars: name -> declared type for the current function.
        allocas: name -> stack slot (alloca op) for the current function.
        global_vars: name -> declared type for module-level variables.
        functions: initialised to an empty dict; not used by this class.
        in_function: True while a function definition is being visited.
        type: the expected type passed to the most recent visit() call.
    """

    in_function = False

    def __init__(self, type_env=None):
        """
        :param type_env: optional mapping from type name to type object; a
                         fresh dict is used when it is omitted or empty.
        """
        self.mod = Module()
        self.type_env = type_env or {}

        self.func = None
        self.builder = None
        self.local_vars = None
        self.allocas = None

        self.global_vars = {}
        self.functions = {}

    # ______________________________________________________________________

    @property
    def vars(self):
        """Declared variables (name -> type) of the scope being visited."""
        if self.in_function:
            return self.local_vars
        else:
            return self.global_vars

    def enter_func(self):
        """Open a function scope with empty local variable tables."""
        self.in_function = True
        self.local_vars = {}
        self.allocas = {}

    def leave_func(self):
        """Add the finished function to the module and close its scope."""
        self.in_function = False
        self.mod.add_function(self.func)
        self.local_vars = None
        self.allocas = None
        self.func = None

    def visit(self, node, type=None):
        """
        Visit a node.

        :type: Whether we have a type for this opcode, which is an LHS type
               or a cast. E.g.:

              (Int) call(...)    // cast
              result = call(...) // assmnt, assuming 'result' is declared
              result = call(..., call(...)) // second 'call' isn't typed

        The type is stored in ``self.type`` for the visitor method to pick
        up. Node classes without a ``visit_<ClassName>`` method are handled
        by ``generic_visit``.
        """
        self.type = type
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def visitif(self, node):
        """Visit `node` if it is truthy; otherwise return None."""
        if node:
            return self.visit(node)

    def visits(self, node):
        """Visit every node of an iterable and return the results as a list."""
        return list(map(self.visit, node))

    # ______________________________________________________________________

    def alloca(self, varname):
        """
        Return the stack slot of local variable `varname`.

        The slot is allocated at the front of the function's start block the
        first time it is requested and cached in ``self.allocas``.
        """
        if varname not in self.allocas:
            # Allocate variable with alloca
            with self.builder.at_front(self.func.startblock):
                type = types.Pointer(self.local_vars[varname])
                result = self.func.temp(varname)
                self.allocas[varname] = self.builder.alloca(type, [], result)

        return self.allocas[varname]

    def assignvar(self, varname, rhs):
        """Store the IR value `rhs` into the stack slot of `varname`."""
        self.builder.store(rhs, self.alloca(varname))

    def assign(self, varname, rhs):
        """
        Assign the AST expression `rhs` to the declared variable `varname`.

        Inside a function the visited value is stored to the variable's
        stack slot; at module level a GlobalValue is added to the module.
        The declared type is passed to visit() as the expected type.
        """
        if self.in_function:
            # Local variable
            type = self.local_vars[varname]
            self.assignvar(varname, self.visit(rhs, type=type))
        else:
            # Global variable. Use the declared type: self.type is transient
            # state that every visit() call overwrites, so it does not hold
            # the variable's type at this point.
            type = self.global_vars[varname]
            self.mod.add_global(
                GlobalValue(varname,
                            type=type,
                            value=self.visit(rhs, type=type)))

    # ______________________________________________________________________

    def visit_Decl(self, decl):
        """
        Record a declaration in the current scope and return its type.

        Redeclaring a name is reported through ``error``. An initializer is
        routed through assign(); an uninitialised module-level declaration
        becomes a GlobalValue that is external when the storage class says
        so.
        """
        if decl.name in self.vars:
            error(decl, "Var '%s' already declared!" % (decl.name, ))

        type = self.visit(decl.type)
        self.vars[decl.name] = type
        if decl.init:
            self.assign(decl.name, decl.init)
        elif not self.in_function:
            extern = decl.storage == 'external'
            self.mod.add_global(GlobalValue(decl.name, type, external=extern))

        return type

    def visit_TypeDecl(self, decl):
        """Return the type denoted by the wrapped type node."""
        return self.visit(decl.type)

    visit_Typename = visit_TypeDecl

    def visit_PtrDecl(self, decl):
        """Return a Pointer to the type found two ``.type`` levels down."""
        return types.Pointer(self.visit(decl.type.type))

    def visit_FuncDecl(self, decl):
        """Return the Function type (result type, parameter types)."""
        if decl.args:
            params = self.visits(decl.args.params)
        else:
            params = []
        return types.Function(self.visit(decl.type), params)

    def visit_IdentifierType(self, node):
        """Look up a single-name type in the type environment."""
        name, = node.names
        return self.type_env[name]

    def visit_Typedef(self, node):
        """
        Resolve a typedef and return the resulting type.

        "Type" and "_list" are looked up in the type environment as-is. Any
        other name is resolved from the aliased type (or, when that alias is
        ``types.Type`` itself, from the attribute of the same name on
        ``types``) and recorded in the type environment.
        """
        if node.name in ("Type", "_list"):
            type = self.type_env[node.name]
        else:
            type = self.visit(node.type)
            if type == types.Type:
                type = getattr(types, node.name)

            self.type_env[node.name] = type

        return type

    def visit_Template(self, node):
        """
        Instantiate a parameterized type from `node.left` and `node.right`.

        The builtin ``list`` yields a plain list of the visited sub-types;
        any other left-hand side must be a ``types.Type`` subclass and is
        called with the sub-types as arguments.
        """
        left = self.visit(node.left)
        # Visit the sub-types once; the result serves both branches
        subtypes = self.visits(node.right)
        if left is list:
            return list(subtypes)
        else:
            assert issubclass(left, types.Type)
            return left(*subtypes)

    # ______________________________________________________________________

    def visit_FuncDef(self, node):
        """
        Build a Function from a C function definition.

        Creates the function with an 'entry' block, spills every argument to
        a stack variable, visits the body and finally registers the function
        in the module (via leave_func).
        """
        assert not node.param_decls
        self.enter_func()

        name = node.decl.name
        type = self.visit(node.decl.type)
        if node.decl.type.args:
            argnames = [p.name or "" for p in node.decl.type.args.params]
        else:
            argnames = []
        self.func = Function(name, argnames, type)
        self.func.new_block('entry')
        self.builder = Builder(self.func)
        self.builder.position_at_end(self.func.startblock)

        # Store arguments in stack variables
        for argname in argnames:
            self.assignvar(argname, self.func.get_arg(argname))

        self.generic_visit(node.body)
        self.leave_func()

    # ______________________________________________________________________

    def visit_FuncCall(self, node):
        """
        Translate a call expression into an IR operation.

        ``list(...)`` simply returns the visited arguments. Otherwise the
        callee name is taken to be a builder opcode; a name the builder does
        not know is emitted as a ``call`` when the module has a function of
        that name and reported through ``error`` otherwise. Non-void opcodes
        need an expected type (from a cast or an assignment).
        """
        # Capture the expected type before the arguments are visited, since
        # visit() overwrites self.type
        type = self.type
        opcode = node.name.name
        args = self.visits(node.args.exprs) if node.args else []

        if opcode == "list":
            return args
        elif not type and not ops.is_void(opcode):
            error(
                node, "Expected a type for sub-expression "
                "(add a cast or assignment)")
        elif not hasattr(self.builder, opcode):
            if opcode in self.mod.functions:
                return self.builder.call(type,
                                         [self.mod.get_function(opcode), args])
            error(node, "No opcode %s" % (opcode, ))

        buildop = getattr(self.builder, opcode)
        if ops.is_void(opcode):
            return buildop(*args)
        else:
            return buildop(type or "Unset", args)

    def visit_ID(self, node):
        """
        Resolve a name used inside a function.

        Locals are loaded from their stack slot; other names are looked up
        as a module function and then as a module global. Unknown names are
        reported through ``error``. Outside a function nothing is returned.
        """
        if self.in_function:
            if node.name in self.local_vars:
                result = self.alloca(node.name)
                return self.builder.load(result.type.base, [result])

            global_val = (self.mod.get_function(node.name)
                          or self.mod.get_global(node.name))

            if not global_val:
                error(node, "Not a local or global: %r" % node.name)

            return global_val

    def visit_Cast(self, node):
        """
        Translate a cast.

        A cast applied to a call supplies that call's result type; any other
        expression is converted unless it already has the target type.
        """
        type = self.visit(node.to_type)
        if isinstance(node.expr, c_ast.FuncCall):
            op = self.visit(node.expr, type=type)
            op.type = type
            return op
        else:
            result = self.visit(node.expr)
            if result.type == type:
                return result
            return self.builder.convert(type, [result])

    def visit_Assignment(self, node):
        """Translate ``name = expr``; other assignment forms are errors."""
        if node.op != '=':
            error(node, "Only assignment with '=' is supported")
        if not isinstance(node.lvalue, c_ast.ID):
            # The previous message ("Canot only assign ...") was garbled
            error(node, "Can only assign to a name")
        self.assign(node.lvalue.name, node.rvalue)

    def visit_Constant(self, node):
        """Wrap a literal in a Const, converted to its type-env type."""
        type = self.type_env[node.type]
        const = types.convert(node.value, types.resolve_typedef(type))
        if isinstance(const, basestring):
            const = const[1:-1]  # slice away quotes
        return Const(const)

    def visit_UnaryOp(self, node):
        """Emit the builder operation for a unary operator."""
        op = defs.unary_defs[node.op]
        buildop = getattr(self.builder, op)
        arg = self.visit(node.expr)
        # NOTE(review): self.type is read after the operand was visited and
        # visit() overwrites it, so this is not necessarily the type that
        # was passed in for this node -- confirm that this is intended.
        type = self.type or arg.type
        return buildop(type, [arg])

    def visit_BinaryOp(self, node):
        """
        Emit the builder operation for a binary operator.

        Comparison operators always produce ``types.Bool``. Without an
        expected type the (typedef-resolved) operand types must agree.
        """
        op = binary_defs[node.op]
        buildop = getattr(self.builder, op)
        left, right = self.visits([node.left, node.right])
        # NOTE(review): self.type is read after the operands were visited
        # and visit() overwrites it, so this is not necessarily the type
        # that was passed in for this node -- confirm that this is intended.
        type = self.type
        if not type:
            l, r = map(types.resolve_typedef, [left.type, right.type])
            assert l == r, (l, r)
        if node.op in defs.compare_defs:
            type = types.Bool
        return buildop(type or left.type, [left, right])

    def visit_If(self, node):
        """
        Translate an if/else statement.

        Both branches end with a jump to the common exit block, where the
        builder is positioned afterwards.
        """
        cond = self.visit(node.cond)
        ifpos, elsepos, exit_block = self.builder.ifelse(cond)

        with ifpos:
            self.visit(node.iftrue)
            self.builder.jump(exit_block)

        with elsepos:
            if node.iffalse:
                self.visit(node.iffalse)
            self.builder.jump(exit_block)

        self.builder.position_at_end(exit_block)

    def _loop(self, init, cond, next, body):
        """
        Emit the control flow shared by ``while`` and ``for`` loops.

        :param init: optional AST node visited once before the loop
        :param cond: AST node of the loop condition (visited as types.Bool)
        :param next: optional AST node visited after the body on each
                     iteration
        :param body: AST node of the loop body

        On return the builder is positioned at the end of the exit block.
        """
        _, exit_block = self.builder.splitblock(self.func.temp("exit"))
        _, body_block = self.builder.splitblock(self.func.temp("body"))
        _, cond_block = self.builder.splitblock(self.func.temp("cond"))

        self.visitif(init)
        self.builder.jump(cond_block)

        with self.builder.at_front(cond_block):
            cond = self.visit(cond, type=types.Bool)
            self.builder.cbranch(cond, body_block, exit_block)

        with self.builder.at_front(body_block):
            self.visit(body)
            self.visitif(next)
            # Only loop back when the body did not already end in a
            # terminator
            bb = self.builder.basic_block
            if not bb.tail or not ops.is_terminator(bb.tail.opcode):
                self.builder.jump(cond_block)

        self.builder.position_at_end(exit_block)

    def visit_While(self, node):
        """Translate a while loop (no init and no step expression)."""
        self._loop(None, node.cond, None, node.stmt)

    def visit_For(self, node):
        """Translate a for loop."""
        # avoid silly 2to3 rewrite to 'node.__next__'
        next = getattr(node, 'next')
        self._loop(node.init, node.cond, next, node.stmt)

    def visit_Return(self, node):
        """Return the visited expression, converted to the result type."""
        b = self.builder
        value = self.visit(node.expr)
        b.ret(b.convert(self.func.type.restype, [value]))
Esempio n. 19
0
class Translate(object):
    """
    Translate bytecode to untyped pykit IR.
    """
    def __init__(self, func, env):
        """
        Prepare the translation state for the Python function `func`.

        :param func: the Python function whose bytecode is translated
        :param env: stored as ``self.env``; not otherwise used here
        """
        self.func = func
        self.env = env
        self.bytecode = ByteCode(func)

        # -------------------------------------------------
        # Per-block bookkeeping

        self.blocks = {}            # offset -> Block
        self.block2offset = {}      # Block -> offset
        self.allocas = {}           # varname -> alloca
        self.stacks = {}            # Block -> value stack
        self.exc_handlers = set()   # { Block }

        # -------------------------------------------------
        # Block stacks

        self.block_stack = []
        self.loop_stack = []
        self.except_stack = []
        self.finally_stack = []

        # -------------------------------------------------
        # CFG

        self.predecessors = collections.defaultdict(set)
        self.phis = collections.defaultdict(list)

        # -------------------------------------------------
        # Variables and scoping

        code = self.bytecode.code
        self.code = code
        self.varnames = code.co_varnames
        self.consts = code.co_consts
        self.names = code.co_names
        self.argnames = list(self.varnames[:code.co_argcount])

        # Builtins first, then the function's own globals on top of them
        namespace = dict(vars(__builtin__))
        self.globals = namespace
        self.builtins = set(namespace.values())
        namespace.update(self.func.func_globals)

        self.call_annotations = collections.defaultdict(dict)

        # -------------------------------------------------
        # Error checks

        argspec = inspect.getargspec(self.func)
        for extra_name in (argspec.varargs, argspec.keywords):
            if extra_name:
                self.argnames.append(extra_name)

        assert not argspec.keywords, "keywords not yet supported"

    def initialize(self):
        """
        Initialize the untyped pykit structures.

        Creates the destination function and its builder, one block per
        bytecode label (each with an empty value stack) and one stack
        variable per local variable name; argument variables are initialised
        from the function arguments.
        """
        # Setup Function: the result and every argument are Opaque
        arg_types = [types.Opaque] * len(self.argnames)
        sig = types.Function(types.Opaque, arg_types, False)
        self.dst = Function(func_name(self.func), self.argnames, sig)

        # Setup Builder
        self.builder = Builder(self.dst)

        # Setup Blocks
        for offset in self.bytecode.labels:
            block = self.dst.new_block(blockname(self.func, offset))
            self.blocks[offset] = block
            self.stacks[block] = []

        # Setup Variables
        self.builder.position_at_beginning(self.dst.startblock)
        for varname in self.varnames:
            slot = self.builder.alloca(types.Pointer(types.Opaque),
                                       result=self.dst.temp(varname))
            self.allocas[varname] = slot

            # Initialize function arguments
            if varname in self.argnames:
                self.builder.store(self.dst.get_arg(varname), slot)

    def interpret(self):
        self.curblock = self.dst.startblock

        for inst in self.bytecode:
            if inst.offset in self.blocks:
                # Block switch
                newblock = self.blocks[inst.offset]
                if self.curblock != newblock:
                    self.switchblock(newblock)
            elif self.curblock.is_terminated():
                continue

            self.op(inst)

        # -------------------------------------------------
        # Finalize

        self.update_phis()

    def op(self, inst):
        """
        Translate a single bytecode instruction.

        The instruction is dispatched to the ``op_<OPNAME>`` method (with
        any '+' in the opname replaced by '_'), falling back to
        ``generic_op``. The handler runs inside ``error_context``, which is
        given the instruction's line number and the Python function.
        """
        # The unused 'during' string that was built here has been dropped
        with error_context(lineno=inst.lineno,
                           during="Translate operation",
                           pyfunc=self.func):
            self.lineno = inst.lineno
            attr = 'op_%s' % inst.opname.replace('+', '_')
            handler = getattr(self, attr, self.generic_op)
            handler(inst)

    def generic_op(self, inst):
        raise NotImplementedError(inst)

    def switchblock(self, newblock):
        """
        Switch to a new block and merge incoming values from the stacks.
        """
        #print("%s -> %s" % (self.curblock.name, newblock.name), self.stack)
        if not self.curblock.is_terminated():
            self.jump(newblock)

        self.builder.position_at_end(newblock)
        self.prevblock = self.curblock
        self.curblock = newblock

        # -------------------------------------------------
        # Find predecessors

        if newblock in self.exc_handlers:
            self.push_insert('exc_fetch')
            self.push_insert('exc_fetch_value')
            self.push_insert('exc_fetch_tb')

        # -------------------------------------------------
        # Find predecessors

        incoming = self.predecessors.get(newblock)
        if not incoming:
            return

        # -------------------------------------------------
        # Merge stack values

        stack = max([self.stacks[block] for block in incoming], key=len)
        for value in stack:
            phi = self.push_insert('phi', [], [])
            self.phis[newblock].append(phi)

        assert len(self.stack) == len(stack)

    def update_phis(self):
        laststack = self.stacks[self.dst.blocks.tail]
        assert not laststack, laststack

        for block in self.dst.blocks:
            phis = self.phis[block]
            preds = list(self.predecessors[block])
            stacks = [self.stacks[pred] for pred in preds]
            stacklen = len(phis)

            # -------------------------------------------------
            # Sanity check

            assert all(len(stack) == stacklen
                       for stack in stacks), (preds, stacks)

            if not preds or not stacklen:
                continue

            # -------------------------------------------------
            # Update φs with stack values from predecessors

            for pos, phi in enumerate(phis):
                values = []
                for pred in preds:
                    value_stack = self.stacks[pred]
                    value = value_stack[pos]
                    values.append(value)

                phi.set_args([preds, values])

    @property
    def stack(self):
        return self.stacks[self.curblock]

    @property
    def stack_level(self):
        return len(self.stack)

    def insert(self, opcode, *args):
        """
        Create an Op for `opcode`, emit it through the builder and return it.

        Void opcodes get type Void, all others Opaque. The current line
        number is attached as 'lineno' metadata.
        """
        if ops.is_void(opcode):
            result_type = types.Void
        else:
            result_type = types.Opaque

        new_op = Op(opcode, result_type, list(args))
        new_op.add_metadata({'lineno': self.lineno})
        self.builder.emit(new_op)
        return new_op

    def push_insert(self, opcode, *args):
        inst = self.insert(opcode, *args)
        self.push(inst)
        return inst

    def push(self, val):
        self.stack.append(val)

    def peek(self):
        """
        Take a peek at the top of stack.
        """
        if not self.stack:
            # Assuming the bytecode is valid, our predecessors must have left
            # some values on the stack.
            # return self._insert_phi()
            raise EmptyStackError
        else:
            return self.stack[-1]

    def pop(self):
        if not self.stack:
            # return self._insert_phi()
            raise EmptyStackError
        else:
            return self.stack.pop()

    def _insert_phi(self):
        with self.builder.at_front(self.curblock):
            phi = self.insert('phi', [], [])

        self.phis[self.curblock].append(phi)
        return phi

    def call(self, func, args=()):
        """
        Emit a 'call' of `func` with `args` and push the result.

        A `func` that is not already a Value is wrapped with ``const``.
        Returns the inserted call operation.
        """
        callee = func if isinstance(func, Value) else const(func)
        return self.push_insert('call', callee, list(args))

    def call_pop(self, func, args=()):
        self.call(func, args)
        return self.pop()

    def binary_op(self, op):
        rhs = self.pop()
        lhs = self.pop()
        self.call(op, args=(lhs, rhs))

    def unary_op(self, op):
        tos = self.pop()
        self.call(op, args=(tos, ))

    def jump(self, target):
        self.predecessors[target].add(self.curblock)
        self.insert('jump', target)

    def jump_if(self, cond, truebr, falsebr):
        self.predecessors[truebr].add(self.curblock)
        self.predecessors[falsebr].add(self.curblock)
        self.insert('cbranch', cond, truebr, falsebr)

    # ------- stack ------- #

    def op_POP_BLOCK(self, inst):
        """
        Leave the innermost block of the block stack.

        The matching kind-specific stack (loop / except / finally) is popped
        as well and the value stack is cut back to the block's level.
        """
        block = self.block_stack.pop()

        kind_stacks = (
            (LoopBlock, self.loop_stack),
            (ExceptionBlock, self.except_stack),
            (FinallyBlock, self.finally_stack),
        )
        for kind, kind_stack in kind_stacks:
            if isinstance(block, kind):
                kind_stack.pop()
                break

        del self.stack[block.level:]

    def op_POP_TOP(self, inst):
        self.pop()

    def op_DUP_TOP(self, inst):
        value = self.pop()
        self.push(value)
        self.push(value)

    def op_DUP_TOPX(self, inst):
        count = inst.arg
        self.stack.extend(self.stack[-count:])

    def op_ROT_TWO(self, inst):
        one = self.pop()
        two = self.pop()
        self.push(one)
        self.push(two)

    def op_ROT_THREE(self, inst):
        one = self.pop()
        two = self.pop()
        three = self.pop()
        self.push(one)
        self.push(three)
        self.push(two)

    def op_ROT_FOUR(self, inst):
        one = self.pop()
        two = self.pop()
        three = self.pop()
        four = self.pop()
        self.push(one)
        self.push(four)
        self.push(three)
        self.push(two)

    # ------- control flow ------- #

    def op_POP_JUMP_IF_TRUE(self, inst):
        falsebr = self.blocks[inst.next]
        truebr = self.blocks[inst.arg]
        self.jump_if(self.pop(), truebr, falsebr)

    def op_POP_JUMP_IF_FALSE(self, inst):
        truebr = self.blocks[inst.next]
        falsebr = self.blocks[inst.arg]
        self.jump_if(self.pop(), truebr, falsebr)

    def op_JUMP_IF_TRUE(self, inst):
        falsebr = self.blocks[inst.next]
        truebr = self.blocks[inst.next + inst.arg]
        self.jump_if(self.peek(), truebr, falsebr)

    def op_JUMP_IF_FALSE(self, inst):
        truebr = self.blocks[inst.next]
        falsebr = self.blocks[inst.next + inst.arg]
        self.jump_if(self.peek(), truebr, falsebr)

    def _make_popblock(self):
        popblock = self.dst.new_block(self.dst.temp("popblock"),
                                      after=self.curblock)
        self.stacks[popblock] = []
        return popblock

    def op_JUMP_IF_TRUE_OR_POP(self, inst):
        falsebr = self.blocks[inst.next]
        truebr = self.blocks[inst.arg]

        popblock = self._make_popblock()
        self.jump_if(self.peek(), truebr, popblock)
        self.switchblock(popblock)
        self.pop()
        self.jump(falsebr)

    def op_JUMP_IF_FALSE_OR_POP(self, inst):
        truebr = self.blocks[inst.next]
        falsebr = self.blocks[inst.arg]

        popblock = self._make_popblock()
        self.jump_if(self.peek(), popblock, falsebr)
        self.switchblock(popblock)
        self.pop()
        self.jump(truebr)

    def op_JUMP_ABSOLUTE(self, inst):
        target = self.blocks[inst.arg]
        self.jump(target)

    def op_JUMP_FORWARD(self, inst):
        target = self.blocks[inst.next + inst.arg]
        self.jump(target)

    def op_RETURN_VALUE(self, inst):
        """Pop TOS and emit ``ret`` with it; a constant ``None`` becomes a bare ``ret``."""
        val = self.pop()
        if isinstance(val, Const) and val.const is None:
            val = None  # Generate a bare 'ret' instruction
        self.insert('ret', val)

    def op_CALL_FUNCTION(self, inst, varargs=None):
        """Pop keyword pairs, positional arguments and the callee, then emit a call.

        ``inst.arg`` holds the positional count in its low byte and the
        keyword count in the next byte.  Keyword arguments are rejected by
        the assertion below.  Returns whatever ``self.call`` returns.

        NOTE(review): ``varargs`` is accepted but never used in this body.
        """
        argc = inst.arg & 0xff
        kwsc = (inst.arg >> 8) & 0xff

        def pop_kws():
            # Each keyword argument is a (key, value) pair with the value on top.
            val = self.pop()
            key = self.pop()
            if key.opcode != 'const':
                raise ValueError('keyword must be a constant')
            # NOTE(review): other handlers read constants through ``.const``;
            # confirm ``.value`` is the right attribute here.
            return key.value, val

        # Values are popped last-first, so reverse to restore source order.
        kws = list(reversed([pop_kws() for i in range(kwsc)]))
        args = list(reversed([self.pop() for i in range(argc)]))
        assert not kws, "Keyword arguments not yet supported"

        func = self.pop()
        return self.call(func, args)

    def op_CALL_FUNCTION_VAR(self, inst):
        """Translate a call that has a ``*args`` iterable on top of the stack.

        The iterable is converted with ``tuple``, the ordinary call is
        emitted, and the tuple is then appended to the call's argument list.
        """
        it = self.call_pop(tuple, [self.pop()])
        #varargs = self.insert('unpack', it)
        call = self.op_CALL_FUNCTION(inst, varargs=it)

        # Add unpacked iterable to args list
        f, args = call.args
        call.set_args([f, args + [it]])

        # Annotate call as a 'varargs' application
        self.call_annotations[call]['varargs'] = True

    def op_GET_ITER(self, inst):
        """Pop TOS and emit a call to ``iter`` on it."""
        self.call(iter, [self.pop()])

    def op_FOR_ITER(self, inst):
        """
        Translate a for loop to:

            it = getiter(iterable)
            try:
                while 1:
                    i = next(it)
                    ...
            except StopIteration:
                pass
        """
        # The iterator stays on the stack for the next iteration.
        iterobj = self.peek()
        delta = inst.arg
        loopexit = self.blocks[inst.next + delta]

        loop_block = self.loop_stack[-1]
        loop_block.catch_block = loopexit

        # -------------------------------------------------
        # Try

        self.insert('exc_setup', [loopexit])
        self.call(next, [iterobj])

        # We assume a 1-to-1 block mapping, resolve a block split in a
        # later pass
        self.insert('exc_end')

        # -------------------------------------------------
        # Catch

        with self.builder.at_front(loopexit):
            self.insert('exc_catch',
                        [Const(StopIteration, type=types.Exception)])

        # -------------------------------------------------
        # Exit

        # Add the loop exit as a successor to the header
        self.predecessors[loopexit].add(self.curblock)

        # Remove ourselves as a predecessor from the actual exit block, set by
        # SETUP_LOOP
        self.predecessors[loop_block.end].remove(self.prevblock)

    def op_BREAK_LOOP(self, inst):
        """Jump out of the innermost loop: to its catch block if set, else to its end block."""
        loopblock = self.loop_stack[-1]
        self.jump(target=loopblock.catch_block or loopblock.end)

    def op_BUILD_TUPLE(self, inst):
        """Pop ``inst.arg`` values and push a tuple built from them.

        All-constant items fold into one constant tuple.  Otherwise, below
        ``tupleobject.STATIC_THRESHOLD`` items, a chain of ``StaticTuple``
        calls ending in ``EmptyTuple`` is emitted.

        Raises NotImplementedError for larger non-constant tuples.
        """
        count = inst.arg
        items = [self.pop() for _ in range(count)]
        ordered = list(reversed(items))
        if all(isinstance(item, Const) for item in ordered):
            # create constant tuple
            self.push(const(tuple(item.const for item in ordered)))
        elif len(ordered) < tupleobject.STATIC_THRESHOLD:
            # Build static tuple
            result = self.call_pop(tupleobject.EmptyTuple)
            # ``items`` is in pop order (last element first), so the first
            # element ends up as the outermost StaticTuple.
            for item in items:
                result = self.call_pop(tupleobject.StaticTuple,
                                       args=(item, result))
            self.push(result)
        else:
            raise NotImplementedError("Generic tuples")

    def op_BUILD_LIST(self, inst):
        """Build a list from the top ``inst.arg`` stack values.

        An empty list is a direct ``listobject.EmptyList`` call; otherwise the
        items are first packed by ``op_BUILD_TUPLE`` and handed to ``list``.
        """
        if inst.arg:
            self.op_BUILD_TUPLE(inst)
            self.call(list, (self.pop(), ))
        else:
            self.call(listobject.EmptyList, ())

    def op_LOAD_ATTR(self, inst):
        """Load attribute ``names[inst.arg]`` of TOS.

        When TOS is a constant that has the attribute, the lookup is done at
        translation time; otherwise a ``getfield`` op is emitted and pushed.
        """
        attr = self.names[inst.arg]
        obj = self.pop()
        if not (isinstance(obj, Const) and hasattr(obj.const, attr)):
            self.push_insert('getfield', obj, attr)
            return
        self.push(const(getattr(obj.const, attr)))

    def op_LOAD_GLOBAL(self, inst):
        """Push the value of global ``names[inst.arg]`` as a constant.

        Raises NameError when the name is not present in ``self.globals``.
        """
        name = self.names[inst.arg]
        if name not in self.globals:
            raise NameError("Could not resolve %r at compile time" % name)
        value = self.globals[name]
        self.push(const(value))

    def op_LOAD_DEREF(self, inst):
        """Push the current contents of closure cell ``inst.arg`` as a constant."""
        # NOTE(review): indexes ``__closure__`` directly with ``inst.arg``;
        # confirm this holds when the function also has cell variables.
        i = inst.arg
        cell = self.func.__closure__[i]
        value = cell.cell_contents
        self.push(const(value))

    def op_LOAD_FAST(self, inst):
        """Emit a ``load`` from the alloca of local ``varnames[inst.arg]`` and push it."""
        name = self.varnames[inst.arg]
        self.push_insert('load', self.allocas[name])

    def op_LOAD_CONST(self, inst):
        """Push ``consts[inst.arg]`` wrapped by ``const``."""
        val = self.consts[inst.arg]
        self.push(const(val))

    def op_STORE_FAST(self, inst):
        """Pop TOS and emit a ``store`` of it into the alloca of local ``varnames[inst.arg]``."""
        value = self.pop()
        name = self.varnames[inst.arg]
        self.insert('store', value, self.allocas[name])

    def op_STORE_ATTR(self, inst):
        """Pop the object (TOS) and then the value, and emit ``setfield`` for ``names[inst.arg]``."""
        attr = self.names[inst.arg]
        obj = self.pop()
        value = self.pop()
        self.insert('setfield', obj, attr, value)

    def op_STORE_SUBSCR(self, inst):
        """Emit ``operator.setitem(TOS1, TOS, TOS2)`` and discard the call result."""
        tos0 = self.pop()
        tos1 = self.pop()
        tos2 = self.pop()
        self.call(operator.setitem, (tos1, tos0, tos2))
        # Drop the result left on the stack by the call.
        self.pop()

    def op_UNPACK_SEQUENCE(self, inst):
        """Pop a sequence and push one ``unpack`` op per item.

        Items are pushed from the highest index down, so item 0 ends on top.
        """
        value = self.pop()
        itemct = inst.arg
        for i in reversed(range(itemct)):
            self.push_insert('unpack', value, i, itemct)

    def op_COMPARE_OP(self, inst):
        """Translate a comparison named by ``dis.cmp_op[inst.arg]``.

        ``not in`` and ``is not`` are lowered as the positive test followed
        by ``operator.not_``; every other comparison maps straight through
        ``COMPARE_OP_FUNC``.
        """
        opname = dis.cmp_op[inst.arg]
        negated_forms = {'not in': 'in', 'is not': 'is'}
        positive = negated_forms.get(opname, opname)

        self.binary_op(COMPARE_OP_FUNC[positive])
        if positive != opname:
            self.unary_op(operator.not_)

    def op_UNARY_POSITIVE(self, inst):
        """Lower UNARY_POSITIVE to ``operator.pos`` via ``self.unary_op``."""
        self.unary_op(operator.pos)

    def op_UNARY_NEGATIVE(self, inst):
        """Lower UNARY_NEGATIVE to ``operator.neg`` via ``self.unary_op``."""
        self.unary_op(operator.neg)

    def op_UNARY_INVERT(self, inst):
        """Lower UNARY_INVERT to ``operator.invert`` via ``self.unary_op``."""
        self.unary_op(operator.invert)

    def op_UNARY_NOT(self, inst):
        """Lower UNARY_NOT to ``operator.not_`` via ``self.unary_op``."""
        self.unary_op(operator.not_)

    def op_BINARY_SUBSCR(self, inst):
        """Lower BINARY_SUBSCR to ``operator.getitem`` via ``self.binary_op``."""
        self.binary_op(operator.getitem)

    def op_BINARY_ADD(self, inst):
        """Lower BINARY_ADD to ``operator.add`` via ``self.binary_op``."""
        self.binary_op(operator.add)

    def op_BINARY_SUBTRACT(self, inst):
        """Lower BINARY_SUBTRACT to ``operator.sub`` via ``self.binary_op``."""
        self.binary_op(operator.sub)

    def op_BINARY_MULTIPLY(self, inst):
        """Lower BINARY_MULTIPLY to ``operator.mul`` via ``self.binary_op``."""
        self.binary_op(operator.mul)

    def op_BINARY_DIVIDE(self, inst):
        """Lower BINARY_DIVIDE to ``operator.floordiv`` via ``self.binary_op``."""
        # NOTE(review): same operator as BINARY_FLOOR_DIVIDE; confirm this is
        # intended for non-integer operands.
        self.binary_op(operator.floordiv)

    def op_BINARY_FLOOR_DIVIDE(self, inst):
        """Lower BINARY_FLOOR_DIVIDE to ``operator.floordiv`` via ``self.binary_op``."""
        self.binary_op(operator.floordiv)

    def op_BINARY_TRUE_DIVIDE(self, inst):
        """Lower BINARY_TRUE_DIVIDE to ``operator.truediv`` via ``self.binary_op``."""
        self.binary_op(operator.truediv)

    def op_BINARY_MODULO(self, inst):
        """Lower BINARY_MODULO to ``operator.mod`` via ``self.binary_op``."""
        self.binary_op(operator.mod)

    def op_BINARY_POWER(self, inst):
        """Lower BINARY_POWER to ``operator.pow`` via ``self.binary_op``."""
        self.binary_op(operator.pow)

    def op_BINARY_RSHIFT(self, inst):
        """Lower BINARY_RSHIFT to ``operator.rshift`` via ``self.binary_op``."""
        self.binary_op(operator.rshift)

    def op_BINARY_LSHIFT(self, inst):
        """Lower BINARY_LSHIFT to ``operator.lshift`` via ``self.binary_op``."""
        self.binary_op(operator.lshift)

    def op_BINARY_AND(self, inst):
        """Lower BINARY_AND to ``operator.and_`` via ``self.binary_op``."""
        self.binary_op(operator.and_)

    def op_BINARY_OR(self, inst):
        """Lower BINARY_OR to ``operator.or_`` via ``self.binary_op``."""
        self.binary_op(operator.or_)

    def op_BINARY_XOR(self, inst):
        """Lower BINARY_XOR to ``operator.xor`` via ``self.binary_op``."""
        self.binary_op(operator.xor)

    def op_INPLACE_ADD(self, inst):
        """Lower INPLACE_ADD to the non-in-place ``operator.add`` via ``self.binary_op``."""
        self.binary_op(operator.add)

    def op_INPLACE_SUBTRACT(self, inst):
        """Lower INPLACE_SUBTRACT to the non-in-place ``operator.sub`` via ``self.binary_op``."""
        self.binary_op(operator.sub)

    def op_INPLACE_MULTIPLY(self, inst):
        """Lower INPLACE_MULTIPLY to the non-in-place ``operator.mul`` via ``self.binary_op``."""
        self.binary_op(operator.mul)

    def op_INPLACE_DIVIDE(self, inst):
        """Lower INPLACE_DIVIDE to ``operator.floordiv`` via ``self.binary_op``."""
        # NOTE(review): same operator as INPLACE_FLOOR_DIVIDE; confirm this is
        # intended for non-integer operands.
        self.binary_op(operator.floordiv)

    def op_INPLACE_FLOOR_DIVIDE(self, inst):
        """Lower INPLACE_FLOOR_DIVIDE to the non-in-place ``operator.floordiv`` via ``self.binary_op``."""
        self.binary_op(operator.floordiv)

    def op_INPLACE_TRUE_DIVIDE(self, inst):
        """Lower INPLACE_TRUE_DIVIDE to the non-in-place ``operator.truediv`` via ``self.binary_op``."""
        self.binary_op(operator.truediv)

    def op_INPLACE_MODULO(self, inst):
        """Lower INPLACE_MODULO to the non-in-place ``operator.mod`` via ``self.binary_op``."""
        self.binary_op(operator.mod)

    def op_INPLACE_POWER(self, inst):
        """Lower INPLACE_POWER to the non-in-place ``operator.pow`` via ``self.binary_op``."""
        self.binary_op(operator.pow)

    def op_INPLACE_RSHIFT(self, inst):
        """Lower INPLACE_RSHIFT to the non-in-place ``operator.rshift`` via ``self.binary_op``."""
        self.binary_op(operator.rshift)

    def op_INPLACE_LSHIFT(self, inst):
        """Lower INPLACE_LSHIFT to the non-in-place ``operator.lshift`` via ``self.binary_op``."""
        self.binary_op(operator.lshift)

    def op_INPLACE_AND(self, inst):
        """Lower INPLACE_AND to the non-in-place ``operator.and_`` via ``self.binary_op``."""
        self.binary_op(operator.and_)

    def op_INPLACE_OR(self, inst):
        """Lower INPLACE_OR to the non-in-place ``operator.or_`` via ``self.binary_op``."""
        self.binary_op(operator.or_)

    def op_INPLACE_XOR(self, inst):
        """Lower INPLACE_XOR to the non-in-place ``operator.xor`` via ``self.binary_op``."""
        self.binary_op(operator.xor)

    def slice(self, start=None, stop=None, step=None):
        """Emit a call constructing ``sliceobject.Slice(start, stop, step)`` and return it.

        Each bound is passed through ``const``; the result is returned via
        ``call_pop`` rather than left on the value stack.
        """
        start, stop, step = map(const, [start, stop, step])
        return self.call_pop(const(sliceobject.Slice), [start, stop, step])

    def op_SLICE_0(self, inst):
        """Translate ``TOS[:]`` as ``operator.getitem`` with a default slice."""
        tos = self.pop()
        self.call(operator.getitem, args=(tos, self.slice()))

    def op_SLICE_1(self, inst):
        """Translate ``TOS1[TOS:]`` as ``operator.getitem`` with a start-only slice."""
        start = self.pop()
        tos = self.pop()
        self.call(operator.getitem, args=(tos, self.slice(start=start)))

    def op_SLICE_2(self, inst):
        """Translate ``TOS1[:TOS]`` as ``operator.getitem`` with a stop-only slice."""
        stop = self.pop()
        tos = self.pop()
        self.call(operator.getitem, args=(tos, self.slice(stop=stop)))

    def op_SLICE_3(self, inst):
        """Translate ``TOS2[TOS1:TOS]`` as ``operator.getitem`` with a start/stop slice."""
        stop = self.pop()
        start = self.pop()
        tos = self.pop()
        self.call(operator.getitem, args=(tos, self.slice(start, stop)))

    def op_STORE_SLICE_0(self, inst):
        """Translate ``TOS[:] = TOS1`` as ``operator.setitem``, discarding the call result."""
        tos = self.pop()
        val = self.pop()
        self.call_pop(operator.setitem, args=(tos, self.slice(), val))

    def op_STORE_SLICE_1(self, inst):
        """Translate ``TOS1[TOS:] = TOS2`` as ``operator.setitem``, discarding the call result."""
        start = self.pop()
        tos = self.pop()
        val = self.pop()
        self.call_pop(operator.setitem,
                      args=(tos, self.slice(start=start), val))

    def op_STORE_SLICE_2(self, inst):
        """Translate ``TOS1[:TOS] = TOS2`` as ``operator.setitem``, discarding the call result."""
        stop = self.pop()
        tos = self.pop()
        val = self.pop()
        self.call_pop(operator.setitem, args=(tos, self.slice(stop=stop), val))

    def op_STORE_SLICE_3(self, inst):
        """Translate ``TOS2[TOS1:TOS] = TOS3`` as ``operator.setitem``, discarding the call result."""
        stop = self.pop()
        start = self.pop()
        tos = self.pop()
        val = self.pop()
        self.call_pop(operator.setitem,
                      args=(tos, self.slice(start, stop), val))

    def op_BUILD_SLICE(self, inst):
        """Pop two or three slice bounds and push the slice built from them.

        With two operands the step defaults to ``None``.  Any other operand
        count raises ``Exception('unreachable')`` after the pops.
        """
        argc = inst.arg
        bounds = [self.pop() for _ in range(argc)]
        # Values come off the stack last-first; restore start, stop[, step].
        bounds.reverse()

        if argc == 2:
            bounds.append(None)
        elif argc != 3:
            raise Exception('unreachable')

        start, stop, step = bounds
        self.push(self.slice(start, stop, step))

    # ------- Exceptions ------- #

    def op_RAISE_VARARGS(self, inst):
        """Pop ``inst.arg`` raise arguments and emit ``exc_throw`` with them.

        Raises CompileError for the three-argument (traceback) form.
        """
        nargs = inst.arg
        if nargs == 3:
            raise CompileError("Traceback argument to raise not supported")

        # Popped last-first; reverse to restore source order.
        args = list(reversed([self.pop() for _ in range(nargs)]))

        # Inside a try block the raise adds a CFG edge for the innermost
        # entry of except_stack.
        if self.except_stack:
            except_block = self.except_stack[-1]
            self.predecessors[except_block].add(self.curblock)

        self.insert('exc_throw', *args)

    # ------- Generators ------- #

    def op_YIELD_VALUE(self, inst):
        """Pop TOS, emit a ``yield`` op for it and push the op.

        Also increments the ``'flypy.state.generator'`` counter in ``self.env``.
        """
        val = self.pop()
        self.push_insert('yield', val)
        self.env['flypy.state.generator'] += 1

    # ------- Blocks ------- #

    def op_SETUP_LOOP(self, inst):
        """Open a loop block whose exit is the block at ``inst.next + inst.arg``.

        The current block is recorded as a predecessor of the exit block, and
        the new ``LoopBlock`` is pushed on both the block and loop stacks.
        """
        exit_block = self.blocks[inst.next + inst.arg]
        self.predecessors[exit_block].add(self.curblock)

        block = LoopBlock(None, exit_block, self.stack_level)
        self.block_stack.append(block)
        self.loop_stack.append(block)

    def op_SETUP_EXCEPT(self, inst):
        """Open a try/except region whose handler is at ``inst.next + inst.arg``.

        Registers the handler block, emits ``exc_setup`` at the front of the
        current block and pushes an ``ExceptionBlock`` on the block and
        except stacks.
        """
        try_block = self.blocks[inst.next]
        except_block = self.blocks[inst.next + inst.arg]
        self.predecessors[except_block].add(self.curblock)
        self.exc_handlers.add(except_block)

        with self.builder.at_front(self.curblock):
            self.builder.exc_setup([except_block])

        block = ExceptionBlock(try_block, except_block, self.stack_level)
        self.block_stack.append(block)
        self.except_stack.append(block)

    def op_SETUP_FINALLY(self, inst):
        """Open a try/finally region whose finally block is at ``inst.next + inst.arg``.

        The current block becomes a predecessor of the finally block, and a
        ``FinallyBlock`` is pushed on the block and finally stacks.
        """
        try_block = self.blocks[inst.next]
        finally_block = self.blocks[inst.next + inst.arg]
        self.predecessors[finally_block].add(self.curblock)

        block = FinallyBlock(try_block, finally_block, self.stack_level)
        self.block_stack.append(block)
        self.finally_stack.append(block)

    def op_END_FINALLY(self, inst):
        """Discard the top three stack values; no IR op is emitted."""
        for _ in range(3):
            self.pop()

    # ------- print ------- #

    def op_PRINT_ITEM(self, inst):
        """Pop TOS and emit a call to ``print`` with it, discarding the call result."""
        self.call_pop(print, [self.pop()])

    def op_PRINT_NEWLINE(self, inst):
        """Emit a call to ``print`` with a newline constant, discarding the call result."""
        self.call_pop(print, [const('\n')])

    # ------- Misc ------- #

    def op_STOP_CODE(self, inst):
        """No-op: nothing is emitted for STOP_CODE."""
        pass
Esempio n. 20
0
class Translate(object):
    """
    Translate bytecode to untypes pykit IR.
    """

    def __init__(self, func, env):
        """
        Set up translation state for the Python function ``func``.

        :param func: function whose bytecode is wrapped in ``ByteCode``
        :param env: environment object, stored as ``self.env``

        NOTE(review): uses ``__builtin__``, ``func_globals`` and
        ``inspect.getargspec`` -- presumably Python 2 only; confirm.
        """
        self.func = func
        self.env = env
        self.bytecode = ByteCode(func)

        # -------------------------------------------------
        # Block bookkeeping

        self.blocks = {}            # offset -> Block
        self.block2offset = {}      # Block -> offset
        self.allocas = {}           # varname -> alloca
        self.stacks = {}            # Block -> value stack
        self.exc_handlers = set()   # { Block }

        # -------------------------------------------------
        # Block stacks

        self.block_stack   = []
        self.loop_stack    = []
        self.except_stack  = []
        self.finally_stack = []

        # -------------------------------------------------
        # CFG

        self.predecessors = collections.defaultdict(set)
        self.phis = collections.defaultdict(list)

        # -------------------------------------------------
        # Variables and scoping

        self.code = self.bytecode.code
        self.varnames = self.bytecode.code.co_varnames
        self.consts = self.bytecode.code.co_consts
        self.names = self.bytecode.code.co_names
        self.argnames = list(self.varnames[:self.bytecode.code.co_argcount])

        # Builtins first, so that the function's own globals shadow them.
        self.globals = dict(vars(__builtin__))
        self.builtins = set(self.globals.values())
        self.globals.update(self.func.func_globals)

        self.call_annotations = collections.defaultdict(dict)

        # -------------------------------------------------
        # Error checks

        # The *args / **kwargs names follow the positional argument names.
        argspec = inspect.getargspec(self.func)
        if argspec.varargs:
            self.argnames.append(argspec.varargs)
        if argspec.keywords:
            self.argnames.append(argspec.keywords)

        # A **kwargs parameter is rejected even though its name was
        # appended above.
        assert not argspec.keywords, "keywords not yet supported"

    def initialize(self):
        """Initialize pykit untypes structures.

        Creates the destination function and its builder, one block per
        bytecode label, and one alloca per local variable in the start block.
        """

        # Setup Function
        # Every argument and the return value are typed Opaque.
        sig = types.Function(types.Opaque, [types.Opaque] * len(self.argnames),
                             False)
        self.dst = Function(func_name(self.func), self.argnames, sig)

        # Setup Builder
        self.builder = Builder(self.dst)

        # Setup Blocks
        for offset in self.bytecode.labels:
            name = blockname(self.func, offset)
            block = self.dst.new_block(name)
            self.blocks[offset] = block
            self.stacks[block] = []

        # Setup Variables
        self.builder.position_at_beginning(self.dst.startblock)
        for varname in self.varnames:
            stackvar = self.builder.alloca(types.Pointer(types.Opaque),
                                           result=self.dst.temp(varname))
            self.allocas[varname] = stackvar

            # Initialize function arguments
            if varname in self.argnames:
                self.builder.store(self.dst.get_arg(varname), stackvar)

    def interpret(self):
        """Translate every bytecode instruction in order, then fill in the phis."""
        self.curblock = self.dst.startblock

        for inst in self.bytecode:
            if inst.offset in self.blocks:
                # Block switch
                newblock = self.blocks[inst.offset]
                if self.curblock != newblock:
                    self.switchblock(newblock)
            elif self.curblock.is_terminated():
                # Skip instructions after a terminator until the next
                # labelled block starts.
                continue

            self.op(inst)

        # -------------------------------------------------
        # Finalize

        self.update_phis()

    def op(self, inst):
        """Dispatch ``inst`` to its ``op_<OPNAME>`` handler.

        ``+`` in the opcode name is replaced by ``_`` when building the
        handler name; opcodes without a handler go to ``generic_op``.  The
        instruction's line number is stored in ``self.lineno`` and the
        handler runs inside an ``error_context`` for that line.
        """
        with error_context(lineno=inst.lineno, during="Translate operation",
                           pyfunc=self.func):
            self.lineno = inst.lineno
            attr = 'op_%s' % inst.opname.replace('+', '_')
            fn = getattr(self, attr, self.generic_op)
            fn(inst)

    def generic_op(self, inst):
        """Fallback handler: raise NotImplementedError for an opcode without an ``op_*`` method."""
        raise NotImplementedError(inst)

    def switchblock(self, newblock):
        """
        Switch to a new block and merge incoming values from the stacks.

        If the current block is not terminated yet, a jump to ``newblock`` is
        emitted first.  One empty phi is pushed per value on the deepest
        stack among the predecessors recorded so far; ``update_phis`` fills
        in the phi arguments afterwards.
        """
        #print("%s -> %s" % (self.curblock.name, newblock.name), self.stack)
        if not self.curblock.is_terminated():
            self.jump(newblock)

        self.builder.position_at_end(newblock)
        self.prevblock = self.curblock
        self.curblock = newblock

        # -------------------------------------------------
        # Exception handlers: push the fetched exception, value and traceback

        if newblock in self.exc_handlers:
            self.push_insert('exc_fetch')
            self.push_insert('exc_fetch_value')
            self.push_insert('exc_fetch_tb')

        # -------------------------------------------------
        # Find predecessors

        incoming = self.predecessors.get(newblock)
        if not incoming:
            return

        # -------------------------------------------------
        # Merge stack values

        # Use the deepest predecessor stack to decide how many phis to create.
        stack = max([self.stacks[block] for block in incoming], key=len)
        for value in stack:
            phi = self.push_insert('phi', [], [])
            self.phis[newblock].append(phi)

        assert len(self.stack) == len(stack)

    def update_phis(self):
        """Fill every recorded phi with the stack values of its block's predecessors.

        Asserts that the last block ends with an empty stack and that each
        predecessor stack is exactly as deep as the block's number of phis.
        """
        laststack = self.stacks[self.dst.blocks.tail]
        assert not laststack, laststack

        for block in self.dst.blocks:
            phis = self.phis[block]
            preds  = list(self.predecessors[block])
            stacks = [self.stacks[pred] for pred in preds]
            stacklen = len(phis)

            # -------------------------------------------------
            # Sanity check

            assert all(len(stack) == stacklen for stack in stacks), (preds, stacks)

            if not preds or not stacklen:
                continue

            # -------------------------------------------------
            # Update φs with stack values from predecessors

            # The phi at position ``pos`` merges slot ``pos`` of every
            # predecessor's stack.
            for pos, phi in enumerate(phis):
                values = []
                for pred in preds:
                    value_stack = self.stacks[pred]
                    value = value_stack[pos]
                    values.append(value)

                phi.set_args([preds, values])

    @property
    def stack(self):
        """Value stack of the current block."""
        return self.stacks[self.curblock]

    @property
    def stack_level(self):
        """Number of values on the current block's stack."""
        return len(self.stack)

    def insert(self, opcode, *args):
        """Build an ``Op`` for ``opcode``, emit it through the builder and return it.

        The op is typed ``types.Void`` for void opcodes and ``types.Opaque``
        otherwise, and carries the current line number as metadata.
        """
        type = types.Void if ops.is_void(opcode) else types.Opaque
        op = Op(opcode, type, list(args))
        op.add_metadata({'lineno': self.lineno})
        self.builder.emit(op)
        return op

    def push_insert(self, opcode, *args):
        """Emit an op through ``insert``, push it on the value stack and return it."""
        inst = self.insert(opcode, *args)
        self.push(inst)
        return inst

    def push(self, val):
        """Push ``val`` on the current block's value stack."""
        self.stack.append(val)

    def peek(self):
        """
        Take a peek at the top of stack.
        """
        if not self.stack:
            # Assuming the bytecode is valid, our predecessors must have left
            # some values on the stack.
            # return self._insert_phi()
            raise EmptyStackError
        else:
            return self.stack[-1]

    def pop(self):
        if not self.stack:
            # return self._insert_phi()
            raise EmptyStackError
        else:
            return self.stack.pop()

    def _insert_phi(self):
        """Insert an empty phi at the front of the current block, record it and return it."""
        with self.builder.at_front(self.curblock):
            phi = self.insert('phi', [], [])

        self.phis[self.curblock].append(phi)
        return phi

    def call(self, func, args=()):
        """Emit a ``call`` op, push it on the stack and return it.

        A callee that is not already a ``Value`` is wrapped with ``const``.
        """
        if not isinstance(func, Value):
            func = const(func)
        return self.push_insert('call', func, list(args))

    def call_pop(self, func, args=()):
        """Emit a call like ``call`` but pop its result off the stack and return it."""
        self.call(func, args)
        return self.pop()

    def binary_op(self, op):
        """Pop the right then the left operand and emit ``op(lhs, rhs)`` as a call."""
        rhs = self.pop()
        lhs = self.pop()
        self.call(op, args=(lhs, rhs))

    def unary_op(self, op):
        """Pop TOS and emit ``op(tos)`` as a call."""
        tos = self.pop()
        self.call(op, args=(tos,))

    def jump(self, target):
        """Record the current block as a predecessor of ``target`` and emit ``jump``."""
        self.predecessors[target].add(self.curblock)
        self.insert('jump', target)

    def jump_if(self, cond, truebr, falsebr):
        """Record the current block as a predecessor of both targets and emit ``cbranch``."""
        self.predecessors[truebr].add(self.curblock)
        self.predecessors[falsebr].add(self.curblock)
        self.insert('cbranch', cond, truebr, falsebr)

    # ------- stack ------- #

    def op_POP_BLOCK(self, inst):
        """Close the innermost block and truncate the stack to its entry level.

        The block is also popped from the loop, except or finally stack that
        matches its type (first match wins).
        """
        block = self.block_stack.pop()
        kind_stacks = (
            (LoopBlock, self.loop_stack),
            (ExceptionBlock, self.except_stack),
            (FinallyBlock, self.finally_stack),
        )
        for kind, tracked in kind_stacks:
            if isinstance(block, kind):
                tracked.pop()
                break

        del self.stack[block.level:]

    def op_POP_TOP(self, inst):
        """Discard the top stack value."""
        self.pop()

    def op_DUP_TOP(self, inst):
        value = self.pop()
        self.push(value)
        self.push(value)

    def op_DUP_TOPX(self, inst):
        count = inst.arg
        self.stack.extend(self.stack[-count:])

    def op_ROT_TWO(self, inst):
        one = self.pop()
        two = self.pop()
        self.push(one)
        self.push(two)

    def op_ROT_THREE(self, inst):
        one = self.pop()
        two = self.pop()
        three = self.pop()
        self.push(one)
        self.push(three)
        self.push(two)

    def op_ROT_FOUR(self, inst):
        one = self.pop()
        two = self.pop()
        three = self.pop()
        four = self.pop()
        self.push(one)
        self.push(four)
        self.push(three)
        self.push(two)

    # ------- control flow ------- #

    def op_POP_JUMP_IF_TRUE(self, inst):
        """Pop the condition; branch to block ``inst.arg`` if true, else to ``inst.next``."""
        falsebr = self.blocks[inst.next]
        truebr = self.blocks[inst.arg]
        self.jump_if(self.pop(), truebr, falsebr)

    def op_POP_JUMP_IF_FALSE(self, inst):
        """Pop the condition; branch to block ``inst.arg`` if false, else to ``inst.next``."""
        truebr = self.blocks[inst.next]
        falsebr = self.blocks[inst.arg]
        self.jump_if(self.pop(), truebr, falsebr)

    def op_JUMP_IF_TRUE(self, inst):
        """Branch to the relative target ``inst.next + inst.arg`` if TOS is true.

        The condition is only peeked; it stays on the stack.
        """
        falsebr = self.blocks[inst.next]
        truebr = self.blocks[inst.next + inst.arg]
        self.jump_if(self.peek(), truebr, falsebr)

    def op_JUMP_IF_FALSE(self, inst):
        """Branch to the relative target ``inst.next + inst.arg`` if TOS is false.

        The condition is only peeked; it stays on the stack.
        """
        truebr = self.blocks[inst.next]
        falsebr = self.blocks[inst.next + inst.arg]
        self.jump_if(self.peek(), truebr, falsebr)

    def _make_popblock(self):
        """Create a new "popblock" after the current block, with an empty value stack."""
        popblock = self.dst.new_block(self.dst.temp("popblock"),
                                      after=self.curblock)
        self.stacks[popblock] = []
        return popblock

    def op_JUMP_IF_TRUE_OR_POP(self, inst):
        """Branch to ``inst.arg`` keeping TOS if it is true; otherwise pop it.

        The pop happens in a dedicated helper block that then jumps to the
        fall-through block at ``inst.next``.
        """
        falsebr = self.blocks[inst.next]
        truebr = self.blocks[inst.arg]

        popblock = self._make_popblock()
        self.jump_if(self.peek(), truebr, popblock)
        # switchblock re-creates the incoming stack in popblock as phis;
        # the pop below drops the condition there.
        self.switchblock(popblock)
        self.pop()
        self.jump(falsebr)


    def op_JUMP_IF_FALSE_OR_POP(self, inst):
        """Branch to ``inst.arg`` keeping TOS if it is false; otherwise pop it.

        The pop happens in a dedicated helper block that then jumps to the
        fall-through block at ``inst.next``.
        """
        truebr = self.blocks[inst.next]
        falsebr = self.blocks[inst.arg]

        popblock = self._make_popblock()
        self.jump_if(self.peek(), popblock, falsebr)
        # switchblock re-creates the incoming stack in popblock as phis;
        # the pop below drops the condition there.
        self.switchblock(popblock)
        self.pop()
        self.jump(truebr)

    def op_JUMP_ABSOLUTE(self, inst):
        """Jump to the block at absolute offset ``inst.arg``."""
        target = self.blocks[inst.arg]
        self.jump(target)

    def op_JUMP_FORWARD(self, inst):
        """Jump to the block at relative offset ``inst.next + inst.arg``."""
        target = self.blocks[inst.next + inst.arg]
        self.jump(target)

    def op_RETURN_VALUE(self, inst):
        """Pop TOS and emit ``ret`` with it; a constant ``None`` becomes a bare ``ret``."""
        val = self.pop()
        if isinstance(val, Const) and val.const is None:
            val = None # Generate a bare 'ret' instruction
        self.insert('ret', val)

    def op_CALL_FUNCTION(self, inst, varargs=None):
        """Pop keyword pairs, positional arguments and the callee, then emit a call.

        ``inst.arg`` holds the positional count in its low byte and the
        keyword count in the next byte.  Keyword arguments are rejected by
        the assertion below.  Returns the call op pushed by ``self.call``.

        NOTE(review): ``varargs`` is accepted but never used in this body.
        """
        argc = inst.arg & 0xff
        kwsc = (inst.arg >> 8) & 0xff
        def pop_kws():
            # Each keyword argument is a (key, value) pair with the value on top.
            val = self.pop()
            key = self.pop()
            if key.opcode != 'const':
                raise ValueError('keyword must be a constant')
            # NOTE(review): other handlers read constants through ``.const``;
            # confirm ``.value`` is the right attribute here.
            return key.value, val

        # Values are popped last-first, so reverse to restore source order.
        kws = list(reversed([pop_kws() for i in range(kwsc)]))
        args = list(reversed([self.pop() for i in range(argc)]))
        assert not kws, "Keyword arguments not yet supported"

        func = self.pop()
        return self.call(func, args)

    def op_CALL_FUNCTION_VAR(self, inst):
        """Translate a call that has a ``*args`` iterable on top of the stack.

        The iterable is converted with ``tuple``, the ordinary call is
        emitted, and the tuple is then appended to the call's argument list.
        """
        it = self.call_pop(tuple, [self.pop()])
        #varargs = self.insert('unpack', it)
        call = self.op_CALL_FUNCTION(inst, varargs=it)

        # Add unpacked iterable to args list
        f, args = call.args
        call.set_args([f, args + [it]])

        # Annotate call as a 'varargs' application
        self.call_annotations[call]['varargs'] = True

    def op_GET_ITER(self, inst):
        """Pop TOS and push a call to ``iter`` on it."""
        self.call(iter, [self.pop()])

    def op_FOR_ITER(self, inst):
        """
        Translate a for loop to:

            it = getiter(iterable)
            try:
                while 1:
                    i = next(it)
                    ...
            except StopIteration:
                pass
        """
        # The iterator stays on the stack for the next iteration.
        iterobj = self.peek()
        delta = inst.arg
        loopexit = self.blocks[inst.next + delta]

        loop_block = self.loop_stack[-1]
        loop_block.catch_block = loopexit

        # -------------------------------------------------
        # Try

        self.insert('exc_setup', [loopexit])
        self.call(next, [iterobj])

        # We assume a 1-to-1 block mapping, resolve a block split in a
        # later pass
        self.insert('exc_end')

        # -------------------------------------------------
        # Catch

        with self.builder.at_front(loopexit):
            self.insert('exc_catch', [Const(StopIteration, type=types.Exception)])

        # -------------------------------------------------
        # Exit

        # Add the loop exit as a successor to the header
        self.predecessors[loopexit].add(self.curblock)

        # Remove ourselves as a predecessor from the actual exit block, set by
        # SETUP_LOOP
        self.predecessors[loop_block.end].remove(self.prevblock)

    def op_BREAK_LOOP(self, inst):
        """Jump out of the innermost loop: to its catch block if set, else to its end block."""
        loopblock = self.loop_stack[-1]
        self.jump(target=loopblock.catch_block or loopblock.end)

    def op_BUILD_TUPLE(self, inst):
        """Pop ``inst.arg`` values and push a tuple built from them.

        All-constant items fold into one constant tuple.  Otherwise, below
        ``tupleobject.STATIC_THRESHOLD`` items, a chain of ``StaticTuple``
        calls ending in ``EmptyTuple`` is emitted.

        Raises NotImplementedError for larger non-constant tuples.
        """
        count = inst.arg
        items = [self.pop() for _ in range(count)]
        ordered = list(reversed(items))
        if all(isinstance(item, Const) for item in ordered):
            # create constant tuple
            self.push(const(tuple(item.const for item in ordered)))
        elif len(ordered) < tupleobject.STATIC_THRESHOLD:
            # Build static tuple
            result = self.call_pop(tupleobject.EmptyTuple)
            # ``items`` is in pop order (last element first), so the first
            # element ends up as the outermost StaticTuple.
            for item in items:
                result = self.call_pop(tupleobject.StaticTuple,
                                       args=(item, result))
            self.push(result)
        else:
            raise NotImplementedError("Generic tuples")

    def op_BUILD_LIST(self, inst):
        """Build a list from the top ``inst.arg`` stack values.

        An empty list is a direct ``listobject.EmptyList`` call; otherwise the
        items are first packed by ``op_BUILD_TUPLE`` and handed to ``list``.
        """
        if inst.arg:
            self.op_BUILD_TUPLE(inst)
            self.call(list, (self.pop(),))
        else:
            self.call(listobject.EmptyList, ())

    def op_LOAD_ATTR(self, inst):
        """Load attribute ``names[inst.arg]`` of TOS.

        When TOS is a constant that has the attribute, the lookup is done at
        translation time; otherwise a ``getfield`` op is emitted and pushed.
        """
        attr = self.names[inst.arg]
        obj = self.pop()
        if not (isinstance(obj, Const) and hasattr(obj.const, attr)):
            self.push_insert('getfield', obj, attr)
            return
        self.push(const(getattr(obj.const, attr)))

    def op_LOAD_GLOBAL(self, inst):
        """Push the value of global ``names[inst.arg]`` as a constant.

        Raises NameError when the name is not present in ``self.globals``.
        """
        name = self.names[inst.arg]
        if name not in self.globals:
            raise NameError("Could not resolve %r at compile time" % name)
        value = self.globals[name]
        self.push(const(value))

    def op_LOAD_DEREF(self, inst):
        """Push the current contents of closure cell ``inst.arg`` as a constant."""
        # NOTE(review): indexes ``__closure__`` directly with ``inst.arg``;
        # confirm this holds when the function also has cell variables.
        i = inst.arg
        cell = self.func.__closure__[i]
        value = cell.cell_contents
        self.push(const(value))

    def op_LOAD_FAST(self, inst):
        """Emit a ``load`` from the alloca of local ``varnames[inst.arg]`` and push it."""
        name = self.varnames[inst.arg]
        self.push_insert('load', self.allocas[name])

    def op_LOAD_CONST(self, inst):
        """Push ``consts[inst.arg]`` wrapped by ``const``."""
        val = self.consts[inst.arg]
        self.push(const(val))

    def op_STORE_FAST(self, inst):
        """Pop TOS and emit a ``store`` of it into the alloca of local ``varnames[inst.arg]``."""
        value = self.pop()
        name = self.varnames[inst.arg]
        self.insert('store', value, self.allocas[name])

    def op_STORE_ATTR(self, inst):
        """Pop the object (TOS) and then the value, and emit ``setfield`` for ``names[inst.arg]``."""
        attr = self.names[inst.arg]
        obj = self.pop()
        value = self.pop()
        self.insert('setfield', obj, attr, value)

    def op_STORE_SUBSCR(self, inst):
        """Emit ``operator.setitem(TOS1, TOS, TOS2)`` and discard the call result."""
        tos0 = self.pop()
        tos1 = self.pop()
        tos2 = self.pop()
        self.call(operator.setitem, (tos1, tos0, tos2))
        # Drop the call op that ``self.call`` pushed.
        self.pop()

    def op_UNPACK_SEQUENCE(self, inst):
        """Pop a sequence and push one ``unpack`` op per item.

        Items are pushed from the highest index down, so item 0 ends on top.
        """
        value = self.pop()
        itemct = inst.arg
        for i in reversed(range(itemct)):
            self.push_insert('unpack', value, i, itemct)

    def op_COMPARE_OP(self, inst):
        """Translate a comparison named by ``dis.cmp_op[inst.arg]``.

        ``not in`` and ``is not`` are lowered as the positive test followed
        by ``operator.not_``; every other comparison maps straight through
        ``COMPARE_OP_FUNC``.
        """
        opname = dis.cmp_op[inst.arg]
        negated_forms = {'not in': 'in', 'is not': 'is'}
        positive = negated_forms.get(opname, opname)

        self.binary_op(COMPARE_OP_FUNC[positive])
        if positive != opname:
            self.unary_op(operator.not_)

    def op_UNARY_POSITIVE(self, inst):
        """Lower UNARY_POSITIVE to ``operator.pos(TOS)``."""
        self.unary_op(operator.pos)

    def op_UNARY_NEGATIVE(self, inst):
        """Lower UNARY_NEGATIVE to ``operator.neg(TOS)``."""
        self.unary_op(operator.neg)

    def op_UNARY_INVERT(self, inst):
        """Lower UNARY_INVERT to ``operator.invert(TOS)``."""
        self.unary_op(operator.invert)

    def op_UNARY_NOT(self, inst):
        """Lower UNARY_NOT to ``operator.not_(TOS)``."""
        self.unary_op(operator.not_)

    def op_BINARY_SUBSCR(self, inst):
        """Lower BINARY_SUBSCR to ``operator.getitem(TOS1, TOS)``."""
        self.binary_op(operator.getitem)

    def op_BINARY_ADD(self, inst):
        """Lower BINARY_ADD to ``operator.add(TOS1, TOS)``."""
        self.binary_op(operator.add)

    def op_BINARY_SUBTRACT(self, inst):
        """Lower BINARY_SUBTRACT to ``operator.sub(TOS1, TOS)``."""
        self.binary_op(operator.sub)

    def op_BINARY_MULTIPLY(self, inst):
        """Lower BINARY_MULTIPLY to ``operator.mul(TOS1, TOS)``."""
        self.binary_op(operator.mul)

    def op_BINARY_DIVIDE(self, inst):
        """Lower BINARY_DIVIDE to ``operator.floordiv(TOS1, TOS)``."""
        # NOTE(review): same operator as BINARY_FLOOR_DIVIDE; confirm this is
        # intended for non-integer operands.
        self.binary_op(operator.floordiv)

    def op_BINARY_FLOOR_DIVIDE(self, inst):
        """Lower BINARY_FLOOR_DIVIDE to ``operator.floordiv(TOS1, TOS)``."""
        self.binary_op(operator.floordiv)

    def op_BINARY_TRUE_DIVIDE(self, inst):
        """Lower BINARY_TRUE_DIVIDE to ``operator.truediv(TOS1, TOS)``."""
        self.binary_op(operator.truediv)

    def op_BINARY_MODULO(self, inst):
        """Lower BINARY_MODULO to ``operator.mod(TOS1, TOS)``."""
        self.binary_op(operator.mod)

    def op_BINARY_POWER(self, inst):
        """Lower BINARY_POWER to ``operator.pow(TOS1, TOS)``."""
        self.binary_op(operator.pow)

    def op_BINARY_RSHIFT(self, inst):
        """Lower BINARY_RSHIFT to ``operator.rshift(TOS1, TOS)``."""
        self.binary_op(operator.rshift)

    def op_BINARY_LSHIFT(self, inst):
        """Lower BINARY_LSHIFT to ``operator.lshift(TOS1, TOS)``."""
        self.binary_op(operator.lshift)

    def op_BINARY_AND(self, inst):
        """Lower BINARY_AND to ``operator.and_(TOS1, TOS)``."""
        self.binary_op(operator.and_)

    def op_BINARY_OR(self, inst):
        """Lower BINARY_OR to ``operator.or_(TOS1, TOS)``."""
        self.binary_op(operator.or_)

    def op_BINARY_XOR(self, inst):
        """Lower BINARY_XOR to ``operator.xor(TOS1, TOS)``."""
        self.binary_op(operator.xor)

    def op_INPLACE_ADD(self, inst):
        """Lower INPLACE_ADD to the non-in-place ``operator.add(TOS1, TOS)``."""
        self.binary_op(operator.add)

    def op_INPLACE_SUBTRACT(self, inst):
        """Lower INPLACE_SUBTRACT to the non-in-place ``operator.sub(TOS1, TOS)``."""
        self.binary_op(operator.sub)

    def op_INPLACE_MULTIPLY(self, inst):
        """Lower INPLACE_MULTIPLY to the non-in-place ``operator.mul(TOS1, TOS)``."""
        self.binary_op(operator.mul)

    def op_INPLACE_DIVIDE(self, inst):
        """Lower INPLACE_DIVIDE to ``operator.floordiv(TOS1, TOS)``."""
        # NOTE(review): same operator as INPLACE_FLOOR_DIVIDE; confirm this is
        # intended for non-integer operands.
        self.binary_op(operator.floordiv)

    def op_INPLACE_FLOOR_DIVIDE(self, inst):
        """Lower INPLACE_FLOOR_DIVIDE to the non-in-place ``operator.floordiv(TOS1, TOS)``."""
        self.binary_op(operator.floordiv)

    def op_INPLACE_TRUE_DIVIDE(self, inst):
        """Lower INPLACE_TRUE_DIVIDE to the non-in-place ``operator.truediv(TOS1, TOS)``."""
        self.binary_op(operator.truediv)

    def op_INPLACE_MODULO(self, inst):
        """Lower INPLACE_MODULO to the non-in-place ``operator.mod(TOS1, TOS)``."""
        self.binary_op(operator.mod)

    def op_INPLACE_POWER(self, inst):
        """Lower INPLACE_POWER to the non-in-place ``operator.pow(TOS1, TOS)``."""
        self.binary_op(operator.pow)

    def op_INPLACE_RSHIFT(self, inst):
        """Lower INPLACE_RSHIFT to the non-in-place ``operator.rshift(TOS1, TOS)``."""
        self.binary_op(operator.rshift)

    def op_INPLACE_LSHIFT(self, inst):
        """Lower INPLACE_LSHIFT to the non-in-place ``operator.lshift(TOS1, TOS)``."""
        self.binary_op(operator.lshift)

    def op_INPLACE_AND(self, inst):
        """Lower INPLACE_AND to the non-in-place ``operator.and_(TOS1, TOS)``."""
        self.binary_op(operator.and_)

    def op_INPLACE_OR(self, inst):
        """Lower INPLACE_OR to the non-in-place ``operator.or_(TOS1, TOS)``."""
        self.binary_op(operator.or_)

    def op_INPLACE_XOR(self, inst):
        """Lower INPLACE_XOR to the non-in-place ``operator.xor(TOS1, TOS)``."""
        self.binary_op(operator.xor)

    def slice(self, start=None, stop=None, step=None):
        """Emit a call constructing ``sliceobject.Slice(start, stop, step)`` and return it.

        Each bound is passed through ``const``; the call op is popped again,
        so it is returned rather than left on the value stack.
        """
        start, stop, step = map(const, [start, stop, step])
        return self.call_pop(const(sliceobject.Slice), [start, stop, step])

    def op_SLICE_0(self, inst):
        """
        Handle SLICE+0 (``TOS[:]``): pop the container and emit an
        ``operator.getitem`` call on it with an unbounded slice.
        """
        container = self.pop()
        whole = self.slice()
        self.call(operator.getitem, args=(container, whole))

    def op_SLICE_1(self, inst):
        """
        Handle SLICE+1 (``TOS1[TOS:]``): pop the start bound, then the
        container, and emit an ``operator.getitem`` call with a slice that
        only has ``start`` set.
        """
        lower, container = self.pop(), self.pop()
        window = self.slice(start=lower)
        self.call(operator.getitem, args=(container, window))

    def op_SLICE_2(self, inst):
        """
        Handle SLICE+2 (``TOS1[:TOS]``): pop the stop bound, then the
        container, and emit an ``operator.getitem`` call with a slice that
        only has ``stop`` set.
        """
        upper, container = self.pop(), self.pop()
        window = self.slice(stop=upper)
        self.call(operator.getitem, args=(container, window))

    def op_SLICE_3(self, inst):
        """
        Handle SLICE+3 (``TOS2[TOS1:TOS]``): pop the stop bound, the start
        bound and the container (in that order) and emit an
        ``operator.getitem`` call with a start/stop slice.
        """
        upper, lower, container = self.pop(), self.pop(), self.pop()
        window = self.slice(lower, upper)
        self.call(operator.getitem, args=(container, window))

    def op_STORE_SLICE_0(self, inst):
        """
        Handle STORE_SLICE+0 (``TOS[:] = TOS1``): pop the container, then the
        value, and emit an ``operator.setitem`` call via ``call_pop`` with an
        unbounded slice.
        """
        container, value = self.pop(), self.pop()
        whole = self.slice()
        self.call_pop(operator.setitem, args=(container, whole, value))

    def op_STORE_SLICE_1(self, inst):
        """
        Handle STORE_SLICE+1 (``TOS1[TOS:] = TOS2``): pop the start bound, the
        container and the value (in that order) and emit an
        ``operator.setitem`` call via ``call_pop``.
        """
        lower, container, value = self.pop(), self.pop(), self.pop()
        window = self.slice(start=lower)
        self.call_pop(operator.setitem, args=(container, window, value))

    def op_STORE_SLICE_2(self, inst):
        """
        Handle STORE_SLICE+2 (``TOS1[:TOS] = TOS2``): pop the stop bound, the
        container and the value (in that order) and emit an
        ``operator.setitem`` call via ``call_pop``.
        """
        upper, container, value = self.pop(), self.pop(), self.pop()
        window = self.slice(stop=upper)
        self.call_pop(operator.setitem, args=(container, window, value))

    def op_STORE_SLICE_3(self, inst):
        """
        Handle STORE_SLICE+3 (``TOS2[TOS1:TOS] = TOS3``): pop the stop bound,
        the start bound, the container and the value (in that order) and emit
        an ``operator.setitem`` call via ``call_pop``.
        """
        upper, lower = self.pop(), self.pop()
        container, value = self.pop(), self.pop()
        window = self.slice(lower, upper)
        self.call_pop(operator.setitem, args=(container, window, value))

    def op_BUILD_SLICE(self, inst):
        """
        Handle BUILD_SLICE: pop ``inst.arg`` operands, build a slice from them
        with ``self.slice`` and push the result.

        With two operands the slice is ``start:stop`` (step ``None``); with
        three it is ``start:stop:step``.  The operands are popped before the
        count is validated.

        Raises:
            Exception: if ``inst.arg`` is neither 2 nor 3.
        """
        count = inst.arg
        # Operands come off the stack in reverse order: the last bound first.
        popped = [self.pop() for _ in range(count)]

        if count == 3:
            step, stop, start = popped
        elif count == 2:
            stop, start = popped
            step = None
        else:
            raise Exception('unreachable')

        built = self.slice(start, stop, step)
        self.push(built)

    # ------- Exceptions ------- #

    def op_RAISE_VARARGS(self, inst):
        """
        Handle RAISE_VARARGS: pop ``inst.arg`` operands and insert an
        ``exc_throw`` with them in their original push order.

        If an exception handler is active (``except_stack`` is non-empty), the
        current block is recorded as a predecessor of the innermost entry.

        Raises:
            CompileError: for the three-argument form (traceback argument),
                before anything is popped.
        """
        count = inst.arg
        if count == 3:
            raise CompileError("Traceback argument to raise not supported")

        # Popping yields the operands last-pushed first; restore push order.
        operands = [self.pop() for _ in range(count)]
        operands.reverse()

        handlers = self.except_stack
        if handlers:
            # NOTE(review): op_SETUP_EXCEPT appends ExceptionBlock wrappers to
            # except_stack but keys ``predecessors`` by the handler block taken
            # from ``self.blocks`` -- confirm this lookup hits the intended key.
            self.predecessors[handlers[-1]].add(self.curblock)

        self.insert('exc_throw', *operands)

    # ------- Generators ------- #

    def op_YIELD_VALUE(self, inst):
        """
        Handle YIELD_VALUE: pop the yielded value, emit a ``'yield'`` through
        ``push_insert`` and increment the ``'flypy.state.generator'`` counter
        in ``self.env``.
        """
        yielded = self.pop()
        self.push_insert('yield', yielded)
        counter_key = 'flypy.state.generator'
        self.env[counter_key] += 1

    # ------- Blocks ------- #

    def op_SETUP_LOOP(self, inst):
        """
        Handle SETUP_LOOP: register the loop's exit block and open a new
        ``LoopBlock`` frame.

        The exit block is looked up at offset ``inst.next + inst.arg``; the
        current block is added to its predecessors.  The new frame records the
        exit block and the current ``stack_level`` and is appended to both
        ``block_stack`` and ``loop_stack``.
        """
        exit_offset = inst.next + inst.arg
        loop_exit = self.blocks[exit_offset]
        self.predecessors[loop_exit].add(self.curblock)

        frame = LoopBlock(None, loop_exit, self.stack_level)
        self.block_stack.append(frame)
        self.loop_stack.append(frame)

    def op_SETUP_EXCEPT(self, inst):
        """
        Handle SETUP_EXCEPT: register the handler block and open a new
        ``ExceptionBlock`` frame.

        The protected body starts at offset ``inst.next`` and the handler at
        ``inst.next + inst.arg``.  The current block becomes a predecessor of
        the handler, the handler is added to ``exc_handlers``, and an
        ``exc_setup`` naming the handler is emitted at the front of the
        current block.  The new frame is appended to both ``block_stack`` and
        ``except_stack``.
        """
        body_offset = inst.next
        protected = self.blocks[body_offset]
        handler = self.blocks[body_offset + inst.arg]

        self.predecessors[handler].add(self.curblock)
        self.exc_handlers.add(handler)

        # The setup instruction goes at the front of the current block.
        with self.builder.at_front(self.curblock):
            self.builder.exc_setup([handler])

        frame = ExceptionBlock(protected, handler, self.stack_level)
        self.block_stack.append(frame)
        self.except_stack.append(frame)

    def op_SETUP_FINALLY(self, inst):
        """
        Handle SETUP_FINALLY: register the finally block and open a new
        ``FinallyBlock`` frame.

        The protected body starts at offset ``inst.next`` and the finally
        block at ``inst.next + inst.arg``.  The current block becomes a
        predecessor of the finally block, and the new frame is appended to
        both ``block_stack`` and ``finally_stack``.
        """
        body_offset = inst.next
        protected = self.blocks[body_offset]
        cleanup = self.blocks[body_offset + inst.arg]
        self.predecessors[cleanup].add(self.curblock)

        frame = FinallyBlock(protected, cleanup, self.stack_level)
        self.block_stack.append(frame)
        self.finally_stack.append(frame)

    def op_END_FINALLY(self, inst):
        """
        Handle END_FINALLY by discarding the top three stack entries.

        No instruction is emitted; inserting an ``'end_finally'`` is
        currently disabled.
        """
        for _ in range(3):
            self.pop()

    # ------- print ------- #

    def op_PRINT_ITEM(self, inst):
        """
        Handle PRINT_ITEM: pop the item and emit a call to ``print`` with it
        as the only argument, via ``call_pop``.
        """
        item = self.pop()
        self.call_pop(print, [item])

    def op_PRINT_NEWLINE(self, inst):
        """
        Handle PRINT_NEWLINE: emit a call to ``print`` with a constant
        ``'\\n'`` as the only argument, via ``call_pop``.
        """
        newline = const('\n')
        self.call_pop(print, [newline])

    # ------- Misc ------- #

    def op_STOP_CODE(self, inst):
        """Handle STOP_CODE: nothing is emitted and no state is touched."""
        return None
Esempio n. 21
0
class TestBuilder(unittest.TestCase):

    def setUp(self):
        """
        Create the fixture shared by every test: a function ``testfunc`` with
        one argument ``a``, a builder positioned at the end of a fresh
        ``entry`` block, and a handle on the argument.

        NOTE(review): the signature presumably reads "Int32 -> Float32, not
        variadic" -- confirm against ``types.Function``.
        """
        signature = types.Function(types.Float32, [types.Int32], False)
        self.f = Function("testfunc", ['a'], signature)

        self.b = Builder(self.f)
        entry = self.f.new_block('entry')
        self.b.position_at_end(entry)

        self.a = self.f.get_arg('a')

    def test_basic_builder(self):
        """
        Build ``a * a``, convert it to Float32, round-trip it through an
        alloca'd slot (store then load), return it, and check that
        interpreting the function with ``a = 10`` yields 100.
        """
        v = self.b.alloca(types.Pointer(types.Float32))
        result = self.b.mul(self.a, self.a, result='r')
        c = self.b.convert(types.Float32, result)
        self.b.store(c, v)
        val = self.b.load(v)
        self.b.ret(val)
        # Use the TestCase assertion rather than a bare ``assert``: it is not
        # stripped under ``python -O`` and reports both values on failure.
        self.assertEqual(interp.run(self.f, args=[10]), 100)

    def test_splitblock(self):
        """
        Split the current block, emit an ``add`` at the front of the old half
        and a ``div`` at the end of the new half, and check that the
        function's opcodes come out as ``['add', 'div']``.
        """
        head, tail = self.b.splitblock('newblock')

        with self.b.at_front(head):
            self.b.add(self.a, self.a)

        with self.b.at_end(tail):
            self.b.div(self.a, self.a)

        emitted = opcodes(self.f)
        self.assertEqual(emitted, ['add', 'div'])

    def test_loop_builder(self):
        """
        Generate a loop after computing ``convert(a * a)``: the loop body
        prints the converted value and the exit block returns it.  Interpreting
        the function with ``a = 10`` must yield 100.0.

        NOTE(review): the three constants passed to ``gen_loop`` are
        presumably start, stop and step -- confirm against ``Builder``.
        """
        squared = self.b.mul(self.a, self.a)
        as_float = self.b.convert(types.Float32, squared)

        # Split right after the multiplication and continue in the new block.
        self.b.position_after(squared)
        _, start_block = self.b.splitblock('start', terminate=True)
        self.b.position_at_end(start_block)

        int32 = partial(Const, type=types.Int32)
        _cond, body, exit_block = self.b.gen_loop(int32(5), int32(10), int32(2))

        with self.b.at_front(body):
            self.b.print(as_float)

        with self.b.at_end(exit_block):
            self.b.ret(as_float)

        outcome = interp.run(self.f, args=[10])
        self.assertEqual(outcome, 100.0)

    def test_splitblock_preserve_phis(self):
        """
        block1:
            %0 = mul a a
            jump(newblock)

        newblock:
            %1 = phi([block1], [%0])
            ret %1
        """
        square = self.b.mul(self.a, self.a)
        old, new = self.b.splitblock('newblock', terminate=True)
        with self.b.at_front(new):
            phi = self.b.phi(types.Int32, [self.f.startblock], [square])
            self.b.ret(phi)

        # Now split block1
        self.b.position_after(square)
        block1, split = self.b.splitblock(terminate=True)

        phi, ret = new.ops
        blocks, values = phi.args
        self.assertEqual(blocks, [split])