def get_array_pointer(scope_stack, builder: ir.IRBuilder, array, index): if isinstance(array, tuple) and array[0] == 'struct': item = get_struct_pointer(scope_stack, builder, array[1], array[2]) elif isinstance(array, tuple) and array[0] == 'struct_ptr': item = get_struct_ptr_pointer(scope_stack, builder, array[1], array[2]) else: identifier = None if isinstance(array, tuple): identifier = get_identifier(scope_stack, array[1]) item = get_pointer(scope_stack, builder, array) if identifier and identifier['type'] == 'val_ptr': item = builder.load(item) return builder.gep(item, [index], True) return builder.gep(item, [int_type(0), index], True)
def _store_to_alloc(self, index: Union[ir.Constant, ir.Instruction], src_nums: List[ir.LoadInstr], builder: ir.IRBuilder) -> None: if not self.dtype in (DType.Complx, DType.DComplx): dest_ptr = builder.gep(self.alloc, [index]) builder.store(src_nums[0], dest_ptr) else: cmplx_dest_index = builder.mul(index, ir.Constant(int_type, 2)) dest_ptr_r = builder.gep(self.alloc, [cmplx_dest_index]) builder.store(src_nums[0], dest_ptr_r) cmplx_dest_index = builder.add(cmplx_dest_index, ir.Constant(int_type, 1)) dest_ptr_i = builder.gep(self.alloc, [cmplx_dest_index]) builder.store(src_nums[1], dest_ptr_i)
def _load_from_alloc(self, index: Union[ir.Constant, ir.Instruction], builder: ir.IRBuilder) -> List[ir.LoadInstr]: if not self.dtype in (DType.Complx, DType.DComplx): self_ptr = builder.gep(self.alloc, [index]) products = [builder.load(self_ptr)] else: cmplx_index = builder.mul(index, ir.Constant(int_type, 2)) self_ptr_real = builder.gep(self.alloc, [cmplx_index]) product_real = builder.load(self_ptr_real) cmplx_index = builder.add(cmplx_index, ir.Constant(int_type, 1)) self_ptr_imag = builder.gep(self.alloc, [cmplx_index]) product_imag = builder.load(self_ptr_imag) products = [product_real, product_imag] return products
def _gen_setitem(self, builder: ir.IRBuilder, index: Union[int, Node, slice, list, tuple, np.ndarray], val: Node) -> None: """When set index to one node, it must be LValue node, if not, the graph maintainer should modify its vtype to LEFT. Also, only when the array is required, will it be generated. """ if isinstance(index, int): const0 = ir.Constant(int_type, 0) src_nums = val.get_ele(const0, builder) if val.dtype != self.dtype: src_nums = build_type_cast(builder, src_nums, val.dtype, self.dtype) index = ir.Constant(int_type, index) self._store_to_alloc(index, src_nums, builder) elif isinstance(index, slice): size = compute_size(index, self.size) start, _, step = index.indices(val.size) v_start = ir.Constant(int_type, start) v_step = ir.Constant(int_type, step) dest_index_ptr = builder.alloca(int_type, 1) builder.store(v_start, dest_index_ptr) with LoopCtx(self.name + "_set_slice", builder, size) as loop: loop_inc = builder.load(loop.inc) dest_index = builder.load(dest_index_ptr) src_nums = val.get_ele(loop_inc, builder) self._store_to_alloc(dest_index, src_nums, builder) builder.store(builder.add(dest_index, v_step), dest_index_ptr) elif isinstance(index, Node): with LoopCtx(self.name + "_set_slice", builder, index.size) as loop: loop_inc = builder.load(loop.inc) dest_index = index.get_ele(loop_inc)[0] src_nums = val.get_ele(loop_inc, builder) self._store_to_alloc(dest_index, src_nums, builder) else: all_inds = builder.alloca(int_type, len(index)) # TODO: change this to malloc function for i in range(len(index)): ind_ptr = builder.gep(all_inds, [ir.Constant(int_type, i)]) builder.store(ir.Constant(int_type, index[i]), ind_ptr) with LoopCtx(self.name + "_set_slice", builder, len(index)) as loop: loop_inc = builder.load(loop.inc) dest_index_ptr = builder.gep(all_inds, [loop_inc]) dest_index = builder.load(dest_index_ptr) src_nums = val.get_ele(loop_inc) self._store_to_alloc(dest_index, src_nums, builder)
def _code_gen(self, builder: ir.IRBuilder) -> None: """Generating indexing llvm ir Note that indexing using numpy array is not recomended, because it generates static loops and will cause the generated llvm ir too large. """ if isinstance(self.ind, slice): start, _, step = self.ind.indices(self.src.size) step_const = ir.Constant(int_type, step) src_index_ptr = builder.alloca(int_type, 1) builder.store(ir.Constant(int_type, start), src_index_ptr) with LoopCtx(self.name, builder, self.size) as loop: loop_inc = builder.load(loop.inc) if isinstance(self.ind, (ir.Constant, ir.Instruction)): src_index = self.ind elif isinstance(self.ind, slice): src_index = builder.load(src_index_ptr) elif isinstance(self.ind, Node): src_index = self.ind.get_ele(loop_inc, builder)[0] else: src_index_ptr = builder.gep(self.src_inds, [loop_inc]) src_index = builder.load(src_index_ptr) src_nums = self.src.get_ele(src_index, builder) self._store_to_alloc(loop_inc, src_nums, builder) if isinstance(self.ind, slice): builder.store(builder.add(src_index, step_const), src_index_ptr)
def get_struct_pointer(scope_stack, builder: ir.IRBuilder, struct, item, index=int_type(0)): identifier = get_identifier(scope_stack, struct) assert identifier['type'] == 'struct' item = struct_map[identifier['val_type']][item] return builder.gep(identifier['value'], [index, item['index']], True)
def impl_construct_dtype_on_stack(context: BaseContext, builder: ir.IRBuilder, sig, args): ty = sig.args[0].dtype_as_type() containing_size = find_size_for_dtype(sig.args[0].dtype) ptr = builder.alloca(ir.IntType(8), containing_size) for i, (name, mem_ty) in enumerate(ty.members): llvm_mem_ty = context.get_value_type(mem_ty) offset = ty.offset(name) v = builder.extract_value(args[1], i) v = context.cast(builder, v, sig.args[1][i], mem_ty) v_ptr_byte = builder.gep(ptr, (ir.Constant(ir.IntType(32), offset),), True) v_ptr = builder.bitcast(v_ptr_byte, llvm_mem_ty.as_pointer()) builder.store(v, v_ptr) return ptr
def string_to_volpe(b: ir.IRBuilder, string: ir.Value): with options(b, int64) as (ret, phi): ret(int64(0)) character = b.load(b.gep(string, [phi])) with b.if_then(b.icmp_unsigned("!=", character, char(0))): ret(b.add(phi, int64(1))) new_string = string_type.unwrap()(ir.Undefined) new_string = b.insert_value(new_string, string, 0) new_string = b.insert_value(new_string, phi, 1) return new_string
def _code_gen(self, builder: ir.IRBuilder) -> None: mod = builder.block.module # instr = ir.values.Function(mod, self.ftype, self.func_name) input_type = [] for in_node in self.SRC: input_type.append(type_map_llvm[in_node.dtype]) instr = mod.declare_intrinsic(self.func_name, input_type) params = [] with LoopCtx(self.name, builder, self.size) as loop: index = builder.load(loop.inc) data_ptr = builder.gep(self.alloc, [index]) for n in self.SRC: params.append(n.get_ele(index, builder)[0]) res = builder.call(instr, params) builder.store(res, data_ptr)
def _load_from_src(self, ind: Union[ir.Constant, ir.Instruction], builder: ir.IRBuilder) -> List[ir.Instruction]: if isinstance(self.ind, (ir.Constant, ir.Instruction)): src_index = self.ind if isinstance(self.ind, slice): start, _, step = self.ind.indices(self.src.size) muled = builder.mul(ind, ir.Constant(int_type, step)) src_index = builder.add(muled, ir.Constant(int_type, start)) elif isinstance(self.ind, Node): src_index = self.ind.get_ele(ind, builder)[0] else: src_index_ptr = builder.gep(self.src_inds, [ind]) src_index = builder.load(src_index_ptr) return self.src.get_ele(src_index, builder)
def get_value(scope_stack, builder: ir.IRBuilder, item): if isinstance(item, tuple) and item[0] == 'id': identifier = get_identifier(scope_stack, item[1]) item = identifier['value'] if item.opname == 'alloca' and identifier['type'] != 'array': item = builder.load(item) elif item.opname == 'alloca' and identifier['type'] == 'array': item = builder.gep(item, [int_type(0), int_type(0)]) elif isinstance(item, tuple) and item[0] == 'array_index': item = builder.load(item[1]) elif isinstance(item, tuple) and item[0] == 'struct': item = get_struct_pointer(scope_stack, builder, item[1], item[2]) item = builder.load(item) elif isinstance(item, tuple) and item[0] == 'struct_ptr': item = get_struct_ptr_pointer(scope_stack, builder, item[1], item[2]) item = builder.load(item) return item
def code_gen(self, builder: ir.IRBuilder) -> None: """Wrapper of the code_gen function, Rcursively generates it's dependences and call itselves _code_gen core Note that only Lvalue node need generate llvm ir """ if self.gened: return for dep in self.dependence: dep.code_gen(builder) self.gened = True if isinstance(self.ind, (list, tuple, np.ndarray)): for i in range(len(self.ind)): index = ir.Constant(int_type, self.ind[i]) builder.store( index, builder.gep(self.src_inds, [ir.Constant(int_type, i)])) if self.vtype == LRValue.LEFT: self._code_gen(builder)
def insert_string(module: ir.Module, scope_stack: list, builder: ir.IRBuilder, raw_data: str, prt=False): try: global_string = scope_stack[0][raw_data] except KeyError: data = eval('%s' % raw_data) data += '\00' data = data.encode() str_type = ir.ArrayType(char_type, len(data)) const_string = ir.Constant(str_type, bytearray(data)) global_string = ir.GlobalVariable(module, str_type, module.get_unique_name(raw_data)) global_string.initializer = const_string scope_stack[0][raw_data] = global_string if prt: return builder.gep(global_string, [int_type(0), int_type(0)], True) else: return global_string
def compile(self, module: ir.Module, builder: ir.IRBuilder, symbols: SymbolTable) -> ir.Value: if self.parent is None: return symbols.get_symbol(self.ID) else: ret = self.parent.compile(module, builder, symbols) rtype = ret.type if isinstance(rtype, ir.PointerType): rtype = rtype.pointee if isinstance(rtype, ir.IdentifiedStructType): index = symbols.get_elements(rtype.name).index(self.ID) ret = builder.gep( ret, [make_constant(IntType, i) for i in (0, index)]) rtype = ret.type if isinstance(rtype, ir.PointerType) and ( rtype.pointee in [i.ir_type for i in PrimitiveTypes]): ret = builder.load(ret) return ret
class GeneratorVisitor(SmallCVisitor): def __init__(self, output_file=None): super(SmallCVisitor, self).__init__() self.Module = Module(name=__file__) self.Module.triple = "x86_64-pc-linux-gnu" self.Builder = None self.function = None self.NamedValues = dict() self.counter = 0 self.loop_stack = [] self.signal_stack = [] self.cond_stack = [] self.var_stack = [] self.cur_decl_type = None self.indentation = 0 self.function_dict = dict() self.error_queue = list() self.output_file = output_file def print(self): if not self.output_file: print(self.Module) else: f = open(self.output_file, "w+") f.write(self.Module.__str__()) f.close() def error(self, info): print("Error: ", info) return 0 def toBool(self, value): zero = Constant(self.getType('bool'), 0) return self.Builder.icmp_signed('!=', value, zero) def getVal_of_expr(self, expr): temp = self.visit(expr) if isinstance(temp, Constant) or isinstance( temp, CallInstr) or isinstance(temp, LoadInstr) or isinstance( temp, Instruction) or isinstance(temp, GlobalVariable): value = temp else: temp_val = self.getVal_local(temp.IDENTIFIER().getText()) temp_ptr = temp_val['ptr'] if temp.array_indexing(): index = self.getVal_of_expr(temp.array_indexing().expr()) temp_ptr = self.Builder.gep(temp_ptr, [Constant(IntType(32), 0), index], inbounds=True) # if isinstance(temp_val['type'],ArrayType): # if temp.array_indexing(): # index = self.getVal_of_expr(temp.array_indexing().expr()) # temp_ptr = self.Builder.gep(temp_ptr, [Constant(IntType(32), 0), index], inbounds=True) # elif temp.AMPERSAND(): # Constant(PointerType(IntType(8)), temp_ptr.getText()) # elif temp.ASTERIKS(): # pass # else: #返回数组地址 # temp_ptr = self.Builder.gep(temp_ptr, [Constant(IntType(32), 0), Constant(IntType(32), 0)], inbounds=True) # return temp_ptr value = self.Builder.load(temp_ptr) return value def getType(self, type): if type == 'int': return IntType(32) elif type == 'char': return IntType(8) elif type == 'float': return FloatType() elif type == 'bool': return IntType(1) elif type == 'void': return VoidType() else: self.error("type error in <getType>") def getVal_local(self, id): temp_maps = self.var_stack[::-1] for map in temp_maps: if id in map.keys(): return map[id] self.error("value error in <getVal_local>") return None def visitFunction_definition(self, ctx: SmallCParser.Function_definitionContext): retType = self.getType(ctx.type_specifier().getText()) if ctx.identifier().ASTERIKS(): retType = retType.as_pointer() argsType = [] argsName = [] # args if ctx.param_decl_list(): args = ctx.param_decl_list() var_arg = False for t in args.getChildren(): if t.getText() != ',': if t.getText() == '...': var_arg = True break t_type = self.getType(t.type_specifier().getText()) if t.identifier().ASTERIKS(): t_type = t_type.as_pointer() argsType.append(t_type) argsName.append(t.identifier().IDENTIFIER().getText()) funcType = FunctionType(retType, tuple(argsType), var_arg=var_arg) # no args else: funcType = FunctionType(retType, ()) # function if ctx.identifier().IDENTIFIER().getText() in self.function_dict: func = self.function_dict[ctx.identifier().IDENTIFIER().getText()] else: func = Function(self.Module, funcType, name=ctx.identifier().IDENTIFIER().getText()) self.function_dict[ctx.identifier().IDENTIFIER().getText()] = func # blocks or ; if ctx.compound_stmt(): self.function = ctx.identifier().IDENTIFIER().getText() block = func.append_basic_block( ctx.identifier().IDENTIFIER().getText()) varDict = dict() self.Builder = IRBuilder(block) for i, arg in enumerate(func.args): arg.name = argsName[i] alloca = self.Builder.alloca(arg.type, name=arg.name) self.Builder.store(arg, alloca) varDict[arg.name] = { "id": arg.name, "type": arg.type, "value": None, "ptr": alloca } self.var_stack.append(varDict) self.visit(ctx.compound_stmt()) if isinstance(retType, VoidType): self.Builder.ret_void() self.var_stack.pop() self.function = None return def visitFunctioncall(self, ctx: SmallCParser.FunctioncallContext): var_map = self.var_stack[-1] function = self.function_dict[ctx.identifier().getText()] arg_types = function.args index = 0 args = [] if ctx.param_list(): for param in ctx.param_list().getChildren(): if (param.getText() == ','): continue temp = self.getVal_of_expr(param) arg_type = None if index < len(arg_types): arg_type = arg_types[index] ptr_flag = False if arg_type: if isinstance(arg_type.type, PointerType): ptr_flag = True elif self.getVal_local(temp.name): temp_type = self.getVal_local(temp.name)['type'] if isinstance(temp_type, PointerType): ptr_flag = True if not ptr_flag and not isinstance(temp, Constant): temp = self.Builder.load(temp) args.append(temp) index += 1 return self.Builder.call(function, args) # def visitVar_decl(self, ctx: SmallCParser.Var_declContext): # type = self.getType(ctx.type_specifier()) # list = ctx.var_decl_list() # for var in list.getChildren(): # if var.getText() != ',': # if self.builder: # alloca = self.builder.alloca(type, name=var.identifier().getText()) # self.builder.store(Constant(type, None), alloca) # self.var_stack[-1][var.identifier().getText()] = alloca # else: # g_var = GlobalVariable(self.Module, type, var.identifier().getText()) # g_var.initializer = Constant(type, None) # return def visitStmt(self, ctx: SmallCParser.StmtContext): if ctx.RETURN(): value = self.getVal_of_expr(ctx.expr()) if isinstance(value.type, PointerType): value = self.Builder.load(value) return self.Builder.ret(value) elif ctx.CONTINUE(): self.signal_stack[-1] = 1 loop_blocks = self.loop_stack[-1] self.Builder.branch(loop_blocks['continue']) self.Builder.position_at_start(loop_blocks['buf']) return None elif ctx.BREAK(): self.signal_stack[-1] = -1 loop_blocks = self.loop_stack[-1] self.Builder.branch(loop_blocks['break']) self.Builder.position_at_start(loop_blocks['buf']) else: return self.visitChildren(ctx) def visitCompound_stmt(self, ctx: SmallCParser.Compound_stmtContext): # builder = IRBuilder(self.block_stack[-1]) # block = self.Builder.append_basic_block() # self.block_stack.append(block) # with self.Builder.goto_block(block): result = self.visitChildren(ctx) # self.block_stack.pop() return result def visitAssignment(self, ctx: SmallCParser.AssignmentContext): value = self.getVal_of_expr(ctx.expr()) identifier = ctx.identifier() identifier = self.getVal_local(identifier.IDENTIFIER().getText()) if isinstance(identifier['type'], ArrayType): if ctx.identifier().array_indexing(): index = self.getVal_of_expr( ctx.identifier().array_indexing().expr()) if isinstance(index.type, PointerType): index = self.Builder.load(index) else: index = Constant(IntType(32), 0) tempPtr = self.Builder.gep(identifier['ptr'], [Constant(IntType(32), 0), index], inbounds=True) if isinstance(value.type, PointerType): value = self.Builder.load(value) return self.Builder.store(value, tempPtr) if isinstance(value.type, PointerType): value = self.Builder.load(value) return self.Builder.store(value, identifier['ptr']) def visitExpr(self, ctx: SmallCParser.ExprContext): if ctx.condition(): return self.visit(ctx.condition()) elif ctx.assignment(): return self.visit(ctx.assignment()) elif ctx.functioncall(): return self.visit(ctx.functioncall()) def visitCondition(self, ctx: SmallCParser.ConditionContext): if ctx.expr(): disjunction = self.getVal_of_expr(ctx.disjunction()) if isinstance(disjunction.type, PointerType): disjunction = self.Builder.load(disjunction) cond = self.Builder.icmp_signed('!=', disjunction, Constant(disjunction.type, 0)) expr = self.getVal_of_expr(ctx.expr()) if isinstance(expr.type, PointerType): expr = self.Builder.load(expr) condition = self.getVal_of_expr(ctx.condition()) return self.Builder.select(cond, expr, condition) else: return self.getVal_of_expr(ctx.disjunction()) def visitDisjunction(self, ctx: SmallCParser.DisjunctionContext): if ctx.disjunction(): disjunction = self.getVal_of_expr(ctx.disjunction()) if isinstance(disjunction.type, PointerType): disjunction = self.Builder.load(disjunction) conjunction = self.getVal_of_expr(ctx.conjunction()) if isinstance(conjunction.type, PointerType): conjunction = self.Builder.load(conjunction) left = self.Builder.icmp_signed('!=', disjunction, Constant(disjunction.type, 0)) right = self.Builder.icmp_signed('!=', conjunction, Constant(conjunction.type, 0)) return self.Builder.or_(left, right) else: return self.getVal_of_expr(ctx.conjunction()) def visitConjunction(self, ctx: SmallCParser.ConjunctionContext): if ctx.conjunction(): conjunction = self.getVal_of_expr(ctx.conjunction()) if isinstance(conjunction.type, PointerType): conjunction = self.Builder.load(conjunction) comparison = self.getVal_of_expr(ctx.comparison()) if isinstance(comparison.type, PointerType): comparison = self.Builder.load(comparison) left = self.Builder.icmp_signed('!=', conjunction, Constant(conjunction.type, 0)) right = self.Builder.icmp_signed('!=', comparison, Constant(comparison.type, 0)) return self.Builder.and_(left, right) else: return self.getVal_of_expr(ctx.comparison()) def visitComparison(self, ctx: SmallCParser.ComparisonContext): if ctx.EQUALITY(): relation1 = self.getVal_of_expr(ctx.relation(0)) if isinstance(relation1.type, PointerType): relation1 = self.Builder.load(relation1) relation2 = self.getVal_of_expr(ctx.relation(1)) if isinstance(relation2.type, PointerType): relation2 = self.Builder.load(relation2) return self.Builder.icmp_signed('==', relation1, relation2) elif ctx.NEQUALITY(): relation1 = self.getVal_of_expr(ctx.relation(0)) if isinstance(relation1.type, PointerType): relation1 = self.Builder.load(relation1) relation2 = self.getVal_of_expr(ctx.relation(1)) if isinstance(relation2.type, PointerType): relation2 = self.Builder.load(relation2) return self.Builder.icmp_signed('!=', relation1, relation2) else: return self.getVal_of_expr(ctx.relation(0)) def visitRelation(self, ctx: SmallCParser.RelationContext): if len(ctx.equation()) > 1: equation1 = self.getVal_of_expr(ctx.equation(0)) if isinstance(equation1.type, PointerType): equation1 = self.Builder.load(equation1) equation2 = self.getVal_of_expr(ctx.equation(1)) if isinstance(equation2.type, PointerType): equation2 = self.Builder.load(equation2) if ctx.LEFTANGLE(): value = self.Builder.icmp_signed('<', equation1, equation2) elif ctx.RIGHTANGLE(): value = self.Builder.icmp_signed('>', equation1, equation2) elif ctx.LEFTANGLEEQUAL(): value = self.Builder.icmp_signed('<=', equation1, equation2) elif ctx.RIGHTANGLEEQUAL(): value = self.Builder.icmp_signed('>=', equation1, equation2) return value else: return self.getVal_of_expr(ctx.equation(0)) def visitFor_stmt(self, ctx: SmallCParser.For_stmtContext): func = self.function_dict[self.function] end_block = func.append_basic_block() self.var_stack.append({}) decl_block = func.append_basic_block() self.var_stack.append({}) cond_block = func.append_basic_block() self.var_stack.append({}) stmt_block = func.append_basic_block() self.var_stack.append({}) loop_block = func.append_basic_block() # 1 -> continue, -1 -> break self.signal_stack.append(0) self.loop_stack.append({ 'continue': cond_block, 'break': end_block, 'buf': loop_block }) with self.Builder.goto_block(decl_block): # self.Builder.position_at_start(end_block) if ctx.var_decl(): self.visit(ctx.var_decl()) elif ctx.var_decl_list(): self.visit(ctx.var_decl_list()) else: self.error("for error in <visitFor_stmt>") self.Builder.branch(cond_block) self.Builder.branch(decl_block) with self.Builder.goto_block(cond_block): # cond_expr cond_expr = ctx.expr(0) cond_expr = self.visit(cond_expr) cond_expr = self.toBool(cond_expr) self.Builder.cbranch(cond_expr, stmt_block, end_block) with self.Builder.goto_block(stmt_block): # expr self.visit(ctx.stmt()) expr = ctx.expr(1) self.visit(expr) self.Builder.branch(cond_block) if self.signal_stack[-1] == 0: loop_blocks = self.loop_stack[-1] self.Builder.position_at_start(loop_blocks['buf']) self.Builder.branch(end_block) self.Builder.position_at_start(end_block) self.loop_stack.pop() self.signal_stack.pop() self.var_stack.pop() self.var_stack.pop() self.var_stack.pop() self.var_stack.pop() def visitWhile_stmt(self, ctx: SmallCParser.While_stmtContext): func = self.function_dict[self.function] end_block = func.append_basic_block() self.var_stack.append({}) cond_block = func.append_basic_block() self.var_stack.append({}) stmt_block = func.append_basic_block() self.var_stack.append({}) loop_block = func.append_basic_block() # 1 -> continue, -1 -> break self.signal_stack.append(0) self.loop_stack.append({ 'continue': cond_block, 'break': end_block, 'buf': loop_block }) self.Builder.branch(cond_block) with self.Builder.goto_block(cond_block): expr = self.getVal_of_expr(ctx.expr()) cond_expr = self.toBool(expr) self.Builder.cbranch(cond_expr, stmt_block, end_block) with self.Builder.goto_block(stmt_block): self.visit(ctx.stmt()) self.Builder.branch(cond_block) if self.signal_stack[-1] == 0: loop_blocks = self.loop_stack[-1] self.Builder.position_at_start(loop_blocks['buf']) self.Builder.branch(end_block) self.Builder.position_at_start(end_block) self.loop_stack.pop() self.signal_stack.pop() self.var_stack.pop() self.var_stack.pop() self.var_stack.pop() def visitCond_stmt(self, ctx: SmallCParser.Cond_stmtContext): expr = self.getVal_of_expr(ctx.expr()) cond_expr = self.toBool(expr) else_expr = ctx.ELSE() if else_expr: with self.Builder.if_else(cond_expr) as (then, otherwise): with then: self.var_stack.append({}) true_stmt = ctx.stmt(0) self.visit(true_stmt) self.var_stack.pop() with otherwise: self.var_stack.append({}) else_stmt = ctx.stmt(1) self.visit(else_stmt) self.var_stack.pop() else: with self.Builder.if_then(cond_expr): self.var_stack.append({}) true_stmt = ctx.stmt(0) self.visit(true_stmt) self.var_stack.pop() return None def visitVar_decl(self, ctx: SmallCParser.Var_declContext): self.cur_decl_type = self.getType(ctx.type_specifier().getText()) return self.visitChildren(ctx) def visitVar_decl_list(self, ctx: SmallCParser.Var_decl_listContext): ans = [] decls = ctx.variable_id() for decl in decls: ans.append(self.visit(decl)) return ans def visitVariable_id(self, ctx: SmallCParser.Variable_idContext): identifier = ctx.identifier() type = self.cur_decl_type if not self.function: if identifier.array_indexing(): length = self.getVal_of_expr( identifier.array_indexing().expr()) type = ArrayType(type, length.constant) g_var = GlobalVariable(self.Module, type, identifier.IDENTIFIER().getText()) else: g_var = GlobalVariable(self.Module, type, identifier.IDENTIFIER().getText()) g_var.initializer = Constant(type, None) atomic = { "id": identifier.IDENTIFIER().getText(), "type": type, "value": None, "ptr": g_var } if not len(self.var_stack): self.var_stack.append({}) self.var_stack[0][identifier.IDENTIFIER().getText()] = atomic return g_var var_map = self.var_stack[-1] if identifier.array_indexing(): length = self.getVal_of_expr(identifier.array_indexing().expr()) type = ArrayType(type, length.constant) ptr = self.Builder.alloca(typ=type, name=identifier.IDENTIFIER().getText()) else: ptr = self.Builder.alloca(typ=type, name=identifier.IDENTIFIER().getText()) expr = ctx.expr() if expr: value = self.getVal_of_expr(expr) else: value = Constant(type, None) if isinstance(value.type, PointerType): value = self.Builder.load(value) self.Builder.store(value, ptr) var_map[identifier.IDENTIFIER().getText()] = { "id": identifier.IDENTIFIER().getText(), "type": type, "value": value, "ptr": ptr } return ptr def visitPrimary(self, ctx: SmallCParser.PrimaryContext): if ctx.BOOLEAN(): return Constant(IntType(1), bool(ctx.getText())) elif ctx.INTEGER(): return Constant(IntType(32), int(ctx.getText())) elif ctx.REAL(): return Constant(FloatType, float(ctx.getText())) elif ctx.CHARCONST(): tempStr = ctx.getText()[1:-1] tempStr = tempStr.replace('\\n', '\n') tempStr = tempStr.replace('\\0', '\0') if ctx.getText()[0] == '"': tempStr += '\0' temp = GlobalVariable(self.Module, ArrayType(IntType(8), len(tempStr)), name="str_" + tempStr[:-1] + str(self.counter)) self.counter += 1 temp.initializer = Constant( ArrayType(IntType(8), len(tempStr)), bytearray(tempStr, encoding='utf-8')) temp.global_constant = True return self.Builder.gep( temp, [Constant(IntType(32), 0), Constant(IntType(32), 0)], inbounds=True) return Constant(IntType(8), ord(tempStr[0])) elif ctx.identifier(): return self.visit(ctx.identifier()) elif ctx.functioncall(): return self.visit(ctx.functioncall()) elif ctx.expr(): return self.visit(ctx.expr()) else: return self.error("type error in <visitPrimary>") def visitFactor(self, ctx: SmallCParser.FactorContext): if (ctx.MINUS()): factor = self.getVal_of_expr(ctx.factor()) if isinstance(factor.type, PointerType): factor = self.Builder.load(factor) factor = self.Builder.neg(factor) return factor return self.visitChildren(ctx) def visitTerm(self, ctx: SmallCParser.TermContext): if (ctx.ASTERIKS()): term = self.getVal_of_expr(ctx.term()) if isinstance(term.type, PointerType): term = self.Builder.load(term) factor = self.getVal_of_expr(ctx.factor()) if isinstance(factor.type, PointerType): factor = self.Builder.load(factor) return self.Builder.mul(term, factor) if (ctx.SLASH()): term = self.getVal_of_expr(ctx.term()) if isinstance(term.type, PointerType): term = self.Builder.load(term) factor = self.getVal_of_expr(ctx.factor()) if isinstance(factor.type, PointerType): factor = self.Builder.load(factor) return self.Builder.sdiv(term, factor) return self.visitChildren(ctx) def visitEquation(self, ctx: SmallCParser.EquationContext): if (ctx.PLUS()): equation = self.getVal_of_expr(ctx.equation()) if isinstance(equation.type, PointerType): equation = self.Builder.load(equation) if str(equation.type) != 'i32': equation = self.Builder.zext(equation, IntType(32)) term = self.getVal_of_expr(ctx.term()) if isinstance(term.type, PointerType): term = self.Builder.load(term) if str(term.type) != 'i32': term = self.Builder.zext(term, IntType(32)) return self.Builder.add(equation, term) if (ctx.MINUS()): equation = self.getVal_of_expr(ctx.equation()) if isinstance(equation.type, PointerType): equation = self.Builder.load(equation) if str(equation.type) != 'i32': equation = self.Builder.zext(equation, IntType(32)) term = self.getVal_of_expr(ctx.term()) if isinstance(term.type, PointerType): term = self.Builder.load(term) if str(term.type) != 'i32': term = self.Builder.zext(term, IntType(32)) return self.Builder.sub(equation, term) return self.visitChildren(ctx) def visitIdentifier(self, ctx: SmallCParser.IdentifierContext): if (ctx.AMPERSAND() and ctx.array_indexing()): return self.Builder.gep(self.getVal_of_expr(ctx.IDENTIFIER()), self.getVal_of_expr(ctx.array_indexing())) if (ctx.ASTERIKS() and ctx.array_indexing()): return self.Builder.load( self.Builder.gep(self.getVal_of_expr(ctx.IDENTIFIER()), self.getVal_of_expr(ctx.array_indexing()))) if (ctx.AMPERSAND()): return self.getVal_local(str(ctx.IDENTIFIER()))['ptr'] if (ctx.ASTERIKS()): return self.getVal_local(str(ctx.IDENTIFIER()))['ptr'] temp = self.getVal_local(ctx.IDENTIFIER().getText()) temp_ptr = temp['ptr'] # if isinstance(temp_val['type'],ArrayType): # if temp.array_indexing(): # index = self.getVal_of_expr(temp.array_indexing().expr()) # temp_ptr = self.Builder.gep(temp_ptr, [Constant(IntType(32), 0), index], inbounds=True) # elif temp.AMPERSAND(): # Constant(PointerType(IntType(8)), temp_ptr.getText()) # elif temp.ASTERIKS(): # pass # else: #返回数组地址 # temp_ptr = self.Builder.gep(temp_ptr, [Constant(IntType(32), 0), Constant(IntType(32), 0)], inbounds=True) # return temp_ptr if isinstance(temp['type'], ArrayType): if ctx.array_indexing(): index = self.getVal_of_expr(ctx.array_indexing().expr()) if isinstance(index.type, PointerType): index = self.Builder.load(index) temp_ptr = self.Builder.gep(temp_ptr, [Constant(IntType(32), 0), index], inbounds=True) value = self.Builder.load(temp_ptr) self.var_stack[-1][temp_ptr.name] = { 'id': temp_ptr.name, 'type': temp_ptr.type, 'value': value, 'ptr': temp_ptr } return temp_ptr else: temp_ptr = self.Builder.gep( temp_ptr, [Constant(IntType(32), 0), Constant(IntType(32), 0)], inbounds=True) value = self.Builder.load(temp_ptr) self.var_stack[-1][temp_ptr.name] = { 'id': temp_ptr.name, 'type': temp_ptr.type, 'value': value, 'ptr': temp_ptr } return temp_ptr temp_val = temp['ptr'] return temp_val def visitArray_indexing(self, ctx: SmallCParser.Array_indexingContext): return self.visit(ctx.expr())
def _build_wrapper(self, library, name): """ The LLVM IRBuilder code to create the gufunc wrapper. The *library* arg is the CodeLibrary to which the wrapper should be added. The *name* arg is the name of the wrapper function being created. """ intp_t = self.context.get_value_type(types.intp) fnty = self._wrapper_function_type() wrapper_module = library.create_ir_module('_gufunc_wrapper') func_type = self.call_conv.get_function_type(self.fndesc.restype, self.fndesc.argtypes) fname = self.fndesc.llvm_func_name func = ir.Function(wrapper_module, func_type, name=fname) func.attributes.add("alwaysinline") wrapper = ir.Function(wrapper_module, fnty, name) # The use of weak_odr linkage avoids the function being dropped due # to the order in which the wrappers and the user function are linked. wrapper.linkage = 'weak_odr' arg_args, arg_dims, arg_steps, arg_data = wrapper.args arg_args.name = "args" arg_dims.name = "dims" arg_steps.name = "steps" arg_data.name = "data" builder = IRBuilder(wrapper.append_basic_block("entry")) loopcount = builder.load(arg_dims, name="loopcount") pyapi = self.context.get_python_api(builder) # Unpack shapes unique_syms = set() for grp in (self.sin, self.sout): for syms in grp: unique_syms |= set(syms) sym_map = {} for syms in self.sin: for s in syms: if s not in sym_map: sym_map[s] = len(sym_map) sym_dim = {} for s, i in sym_map.items(): sym_dim[s] = builder.load( builder.gep(arg_dims, [self.context.get_constant(types.intp, i + 1)])) # Prepare inputs arrays = [] step_offset = len(self.sin) + len(self.sout) for i, (typ, sym) in enumerate( zip(self.signature.args, self.sin + self.sout)): ary = GUArrayArg(self.context, builder, arg_args, arg_steps, i, step_offset, typ, sym, sym_dim) step_offset += len(sym) arrays.append(ary) bbreturn = builder.append_basic_block('.return') # Prologue self.gen_prologue(builder, pyapi) # Loop with cgutils.for_range(builder, loopcount, intp=intp_t) as loop: args = [a.get_array_at_offset(loop.index) for a in arrays] innercall, error = self.gen_loop_body(builder, pyapi, func, args) # If error, escape cgutils.cbranch_or_continue(builder, error, bbreturn) builder.branch(bbreturn) builder.position_at_end(bbreturn) # Epilogue self.gen_epilogue(builder, pyapi) builder.ret_void() # Link library.add_ir_module(wrapper_module) library.add_linking_library(self.library)