def GenerateMemcpyFun(unit: ir.Unit, addr_type: o.DK) -> ir.Fun:
    """Add a byte-wise `$memcpy(dst, src, cnt)` helper function to `unit`.

    The generated code copies `cnt` bytes from `src` to `dst`, one byte at a
    time, starting with the last byte and moving towards the first. Control
    enters at `prolog`, jumps directly to the loop test in `epilog`, and runs
    `loop` while `ZERO < cnt` (so a zero count copies nothing).

    Returns the newly added function.
    """
    memcpy = unit.AddFun(
        ir.Fun("$memcpy", o.FUN_KIND.NORMAL, [],
               [addr_type, addr_type, o.DK.U32]))

    # The three parameters plus a byte-sized scratch register.
    reg_dst = memcpy.AddReg(ir.Reg("dst", addr_type))
    reg_src = memcpy.AddReg(ir.Reg("src", addr_type))
    reg_cnt = memcpy.AddReg(ir.Reg("cnt", o.DK.U32))
    reg_byte = memcpy.AddReg(ir.Reg("data", o.DK.U8))

    # Blocks are added in layout order: prolog, loop, epilog.
    bbl_prolog = memcpy.AddBbl(ir.Bbl("prolog"))
    bbl_loop = memcpy.AddBbl(ir.Bbl("loop"))
    bbl_epilog = memcpy.AddBbl(ir.Bbl("epilog"))

    for param in (reg_dst, reg_src, reg_cnt):
        bbl_prolog.AddIns(ir.Ins(o.POPARG, [param]))
    # Test the count before the first iteration.
    bbl_prolog.AddIns(ir.Ins(o.BRA, [bbl_epilog]))

    # cnt -= 1; dst[cnt] = src[cnt]  (ONE / ZERO are module-level constants)
    bbl_loop.AddIns(ir.Ins(o.SUB, [reg_cnt, reg_cnt, ONE]))
    bbl_loop.AddIns(ir.Ins(o.LD, [reg_byte, reg_src, reg_cnt]))
    bbl_loop.AddIns(ir.Ins(o.ST, [reg_dst, reg_cnt, reg_byte]))

    bbl_epilog.AddIns(ir.Ins(o.BLT, [ZERO, reg_cnt, bbl_loop]))
    bbl_epilog.AddIns(ir.Ins(o.RET, []))
    return memcpy
def BuildExample() -> ir.Unit:
    """Build a unit holding a naive recursive `fib` function over U32.

    fib(in) returns `in` unless 1 < in, in which case it returns
    fib(in - 1) + fib(in - 2).
    """
    unit = ir.Unit("fib")
    fib = unit.AddFun(
        ir.Fun("fib", o.FUN_KIND.NORMAL, [o.DK.U32], [o.DK.U32]))

    start = fib.AddBbl(ir.Bbl("start"))
    recurse = fib.AddBbl(ir.Bbl("difficult"))

    r_in = fib.AddReg(ir.Reg("in", o.DK.U32))
    r_x = fib.AddReg(ir.Reg("x", o.DK.U32))
    r_out = fib.AddReg(ir.Reg("out", o.DK.U32))

    # Base case: fall through and return `in` unless 1 < in.
    start.AddIns(ir.Ins(o.POPARG, [r_in]))
    start.AddIns(ir.Ins(o.BLT, [ir.Const(o.DK.U32, 1), r_in, recurse]))
    start.AddIns(ir.Ins(o.PUSHARG, [r_in]))
    start.AddIns(ir.Ins(o.RET, []))

    # Recursive case: out = 0; out += fib(in - k) for k = 1, 2.
    recurse.AddIns(ir.Ins(o.MOV, [r_out, ir.Const(o.DK.U32, 0)]))
    for k in (1, 2):
        recurse.AddIns(ir.Ins(o.SUB, [r_x, r_in, ir.Const(o.DK.U32, k)]))
        recurse.AddIns(ir.Ins(o.PUSHARG, [r_x]))
        recurse.AddIns(ir.Ins(o.BSR, [fib]))
        recurse.AddIns(ir.Ins(o.POPARG, [r_x]))
        recurse.AddIns(ir.Ins(o.ADD, [r_out, r_out, r_x]))
    recurse.AddIns(ir.Ins(o.PUSHARG, [r_out]))
    recurse.AddIns(ir.Ins(o.RET, []))
    return unit
def GenerateInitGlobalVarsFun(mod: wasm.Module, unit: ir.Unit,
                              addr_type: o.DK) -> ir.Fun:
    """Create one `global_vars_<n>` memory per wasm global plus an init function.

    Globals with a constant initializer get that constant as their memory
    contents. Globals initialized via GLOBAL_GET get zero-filled storage
    (4 or 8 bytes) and a load/store pair in `init_global_vars_fun` that copies
    the referenced global at startup. Immutable globals use RO memory;
    mutable ones use RW memory.

    `addr_type` is not used.

    Returns `init_global_vars_fun`. It is added to `unit` even when the
    module has no GLOBAL section.
    """
    init_fun = unit.AddFun(
        ir.Fun("init_global_vars_fun", o.FUN_KIND.NORMAL, [], []))
    body = init_fun.AddBbl(ir.Bbl("start"))
    end = init_fun.AddBbl(ir.Bbl("end"))
    end.AddIns(ir.Ins(o.RET, []))

    global_section = mod.sections.get(wasm.SECTION_ID.GLOBAL)
    if not global_section:
        return init_fun

    scratch32 = init_fun.AddReg(ir.Reg("val32", o.DK.U32))
    scratch64 = init_fun.AddReg(ir.Reg("val64", o.DK.U64))
    for index, glob in enumerate(global_section.items):
        if glob.global_type.mut is wasm.MUT.CONST:
            mem_kind = o.MEM_KIND.RO
        else:
            mem_kind = o.MEM_KIND.RW
        mem = unit.AddMem(ir.Mem(f"global_vars_{index}", 16, mem_kind))
        init_ins = GetInsFromInitializerExpression(glob.expr)
        value_type = glob.global_type.value_type

        if init_ins.opcode is wasm_opc.GLOBAL_GET:
            is32 = value_type.is_32bit()
            mem.AddData(ir.DataBytes(1, b"\0" * (4 if is32 else 8)))
            source = unit.GetMem(f"global_vars_{int(init_ins.args[0])}")
            scratch = scratch32 if is32 else scratch64
            body.AddIns(ir.Ins(o.LD_MEM, [scratch, source, ZERO]))
            body.AddIns(ir.Ins(o.ST_MEM, [mem, ZERO, scratch]))
        elif init_ins.opcode.kind is wasm_opc.OPC_KIND.CONST:
            mem.AddData(
                ir.DataBytes(1, ExtractBytesFromConstIns(init_ins, value_type)))
        else:
            assert False, f"unsupported init instructions {init_ins}"
    return init_fun
def GenerateStartup(unit: ir.Unit, global_argc, global_argv, main: ir.Fun,
                    init_global: ir.Fun, init_data: ir.Fun,
                    initial_heap_size_pages: int, addr_type: o.DK) -> ir.Fun:
    """Add the native entry function `main(argc: U32, argv) -> U32` to `unit`.

    The generated code:
      1. pops argc/argv and stores them into the `global_argc` and
         `global_argv` memories,
      2. calls `__wasi_init`,
      3. if `initial_heap_size_pages` is non-zero, calls `__memory_grow` with
         it and discards the result,
      4. loads `mem_base` from the extern memory `__memory_base`,
      5. calls `init_global()` and then `init_data(mem_base)`, each only if
         given (truthy),
      6. calls the translated wasm entry `main(mem_base)` and returns 0.

    Note that the function created here is also named "main". The wasm entry
    point is passed in as the `main` parameter.
    """
    global_mem_base = unit.AddMem(
        ir.Mem("__memory_base", 0, o.MEM_KIND.EXTERN))
    fun = unit.AddFun(
        ir.Fun("main", o.FUN_KIND.NORMAL, [o.DK.U32], [o.DK.U32, addr_type]))
    argc = fun.AddReg(ir.Reg("argc", o.DK.U32))
    argv = fun.AddReg(ir.Reg("argv", addr_type))

    bbl = fun.AddBbl(ir.Bbl("start"))
    bbl.AddIns(ir.Ins(o.POPARG, [argc]))
    bbl.AddIns(ir.Ins(o.POPARG, [argv]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argc, ZERO, argc]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argv, ZERO, argv]))
    bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__wasi_init")]))

    if initial_heap_size_pages:
        bbl.AddIns(
            ir.Ins(o.PUSHARG, [ir.Const(o.DK.S32, initial_heap_size_pages)]))
        bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
        # The result of __memory_grow is popped into a throw-away register.
        # NOTE(review): a failed grow is not detected here - confirm intended.
        dummy = fun.AddReg(ir.Reg("dummy", o.DK.S32))
        bbl.AddIns(ir.Ins(o.POPARG, [dummy]))

    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.LD_MEM, [mem_base, global_mem_base, ZERO]))
    if init_global:
        bbl.AddIns(ir.Ins(o.BSR, [init_global]))
    if init_data:
        bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
        bbl.AddIns(ir.Ins(o.BSR, [init_data]))

    bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
    bbl.AddIns(ir.Ins(o.BSR, [main]))
    # Exit status 0.
    bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, 0)]))
    bbl.AddIns(ir.Ins(o.RET, []))
    return fun
def _GetRegOrConstOperand(fun: ir.Fun, last_kind: o.DK, ok: o.OP_KIND,
                          tc: o.TC, token: str,
                          regs_cpu: Dict[str, ir.Reg]) -> Any:
    """Parse one textual operand `token` into an `ir.Reg` or a constant.

    Register syntax:
        name               - an existing register of `fun`
        name:KIND          - declares a new register of that kind in `fun`
        ...@CPUREG / @STK  - optional suffix pinning the register to a cpu
                             register from `regs_cpu` or to a stack slot
    Constant syntax:
        value:KIND         - explicit kind
        value              - kind deduced from the type constraint `tc`
                             (SAME_AS_PREV uses `last_kind`; OFFSET and UINT
                             parse an offset constant, UINT rejects a '-')

    When `ok` is REG_OR_CONST the choice is made with `parse.IsLikelyConst`.
    """
    if ok == o.OP_KIND.REG_OR_CONST:
        ok = o.OP_KIND.CONST if parse.IsLikelyConst(token) else o.OP_KIND.REG

    if ok is o.OP_KIND.REG:
        cpu_reg = None
        pos = token.find("@")
        if pos > 0:
            cpu_reg_name = token[pos + 1:]
            token = token[:pos]
            if cpu_reg_name == "STK":
                cpu_reg = ir.StackSlot(0)
            else:
                cpu_reg = regs_cpu.get(cpu_reg_name)
                # Report the cpu reg name that was looked up. `token` has
                # already been truncated at this point.
                assert cpu_reg is not None, f"unknown cpu_reg {cpu_reg_name} known regs {regs_cpu.keys()}"
        pos = token.find(":")
        if pos < 0:
            reg = fun.GetReg(token)
        else:
            kind = token[pos + 1:]
            reg_name = token[:pos]
            reg = ir.Reg(reg_name, o.SHORT_STR_TO_RK.get(kind))
            fun.AddReg(reg)
        assert o.CheckTypeConstraint(last_kind, tc, reg.kind)
        # Explicit None check so that a pinned location is never dropped
        # merely because the object happens to be falsy.
        if cpu_reg is not None:
            if reg.cpu_reg:
                assert reg.cpu_reg == cpu_reg
            else:
                reg.cpu_reg = cpu_reg
        return reg

    pos = token.find(":")
    if pos >= 0:
        kind = token[pos + 1:]
        value_str = token[:pos]
        return ir.ParseConst(value_str, o.SHORT_STR_TO_RK.get(kind))
    if tc == o.TC.SAME_AS_PREV:
        return ir.ParseConst(token, last_kind)
    if tc == o.TC.OFFSET:
        return ir.ParseOffsetConst(token)
    if tc == o.TC.UINT:
        assert token[0] != "-"
        return ir.ParseOffsetConst(token)
    assert False, f"cannot deduce type for const {token} [{tc}]"
def GenerateInitDataFun(mod: wasm.Module, unit: ir.Unit, memcpy: ir.Fun,
                        addr_type: o.DK) -> typing.Optional[ir.Fun]:
    """Create `init_data_fun(mem_base)`, which copies wasm DATA segments into memory.

    Each data segment `n` becomes a read-only memory `global_init_mem_<n>`.
    The generated function copies it with `memcpy` to `mem_base + offset`.
    The offset comes from an I32_CONST or from a `global_vars_<i>` memory
    (GLOBAL_GET).

    Returns None, without adding anything to `unit`, when the module has no
    DATA section. Otherwise returns the new function.
    """
    section = mod.sections.get(wasm.SECTION_ID.DATA)
    if not section:
        # Check before creating the function so that no unused, never-called
        # `init_data_fun` is left behind in `unit`.
        return None

    fun = unit.AddFun(
        ir.Fun("init_data_fun", o.FUN_KIND.NORMAL, [], [addr_type]))
    bbl = fun.AddBbl(ir.Bbl("start"))
    epilog = fun.AddBbl(ir.Bbl("end"))
    epilog.AddIns(ir.Ins(o.RET, []))
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.POPARG, [mem_base]))

    offset = fun.AddReg(ir.Reg("offset", o.DK.S32))
    src = fun.AddReg(ir.Reg("src", addr_type))
    dst = fun.AddReg(ir.Reg("dst", addr_type))
    for n, data in enumerate(section.items):
        assert data.memory_index == 0
        assert isinstance(data.offset, wasm.Expression)
        ins = GetInsFromInitializerExpression(data.offset)
        init = unit.AddMem(ir.Mem(f"global_init_mem_{n}", 16, o.MEM_KIND.RO))
        init.AddData(ir.DataBytes(1, data.init))
        if ins.opcode is wasm_opc.GLOBAL_GET:
            src_mem = unit.GetMem(f"global_vars_{int(ins.args[0])}")
            bbl.AddIns(ir.Ins(o.LD_MEM, [offset, src_mem, ZERO]))
        elif ins.opcode is wasm_opc.I32_CONST:
            bbl.AddIns(ir.Ins(o.MOV, [offset, ir.Const(o.DK.S32, ins.args[0])]))
        else:
            assert False, f"unsupported init instructions {ins}"
        bbl.AddIns(ir.Ins(o.LEA, [dst, mem_base, offset]))
        bbl.AddIns(ir.Ins(o.LEA_MEM, [src, init, ZERO]))
        # Arguments are pushed last-to-first (cnt, src, dst). This presumably
        # mirrors a callee that pops dst, src, cnt, as $memcpy does.
        bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, len(data.init))]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [src]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [dst]))
        bbl.AddIns(ir.Ins(o.BSR, [memcpy]))
    return fun
def testNoChange(self):
    # One bbl defines x (poparg) and then reads it in a conditional branch,
    # with x live-out. Expect exactly one live range; it crosses the bbl
    # boundary and is not flagged LAC.
    # NOTE(review): other code in this file passes the blt branch target as
    # the last operand; confirm the operand order used here is intentional.
    x = ir.Reg("x", o.DK.S32)
    target = ir.Bbl("target")
    bbl = ir.Bbl("bbl")
    bbl.live_out.add(x)
    bbl.AddIns(ir.Ins(O("poparg"), [x]))
    bbl.AddIns(ir.Ins(O("blt"), [target, ir.OffsetConst(1), x]))
    DumpBbl(bbl)

    ranges = liveness.BblGetLiveRanges(bbl, None, bbl.live_out, False)
    ranges.sort()
    crossing = [r for r in ranges if r.is_cross_bbl()]
    lr_lac = [r for r in ranges if liveness.LiveRangeFlag.LAC in r.flags]

    assert len(ranges) == 1
    assert len(crossing) == 1
    assert len(lr_lac) == 0, f"{lr_lac}"
def ParseLiveRanges(fin, cpu_reg_map: Dict[str, ir.CpuReg]) -> List[LiveRange]:
    """Parse textual live ranges, one `LR` record per line, from `fin`.

    Line format (whitespace separated):
        LR <start> - <end> {PRE_ALLOC | LAC | uses:<reg>:<def_pos>,...}*
           [def:<name>:<KIND>[@<CPUREG>]] [#comment]
    Positions are parsed with `_ParsePos`. Each `uses:` entry is resolved via
    `FindDefRange` against the ranges parsed so far. Parsing of a line stops
    at the `def:` token or at a token starting with '#'. Blank lines and lines
    whose first token starts with '#' are skipped.
    """
    out: List[LiveRange] = []
    for line in fin:
        token = line.split()
        # Accept both "# comment" and "#comment" as comment lines.
        if not token or token[0].startswith("#"):
            continue
        assert token.pop(0) == "LR"
        start = _ParsePos(token.pop(0))
        assert token.pop(0) == "-"
        end = _ParsePos(token.pop(0))
        lr = LiveRange(start, end, ir.REG_INVALID, 0)
        out.append(lr)
        while token:
            t = token.pop(0)
            if t == "PRE_ALLOC":
                lr.flags |= LiveRangeFlag.PRE_ALLOC
            elif t == "LAC":
                lr.flags |= LiveRangeFlag.LAC
            elif t.startswith("def:"):
                reg_str = t[4:]
                cpu_reg_str = ""
                reg_name, kind_str = reg_str.split(":")
                if "@" in kind_str:
                    kind_str, cpu_reg_str = kind_str.split("@")
                lr.reg = ir.Reg(reg_name, o.DK[kind_str])
                if cpu_reg_str:
                    lr.reg.cpu_reg = cpu_reg_map[cpu_reg_str]
                lr.cpu_reg = lr.reg.cpu_reg
                break
            elif t.startswith("uses:"):
                reg_str = t[5:]
                uses = [u.split(":") for u in reg_str.split(",")]
                for reg_name, def_pos_str in uses:
                    lr.uses.append(
                        FindDefRange(reg_name, int(def_pos_str), out))
            elif t.startswith("#"):
                break
            else:
                assert False, f"parse error [{t}] {line}"
    return out
def ParseLiveRanges(fin, cpu_reg_map: Dict[str, ir.CpuReg]) -> List[LiveRange]:
    """Parse textual live ranges, one `RANGE` record per line, from `fin`.

    Line format (whitespace separated):
        RANGE <start> - <end> {PRE_ALLOC | LAC | uses: <reg>:<def_pos>,...}*
              [def: <name>:<KIND>[@<CPUREG>]] [#comment]
    <start> is an int or BEFORE_BBL; <end> is an int or AFTER_BBL. Each
    `uses:` entry is resolved via `FindDefRange` against the ranges parsed so
    far. Parsing of a line stops after the `def:` entry or at a token starting
    with '#'. Blank lines and lines whose first token starts with '#' are
    skipped.
    """
    out: List[LiveRange] = []
    for line in fin:
        token = line.split()
        if not token or token[0].startswith("#"):
            continue
        assert token.pop(0) == "RANGE"
        start_str = token.pop(0)
        start = BEFORE_BBL if start_str == "BEFORE_BBL" else int(start_str)
        assert token.pop(0) == "-"
        end_str = token.pop(0)
        end = AFTER_BBL if end_str == "AFTER_BBL" else int(end_str)
        lr = LiveRange(start, end, ir.REG_INVALID, 0)
        out.append(lr)
        while token:
            t = token.pop(0)
            if t == "PRE_ALLOC":
                lr.flags |= LiveRangeFlag.PRE_ALLOC
            elif t == "LAC":
                lr.flags |= LiveRangeFlag.LAC
            elif t == "def:":
                reg_str = token.pop(0)
                cpu_reg_str = ""
                reg_name, kind_str = reg_str.split(":")
                if "@" in kind_str:
                    kind_str, cpu_reg_str = kind_str.split("@")
                lr.reg = ir.Reg(reg_name, o.DK[kind_str])
                if cpu_reg_str:
                    lr.reg.cpu_reg = cpu_reg_map[cpu_reg_str]
                lr.cpu_reg = lr.reg.cpu_reg
                break
            elif t == "uses:":
                uses = [u.split(":") for u in token.pop(0).split(",")]
                for reg_name, def_pos_str in uses:
                    lr.uses.append(
                        FindDefRange(reg_name, int(def_pos_str), out))
            elif t.startswith("#"):
                break
            else:
                assert False, f"parse error [{t}] {line}"
    return out
from typing import List, Dict, Optional, Tuple from Base import canonicalize from Base import reg_alloc from Base import ir from Base import liveness from Base import lowering from Base import opcode_tab as o from Base import reg_stats from Base import sanity from Base import optimize from Base import serialize from CodeGenA64 import isel_tab from CodeGenA64 import regs _DUMMY_A32 = ir.Reg("dummy", o.DK.A32) _ZERO_OFFSET = ir.Const(o.DK.U32, 0) def _InsRewriteOutOfBoundsImmediates( ins: ir.Ins, fun: ir.Fun, unit: ir.Unit) -> Optional[List[ir.Ins]]: if ins.opcode in isel_tab.OPCODES_REQUIRING_SPECIAL_HANDLING: return None inss = [] mismatches = isel_tab.FindtImmediateMismatchesInBestMatchPattern(ins, True) assert mismatches != isel_tab.MATCH_IMPOSSIBLE, f"could not match opcode {ins} {ins.operands}" if mismatches == 0: return None for pos in range(o.MAX_OPERANDS): if mismatches & (1 << pos) != 0:
#!/usr/bin/python3 """ mov.r reg_s32 reg_u32 """ import unittest from Base import ir from Base import sanity from Base import opcode_tab as o reg_s64 = ir.Reg("reg_s64", o.DK.S64) reg_s32 = ir.Reg("reg_s32", o.DK.S32) reg_s18 = ir.Reg("reg_s16", o.DK.S16) reg_s8 = ir.Reg("reg_s8", o.DK.S8) reg_u64 = ir.Reg("reg_u64", o.DK.U64) reg_u32 = ir.Reg("reg_u32", o.DK.U32) reg_u16 = ir.Reg("reg_u16", o.DK.U16) reg_u8 = ir.Reg("reg_u8", o.DK.U8) reg_a64 = ir.Reg("reg_a64", o.DK.A64) reg_a32 = ir.Reg("reg_a32", o.DK.A32) reg_c64 = ir.Reg("reg_c64", o.DK.C64) reg_c32 = ir.Reg("reg_c32", o.DK.C32) class TestRegState(unittest.TestCase): def testMov(self): mov = o.Opcode.Lookup("mov")
def DirReg(unit: ir.Unit, operands: List):
    """Declare registers in the most recently added function of `unit`.

    `operands[0]` is the register kind and `operands[1]` is a list of
    register names. One `ir.Reg` of that kind is added per name.
    """
    current_fun = unit.funs[-1]
    names = operands[1]
    assert isinstance(names, list)
    kind = operands[0]
    for name in names:
        current_fun.AddReg(ir.Reg(name, kind))
def GenerateFun(unit: ir.Unit, mod: wasm.Module, wasm_fun: wasm.Function,
                fun: ir.Fun, global_table, addr_type):
    """Translate the body of `wasm_fun` into Cwerg IR inside `fun`.

    The wasm operand stack is simulated at translation time with `op_stack`,
    and structured control flow (block/loop/if) with `block_stack`. Each wasm
    instruction appends IR instructions to the current bbl (`bbls[-1]`). At
    the end, `fun.bbls` is replaced by `bbls`, i.e. the bbls are put in the
    order in which they became current.

    Args:
        unit: holds the `global_vars_<n>` memories and the callee functions,
            which are looked up by name.
        mod: the wasm module (GLOBAL and TYPE sections, function list).
        wasm_fun: the wasm function whose body is translated.
        fun: the already-declared IR function. Its first input must be
            `addr_type` (the linear-memory base, popped into `mem_base`),
            followed by the wasm parameters.
        global_table: memory holding the indirect-call table (CALL_INDIRECT).
        addr_type: address kind. `o.DK.A32` selects C32 code pointers;
            anything else selects C64.
    """
    # op_stack models the wasm operand stack. Its registers come from
    # GetOpReg, whose names encode the stack position (`$op_<pos>_<KIND>`);
    # a register may only ever sit at the stack position encoded in its name.
    op_stack: typing.List[typing.Union[ir.Reg, ir.Const]] = []
    block_stack: typing.List[Block] = []
    bbls: typing.List[ir.Bbl] = []
    bbl_count = 0
    jtb_count = 0
    bbls.append(fun.AddBbl(ir.Bbl("start")))
    loc_index = 0

    assert fun.input_types[0] is addr_type
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbls[-1].AddIns(ir.Ins(o.POPARG, [mem_base]))

    # wasm parameters become locals $loc_0 .. $loc_{k-1}.
    for dk in fun.input_types[1:]:
        reg = fun.AddReg(ir.Reg(f"$loc_{loc_index}", dk))
        loc_index += 1
        bbls[-1].AddIns(ir.Ins(o.POPARG, [reg]))

    # Declared wasm locals follow the parameters and are zero-initialized.
    # NOTE(review): `locals` shadows the builtin of the same name.
    for locals in wasm_fun.impl.locals_list:
        for i in range(locals.count):
            reg = fun.AddReg(
                ir.Reg(f"$loc_{loc_index}", WASM_TYPE_TO_CWERG_TYPE[locals.kind]))
            loc_index += 1
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, ir.Const(reg.kind, 0)]))

    DEBUG = False
    if DEBUG:
        print()
        print(
            f"# {fun.name} #ins:{len(wasm_fun.impl.expr.instructions)} in:{fun.input_types} out:{fun.output_types}"
        )
    opc = None
    last_opc = None
    for n, wasm_ins in enumerate(wasm_fun.impl.expr.instructions):
        last_opc = opc
        opc = wasm_ins.opcode
        args = wasm_ins.args
        # Temporaries are placed at slots above this size; presumably this
        # keeps them clear of the live operand registers.
        op_stack_size_before = len(op_stack)
        if DEBUG:
            print(f"#@@ {opc.name}", args, len(op_stack))
        if opc.kind is wasm_opc.OPC_KIND.CONST:
            # NOTE(review): the original notes this breaks for floats.
            kind = OPC_TYPE_TO_CWERG_TYPE[opc.op_type]
            dst = GetOpReg(fun, kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, ir.Const(kind, args[0])]))
            op_stack.append(dst)
        elif opc is wasm_opc.NOP:
            pass
        elif opc is wasm_opc.DROP:
            op_stack.pop(-1)
        elif opc is wasm_opc.LOCAL_GET:
            loc = GetLocalReg(fun, int(args[0]))
            dst = GetOpReg(fun, loc.kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, loc]))
            op_stack.append(dst)
        elif opc is wasm_opc.LOCAL_SET:
            op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.MOV, [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.LOCAL_TEE:
            op = op_stack[-1]  # no pop! tee leaves the value on the stack
            bbls[-1].AddIns(ir.Ins(o.MOV, [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.GLOBAL_GET:
            var_index = int(args[0])
            var: wasm.Glob = mod.sections.get(
                wasm.SECTION_ID.GLOBAL).items[var_index]
            dst = GetOpReg(fun, WASM_TYPE_TO_CWERG_TYPE[var.global_type.value_type],
                           len(op_stack))
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.LD_MEM, [dst, var_mem, ZERO]))
            op_stack.append(dst)
        elif opc is wasm_opc.GLOBAL_SET:
            op = op_stack.pop(-1)
            var_index = int(args[0])
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.ST_MEM, [var_mem, ZERO, op]))
        elif opc.kind is wasm_opc.OPC_KIND.ALU:
            if wasm_opc.FLAGS.UNARY in opc.flags:
                opcode, arg_factory = WASM_ALU1_TO_CWERG[opc.basename]
                op = op_stack.pop(-1)
                dst = GetOpReg(fun, op.kind, len(op_stack))
                bbls[-1].AddIns(ir.Ins(opcode, [dst] + arg_factory(op)))
                op_stack.append(dst)
            else:
                op2 = op_stack.pop(-1)
                op1 = op_stack.pop(-1)
                dst = GetOpReg(fun, op1.kind, len(op_stack))
                # `alu` is either an IR opcode or a callable that emits the
                # instruction sequence itself.
                alu = WASM_ALU_TO_CWERG[opc.basename]
                if wasm_opc.FLAGS.UNSIGNED in opc.flags:
                    # Unsigned op: CONV both operands to the unsigned kind,
                    # operate there, then CONV the result back into dst.
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                    tmp3 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 3)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [tmp3, tmp1, tmp2]))
                    else:
                        alu(tmp3, tmp1, tmp2, bbls[-1])
                    bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp3]))
                else:
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [dst, op1, op2]))
                    else:
                        alu(dst, op1, op2, bbls[-1])
                op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CONV:
            conv, dst_unsigned, src_unsigned = WASM_CONV_TO_CWERG[opc.name]
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type], len(op_stack))
            if src_unsigned:
                # Reinterpret the source as unsigned before converting.
                tmp = GetOpReg(fun, ToUnsigned(op.kind), op_stack_size_before + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, op]))
                op = tmp
            if dst_unsigned:
                # Convert into an unsigned temp (swapped into `dst`); the CONV
                # below copies it back into the real destination (now `tmp`).
                tmp = GetOpReg(fun, ToUnsigned(dst.kind), op_stack_size_before + 1)
                dst, tmp = tmp, dst
            bbls[-1].AddIns(ir.Ins(conv, [dst, op]))
            if dst_unsigned:
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, dst]))
                dst = tmp
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.BITCAST:
            assert opc.name in SUPPORTED_BITCAST
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type], len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.BITCAST, [dst, op]))
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CMP:
            # instructions[n + 1] always exists: the list ends with the
            # function's END sentinel.
            succ = wasm_fun.impl.expr.instructions[n + 1]
            # When the compare feeds IF/BR_IF/SELECT nothing is emitted here.
            # Those handlers pass this opcode (instructions[n - 1]) to
            # MakeBranch instead.
            if succ.opcode not in {
                    wasm_opc.IF, wasm_opc.BR_IF, wasm_opc.SELECT
            }:
                cmp, res1, res2, op1, op2, unsigned = MakeCompare(
                    opc, op_stack)
                if unsigned:
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    op1 = tmp1
                    op2 = tmp2
                dst = GetOpReg(fun, o.DK.S32, len(op_stack))
                bbls[-1].AddIns(ir.Ins(cmp, [dst, res1, res2, op1, op2]))
                op_stack.append(dst)
        elif opc is wasm_opc.LOOP or opc is wasm_opc.BLOCK:
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            bbl_count += 1
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.IF:
            # The block's start bbl is made current only at the end, because
            # the conditional branch must still go into the current (old) bbl.
            # instructions[n - 1] is valid: IF consumes a condition, so some
            # instruction must precede it.
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, True)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            # assert len(block_stack[-1].param_types) == 0
            bbl_count += 1
            # MakeBranch is called with True here, presumably producing a
            # branch to the else bbl when the wasm condition is false.
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, block_stack[-1].else_bbl]))
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.ELSE:
            block = block_stack[-1]
            assert block.opcode is wasm_opc.IF
            # The last argument tells FinalizeIf whether the then-arm ended in
            # UNREACHABLE or BR.
            block.FinalizeIf(op_stack, bbls[-1], fun,
                             last_opc in {wasm_opc.UNREACHABLE, wasm_opc.BR})
            if DEBUG:
                print(f"#@@ AFTER STACK RESTORE", len(op_stack))
            # The else-arm starts from the operand stack as it was at IF entry.
            op_stack = op_stack[0:block.stack_start]
            bbls[-1].AddIns(ir.Ins(o.BRA, [block.end_bbl]))
            assert block.else_bbl is not None
            bbls.append(block.else_bbl)
            # Mark the else bbl as placed so END does not append it again.
            block.else_bbl = None
        elif opc is wasm_opc.END:
            if block_stack:
                block = block_stack.pop(-1)
                block.FinalizeEnd(
                    op_stack, bbls[-1], fun,
                    last_opc in {wasm_opc.UNREACHABLE, wasm_opc.BR})
                # An IF without ELSE still owns an unplaced else bbl.
                if block.else_bbl:
                    bbls.append(block.else_bbl)
                bbls.append(block.end_bbl)
            else:
                # end of function
                assert n + 1 == len(wasm_fun.impl.expr.instructions)
                pred = wasm_fun.impl.expr.instructions[n - 1].opcode
                # Emit the implicit return unless control cannot reach here.
                if pred not in {
                        wasm_opc.RETURN, wasm_opc.UNREACHABLE, wasm_opc.BR
                }:
                    for x in reversed(fun.output_types):
                        op = op_stack.pop(-1)
                        assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                        bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
                    bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.CALL:
            wasm_callee = mod.functions[int(wasm_ins.args[0])]
            callee = unit.GetFun(wasm_callee.name)
            assert callee, f"unknown fun: {wasm_callee.name}"
            EmitCall(fun, bbls[-1], ir.Ins(o.BSR, [callee]), op_stack,
                     mem_base, callee)
        elif opc is wasm_opc.CALL_INDIRECT:
            assert isinstance(args[1], wasm.TableIdx), f"{type(args[1])}"
            assert int(args[1]) == 0, f"only one table supported"
            assert isinstance(args[0], wasm.TypeIdx), f"{type(args[0])}"
            type_sec = mod.sections.get(wasm.SECTION_ID.TYPE)
            func_type: wasm.FunctionType = type_sec.items[int(args[0])]
            # Every translated function takes the memory base first.
            arguments = [addr_type] + TranslateTypeList(func_type.args)
            returns = TranslateTypeList(func_type.rets)
            # print (f"# @@@@ CALL INDIRECT {returns} <- {arguments} [{int(args[0])}] {func_type}")
            signature = FindFunWithSignature(unit, arguments, returns)
            table_reg = GetOpReg(fun, addr_type, len(op_stack))
            code_type = o.DK.C32 if addr_type is o.DK.A32 else o.DK.C64
            fun_reg = GetOpReg(fun, code_type, len(op_stack) + 1)
            index = op_stack.pop(-1)
            assert index.kind is o.DK.S32
            # Scale the table index to a byte offset (in place on the popped
            # slot) and load the code pointer from global_table.
            # NOTE(review): MUL mixes an S32 register with a U32 constant;
            # confirm the IR accepts this combination.
            bbls[-1].AddIns(
                ir.Ins(o.MUL, [
                    index, index,
                    ir.Const(o.DK.U32,
                             code_type.bitwidth() // 8)
                ]))
            bbls[-1].AddIns(ir.Ins(o.LEA_MEM, [table_reg, global_table, ZERO]))
            bbls[-1].AddIns(ir.Ins(o.LD, [fun_reg, table_reg, index]))
            EmitCall(fun, bbls[-1], ir.Ins(o.JSR, [fun_reg, signature]),
                     op_stack, mem_base, signature)
        elif opc is wasm_opc.RETURN:
            for x in reversed(fun.output_types):
                op = op_stack.pop(-1)
                assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.BR:
            assert isinstance(args[0], wasm.LabelIdx)
            block = GetTargetBlock(block_stack, args[0])
            # Branching to a LOOP goes back to its start. Any other block
            # branches to its end, after copying the block results.
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(o.BRA, [target]))
        elif opc is wasm_opc.BR_IF:
            assert isinstance(args[0], wasm.LabelIdx)
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, target]))
        elif opc is wasm_opc.SELECT:
            # Lowered to: reg = val_t; if cond goto select_N; reg = val_f;
            # select_N: ...   (a new bbl is opened for the join point)
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            val_f = op_stack.pop(-1)
            val_t = op_stack.pop(-1)
            assert val_f.kind == val_t.kind
            reg = GetOpReg(fun, val_f.kind, len(op_stack))
            op_stack.append(reg)
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, val_t]))
            bbls.append(fun.AddBbl(ir.Bbl(f"select_{bbl_count}")))
            bbl_count += 1
            # From here on bbls[-2] is the bbl holding the branch.
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            bbls[-2].AddIns(ir.Ins(br, [op1, op2, bbls[-1]]))
            bbls[-2].AddIns(ir.Ins(o.MOV, [reg, val_f]))
        elif opc.kind is wasm_opc.OPC_KIND.STORE:
            val = op_stack.pop(-1)
            offset = op_stack.pop(-1)
            # args[1] is the static offset immediate; fold it into the address.
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD, [tmp, offset, ir.Const(offset.kind, args[1])]))
                offset = tmp
            # Narrowing stores first CONV the value to the stored kind.
            dk_tmp = STORE_TO_CWERG_TYPE.get(opc.name)
            if dk_tmp is not None:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack) + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, val]))
                val = tmp
            bbls[-1].AddIns(ir.Ins(o.ST, [mem_base, offset, val]))
        elif opc.kind is wasm_opc.OPC_KIND.LOAD:
            offset = op_stack.pop(-1)
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD, [tmp, offset, ir.Const(offset.kind, args[1])]))
                offset = tmp
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type], len(op_stack))
            op_stack.append(dst)
            # Narrow loads read into a temp of the memory kind and CONV to dst.
            dk_tmp = LOAD_TO_CWERG_TYPE.get(opc.basename)
            if dk_tmp:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack))
                bbls[-1].AddIns(ir.Ins(o.LD, [tmp, mem_base, offset]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp]))
            else:
                bbls[-1].AddIns(ir.Ins(o.LD, [dst, mem_base, offset]))
        elif opc is wasm_opc.BR_TABLE:
            # Note: the comprehension's `n` is local to the comprehension and
            # does not clobber the loop variable `n`.
            bbl_tab = {
                n: GetTargetBbl(block_stack, x)
                for n, x in enumerate(args[0])
            }
            bbl_def = GetTargetBbl(block_stack, args[1])
            op = op_stack.pop(-1)
            tab_size = ir.Const(ToUnsigned(op.kind), len(bbl_tab))
            jtb_count += 1
            jtb = fun.AddJtb(
                ir.Jtb(f"jtb_{jtb_count}", bbl_def, bbl_tab, tab_size.value))
            # Compare the index as unsigned. Out-of-range indices
            # (tab_size <= index) go to the default target, the rest go
            # through the jump table.
            reg_unsigned = GetOpReg(fun, ToUnsigned(op.kind), len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.CONV, [reg_unsigned, op]))
            bbls[-1].AddIns(ir.Ins(o.BLE, [tab_size, reg_unsigned, bbl_def]))
            bbls[-1].AddIns(ir.Ins(o.SWITCH, [reg_unsigned, jtb]))
        elif opc is wasm_opc.UNREACHABLE:
            bbls[-1].AddIns(ir.Ins(o.TRAP, []))
        elif opc is wasm_opc.MEMORY_GROW or opc is wasm_opc.MEMORY_SIZE:
            # MEMORY_SIZE is implemented as __memory_grow(ZERO_S).
            op = ZERO_S
            if opc is wasm_opc.MEMORY_GROW:
                op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            assert unit.GetFun("__memory_grow")
            bbls[-1].AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
            dst = GetOpReg(fun, o.DK.S32, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.POPARG, [dst]))
            op_stack.append(dst)
        else:
            assert False, f"unsupported opcode [{opc.name}]"
    assert not op_stack, f"op_stack not empty in {fun.name}: {op_stack}"
    assert not block_stack, f"block_stack not empty in {fun.name}: {block_stack}"
    # Every bbl added to `fun` must have been placed exactly once.
    assert len(bbls) == len(fun.bbls)
    fun.bbls = bbls
def GetOpReg(fun: ir.Fun, dk: o.DK, pos: int) -> ir.Reg:
    """Return the operand-stack register for slot `pos` of kind `dk`.

    The register is named `$op_<pos>_<KIND>`. An existing register with that
    name is reused; otherwise a new one is added to `fun`.
    """
    name = f"$op_{pos}_{dk.name}"
    reg = fun.MaybeGetReg(name)
    if not reg:
        reg = fun.AddReg(ir.Reg(name, dk))
    return reg
def testMov(self):
    # A register-to-register mov must have an isel pattern for each of
    # S32, U32, A32 and F32.
    for dk in (o.DK.S32, o.DK.U32, o.DK.A32, o.DK.F32):
        dst = ir.Reg(name="0", kind=dk)
        src = ir.Reg(name="1", kind=dk)
        mov = ir.Ins(o.MOV, [dst, src])
        assert isel_tab.FindMatchingPattern(mov) is not None