def _GetRegOrConstOperand(fun: ir.Fun, last_kind: o.DK, ok: o.OP_KIND,
                          tc: o.TC, token: str,
                          regs_cpu: Dict[str, ir.Reg]) -> Any:
    """Parse one textual instruction operand into a register or a constant.

    Register tokens look like ``name``, ``name:KIND``, optionally followed by
    ``@CPUREG`` (``@STK`` selects ``ir.StackSlot(0)``). A register without a
    ``:KIND`` suffix must already exist in ``fun``; one with a suffix is
    created and added to ``fun``.

    Constant tokens either carry an explicit ``value:KIND`` suffix or get their
    kind deduced from the type constraint ``tc``.

    Args:
        fun: function whose register table is consulted/extended.
        last_kind: kind of the previous operand; used for ``TC.SAME_AS_PREV``.
        ok: expected operand kind; ``REG_OR_CONST`` is resolved here by looking
            at the token.
        tc: type constraint the operand must satisfy.
        token: the operand text.
        regs_cpu: cpu register names -> cpu registers, for ``@CPUREG`` suffixes.

    Returns:
        The parsed register or constant.

    Raises:
        AssertionError: unknown cpu register, type-constraint violation,
            conflicting cpu register assignment, or a const whose kind
            cannot be deduced.
    """
    if ok is o.OP_KIND.REG_OR_CONST:
        ok = o.OP_KIND.CONST if parse.IsLikelyConst(token) else o.OP_KIND.REG

    if ok is o.OP_KIND.REG:
        cpu_reg = None
        pos = token.find("@")
        if pos > 0:
            cpu_reg_name = token[pos + 1:]
            token = token[:pos]
            if cpu_reg_name == "STK":
                cpu_reg = ir.StackSlot(0)
            else:
                cpu_reg = regs_cpu.get(cpu_reg_name)
            # report the name that was actually looked up: `token` has already
            # been truncated, so slicing it again would yield an empty string
            assert cpu_reg is not None, f"unknown cpu_reg {cpu_reg_name} known regs {regs_cpu.keys()}"
        pos = token.find(":")
        if pos < 0:
            reg = fun.GetReg(token)
        else:
            kind = token[pos + 1:]
            reg_name = token[:pos]
            reg = ir.Reg(reg_name, o.SHORT_STR_TO_RK.get(kind))
            fun.AddReg(reg)
        assert o.CheckTypeConstraint(last_kind, tc, reg.kind)
        if cpu_reg:
            if reg.cpu_reg:
                # a register may only ever be pinned to one cpu register
                assert reg.cpu_reg == cpu_reg
            else:
                reg.cpu_reg = cpu_reg
        return reg

    pos = token.find(":")
    if pos >= 0:
        # explicit kind suffix: "value:KIND"
        kind = token[pos + 1:]
        value_str = token[:pos]
        return ir.ParseConst(value_str, o.SHORT_STR_TO_RK.get(kind))
    elif tc == o.TC.SAME_AS_PREV:
        return ir.ParseConst(token, last_kind)
    elif tc == o.TC.OFFSET:
        return ir.ParseOffsetConst(token)
    elif tc == o.TC.UINT:
        assert token[0] != "-"
        return ir.ParseOffsetConst(token)
    else:
        assert False, f"cannot deduce type for const {token} [{tc}]"
def GenerateFun(unit: ir.Unit, mod: wasm.Module, wasm_fun: wasm.Function,
                fun: ir.Fun, global_table, addr_type):
    """Translate the body of ``wasm_fun`` into Cwerg IR inside ``fun``.

    The first input of ``fun`` must be of kind ``addr_type``; it is popped
    into the register ``mem_base``, which all loads and stores are relative
    to. The remaining inputs become the registers ``$loc_0``, ``$loc_1``, ...
    The wasm locals continue that numbering and are zero-initialized.

    The wasm value stack is modelled by ``op_stack``, whose entries are
    registers created by ``GetOpReg``.

    Args:
        unit: IR unit; provides global-variable mems (``global_vars_<i>``) and
            functions (callees, ``__memory_grow``).
        mod: the wasm module (globals, types, functions).
        wasm_fun: the wasm function to translate.
        fun: the (already declared) IR function to populate.
        global_table: mem holding the indirect-call table (used by
            ``call_indirect``).
        addr_type: address kind; ``o.DK.A32`` selects 32-bit code pointers,
            anything else 64-bit ones.

    Side effects:
        Adds regs, bbls and jump tables to ``fun`` and finally replaces
        ``fun.bbls`` with the basic blocks in emission order.
    """
    # op_stack mirrors the wasm value stack. Its entries are regs produced by
    # GetOpReg, and such a reg may only occur at the stack position encoded
    # in its name.
    op_stack: typing.List[typing.Union[ir.Reg, ir.Const]] = []
    block_stack: typing.List[Block] = []
    bbls: typing.List[ir.Bbl] = []
    bbl_count = 0
    jtb_count = 0

    bbls.append(fun.AddBbl(ir.Bbl("start")))

    loc_index = 0
    assert fun.input_types[0] is addr_type
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbls[-1].AddIns(ir.Ins(o.POPARG, [mem_base]))

    # wasm parameters
    for dk in fun.input_types[1:]:
        reg = fun.AddReg(ir.Reg(f"$loc_{loc_index}", dk))
        loc_index += 1
        bbls[-1].AddIns(ir.Ins(o.POPARG, [reg]))

    # wasm locals (zero-initialized)
    for local_decl in wasm_fun.impl.locals_list:
        for _ in range(local_decl.count):
            reg = fun.AddReg(
                ir.Reg(f"$loc_{loc_index}",
                       WASM_TYPE_TO_CWERG_TYPE[local_decl.kind]))
            loc_index += 1
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, ir.Const(reg.kind, 0)]))

    DEBUG = False
    if DEBUG:
        print()
        print(
            f"# {fun.name} #ins:{len(wasm_fun.impl.expr.instructions)} in:{fun.input_types} out:{fun.output_types}"
        )
    opc = None
    last_opc = None
    for n, wasm_ins in enumerate(wasm_fun.impl.expr.instructions):
        last_opc = opc
        opc = wasm_ins.opcode
        args = wasm_ins.args
        # temporaries are placed above the stack as it was before this
        # instruction so they cannot collide with live operands
        op_stack_size_before = len(op_stack)

        if DEBUG:
            print(f"#@@ {opc.name}", args, len(op_stack))
        if opc.kind is wasm_opc.OPC_KIND.CONST:
            # breaks for floats
            kind = OPC_TYPE_TO_CWERG_TYPE[opc.op_type]
            dst = GetOpReg(fun, kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, ir.Const(kind, args[0])]))
            op_stack.append(dst)
        elif opc is wasm_opc.NOP:
            pass
        elif opc is wasm_opc.DROP:
            op_stack.pop(-1)
        elif opc is wasm_opc.LOCAL_GET:
            loc = GetLocalReg(fun, int(args[0]))
            dst = GetOpReg(fun, loc.kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, loc]))
            op_stack.append(dst)
        elif opc is wasm_opc.LOCAL_SET:
            op = op_stack.pop(-1)
            bbls[-1].AddIns(
                ir.Ins(o.MOV, [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.LOCAL_TEE:
            op = op_stack[-1]  # no pop!
            bbls[-1].AddIns(
                ir.Ins(o.MOV, [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.GLOBAL_GET:
            var_index = int(args[0])
            var: wasm.Glob = mod.sections.get(
                wasm.SECTION_ID.GLOBAL).items[var_index]
            dst = GetOpReg(fun,
                           WASM_TYPE_TO_CWERG_TYPE[var.global_type.value_type],
                           len(op_stack))
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.LD_MEM, [dst, var_mem, ZERO]))
            op_stack.append(dst)
        elif opc is wasm_opc.GLOBAL_SET:
            op = op_stack.pop(-1)
            var_index = int(args[0])
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.ST_MEM, [var_mem, ZERO, op]))
        elif opc.kind is wasm_opc.OPC_KIND.ALU:
            if wasm_opc.FLAGS.UNARY in opc.flags:
                opcode, arg_factory = WASM_ALU1_TO_CWERG[opc.basename]
                op = op_stack.pop(-1)
                dst = GetOpReg(fun, op.kind, len(op_stack))
                bbls[-1].AddIns(ir.Ins(opcode, [dst] + arg_factory(op)))
                op_stack.append(dst)
            else:
                op2 = op_stack.pop(-1)
                op1 = op_stack.pop(-1)
                dst = GetOpReg(fun, op1.kind, len(op_stack))
                alu = WASM_ALU_TO_CWERG[opc.basename]
                if wasm_opc.FLAGS.UNSIGNED in opc.flags:
                    # perform the operation on unsigned copies, convert back
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    tmp3 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 3)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [tmp3, tmp1, tmp2]))
                    else:
                        alu(tmp3, tmp1, tmp2, bbls[-1])
                    bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp3]))
                else:
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [dst, op1, op2]))
                    else:
                        alu(dst, op1, op2, bbls[-1])
                op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CONV:
            conv, dst_unsigned, src_unsigned = WASM_CONV_TO_CWERG[opc.name]
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            if src_unsigned:
                tmp = GetOpReg(fun, ToUnsigned(op.kind),
                               op_stack_size_before + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, op]))
                op = tmp
            if dst_unsigned:
                # convert into an unsigned temp first, then into the real dst
                tmp = GetOpReg(fun, ToUnsigned(dst.kind),
                               op_stack_size_before + 1)
                dst, tmp = tmp, dst
            bbls[-1].AddIns(ir.Ins(conv, [dst, op]))
            if dst_unsigned:
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, dst]))
                dst = tmp
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.BITCAST:
            assert opc.name in SUPPORTED_BITCAST
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.BITCAST, [dst, op]))
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CMP:
            # this always works because of the sentinel: "end"
            succ = wasm_fun.impl.expr.instructions[n + 1]
            # when the compare feeds IF/BR_IF/SELECT it is folded into the
            # branch by MakeBranch; only materialize it otherwise
            if succ.opcode not in {
                    wasm_opc.IF, wasm_opc.BR_IF, wasm_opc.SELECT
            }:
                cmp, res1, res2, op1, op2, unsigned = MakeCompare(
                    opc, op_stack)
                if unsigned:
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    op1 = tmp1
                    op2 = tmp2
                dst = GetOpReg(fun, o.DK.S32, len(op_stack))
                bbls[-1].AddIns(ir.Ins(cmp, [dst, res1, res2, op1, op2]))
                op_stack.append(dst)
        elif opc is wasm_opc.LOOP or opc is wasm_opc.BLOCK:
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            bbl_count += 1
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.IF:
            # note: we do NOT switch to the new bbl right away because we
            # still add some instructions to the old one
            # this always works because the stack cannot be empty at this point
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, True)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            # assert len(block_stack[-1].param_types) == 0
            bbl_count += 1
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, block_stack[-1].else_bbl]))
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.ELSE:
            block = block_stack[-1]
            assert block.opcode is wasm_opc.IF
            block.FinalizeIf(op_stack, bbls[-1], fun, last_opc
                             in {wasm_opc.UNREACHABLE, wasm_opc.BR})
            if DEBUG:
                print(f"#@@ AFTER STACK RESTORE", len(op_stack))
            # the else-arm starts with the stack as it was on block entry
            op_stack = op_stack[0:block.stack_start]
            bbls[-1].AddIns(ir.Ins(o.BRA, [block.end_bbl]))
            assert block.else_bbl is not None
            bbls.append(block.else_bbl)
            # mark the else-arm as consumed so END does not emit it again
            block.else_bbl = None
        elif opc is wasm_opc.END:
            if block_stack:
                block = block_stack.pop(-1)
                block.FinalizeEnd(
                    op_stack, bbls[-1], fun, last_opc
                    in {wasm_opc.UNREACHABLE, wasm_opc.BR})
                if block.else_bbl:
                    # IF without an explicit ELSE
                    bbls.append(block.else_bbl)
                bbls.append(block.end_bbl)
            else:
                # end of function
                assert n + 1 == len(wasm_fun.impl.expr.instructions)
                pred = wasm_fun.impl.expr.instructions[n - 1].opcode
                if pred not in {
                        wasm_opc.RETURN, wasm_opc.UNREACHABLE, wasm_opc.BR
                }:
                    for x in reversed(fun.output_types):
                        op = op_stack.pop(-1)
                        assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                        bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
                    bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.CALL:
            wasm_callee = mod.functions[int(wasm_ins.args[0])]
            callee = unit.GetFun(wasm_callee.name)
            assert callee, f"unknown fun: {wasm_callee.name}"
            EmitCall(fun, bbls[-1], ir.Ins(o.BSR, [callee]), op_stack,
                     mem_base, callee)
        elif opc is wasm_opc.CALL_INDIRECT:
            assert isinstance(args[1], wasm.TableIdx), f"{type(args[1])}"
            assert int(args[1]) == 0, f"only one table supported"
            assert isinstance(args[0], wasm.TypeIdx), f"{type(args[0])}"
            type_sec = mod.sections.get(wasm.SECTION_ID.TYPE)
            func_type: wasm.FunctionType = type_sec.items[int(args[0])]
            arguments = [addr_type] + TranslateTypeList(func_type.args)
            returns = TranslateTypeList(func_type.rets)
            # print (f"# @@@@ CALL INDIRECT {returns} <- {arguments} [{int(args[0])}] {func_type}")
            signature = FindFunWithSignature(unit, arguments, returns)
            table_reg = GetOpReg(fun, addr_type, len(op_stack))
            code_type = o.DK.C32 if addr_type is o.DK.A32 else o.DK.C64
            fun_reg = GetOpReg(fun, code_type, len(op_stack) + 1)
            index = op_stack.pop(-1)
            assert index.kind is o.DK.S32
            # scale the table index by the code-pointer size (in place; the
            # popped op reg is dead afterwards). The constant uses the same
            # kind as the registers it is combined with.
            bbls[-1].AddIns(
                ir.Ins(o.MUL, [
                    index, index,
                    ir.Const(index.kind, code_type.bitwidth() // 8)
                ]))
            bbls[-1].AddIns(
                ir.Ins(o.LEA_MEM, [table_reg, global_table, ZERO]))
            bbls[-1].AddIns(ir.Ins(o.LD, [fun_reg, table_reg, index]))
            EmitCall(fun, bbls[-1], ir.Ins(o.JSR, [fun_reg, signature]),
                     op_stack, mem_base, signature)
        elif opc is wasm_opc.RETURN:
            for x in reversed(fun.output_types):
                op = op_stack.pop(-1)
                assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.BR:
            assert isinstance(args[0], wasm.LabelIdx)
            block = GetTargetBlock(block_stack, args[0])
            # a branch to a LOOP jumps back to its header; any other block
            # is exited through its end bbl
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                # NOTE(review): nesting reconstructed. Results are copied
                # only for forward (non-LOOP) targets. Confirm against
                # FinalizeResultsCopy.
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(o.BRA, [target]))
        elif opc is wasm_opc.BR_IF:
            assert isinstance(args[0], wasm.LabelIdx)
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                # NOTE(review): nesting reconstructed. Results are copied
                # only for forward (non-LOOP) targets. Confirm against
                # FinalizeResultsCopy.
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, target]))
        elif opc is wasm_opc.SELECT:
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            val_f = op_stack.pop(-1)
            val_t = op_stack.pop(-1)
            assert val_f.kind == val_t.kind
            reg = GetOpReg(fun, val_f.kind, len(op_stack))
            op_stack.append(reg)
            # reg = val_t; if the condition holds, skip over "reg = val_f"
            # by branching to a fresh bbl
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, val_t]))
            bbls.append(fun.AddBbl(ir.Bbl(f"select_{bbl_count}")))
            bbl_count += 1
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            bbls[-2].AddIns(ir.Ins(br, [op1, op2, bbls[-1]]))
            bbls[-2].AddIns(ir.Ins(o.MOV, [reg, val_f]))
        elif opc.kind is wasm_opc.OPC_KIND.STORE:
            val = op_stack.pop(-1)
            offset = op_stack.pop(-1)
            if args[1] != 0:
                # fold the static memarg offset into the dynamic offset
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD,
                           [tmp, offset,
                            ir.Const(offset.kind, args[1])]))
                offset = tmp
            dk_tmp = STORE_TO_CWERG_TYPE.get(opc.name)
            if dk_tmp is not None:
                # narrowing store: convert to the stored width first
                tmp = GetOpReg(fun, dk_tmp, len(op_stack) + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, val]))
                val = tmp
            bbls[-1].AddIns(ir.Ins(o.ST, [mem_base, offset, val]))
        elif opc.kind is wasm_opc.OPC_KIND.LOAD:
            offset = op_stack.pop(-1)
            if args[1] != 0:
                # fold the static memarg offset into the dynamic offset
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD,
                           [tmp, offset,
                            ir.Const(offset.kind, args[1])]))
                offset = tmp
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            op_stack.append(dst)
            dk_tmp = LOAD_TO_CWERG_TYPE.get(opc.basename)
            if dk_tmp is not None:
                # narrow load: load the narrow kind, then widen into dst
                tmp = GetOpReg(fun, dk_tmp, len(op_stack))
                bbls[-1].AddIns(ir.Ins(o.LD, [tmp, mem_base, offset]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp]))
            else:
                bbls[-1].AddIns(ir.Ins(o.LD, [dst, mem_base, offset]))
        elif opc is wasm_opc.BR_TABLE:
            bbl_tab = {
                i: GetTargetBbl(block_stack, x)
                for i, x in enumerate(args[0])
            }
            bbl_def = GetTargetBbl(block_stack, args[1])
            op = op_stack.pop(-1)
            tab_size = ir.Const(ToUnsigned(op.kind), len(bbl_tab))
            jtb_count += 1
            jtb = fun.AddJtb(
                ir.Jtb(f"jtb_{jtb_count}", bbl_def, bbl_tab, tab_size.value))
            reg_unsigned = GetOpReg(fun, ToUnsigned(op.kind), len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.CONV, [reg_unsigned, op]))
            # out-of-range indices (tab_size <= index) go to the default target
            bbls[-1].AddIns(ir.Ins(o.BLE, [tab_size, reg_unsigned, bbl_def]))
            bbls[-1].AddIns(ir.Ins(o.SWITCH, [reg_unsigned, jtb]))
        elif opc is wasm_opc.UNREACHABLE:
            bbls[-1].AddIns(ir.Ins(o.TRAP, []))
        elif opc is wasm_opc.MEMORY_GROW or opc is wasm_opc.MEMORY_SIZE:
            # memory.size is lowered as a call to __memory_grow with a
            # delta of zero
            op = ZERO_S
            if opc is wasm_opc.MEMORY_GROW:
                op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            memory_grow = unit.GetFun("__memory_grow")
            assert memory_grow
            bbls[-1].AddIns(ir.Ins(o.BSR, [memory_grow]))
            dst = GetOpReg(fun, o.DK.S32, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.POPARG, [dst]))
            op_stack.append(dst)
        else:
            assert False, f"unsupported opcode [{opc.name}]"
    assert not op_stack, f"op_stack not empty in {fun.name}: {op_stack}"
    assert not block_stack, f"block_stack not empty in {fun.name}: {block_stack}"
    assert len(bbls) == len(fun.bbls)
    # bbls were appended to fun in creation order; use emission order instead
    fun.bbls = bbls
def GetOpReg(fun: ir.Fun, dk: o.DK, pos: int) -> ir.Reg:
    """Return the operand-stack register for slot ``pos`` holding kind ``dk``.

    The register is named ``$op_<pos>_<KIND>``. It is looked up in ``fun``
    first and only created (and added to ``fun``) when missing, so repeated
    requests for the same slot and kind share one register.
    """
    name = f"$op_{pos}_{dk.name}"
    existing = fun.MaybeGetReg(name)
    if existing:
        return existing
    return fun.AddReg(ir.Reg(name, dk))