Exemple #1
0
def get_array_pointer(scope_stack, builder: ir.IRBuilder, array, index):
    """Emit an inbounds GEP to element ``index`` of ``array``.

    ``array`` is either a ``('struct', var, member)`` /
    ``('struct_ptr', var, member)`` tuple naming a struct member, or any value
    accepted by ``get_pointer``.  When ``array`` is a tuple whose identifier is
    registered with type ``'val_ptr'``, the stored pointer is loaded and
    indexed directly; every other case is indexed through a leading zero.
    """
    is_tuple = isinstance(array, tuple)
    tag = array[0] if is_tuple else None

    if tag == 'struct':
        base = get_struct_pointer(scope_stack, builder, array[1], array[2])
    elif tag == 'struct_ptr':
        base = get_struct_ptr_pointer(scope_stack, builder, array[1], array[2])
    else:
        entry = get_identifier(scope_stack, array[1]) if is_tuple else None
        base = get_pointer(scope_stack, builder, array)
        if entry and entry['type'] == 'val_ptr':
            # The slot holds a pointer value: load it, then index it as-is.
            return builder.gep(builder.load(base), [index], True)

    return builder.gep(base, [int_type(0), index], True)
Exemple #2
0
 def _store_to_alloc(self, index: Union[ir.Constant, ir.Instruction],
                     src_nums: List[ir.LoadInstr],
                     builder: ir.IRBuilder) -> None:
     """Store one element into ``self.alloc`` at logical position ``index``.

     For ``DType.Complx``/``DType.DComplx`` an element occupies two
     consecutive slots: ``src_nums[0]`` (real) goes to ``2*index`` and
     ``src_nums[1]`` (imaginary) to ``2*index + 1``.  For every other dtype
     only ``src_nums[0]`` is stored, at ``index``.
     """
     if self.dtype in (DType.Complx, DType.DComplx):
         real_slot = builder.mul(index, ir.Constant(int_type, 2))
         builder.store(src_nums[0], builder.gep(self.alloc, [real_slot]))
         imag_slot = builder.add(real_slot, ir.Constant(int_type, 1))
         builder.store(src_nums[1], builder.gep(self.alloc, [imag_slot]))
         return

     builder.store(src_nums[0], builder.gep(self.alloc, [index]))
Exemple #3
0
    def _load_from_alloc(self, index: Union[ir.Constant, ir.Instruction],
                         builder: ir.IRBuilder) -> List[ir.LoadInstr]:
        """Load the element at logical position ``index`` from ``self.alloc``.

        Returns ``[value]`` for non-complex dtypes.  For ``DType.Complx`` /
        ``DType.DComplx`` returns ``[real, imag]``, read from slots
        ``2*index`` and ``2*index + 1``.
        """
        if self.dtype not in (DType.Complx, DType.DComplx):
            return [builder.load(builder.gep(self.alloc, [index]))]

        real_slot = builder.mul(index, ir.Constant(int_type, 2))
        real_part = builder.load(builder.gep(self.alloc, [real_slot]))
        imag_slot = builder.add(real_slot, ir.Constant(int_type, 1))
        imag_part = builder.load(builder.gep(self.alloc, [imag_slot]))
        return [real_part, imag_part]
Exemple #4
0
 def _gen_setitem(self, builder: ir.IRBuilder,
                  index: Union[int, Node, slice, list, tuple,
                               np.ndarray], val: Node) -> None:
     """Emit IR that writes the elements of ``val`` into ``self`` at ``index``.

     When an index is set on a node, that node must be an LValue node; if it
     is not, the graph maintainer should change its vtype to LEFT.  The array
     is only generated when it is required.

     ``index`` may be:

     * an ``int`` - element 0 of ``val`` is stored at that position;
     * a ``slice`` - resolved against ``self.size``; element ``k`` of ``val``
       goes to the ``k``-th position of the slice;
     * a ``Node`` - its element ``k`` is the destination of ``val[k]``;
     * a list/tuple/ndarray of ints - same, from a static index table.

     Source elements are cast to ``self.dtype`` whenever ``val.dtype`` differs.
     """

     def load_src(src_index):
         # Read one element of ``val``, converted to this node's dtype.
         src_nums = val.get_ele(src_index, builder)
         if val.dtype != self.dtype:
             src_nums = build_type_cast(builder, src_nums, val.dtype,
                                        self.dtype)
         return src_nums

     if isinstance(index, int):
         src_nums = load_src(ir.Constant(int_type, 0))
         self._store_to_alloc(ir.Constant(int_type, index), src_nums, builder)
     elif isinstance(index, slice):
         size = compute_size(index, self.size)
         # The slice addresses ``self``, so its bounds are resolved against
         # self.size (as compute_size does), not against val.size.
         start, _, step = index.indices(self.size)
         v_start = ir.Constant(int_type, start)
         v_step = ir.Constant(int_type, step)
         # Running destination index, advanced by ``step`` every iteration.
         dest_index_ptr = builder.alloca(int_type, 1)
         builder.store(v_start, dest_index_ptr)
         with LoopCtx(self.name + "_set_slice", builder, size) as loop:
             loop_inc = builder.load(loop.inc)
             dest_index = builder.load(dest_index_ptr)
             src_nums = load_src(loop_inc)
             self._store_to_alloc(dest_index, src_nums, builder)
             builder.store(builder.add(dest_index, v_step), dest_index_ptr)
     elif isinstance(index, Node):
         with LoopCtx(self.name + "_set_slice", builder,
                      index.size) as loop:
             loop_inc = builder.load(loop.inc)
             # get_ele needs the builder to emit its loads.
             dest_index = index.get_ele(loop_inc, builder)[0]
             src_nums = load_src(loop_inc)
             self._store_to_alloc(dest_index, src_nums, builder)
     else:
         # TODO: change this to malloc function
         all_inds = builder.alloca(int_type, len(index))
         for i, ind in enumerate(index):
             ind_ptr = builder.gep(all_inds, [ir.Constant(int_type, i)])
             builder.store(ir.Constant(int_type, ind), ind_ptr)
         with LoopCtx(self.name + "_set_slice", builder,
                      len(index)) as loop:
             loop_inc = builder.load(loop.inc)
             dest_index = builder.load(builder.gep(all_inds, [loop_inc]))
             src_nums = load_src(loop_inc)
             self._store_to_alloc(dest_index, src_nums, builder)
Exemple #5
0
    def _code_gen(self, builder: ir.IRBuilder) -> None:
        """Generate the LLVM IR that materializes this indexing node.

        Loops over the ``self.size`` output positions.  For each position it
        reads the addressed element of ``self.src`` and stores it into this
        node's allocation.  Indexing with a numpy array is not recommended,
        because the index table is emitted statically and can make the
        generated IR very large.
        """
        ind = self.ind
        is_slice = isinstance(ind, slice)
        if is_slice:
            # Slices walk the source with a cursor kept in a stack slot.
            first, _, stride = ind.indices(self.src.size)
            stride_const = ir.Constant(int_type, stride)
            cursor_ptr = builder.alloca(int_type, 1)
            builder.store(ir.Constant(int_type, first), cursor_ptr)

        with LoopCtx(self.name, builder, self.size) as loop:
            position = builder.load(loop.inc)
            if isinstance(ind, (ir.Constant, ir.Instruction)):
                src_index = ind
            elif is_slice:
                src_index = builder.load(cursor_ptr)
            elif isinstance(ind, Node):
                src_index = ind.get_ele(position, builder)[0]
            else:
                src_index = builder.load(builder.gep(self.src_inds, [position]))

            self._store_to_alloc(position, self.src.get_ele(src_index, builder),
                                 builder)

            if is_slice:
                builder.store(builder.add(src_index, stride_const), cursor_ptr)
Exemple #6
0
def get_struct_pointer(scope_stack,
                       builder: ir.IRBuilder,
                       struct,
                       item,
                       index=int_type(0)):
    """Return an inbounds GEP to member ``item`` of the struct variable ``struct``.

    ``struct`` is looked up in ``scope_stack`` and must be registered with
    type ``'struct'``.  The member's field number comes from ``struct_map``,
    keyed by the identifier's ``'val_type'``.  ``index`` is the GEP's leading
    index (default 0).
    """
    entry = get_identifier(scope_stack, struct)
    assert entry['type'] == 'struct'
    field = struct_map[entry['val_type']][item]
    return builder.gep(entry['value'], [index, field['index']], True)
Exemple #7
0
def impl_construct_dtype_on_stack(context: BaseContext, builder: ir.IRBuilder, sig, args):
    """Lay out a value of dtype ``sig.args[0].dtype`` in a stack byte buffer.

    Allocates ``find_size_for_dtype(...)`` bytes (``i8`` alloca).  For each
    member of ``dtype_as_type()`` it does the following:

    * extracts the matching element of the aggregate ``args[1]``;
    * casts it from ``sig.args[1][i]`` to the member type;
    * stores it through a typed pointer at the member's byte offset.

    Returns the ``i8*`` buffer pointer.

    NOTE(review): the buffer uses the alloca's default alignment while members
    are stored through bitcast pointers at arbitrary offsets; confirm that the
    alignment is acceptable for every member type.
    """
    record_ty = sig.args[0].dtype_as_type()
    n_bytes = find_size_for_dtype(sig.args[0].dtype)
    buffer_ptr = builder.alloca(ir.IntType(8), n_bytes)
    source_tuple = args[1]
    for position, (member_name, member_ty) in enumerate(record_ty.members):
        member_llvm_ty = context.get_value_type(member_ty)
        byte_offset = record_ty.offset(member_name)
        raw_value = builder.extract_value(source_tuple, position)
        stored_value = context.cast(builder, raw_value, sig.args[1][position], member_ty)
        byte_ptr = builder.gep(buffer_ptr, (ir.Constant(ir.IntType(32), byte_offset),), True)
        typed_ptr = builder.bitcast(byte_ptr, member_llvm_ty.as_pointer())
        builder.store(stored_value, typed_ptr)
    return buffer_ptr
Exemple #8
0
def string_to_volpe(b: ir.IRBuilder, string: ir.Value):
    """Build a ``string_type`` aggregate from the character pointer ``string``.

    Returns a ``string_type`` value whose field 0 is ``string`` and whose
    field 1 is ``phi``.

    NOTE(review): ``options`` is not shown here.  From its use it appears to
    create an ``int64`` phi whose incoming values are supplied through
    ``ret``: 0 on entry, and ``phi + 1`` while ``string[phi]`` is not NUL.
    If so, ``phi`` ends up as the length of the NUL-terminated string.
    Confirm against ``options``.
    """
    # Presumably supplies the counter's initial (entry) value.
    with options(b, int64) as (ret, phi):
        ret(int64(0))

    # Read the character at the current position; if it is not the NUL
    # terminator, feed phi + 1 back in (presumably iterating again).
    character = b.load(b.gep(string, [phi]))
    with b.if_then(b.icmp_unsigned("!=", character, char(0))):
        ret(b.add(phi, int64(1)))

    # Assemble {string, phi} starting from an undefined aggregate.
    new_string = string_type.unwrap()(ir.Undefined)
    new_string = b.insert_value(new_string, string, 0)
    new_string = b.insert_value(new_string, phi, 1)

    return new_string
Exemple #9
0
 def _code_gen(self, builder: ir.IRBuilder) -> None:
     """Emit a loop that applies the LLVM intrinsic ``self.func_name`` elementwise.

     The intrinsic is declared in the current module, with the LLVM types of
     the source nodes' dtypes as its overload types.  For each of the
     ``self.size`` positions ``i``, element ``i`` of every node in
     ``self.SRC`` is passed as an argument.  The result is stored at
     ``self.alloc[i]``.
     """
     module = builder.block.module
     overload_types = [type_map_llvm[src.dtype] for src in self.SRC]
     intrinsic = module.declare_intrinsic(self.func_name, overload_types)
     with LoopCtx(self.name, builder, self.size) as loop:
         i = builder.load(loop.inc)
         dest_ptr = builder.gep(self.alloc, [i])
         call_args = [src.get_ele(i, builder)[0] for src in self.SRC]
         builder.store(builder.call(intrinsic, call_args), dest_ptr)
Exemple #10
0
    def _load_from_src(self, ind: Union[ir.Constant, ir.Instruction],
                       builder: ir.IRBuilder) -> List[ir.Instruction]:
        """Load the source element that output position ``ind`` maps to.

        ``self.ind`` decides how ``ind`` is translated into a source index:

        * ``ir.Constant``/``ir.Instruction`` - used directly as the index;
        * ``slice`` - ``start + ind * step``, resolved against
          ``self.src.size``;
        * ``Node`` - element ``ind`` of that node;
        * anything else - element ``ind`` of the table ``self.src_inds``.

        Returns what ``self.src.get_ele`` returns for that source index.
        """
        # One elif chain: with a separate `if` here, a constant/instruction
        # index fell through to the final `else` and was overwritten by a
        # load from self.src_inds.
        if isinstance(self.ind, (ir.Constant, ir.Instruction)):
            src_index = self.ind
        elif isinstance(self.ind, slice):
            start, _, step = self.ind.indices(self.src.size)
            muled = builder.mul(ind, ir.Constant(int_type, step))
            src_index = builder.add(muled, ir.Constant(int_type, start))
        elif isinstance(self.ind, Node):
            src_index = self.ind.get_ele(ind, builder)[0]
        else:
            src_index_ptr = builder.gep(self.src_inds, [ind])
            src_index = builder.load(src_index_ptr)

        return self.src.get_ele(src_index, builder)
Exemple #11
0
def get_value(scope_stack, builder: ir.IRBuilder, item):
    """Turn an AST operand into an rvalue, emitting loads where needed.

    Non-tuples, and tuples with an unrecognized tag, are returned unchanged.
    Tagged tuples are handled as follows:

    * ``('id', name)`` - an alloca-backed non-array is loaded; an
      alloca-backed array yields a GEP to its first element; any other
      storage value is returned as is;
    * ``('array_index', ptr)`` - ``ptr`` is loaded;
    * ``('struct', var, member)`` / ``('struct_ptr', var, member)`` - the
      member pointer is computed and loaded.
    """
    if not isinstance(item, tuple):
        return item

    tag = item[0]
    if tag == 'id':
        entry = get_identifier(scope_stack, item[1])
        storage = entry['value']
        if storage.opname != 'alloca':
            return storage
        if entry['type'] == 'array':
            return builder.gep(storage, [int_type(0), int_type(0)])
        return builder.load(storage)
    if tag == 'array_index':
        return builder.load(item[1])
    if tag == 'struct':
        member_ptr = get_struct_pointer(scope_stack, builder, item[1], item[2])
        return builder.load(member_ptr)
    if tag == 'struct_ptr':
        member_ptr = get_struct_ptr_pointer(scope_stack, builder, item[1], item[2])
        return builder.load(member_ptr)
    return item
Exemple #12
0
    def code_gen(self, builder: ir.IRBuilder) -> None:
        """Generate this node's IR after first generating its dependencies.

        Wrapper around ``_code_gen``.  It returns immediately once
        ``self.gened`` is set.  Otherwise it:

        * recursively generates every node in ``self.dependence``;
        * marks this node as generated;
        * fills ``self.src_inds`` when the index is a list/tuple/ndarray;
        * calls ``_code_gen`` only for LValue (``LRValue.LEFT``) nodes.
        """
        if self.gened:
            return
        for dependency in self.dependence:
            dependency.code_gen(builder)
        self.gened = True

        if isinstance(self.ind, (list, tuple, np.ndarray)):
            # Write the static index table, one slot per index.
            for slot, value in enumerate(self.ind):
                slot_ptr = builder.gep(self.src_inds, [ir.Constant(int_type, slot)])
                builder.store(ir.Constant(int_type, value), slot_ptr)
        if self.vtype == LRValue.LEFT:
            self._code_gen(builder)
Exemple #13
0
def insert_string(module: ir.Module,
                  scope_stack: list,
                  builder: ir.IRBuilder,
                  raw_data: str,
                  prt=False):
    """Return the global constant holding the string literal ``raw_data``.

    ``raw_data`` is presumably the literal as written in the compiled source,
    quotes and escapes included.  It is decoded, NUL-terminated,
    UTF-8-encoded and emitted as a global ``[N x char_type]`` array.  Globals
    are cached in ``scope_stack[0]``, keyed by the raw literal text, so that
    identical literals share one global.

    :param prt: if true, return an inbounds GEP to the first character
        instead of the global itself.
    :raises ValueError/SyntaxError: if ``raw_data`` is not a Python literal.
    """
    import ast  # only needed to decode literals seen for the first time

    try:
        global_string = scope_stack[0][raw_data]
    except KeyError:
        # raw_data comes from the program being compiled.  ast.literal_eval
        # decodes a string literal just as eval() did, but rejects anything
        # that is not a literal instead of executing it.
        data = ast.literal_eval(raw_data)
        data += '\00'  # NUL terminator
        data = data.encode()
        str_type = ir.ArrayType(char_type, len(data))
        const_string = ir.Constant(str_type, bytearray(data))
        global_string = ir.GlobalVariable(module, str_type,
                                          module.get_unique_name(raw_data))
        global_string.initializer = const_string
        scope_stack[0][raw_data] = global_string
    if prt:
        return builder.gep(global_string, [int_type(0), int_type(0)], True)
    return global_string
Exemple #14
0
    def compile(self, module: ir.Module, builder: ir.IRBuilder,
                symbols: SymbolTable) -> ir.Value:
        """Compile this (possibly dotted) identifier reference.

        A root identifier (``self.parent is None``) is returned straight from
        the symbol table.  Otherwise the parent is compiled first.  If its
        type, or the pointee of its pointer type, is an identified struct, a
        GEP to the field named ``self.ID`` is emitted.  A result that is a
        pointer to one of the ``PrimitiveTypes`` is then loaded.
        """
        if self.parent is None:
            return symbols.get_symbol(self.ID)

        value = self.parent.compile(module, builder, symbols)
        aggregate_ty = value.type
        if isinstance(aggregate_ty, ir.PointerType):
            aggregate_ty = aggregate_ty.pointee
        if isinstance(aggregate_ty, ir.IdentifiedStructType):
            field_no = symbols.get_elements(aggregate_ty.name).index(self.ID)
            value = builder.gep(value, [make_constant(IntType, 0),
                                        make_constant(IntType, field_no)])

        result_ty = value.type
        points_to_primitive = isinstance(result_ty, ir.PointerType) and (
            result_ty.pointee in [prim.ir_type for prim in PrimitiveTypes])
        if points_to_primitive:
            value = builder.load(value)
        return value
Exemple #15
0
class GeneratorVisitor(SmallCVisitor):
    def __init__(self, output_file=None):
        """Create a visitor that lowers a SmallC parse tree to an LLVM module.

        :param output_file: path that ``print`` writes the textual IR to;
            when falsy the IR is printed to stdout instead.
        """
        # Start the MRO lookup at this class.  super(SmallCVisitor, self)
        # skipped SmallCVisitor.__init__ entirely.
        super().__init__()
        self.Module = Module(name=__file__)
        self.Module.triple = "x86_64-pc-linux-gnu"
        self.Builder = None          # IRBuilder of the function being emitted
        self.function = None         # name of the function being emitted
        self.NamedValues = dict()
        self.counter = 0
        self.loop_stack = []         # per loop: {'continue', 'break', 'buf'} blocks
        self.signal_stack = []       # per loop: 1 continue, -1 break, 0 neither
        self.cond_stack = []
        self.var_stack = []          # variable scopes, innermost last
        self.cur_decl_type = None    # LLVM type of the declaration being visited

        self.indentation = 0
        self.function_dict = dict()  # function name -> ir.Function
        self.error_queue = list()
        self.output_file = output_file

    def print(self):
        """Emit the generated module's textual IR.

        Writes it to ``self.output_file`` when one was given; otherwise prints
        it to stdout.
        """
        if not self.output_file:
            print(self.Module)
        else:
            # `with` closes the file even if writing raises.
            with open(self.output_file, "w+") as f:
                f.write(str(self.Module))

    def error(self, info):
        """Print ``info`` as an error message on stdout and return 0.

        No exception is raised; callers continue after reporting.
        """
        status = 0
        print("Error: ", info)
        return status

    def toBool(self, value):
        """Return an ``i1`` that is true when ``value`` is non-zero.

        The zero constant is built with ``value``'s own type, so the
        comparison is well-typed for any integer width (``i1``, ``i8``,
        ``i32``).  Previously it was always ``bool`` (``i1``), which mismatched
        wider operands.  ``float`` values are compared with an unordered
        ``fcmp !=``, so NaN counts as true, as in C.
        """
        zero = Constant(value.type, 0)
        if isinstance(value.type, FloatType):
            return self.Builder.fcmp_unordered('!=', value, zero)
        return self.Builder.icmp_signed('!=', value, zero)

    def getVal_of_expr(self, expr):
        """Visit ``expr`` and return the result as an LLVM value.

        If the visit already yields an LLVM value (``Constant``,
        ``CallInstr``, ``LoadInstr``, any ``Instruction`` or a
        ``GlobalVariable``), that value is returned unchanged.  Otherwise the
        result is treated as an identifier node:

        * the variable's storage is looked up in the scope stack;
        * it is indexed through a GEP when the node has an
          ``array_indexing`` part;
        * the result is loaded.
        """
        visited = self.visit(expr)
        if isinstance(visited, (Constant, CallInstr, LoadInstr, Instruction,
                                GlobalVariable)):
            return visited

        record = self.getVal_local(visited.IDENTIFIER().getText())
        address = record['ptr']
        indexing = visited.array_indexing()
        if indexing:
            offset = self.getVal_of_expr(indexing.expr())
            address = self.Builder.gep(address,
                                       [Constant(IntType(32), 0), offset],
                                       inbounds=True)
        return self.Builder.load(address)

    def getType(self, type):
        """Map a SmallC type name to a newly built LLVM type.

        ``'int'`` -> ``i32``, ``'char'`` -> ``i8``, ``'float'`` -> ``float``,
        ``'bool'`` -> ``i1``, ``'void'`` -> ``void``.  Any other name is
        reported through ``self.error`` and ``None`` is returned.
        """
        factories = (
            ('int', lambda: IntType(32)),
            ('char', lambda: IntType(8)),
            ('float', FloatType),
            ('bool', lambda: IntType(1)),
            ('void', VoidType),
        )
        for name, factory in factories:
            if type == name:
                return factory()
        self.error("type error in <getType>")
        return None

    def getVal_local(self, id):
        """Look up variable ``id``, searching from the innermost scope outwards.

        Returns the variable's record dict.  If no scope defines it, an error
        is reported and ``None`` is returned.
        """
        for scope in reversed(self.var_stack):
            if id in scope:
                return scope[id]
        self.error("value error in <getVal_local>")
        return None

    def visitFunction_definition(self,
                                 ctx: SmallCParser.Function_definitionContext):
        """Declare and, when it has a body, define a SmallC function.

        The LLVM function type is built from the declared return type and
        parameters.  A ``*`` on an identifier makes its type a pointer; a
        ``...`` parameter makes the function variadic.  The ``ir.Function`` is
        created once and cached in ``self.function_dict``, so a definition
        reuses an earlier prototype of the same name.

        For a definition:

        * each argument is spilled into an ``alloca`` registered in a new scope;
        * the body is visited;
        * for void functions, ``ret void`` is appended unless the body already
          terminated the current block.
        """
        func_name = ctx.identifier().IDENTIFIER().getText()
        retType = self.getType(ctx.type_specifier().getText())
        if ctx.identifier().ASTERIKS():
            retType = retType.as_pointer()
        argsType = []
        argsName = []
        # args
        if ctx.param_decl_list():
            args = ctx.param_decl_list()
            var_arg = False
            for t in args.getChildren():
                if t.getText() == ',':
                    continue
                if t.getText() == '...':
                    var_arg = True
                    break
                t_type = self.getType(t.type_specifier().getText())
                if t.identifier().ASTERIKS():
                    t_type = t_type.as_pointer()
                argsType.append(t_type)
                argsName.append(t.identifier().IDENTIFIER().getText())
            funcType = FunctionType(retType, tuple(argsType), var_arg=var_arg)
        # no args
        else:
            funcType = FunctionType(retType, ())

        # Reuse an earlier prototype of the same name, if any.
        func = self.function_dict.get(func_name)
        if func is None:
            func = Function(self.Module, funcType, name=func_name)
            self.function_dict[func_name] = func

        # blocks or ;
        if ctx.compound_stmt():
            self.function = func_name
            block = func.append_basic_block(func_name)
            varDict = dict()
            self.Builder = IRBuilder(block)
            # Spill every argument to a stack slot so the body can treat it
            # like any other local variable.
            for i, arg in enumerate(func.args):
                arg.name = argsName[i]
                alloca = self.Builder.alloca(arg.type, name=arg.name)
                self.Builder.store(arg, alloca)
                varDict[arg.name] = {
                    "id": arg.name,
                    "type": arg.type,
                    "value": None,
                    "ptr": alloca
                }
            self.var_stack.append(varDict)
            self.visit(ctx.compound_stmt())
            # Skip the implicit `ret void` when the body already ended the
            # block (e.g. with `return;`); llvmlite refuses a second
            # terminator in one block.
            if (isinstance(retType, VoidType)
                    and not self.Builder.block.is_terminated):
                self.Builder.ret_void()
            self.var_stack.pop()
            self.function = None
        return

    def visitFunctioncall(self, ctx: SmallCParser.FunctioncallContext):
        """Emit a call to a previously declared function.

        Each argument expression is evaluated with ``getVal_of_expr``.  The
        argument is then loaded unless it is a ``Constant`` or is passed as a
        pointer.  "Pointer" is decided as follows:

        * for declared parameters, by the parameter's type;
        * for extra variadic arguments, by the type recorded for the named
          local variable the value came from.  Values with no ``name``
          attribute are not looked up.
        """
        function = self.function_dict[ctx.identifier().getText()]
        arg_types = function.args
        index = 0

        args = []
        if ctx.param_list():
            for param in ctx.param_list().getChildren():
                if param.getText() == ',':
                    continue

                temp = self.getVal_of_expr(param)

                ptr_flag = False
                if index < len(arg_types):
                    # Declared parameter: its type decides.
                    if isinstance(arg_types[index].type, PointerType):
                        ptr_flag = True
                else:
                    # Extra variadic argument.  Values such as ir.Constant
                    # have no `name` and previously raised AttributeError
                    # here.  Look the variable up once, not twice.
                    name = getattr(temp, 'name', None)
                    if name is not None:
                        local = self.getVal_local(name)
                        if local and isinstance(local['type'], PointerType):
                            ptr_flag = True
                if not ptr_flag and not isinstance(temp, Constant):
                    temp = self.Builder.load(temp)
                args.append(temp)

                index += 1

        return self.Builder.call(function, args)

    # def visitVar_decl(self, ctx: SmallCParser.Var_declContext):
    #     type = self.getType(ctx.type_specifier())
    #     list = ctx.var_decl_list()
    #     for var in list.getChildren():
    #         if var.getText() != ',':
    #             if self.builder:
    #                 alloca = self.builder.alloca(type, name=var.identifier().getText())
    #                 self.builder.store(Constant(type, None), alloca)
    #                 self.var_stack[-1][var.identifier().getText()] = alloca
    #             else:
    #                 g_var = GlobalVariable(self.Module, type, var.identifier().getText())
    #                 g_var.initializer = Constant(type, None)
    #     return

    def visitStmt(self, ctx: SmallCParser.StmtContext):
        """Lower ``return``, ``continue`` and ``break``; delegate other statements.

        * ``return expr;`` evaluates its operand (loading it if it is a
          pointer) and emits ``ret``.  A bare ``return;`` emits ``ret void``.
        * ``continue``/``break`` record the jump kind in
          ``self.signal_stack``, branch to the innermost loop's continue/break
          block, and move the builder to the start of that loop's ``buf``
          block (presumably so that any following dead statements have a
          block to go into).
        """
        if ctx.RETURN():
            expr = ctx.expr()
            if expr is None:
                # Bare `return;`: there is no operand, and visiting None
                # would crash.
                return self.Builder.ret_void()
            value = self.getVal_of_expr(expr)
            if isinstance(value.type, PointerType):
                value = self.Builder.load(value)
            return self.Builder.ret(value)
        elif ctx.CONTINUE():
            self.signal_stack[-1] = 1
            loop_blocks = self.loop_stack[-1]
            self.Builder.branch(loop_blocks['continue'])
            self.Builder.position_at_start(loop_blocks['buf'])
            return None
        elif ctx.BREAK():
            self.signal_stack[-1] = -1
            loop_blocks = self.loop_stack[-1]
            self.Builder.branch(loop_blocks['break'])
            self.Builder.position_at_start(loop_blocks['buf'])
        else:
            return self.visitChildren(ctx)

    def visitCompound_stmt(self, ctx: SmallCParser.Compound_stmtContext):
        """Visit the statements of a ``{ ... }`` block and return the result.

        No basic block or scope is opened here; functions, loops and
        if/else push their own scopes.
        """
        return self.visitChildren(ctx)

    def visitAssignment(self, ctx: SmallCParser.AssignmentContext):
        """Store the value of ``ctx.expr()`` into the assigned variable.

        For array variables the target is element ``[index]``, or element 0
        when no index is written.  Pointer-typed values and indices are
        loaded first.  Returns the emitted ``store``.
        """
        value = self.getVal_of_expr(ctx.expr())
        target_ctx = ctx.identifier()
        record = self.getVal_local(target_ctx.IDENTIFIER().getText())

        if isinstance(record['type'], ArrayType):
            indexing = target_ctx.array_indexing()
            if indexing:
                element = self.getVal_of_expr(indexing.expr())
                if isinstance(element.type, PointerType):
                    element = self.Builder.load(element)
            else:
                element = Constant(IntType(32), 0)
            target = self.Builder.gep(record['ptr'],
                                      [Constant(IntType(32), 0), element],
                                      inbounds=True)
        else:
            target = record['ptr']

        if isinstance(value.type, PointerType):
            value = self.Builder.load(value)
        return self.Builder.store(value, target)

    def visitExpr(self, ctx: SmallCParser.ExprContext):
        """Visit whichever alternative is present: condition, assignment or call.

        Returns ``None`` if none of them is present.
        """
        for accessor in (ctx.condition, ctx.assignment, ctx.functioncall):
            node = accessor()
            if node:
                return self.visit(node)
        return None

    def visitCondition(self, ctx: SmallCParser.ConditionContext):
        """Lower ``disjunction ? expr : condition``, or a plain disjunction.

        In the ternary form all three operands are evaluated, and ``select``
        on ``disjunction != 0`` picks between the two arms.  Unlike C's lazy
        ``?:``, both arms are always evaluated.  Pointer operands are loaded
        first.
        """
        if ctx.expr():
            disjunction = self.getVal_of_expr(ctx.disjunction())
            if isinstance(disjunction.type, PointerType):
                disjunction = self.Builder.load(disjunction)
            cond = self.Builder.icmp_signed('!=', disjunction,
                                            Constant(disjunction.type, 0))
            expr = self.getVal_of_expr(ctx.expr())
            if isinstance(expr.type, PointerType):
                expr = self.Builder.load(expr)
            condition = self.getVal_of_expr(ctx.condition())
            # Load the false arm exactly like the true arm.  Otherwise a
            # pointer here reaches `select` with a type that does not match
            # `expr`.
            if isinstance(condition.type, PointerType):
                condition = self.Builder.load(condition)
            return self.Builder.select(cond, expr, condition)
        else:
            return self.getVal_of_expr(ctx.disjunction())

    def visitDisjunction(self, ctx: SmallCParser.DisjunctionContext):
        """Lower ``disjunction || conjunction`` to an ``or`` of two ``!= 0`` tests.

        Both operands are always evaluated (no short-circuit).  Pointer
        operands are loaded before being compared with zero.  Without ``||``
        the single conjunction is returned.
        """
        if not ctx.disjunction():
            return self.getVal_of_expr(ctx.conjunction())

        def rvalue(node):
            # Evaluate an operand and load it if it is still a pointer.
            result = self.getVal_of_expr(node)
            if isinstance(result.type, PointerType):
                result = self.Builder.load(result)
            return result

        lhs = rvalue(ctx.disjunction())
        rhs = rvalue(ctx.conjunction())
        lhs_true = self.Builder.icmp_signed('!=', lhs, Constant(lhs.type, 0))
        rhs_true = self.Builder.icmp_signed('!=', rhs, Constant(rhs.type, 0))
        return self.Builder.or_(lhs_true, rhs_true)

    def visitConjunction(self, ctx: SmallCParser.ConjunctionContext):
        """Lower ``conjunction && comparison`` to an ``and`` of two ``!= 0`` tests.

        Both operands are always evaluated (no short-circuit).  Pointer
        operands are loaded before being compared with zero.  Without ``&&``
        the single comparison is returned.
        """
        if not ctx.conjunction():
            return self.getVal_of_expr(ctx.comparison())

        def rvalue(node):
            # Evaluate an operand and load it if it is still a pointer.
            result = self.getVal_of_expr(node)
            if isinstance(result.type, PointerType):
                result = self.Builder.load(result)
            return result

        lhs = rvalue(ctx.conjunction())
        rhs = rvalue(ctx.comparison())
        lhs_true = self.Builder.icmp_signed('!=', lhs, Constant(lhs.type, 0))
        rhs_true = self.Builder.icmp_signed('!=', rhs, Constant(rhs.type, 0))
        return self.Builder.and_(lhs_true, rhs_true)

    def visitComparison(self, ctx: SmallCParser.ComparisonContext):
        """Lower ``relation == relation`` / ``relation != relation`` to a signed icmp.

        Pointer operands are loaded first.  Without an equality operator the
        single relation is returned as is.
        """
        if ctx.EQUALITY():
            predicate = '=='
        elif ctx.NEQUALITY():
            predicate = '!='
        else:
            return self.getVal_of_expr(ctx.relation(0))

        operands = []
        for position in (0, 1):
            operand = self.getVal_of_expr(ctx.relation(position))
            if isinstance(operand.type, PointerType):
                operand = self.Builder.load(operand)
            operands.append(operand)
        return self.Builder.icmp_signed(predicate, operands[0], operands[1])

    def visitRelation(self, ctx: SmallCParser.RelationContext):
        """Lower ``equation <op> equation`` for ``<``, ``>``, ``<=`` and ``>=``.

        Operands are loaded if they are pointers and compared with a signed
        ``icmp``.  A relation with a single equation is returned unchanged.
        """
        if len(ctx.equation()) <= 1:
            return self.getVal_of_expr(ctx.equation(0))

        lhs = self.getVal_of_expr(ctx.equation(0))
        if isinstance(lhs.type, PointerType):
            lhs = self.Builder.load(lhs)
        rhs = self.getVal_of_expr(ctx.equation(1))
        if isinstance(rhs.type, PointerType):
            rhs = self.Builder.load(rhs)

        operators = ((ctx.LEFTANGLE, '<'), (ctx.RIGHTANGLE, '>'),
                     (ctx.LEFTANGLEEQUAL, '<='), (ctx.RIGHTANGLEEQUAL, '>='))
        for token, symbol in operators:
            if token():
                predicate = symbol
                break
        return self.Builder.icmp_signed(predicate, lhs, rhs)

    def visitFor_stmt(self, ctx: SmallCParser.For_stmtContext):
        """Lower ``for (decl; cond; step) stmt``.

        Control flow:

        * the current block jumps to ``decl_block``, which runs the
          declaration and falls into ``cond_block``;
        * ``cond_block`` branches to ``stmt_block`` (body) or ``end_block``;
        * the body falls into ``inc_block``, which evaluates ``step`` and
          jumps back to ``cond_block``.

        ``continue`` targets ``inc_block``, so the step still runs as C
        requires; ``break`` targets ``end_block``.  ``loop_block`` is the
        spare "buf" block that ``continue``/``break`` move the builder into.
        The builder finishes at the start of ``end_block``.
        """
        func = self.function_dict[self.function]

        end_block = func.append_basic_block()
        self.var_stack.append({})
        decl_block = func.append_basic_block()
        self.var_stack.append({})
        cond_block = func.append_basic_block()
        self.var_stack.append({})
        stmt_block = func.append_basic_block()
        self.var_stack.append({})
        # The step gets its own block so `continue` can target it.  Jumping
        # straight to cond_block skipped the step, which turned loops such as
        # `for (i = 0; i < n; i++) { if (c) continue; }` into infinite loops.
        inc_block = func.append_basic_block()
        loop_block = func.append_basic_block()
        # 1 -> continue, -1 -> break
        self.signal_stack.append(0)

        self.loop_stack.append({
            'continue': inc_block,
            'break': end_block,
            'buf': loop_block
        })

        with self.Builder.goto_block(decl_block):
            if ctx.var_decl():
                self.visit(ctx.var_decl())
            elif ctx.var_decl_list():
                self.visit(ctx.var_decl_list())
            else:
                self.error("for error in <visitFor_stmt>")
            self.Builder.branch(cond_block)

        self.Builder.branch(decl_block)
        with self.Builder.goto_block(cond_block):
            # getVal_of_expr (not a bare visit) loads a plain variable used as
            # the condition, as visitWhile_stmt does.
            cond_value = self.getVal_of_expr(ctx.expr(0))
            self.Builder.cbranch(self.toBool(cond_value), stmt_block, end_block)

        with self.Builder.goto_block(stmt_block):
            self.visit(ctx.stmt())
            self.Builder.branch(inc_block)
            if self.signal_stack[-1] == 0:
                # Give the unused spare block a terminator.
                loop_blocks = self.loop_stack[-1]
                self.Builder.position_at_start(loop_blocks['buf'])
                self.Builder.branch(end_block)

        with self.Builder.goto_block(inc_block):
            self.visit(ctx.expr(1))
            self.Builder.branch(cond_block)

        self.Builder.position_at_start(end_block)

        self.loop_stack.pop()
        self.signal_stack.pop()

        self.var_stack.pop()
        self.var_stack.pop()
        self.var_stack.pop()
        self.var_stack.pop()

    def visitWhile_stmt(self, ctx: SmallCParser.While_stmtContext):
        """Lower ``while (expr) stmt``.

        Control flow:

        * the current block jumps to ``cond_block``;
        * ``cond_block`` tests ``expr`` and branches to ``stmt_block`` (body)
          or ``end_block``;
        * the body jumps back to ``cond_block``.

        ``continue`` targets ``cond_block`` and ``break`` targets
        ``end_block``.  ``loop_block`` is the spare "buf" block that
        ``continue``/``break`` move the builder into.  The builder finishes at
        the start of ``end_block``.
        """
        func = self.function_dict[self.function]

        # Three scopes are pushed here and popped again at the end.
        end_block = func.append_basic_block()
        self.var_stack.append({})
        cond_block = func.append_basic_block()
        self.var_stack.append({})
        stmt_block = func.append_basic_block()
        self.var_stack.append({})
        loop_block = func.append_basic_block()
        # 1 -> continue, -1 -> break
        self.signal_stack.append(0)

        self.loop_stack.append({
            'continue': cond_block,
            'break': end_block,
            'buf': loop_block
        })

        # Enter the loop from the current block.
        self.Builder.branch(cond_block)

        with self.Builder.goto_block(cond_block):
            expr = self.getVal_of_expr(ctx.expr())
            cond_expr = self.toBool(expr)
            self.Builder.cbranch(cond_expr, stmt_block, end_block)

        with self.Builder.goto_block(stmt_block):
            self.visit(ctx.stmt())
            self.Builder.branch(cond_block)
            # No continue/break set the signal: the spare block was never
            # used, so give it a terminator (a branch to end_block).
            if self.signal_stack[-1] == 0:
                loop_blocks = self.loop_stack[-1]
                self.Builder.position_at_start(loop_blocks['buf'])
                self.Builder.branch(end_block)

        self.Builder.position_at_start(end_block)

        self.loop_stack.pop()
        self.signal_stack.pop()

        self.var_stack.pop()
        self.var_stack.pop()
        self.var_stack.pop()

    def visitCond_stmt(self, ctx: SmallCParser.Cond_stmtContext):
        """Lower ``if (expr) stmt [else stmt]`` with llvmlite's if helpers.

        Each arm is visited inside its own fresh scope.  Always returns None.
        """
        predicate = self.toBool(self.getVal_of_expr(ctx.expr()))

        def visit_in_scope(stmt):
            # Declarations made inside an arm stay local to that arm.
            self.var_stack.append({})
            self.visit(stmt)
            self.var_stack.pop()

        if not ctx.ELSE():
            with self.Builder.if_then(predicate):
                visit_in_scope(ctx.stmt(0))
            return None

        with self.Builder.if_else(predicate) as (then, otherwise):
            with then:
                visit_in_scope(ctx.stmt(0))
            with otherwise:
                visit_in_scope(ctx.stmt(1))
        return None

    def visitVar_decl(self, ctx: SmallCParser.Var_declContext):
        """Record the declared LLVM type, then visit the declarators.

        ``visitVariable_id`` reads the type back from ``self.cur_decl_type``.
        """
        type_name = ctx.type_specifier().getText()
        self.cur_decl_type = self.getType(type_name)
        return self.visitChildren(ctx)

    def visitVar_decl_list(self, ctx: SmallCParser.Var_decl_listContext):
        """Visit every ``variable_id`` in order and return their results as a list."""
        return [self.visit(declarator) for declarator in ctx.variable_id()]

    def visitVariable_id(self, ctx: SmallCParser.Variable_idContext):
        """Declare a single variable whose base type is ``self.cur_decl_type``.

        A ``[N]`` suffix on the declarator turns the type into
        ``ArrayType(base, N)``.  ``N`` is evaluated with ``getVal_of_expr``, and
        its ``.constant`` attribute is used, so it must be an ``ir.Constant``.

        * Outside a function (``self.function`` is falsy), a zero-initialised
          ``GlobalVariable`` is created and recorded in the outermost scope
          ``self.var_stack[0]``.  That scope is created if the stack is empty.
          The global is returned.
        * Inside a function, the storage is ``alloca``-ed and the initialiser
          is stored into it.  When there is no initialiser, a zero constant is
          stored instead.  The variable is recorded in the innermost scope
          ``self.var_stack[-1]``, and the ``alloca`` is returned.

        Each scope entry is a dict with the keys ``id``, ``type``, ``value``
        and ``ptr``.
        """
        identifier = ctx.identifier()
        name = identifier.IDENTIFIER().getText()
        # Named var_type (not ``type``) so the builtin is not shadowed.
        var_type = self.cur_decl_type

        def apply_array_suffix(base_type):
            """Return ArrayType(base_type, N) for ``name[N]``, else base_type."""
            indexing = identifier.array_indexing()
            if not indexing:
                return base_type
            length = self.getVal_of_expr(indexing.expr())
            return ArrayType(base_type, length.constant)

        if not self.function:
            var_type = apply_array_suffix(var_type)
            g_var = GlobalVariable(self.Module, var_type, name)
            # NOTE(review): an initialiser expression (ctx.expr()) is ignored
            # for globals; they are always zero-initialised.
            g_var.initializer = Constant(var_type, None)
            if not self.var_stack:
                self.var_stack.append({})
            self.var_stack[0][name] = {
                "id": name,
                "type": var_type,
                "value": None,
                "ptr": g_var,
            }
            return g_var

        var_map = self.var_stack[-1]
        var_type = apply_array_suffix(var_type)
        ptr = self.Builder.alloca(typ=var_type, name=name)

        init_expr = ctx.expr()
        if init_expr:
            value = self.getVal_of_expr(init_expr)
        else:
            value = Constant(var_type, None)
        # Lvalue expressions evaluate to their storage pointer (e.g.
        # visitIdentifier returns a variable's 'ptr'), so load to get the value.
        if isinstance(value.type, PointerType):
            value = self.Builder.load(value)
        self.Builder.store(value, ptr)
        var_map[name] = {
            "id": name,
            "type": var_type,
            "value": value,
            "ptr": ptr,
        }
        return ptr

    def visitPrimary(self, ctx: SmallCParser.PrimaryContext):
        """Lower a primary expression to an LLVM value.

        * ``BOOLEAN`` becomes an ``i1`` constant.  It is 1 only for the
          literal ``true``, compared case-insensitively.
        * ``INTEGER`` becomes an ``i32`` constant.
        * ``REAL`` becomes a ``float`` constant.
        * A double-quoted ``CHARCONST`` becomes a NUL-terminated global
          constant ``[N x i8]`` holding the literal's UTF-8 bytes.  The result
          is an inbounds GEP to its first byte.
        * A single-quoted ``CHARCONST`` becomes an ``i8`` constant of its
          first character.  C escapes in both forms are decoded by
          ``_unescape_c_literal``.
        * Identifiers, function calls and parenthesised expressions are
          delegated to their own visitors.

        Anything else is reported through ``self.error``.
        """
        if ctx.BOOLEAN():
            # bool(text) is True for any non-empty string, including "false",
            # so compare the spelling instead.
            # NOTE(review): assumes BOOLEAN tokens are spelled true/false.
            return Constant(IntType(1), ctx.getText().lower() == 'true')
        if ctx.INTEGER():
            return Constant(IntType(32), int(ctx.getText()))
        if ctx.REAL():
            # Constant needs a type *instance*; the bare FloatType class is not
            # an IR type.
            return Constant(FloatType(), float(ctx.getText()))
        if ctx.CHARCONST():
            literal = ctx.getText()
            text = self._unescape_c_literal(literal[1:-1])
            if literal[0] != '"':
                return Constant(IntType(8), ord(text[0]))
            # Size the array from the encoded bytes.  Non-ASCII characters take
            # several bytes in UTF-8, so len(text) would under-count.
            data = (text + '\0').encode('utf-8')
            str_type = ArrayType(IntType(8), len(data))
            g_str = GlobalVariable(self.Module, str_type,
                                   name="str_" + text + str(self.counter))
            self.counter += 1
            g_str.initializer = Constant(str_type, bytearray(data))
            g_str.global_constant = True
            return self.Builder.gep(
                g_str, [Constant(IntType(32), 0),
                        Constant(IntType(32), 0)],
                inbounds=True)
        if ctx.identifier():
            return self.visit(ctx.identifier())
        if ctx.functioncall():
            return self.visit(ctx.functioncall())
        if ctx.expr():
            return self.visit(ctx.expr())
        return self.error("type error in <visitPrimary>")

    @staticmethod
    def _unescape_c_literal(text):
        r"""Decode the C escapes \n \t \r \0 \\ \' \" in *text*.

        Decoding is one left-to-right pass, so ``\\n`` becomes a backslash
        followed by ``n`` rather than a backslash and a newline.
        Unrecognised escape sequences are kept verbatim.
        """
        escapes = {'n': '\n', 't': '\t', 'r': '\r', '0': '\0',
                   '\\': '\\', "'": "'", '"': '"'}
        out = []
        i = 0
        while i < len(text):
            if text[i] == '\\' and i + 1 < len(text) and text[i + 1] in escapes:
                out.append(escapes[text[i + 1]])
                i += 2
            else:
                out.append(text[i])
                i += 1
        return ''.join(out)

    def visitFactor(self, ctx: SmallCParser.FactorContext):
        """Lower unary minus (``-factor``) with ``IRBuilder.neg``.

        A pointer-typed operand (an lvalue) is loaded before negation.
        Factors without a leading ``-`` are handled by visiting the children.
        """
        if not ctx.MINUS():
            return self.visitChildren(ctx)
        operand = self.getVal_of_expr(ctx.factor())
        if isinstance(operand.type, PointerType):
            operand = self.Builder.load(operand)
        return self.Builder.neg(operand)

    def visitTerm(self, ctx: SmallCParser.TermContext):
        """Lower ``term * factor`` to ``mul`` and ``term / factor`` to ``sdiv``.

        Pointer-typed operands (lvalues) are loaded first.  If the two operand
        types still differ, each integer operand narrower than 32 bits is
        zero-extended to ``i32``, the same widening ``visitEquation`` uses.
        This is needed because ``IRBuilder`` rejects binary operations on
        mismatched types (e.g. ``'a' * 2``: ``i8`` vs ``i32``).  Operands of
        equal type are left untouched.  Without an operator, the children are
        visited as-is.
        """
        is_mul = bool(ctx.ASTERIKS())
        if not is_mul and not ctx.SLASH():
            return self.visitChildren(ctx)

        def rvalue(node):
            """Evaluate *node*, loading through it if it is a pointer."""
            value = self.getVal_of_expr(node)
            if isinstance(value.type, PointerType):
                value = self.Builder.load(value)
            return value

        def widen(value):
            """Zero-extend an integer narrower than 32 bits to i32."""
            if isinstance(value.type, IntType) and value.type.width < 32:
                return self.Builder.zext(value, IntType(32))
            return value

        lhs = rvalue(ctx.term())
        rhs = rvalue(ctx.factor())
        if lhs.type != rhs.type:
            lhs, rhs = widen(lhs), widen(rhs)
        if is_mul:
            return self.Builder.mul(lhs, rhs)
        return self.Builder.sdiv(lhs, rhs)

    def visitEquation(self, ctx: SmallCParser.EquationContext):
        """Lower ``equation + term`` to ``add`` and ``equation - term`` to ``sub``.

        Operands are evaluated left to right and loaded if they are pointers
        (lvalues).  Integers narrower than 32 bits (e.g. ``i1`` booleans,
        ``i8`` chars) are zero-extended to ``i32``.  Other types pass through
        unchanged, because ``zext`` is only valid from a narrower integer
        type; applying it to wider integers, floats or pointers emits invalid
        IR.  Without an operator, the children are visited as-is.
        """
        is_add = bool(ctx.PLUS())
        if not is_add and not ctx.MINUS():
            return self.visitChildren(ctx)

        def as_operand(node):
            """Evaluate *node*, load lvalues, widen narrow integers to i32."""
            value = self.getVal_of_expr(node)
            if isinstance(value.type, PointerType):
                value = self.Builder.load(value)
            if isinstance(value.type, IntType) and value.type.width < 32:
                value = self.Builder.zext(value, IntType(32))
            return value

        lhs = as_operand(ctx.equation())
        rhs = as_operand(ctx.term())
        if is_add:
            return self.Builder.add(lhs, rhs)
        return self.Builder.sub(lhs, rhs)

    def visitIdentifier(self, ctx: SmallCParser.IdentifierContext):
        """Lower an identifier reference, optionally prefixed by ``&``/``*``
        and/or suffixed by ``[index]``.

        Variables are represented by their storage pointer (the ``'ptr'`` of
        the record from ``getVal_local``); callers load it when they need the
        value.  Accordingly:

        * ``&x`` / ``*x`` (no index) return the variable's ``ptr`` unchanged.
        * ``&a[i]`` returns an inbounds pointer to element ``i``, and
          ``*a[i]`` loads through that pointer.
        * On an array variable, ``a[i]`` returns a pointer to element ``i``
          and bare ``a`` returns a pointer to element 0.
        * Any other identifier returns its ``ptr`` unchanged.  An index on a
          non-array variable is ignored.
        """
        name = ctx.IDENTIFIER().getText()
        entry = self.getVal_local(name)
        var_ptr = entry['ptr']
        indexing = ctx.array_indexing()
        i32 = IntType(32)

        def index_value():
            """Evaluate the ``[...]`` expression, loading it if it is a pointer."""
            index = self.getVal_of_expr(indexing.expr())
            if isinstance(index.type, PointerType):
                index = self.Builder.load(index)
            return index

        def element_ptr(index):
            """Return an inbounds GEP to element *index* of the array at var_ptr.

            The GEP is also registered, together with a load of it, in the
            innermost scope under the GEP instruction's name.
            """
            ptr = self.Builder.gep(var_ptr, [Constant(i32, 0), index],
                                   inbounds=True)
            value = self.Builder.load(ptr)
            self.var_stack[-1][ptr.name] = {
                'id': ptr.name,
                'type': ptr.type,
                'value': value,
                'ptr': ptr
            }
            return ptr

        if ctx.AMPERSAND() or ctx.ASTERIKS():
            if not indexing:
                return var_ptr
            # gep needs an index *list* and [0, i] to step into the array
            # object, the same as the plain a[i] path below.
            elem = element_ptr(index_value())
            if ctx.AMPERSAND():
                return elem
            return self.Builder.load(elem)

        if not isinstance(entry['type'], ArrayType):
            return var_ptr
        if indexing:
            return element_ptr(index_value())
        # Bare array name: pointer to the first element.
        return element_ptr(Constant(i32, 0))

    def visitArray_indexing(self, ctx: SmallCParser.Array_indexingContext):
        """Lower ``[expr]`` to whatever its inner expression lowers to."""
        index_expr = ctx.expr()
        return self.visit(index_expr)
Exemple #16
0
    def _build_wrapper(self, library, name):
        """
        The LLVM IRBuilder code to create the gufunc wrapper.
        The *library* arg is the CodeLibrary to which the wrapper should
        be added.  The *name* arg is the name of the wrapper function being
        created.

        The following is emitted into a new ``_gufunc_wrapper`` IR module:

        * A declaration (no body) of the kernel
          ``self.fndesc.llvm_func_name``, marked ``alwaysinline``.
        * The wrapper function *name*, whose type comes from
          ``self._wrapper_function_type()``.  It has ``weak_odr`` linkage and
          takes the arguments ``args, dims, steps, data``.
        * A loop of ``dims[0]`` iterations.  Each iteration calls
          ``gen_loop_body`` with every argument's
          ``GUArrayArg.get_array_at_offset(loop.index)`` and branches to the
          ``.return`` block when it reports an error.  The loop is framed by
          ``gen_prologue`` / ``gen_epilogue``.

        Finally the module is added to *library*, and ``self.library`` is
        added as a linking library.
        """
        intp_t = self.context.get_value_type(types.intp)
        fnty = self._wrapper_function_type()

        wrapper_module = library.create_ir_module('_gufunc_wrapper')
        # Only a declaration of the kernel is created here; its definition is
        # presumably supplied by self.library, linked at the end (confirm).
        func_type = self.call_conv.get_function_type(self.fndesc.restype,
                                                     self.fndesc.argtypes)
        fname = self.fndesc.llvm_func_name
        func = ir.Function(wrapper_module, func_type, name=fname)

        # alwaysinline asks LLVM to inline the kernel into the loop below.
        func.attributes.add("alwaysinline")
        wrapper = ir.Function(wrapper_module, fnty, name)
        # The use of weak_odr linkage avoids the function being dropped due
        # to the order in which the wrappers and the user function are linked.
        wrapper.linkage = 'weak_odr'
        arg_args, arg_dims, arg_steps, arg_data = wrapper.args
        arg_args.name = "args"
        arg_dims.name = "dims"
        arg_steps.name = "steps"
        arg_data.name = "data"

        builder = IRBuilder(wrapper.append_basic_block("entry"))
        # dims[0] is the trip count of the outer loop (see for_range below).
        loopcount = builder.load(arg_dims, name="loopcount")
        pyapi = self.context.get_python_api(builder)

        # Unpack shapes
        # NOTE(review): unique_syms is built here but not read anywhere below
        # in this method -- possibly dead code; confirm before removing.
        unique_syms = set()
        for grp in (self.sin, self.sout):
            for syms in grp:
                unique_syms |= set(syms)

        # Number each distinct dimension symbol of the *input* signatures
        # (self.sin) in order of first appearance.
        sym_map = {}
        for syms in self.sin:
            for s in syms:
                if s not in sym_map:
                    sym_map[s] = len(sym_map)

        # The size of symbol number i is loaded from dims[i + 1]; dims[0] is
        # the loop count read above.
        sym_dim = {}
        for s, i in sym_map.items():
            sym_dim[s] = builder.load(
                builder.gep(arg_dims,
                            [self.context.get_constant(types.intp, i + 1)]))

        # Prepare inputs
        # Arguments are wrapped in signature order (inputs, then outputs).
        # step_offset starts after the first len(sin) + len(sout) entries of
        # steps[] and advances by len(sym) per argument -- presumably one
        # stride per core dimension; confirm against GUArrayArg.
        arrays = []
        step_offset = len(self.sin) + len(self.sout)
        for i, (typ, sym) in enumerate(
                zip(self.signature.args, self.sin + self.sout)):
            ary = GUArrayArg(self.context, builder, arg_args, arg_steps, i,
                             step_offset, typ, sym, sym_dim)
            step_offset += len(sym)
            arrays.append(ary)

        # Exit block: reached after the loop, or early when the body errors.
        bbreturn = builder.append_basic_block('.return')

        # Prologue
        self.gen_prologue(builder, pyapi)

        # Loop
        with cgutils.for_range(builder, loopcount, intp=intp_t) as loop:
            args = [a.get_array_at_offset(loop.index) for a in arrays]
            # innercall is not used here; only the error flag is consumed.
            innercall, error = self.gen_loop_body(builder, pyapi, func, args)
            # If error, escape
            cgutils.cbranch_or_continue(builder, error, bbreturn)

        builder.branch(bbreturn)
        builder.position_at_end(bbreturn)

        # Epilogue
        self.gen_epilogue(builder, pyapi)

        builder.ret_void()

        # Link
        library.add_ir_module(wrapper_module)
        library.add_linking_library(self.library)