예제 #1
0
    def visit_Module(self, node):
        """Create the root IR module and its builder, then visit the module body.

        A fresh ``Context`` is shared by the ``Module`` (named "root_module")
        and the ``IRBuilder``; both are stored on the visitor before any
        top-level statement is visited.
        """
        context = Context()
        root = Module("root_module", context)
        self.module = root
        self.irbuilder = IRBuilder(root, context)

        # Dispatch explicitly through the base-class visitor, one top-level
        # node at a time, in source order.
        for top_level in node.body:
            ast.NodeVisitor.visit(self, top_level)
예제 #2
0
class IRGenerator(ast.NodeVisitor):
    """
    IR Generator Visitor: walks a Python AST and emits IR through an IRBuilder.

    Local variables live in allocas that are tracked per function in a stack
    of symbol tables; functions are recorded in ``global_scope`` by name.
    Every function gets a "retv" alloca and a single exit block: a return
    stores into "retv" and branches to that exit block.
    """

    # Supported binary operators: (AST operator type, IRBuilder method that
    # emits it, name given to the resulting instruction).
    _BINARY_OPS = (
        (ast.Add, "create_add", "add"),
        (ast.Sub, "create_sub", "sub"),
        (ast.Mult, "create_mul", "mult"),
        (ast.Div, "create_div", "div"),
        (ast.LShift, "create_shl", "shl"),
        (ast.RShift, "create_lshr", "lshr"),
    )

    def __init__(self):
        self.module = None
        self.irbuilder = None
        self.current_scope_stack = []
        # The bottom table is the module-level scope; visit_FunctionDef
        # pushes one table per function and pops it when the function is done.
        self.symbol_tables = [{}]
        # Function name -> IR function, consulted by visit_Call.
        self.global_scope = {}
        self.current_func = None
        self.cfg_stack = []
        self.current_exit_block = None

    def get_current_symbol_table(self):
        """Return the innermost (most recently pushed) symbol table."""
        return self.symbol_tables[-1]

    def set_symbol(self, sym, val):
        """Bind ``sym`` to ``val`` in the innermost symbol table."""
        self.get_current_symbol_table()[sym] = val

    def get_symbol(self, sym):
        """Return the value bound to ``sym`` in the innermost table, or None."""
        return self.get_current_symbol_table().get(sym)

    def in_symbol_table(self, sym):
        """Return True when ``sym`` is bound in the innermost symbol table."""
        return sym in self.get_current_symbol_table()

    def current_scope(self):
        """Return the top of ``current_scope_stack`` (IndexError when empty)."""
        return self.current_scope_stack[-1]

    def visit_Module(self, node):
        """Create the root module and its IRBuilder, then visit the body."""
        ctx = Context()
        self.module = Module("root_module", ctx)
        self.irbuilder = IRBuilder(self.module, ctx)

        for stmt in node.body:
            self.visit(stmt)

    def visit_Constant(self, node):
        """Emit a number or string for an ``ast.Constant`` literal.

        Python 3.8+ parses literals into ``ast.Constant``; handling it here
        avoids relying on the deprecated ``visit_Num``/``visit_Str`` fallback.
        Other constants (True/False/None, bytes, ...) fall through to
        ``generic_visit`` and yield None.
        """
        value = node.value
        if isinstance(value, str):
            return self.irbuilder.create_string(value)
        # bool is an int subclass but is not a numeric literal.
        if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            return self.irbuilder.create_number(value)
        return self.generic_visit(node)

    def visit_Num(self, node):
        """Emit a number constant (``ast.Num`` nodes from older Pythons)."""
        return self.irbuilder.create_number(node.n)

    def visit_Str(self, node):
        """Emit a string constant (``ast.Str`` nodes from older Pythons)."""
        return self.irbuilder.create_string(node.s)

    def visit_Name(self, node):
        """Return the symbol bound to the name; raise when it is not bound."""
        name = node.id
        # Find the name in the symbol table
        if self.in_symbol_table(name):
            return self.get_symbol(name)

        raise Exception("Variable %s not found" % name)

    def visit_BinOp(self, node):
        """Emit the instruction for a binary operation and return it.

        Both operands are visited (left first) before the operator is
        checked. Raises Exception for operators not in ``_BINARY_OPS``.
        """
        irbuilder = self.irbuilder
        irleft = self.visit(node.left)
        irright = self.visit(node.right)

        inst = None
        for op_type, builder_method, label in self._BINARY_OPS:
            if isinstance(node.op, op_type):
                inst = getattr(irbuilder, builder_method)(irleft, irright, label)
                break

        if inst is None:
            raise Exception("Inst cannot be None")

        return inst

    def visit_Expr(self, node):
        """Visit an expression statement; its value is discarded."""
        self.visit(node.value)

    def visit_Assign(self, node):
        """Store the assigned value into the target's alloca.

        The alloca is created on first assignment. Only a single, plain-name
        target is supported; anything else raises Exception.
        """
        irbuilder = self.irbuilder
        targets = node.targets

        if len(targets) > 1:
            raise Exception("Do not handle more than one return value")

        target = targets[0]
        if not isinstance(target, ast.Name):
            # Tuple/attribute/subscript targets have no ``id`` to look up.
            raise Exception("Only simple variable names can be assigned to")

        rhs = self.visit(node.value)

        target_sym = self.get_symbol(target.id)
        if target_sym is None:
            # Create an alloca for this target symbol
            target_sym = irbuilder.create_alloca(name=target.id)
            self.set_symbol(target.id, target_sym)

        assert isinstance(target_sym, AllocaInstruction)

        # Assigning from another variable copies its current value.
        if isinstance(rhs, AllocaInstruction):
            rhs = irbuilder.create_load(rhs)

        irbuilder.create_store(target_sym, rhs)

    def visit_Return(self, node):
        """Store the return value into "retv" and branch to the exit block.

        ``node`` may be None (the default return synthesized by
        visit_FunctionDef) or a bare ``return``; both store ``ConstantNone()``.
        The function's symbol table is left in place so that code following
        an early return still resolves its variables; visit_FunctionDef pops
        it once the whole body has been generated.
        """
        irbuilder = self.irbuilder
        retv_sym = self.get_symbol("retv")

        if node is None or node.value is None:
            result = ConstantNone()
        else:
            result = self.visit(node.value)
            if isinstance(result, AllocaInstruction):
                result = irbuilder.create_load(result)

        irbuilder.create_store(retv_sym, result)
        irbuilder.create_branch(self.current_exit_block)

    # ast.NodeVisitor dispatches on the node class name, and the comparison
    # operator classes are ast.Gt / GtE / Eq / LtE / Lt. The upper-case method
    # names are kept for existing callers; the aliases make dispatch work.
    def visit_GT(self, node=None):
        return CompareTypes.SGT

    def visit_GE(self, node=None):
        return CompareTypes.SGE

    def visit_EQ(self, node=None):
        return CompareTypes.EQ

    def visit_LE(self, node=None):
        return CompareTypes.SLE

    def visit_LT(self, node=None):
        return CompareTypes.SLT

    visit_Gt = visit_GT
    visit_GtE = visit_GE
    visit_Eq = visit_EQ
    visit_LtE = visit_LE
    visit_Lt = visit_LT

    def visit_Compare(self, node):
        """Emit an icmp for a single two-operand comparison and return it.

        Chained comparisons fail the assertions; operators without a
        CompareTypes mapping raise Exception.
        """
        assert len(node.comparators) == 1, "chained comparisons are not supported"
        assert len(node.ops) == 1, "chained comparisons are not supported"

        lhs = self.visit(node.left)
        rhs = self.visit(node.comparators[0])
        compare_type = self.visit(node.ops[0])
        if compare_type is None:
            raise Exception("Unsupported comparison operator")

        # So create a compare instruction
        return self.irbuilder.create_icmp(lhs, rhs, compare_type)

    def visit_If(self, node):
        """Emit then/else/endif blocks for an if statement.

        Each arm branches to "endif" unless it already ended in a terminator.
        The builder is left inserting into the "endif" block.
        """
        cmp = self.visit(node.test)

        irbuilder = self.irbuilder
        func = self.current_func
        if_block = irbuilder.create_basic_block("then", func)
        else_block = irbuilder.create_basic_block("else", func)
        exit_if_block = irbuilder.create_basic_block("endif", func)

        irbuilder.create_cond_branch(cmp, 1, if_block, else_block)

        for block, stmts in ((if_block, node.body), (else_block, node.orelse)):
            func.basic_blocks.append(block)
            irbuilder.insert_after(block)
            for stmt in stmts:
                self.visit(stmt)

            # An arm that ended in a return already has its terminator.
            if not irbuilder.get_current_bb().has_terminator():
                irbuilder.create_branch(exit_if_block)

        func.basic_blocks.append(exit_if_block)
        irbuilder.insert_after(exit_if_block)

    def visit_While(self, node):
        """Emit while_entry/while/exit_while blocks for a while loop.

        The condition is generated inside "while_entry", the block the body
        branches back to. The builder is left inserting into "exit_while".
        The loop's ``else`` clause is not handled.
        """
        irbuilder = self.irbuilder
        func = self.current_func

        while_entry_block = irbuilder.create_basic_block("while_entry", func)
        irbuilder.create_branch(while_entry_block)

        while_block = irbuilder.create_basic_block("while", func)
        exit_while_block = irbuilder.create_basic_block("exit_while", func)

        # Blocks are appended as they are filled (same as visit_If) so blocks
        # created by nested statements land after their enclosing loop block.
        func.basic_blocks.append(while_entry_block)
        irbuilder.insert_after(while_entry_block)
        # Create a comparison object for the while, inside the loop header
        cmp = self.visit(node.test)
        irbuilder.create_cond_branch(cmp, 1, while_block, exit_while_block)

        func.basic_blocks.append(while_block)
        irbuilder.insert_after(while_block)
        for stmt in node.body:
            self.visit(stmt)

        # A body that ended in a return already has its terminator.
        if not irbuilder.get_current_bb().has_terminator():
            irbuilder.create_branch(while_entry_block)

        func.basic_blocks.append(exit_while_block)
        irbuilder.insert_after(exit_while_block)

    def visit_For(self, node):
        raise Exception("For: Not been implemented since we haven't implemented range iterations")

    def visit_Call(self, node):
        """Emit a call to a previously defined function and return it.

        The callee is looked up by name in ``global_scope`` (KeyError when it
        is unknown). Arguments are visited in order; an unbound variable
        argument raises through visit_Name.
        """
        irfunc = self.global_scope[node.func.id]
        irargs = [self.visit(arg) for arg in node.args]
        return self.irbuilder.create_call(irfunc, irargs)

    def visit_FunctionDef(self, func):
        """Emit an IR function for a function definition.

        Creates the entry and exit blocks, the "retv" alloca and one alloca
        per argument, generates the body, and finishes with an exit block that
        loads and returns "retv". A function named "main" becomes the module's
        entry point.
        """
        irbuilder = self.irbuilder
        irargs = [Argument(arg.arg) for arg in func.args.args]

        irfunc = irbuilder.create_function(func.name, irargs)
        self.current_func = irfunc
        # Register before generating the body so recursive calls resolve.
        self.global_scope[func.name] = irfunc

        entry_bb = irbuilder.create_basic_block("entry", irfunc)
        irfunc.basic_blocks.append(entry_bb)
        exit_bb = irbuilder.create_basic_block("exit", irfunc)
        self.current_exit_block = exit_bb

        self.symbol_tables.append({})
        try:
            irbuilder.insert_after(entry_bb)
            # Create an alloca for the return value
            retv = irbuilder.create_alloca(name="retv")
            self.set_symbol("retv", retv)

            # Now create allocas for every argument
            for irarg in irargs:
                alloca = irbuilder.create_alloca()
                irbuilder.create_store(alloca, irarg)
                self.set_symbol(irarg.name, alloca)

            for stmt in func.body:
                self.visit(stmt)

            # Falling off the end of the body returns None: create a default
            # return whenever the last block was left without a terminator.
            if not irbuilder.get_current_bb().has_terminator():
                self.visit_Return(None)
        finally:
            # Leave the function scope exactly once, however many return
            # statements the body contained.
            self.symbol_tables.pop()

        # Now we should have completed generating all the code. So create
        # the return value
        irfunc.basic_blocks.append(exit_bb)

        irbuilder.insert_after(exit_bb)
        retv_loaded = irbuilder.create_load(retv)
        irbuilder.create_return(retv_loaded)

        if func.name == "main":
            self.module.entry_point = irfunc