def GenerateInitGlobalVarsFun(mod: wasm.Module, unit: ir.Unit,
                              addr_type: o.DK) -> ir.Fun:
    """Create `init_global_vars_fun` and one Mem per wasm global.

    Every item of the module's GLOBAL section gets a Mem named
    `global_vars_<index>` (alignment 16, RO when the global is
    `wasm.MUT.CONST`, RW otherwise).

    * Constant initializer: the bytes returned by
      `ExtractBytesFromConstIns` become the Mem's data.
    * `global.get` initializer: the Mem is zero filled (4 or 8 bytes) and
      a LD_MEM from the referenced global's Mem followed by a ST_MEM into
      the new Mem is appended to the "start" Bbl of the returned Fun.

    Args:
        mod: wasm module whose GLOBAL section is processed.
        unit: unit that receives the new Fun and Mems.
        addr_type: not used by this function.

    Returns:
        The new Fun.  When the module has no GLOBAL section the Fun only
        contains its two Bbls ("start" and "end" with a RET).

    Raises:
        AssertionError: if an initializer is neither `global.get` nor a
            constant instruction.
    """
    fun = unit.AddFun(
        ir.Fun("init_global_vars_fun", o.FUN_KIND.NORMAL, [], []))
    bbl = fun.AddBbl(ir.Bbl("start"))
    epilog = fun.AddBbl(ir.Bbl("end"))
    epilog.AddIns(ir.Ins(o.RET, []))

    section = mod.sections.get(wasm.SECTION_ID.GLOBAL)
    if not section:
        return fun

    # scratch registers used to copy one global into another
    val32 = fun.AddReg(ir.Reg("val32", o.DK.U32))
    val64 = fun.AddReg(ir.Reg("val64", o.DK.U64))
    for n, data in enumerate(section.items):
        is_const = data.global_type.mut is wasm.MUT.CONST
        kind = o.MEM_KIND.RO if is_const else o.MEM_KIND.RW
        mem = unit.AddMem(ir.Mem(f"global_vars_{n}", 16, kind))
        ins = GetInsFromInitializerExpression(data.expr)
        var_type = data.global_type.value_type
        if ins.opcode is wasm_opc.GLOBAL_GET:
            # reserve zeroed storage; the value is copied by emitted code
            mem.AddData(
                ir.DataBytes(1, b"\0" * (4 if var_type.is_32bit() else 8)))
            src_mem = unit.GetMem(f"global_vars_{int(ins.args[0])}")
            reg = val32 if var_type.is_32bit() else val64
            bbl.AddIns(ir.Ins(o.LD_MEM, [reg, src_mem, ZERO]))
            bbl.AddIns(ir.Ins(o.ST_MEM, [mem, ZERO, reg]))
        elif ins.opcode.kind is wasm_opc.OPC_KIND.CONST:
            mem.AddData(
                ir.DataBytes(1, ExtractBytesFromConstIns(ins, var_type)))
        else:
            # explicit raise: `assert False` would vanish under `python -O`
            raise AssertionError(f"unsupported init instructions {ins}")
    return fun
def MaybeMakeGlobalTable(mod: wasm.Module, unit: ir.Unit, addr_type: o.DK):
    """Emit the module's (single) function table as the Mem `global_table`.

    The table has `table_type.limits.max` slots.  Slots named by the
    ELEMENT section hold the address of the corresponding Fun; all other
    slots are zero filled.  Each slot is `addr_type.bitwidth() // 8` bytes
    wide, which is also used as the Mem's alignment.

    Args:
        mod: wasm module providing the TABLE and ELEMENT sections.
        unit: unit that receives the new Mem and is used to look up Funs.
        addr_type: address kind determining the slot width.

    Returns:
        The new RO Mem, or None if the module has no TABLE section.

    Raises:
        AssertionError: if the ELEMENT section is missing, there is more
            than one table, the table does not hold FUNCREFs, or an
            element offset is not an `i32.const`.
    """
    table_sec = mod.sections.get(wasm.SECTION_ID.TABLE)
    if not table_sec:
        return None

    table_elements = mod.sections.get(wasm.SECTION_ID.ELEMENT)
    assert table_elements
    assert len(table_sec.items) == 1
    table_type: wasm.TableType = table_sec.items[0].table_type
    assert table_type.element_type == wasm.REF_TYPE.FUNCREF

    # None marks a slot that no element segment initializes
    table_data = [None] * table_type.limits.max
    for elem in table_elements.items:
        ins = GetInsFromInitializerExpression(elem.expr)
        assert ins.opcode is wasm_opc.I32_CONST
        start = ins.args[0]
        for n, func_idx in enumerate(elem.funcidxs):
            table_data[start + n] = func_idx

    width = addr_type.bitwidth() // 8
    global_table = unit.AddMem(
        ir.Mem("global_table", width, o.MEM_KIND.RO))
    for func_idx in table_data:
        if func_idx is None:
            global_table.AddData(ir.DataBytes(width, b"\0"))
        else:
            assert isinstance(func_idx, wasm.FuncIdx)
            callee = unit.GetFun(mod.functions[int(func_idx)].name)
            global_table.AddData(ir.DataAddrFun(width, callee))
    return global_table
def DirFun(unit: ir.Unit, operands: List):
    """Process a Fun directive.

    `operands` is `[name, kind, output_types, input_types]`.  A new Fun is
    added to `unit` if the name is unknown; a forward declared Fun is
    completed via `unit.InitForwardDeclaredFun`.

    Raises:
        ParseError: if either type list exceeds `o.MAX_PARAMETERS` or a
            fully declared Fun with that name already exists.
    """
    name, kind, output_types, input_types = operands
    if any(len(types) > o.MAX_PARAMETERS
           for types in (input_types, output_types)):
        raise ParseError(f"parameter list too long {name}")

    existing = unit.GetFun(name)
    if existing is None:
        unit.AddFun(ir.Fun(name, kind, output_types, input_types))
        return
    if not existing.forward_declared:
        raise ParseError(f"duplicate Fun {name}")
    unit.InitForwardDeclaredFun(existing, kind, output_types, input_types)
def DirMem(unit: ir.Unit, operands: List):
    """Process a Mem directive.

    `operands` is `[name, alignment, kind]`.

    * unknown name: a new Mem is added to `unit`
    * directive kind is EXTERN and the Mem exists: nothing happens
    * existing Mem is EXTERN: it takes over `kind` and `alignment` and is
      moved to the end of `unit.mems`

    Raises:
        ParseError: if a non-extern Mem with that name already exists and
            the directive is not an extern declaration.
    """
    name, alignment, kind = operands
    existing = unit.GetMem(name)
    if existing is None:
        unit.AddMem(ir.Mem(name, alignment, kind))
        return
    if kind is o.MEM_KIND.EXTERN:
        return
    if existing.kind is not o.MEM_KIND.EXTERN:
        raise ParseError(f"Duplicate Mem symbol: {name}")

    existing.kind = kind
    existing.alignment = alignment
    # move mem to the end of the list to make it current
    unit.mems.remove(existing)
    unit.mems.append(existing)
def GenerateMemcpyFun(unit: ir.Unit, addr_type: o.DK) -> ir.Fun:
    """Add the helper Fun `$memcpy(dst, src, cnt)` to `unit` and return it.

    The Fun has no outputs and takes two `addr_type` values plus a U32
    count.  The emitted loop decrements `cnt` and then moves one U8 from
    `src + cnt` to `dst + cnt` for as long as `0 < cnt`.
    """
    fun = unit.AddFun(
        ir.Fun("$memcpy", o.FUN_KIND.NORMAL, [],
               [addr_type, addr_type, o.DK.U32]))

    reg_specs = (("dst", addr_type), ("src", addr_type),
                 ("cnt", o.DK.U32), ("data", o.DK.U8))
    dst, src, cnt, data = [
        fun.AddReg(ir.Reg(reg_name, reg_kind))
        for reg_name, reg_kind in reg_specs
    ]
    prolog, loop, epilog = [
        fun.AddBbl(ir.Bbl(bbl_name))
        for bbl_name in ("prolog", "loop", "epilog")
    ]

    # prolog: fetch arguments, then jump straight to the loop condition
    for arg_reg in (dst, src, cnt):
        prolog.AddIns(ir.Ins(o.POPARG, [arg_reg]))
    prolog.AddIns(ir.Ins(o.BRA, [epilog]))

    # loop body: copy the byte at offset cnt - 1
    loop.AddIns(ir.Ins(o.SUB, [cnt, cnt, ONE]))
    loop.AddIns(ir.Ins(o.LD, [data, src, cnt]))
    loop.AddIns(ir.Ins(o.ST, [dst, cnt, data]))

    # epilog: loop condition and return
    epilog.AddIns(ir.Ins(o.BLT, [ZERO, cnt, loop]))
    epilog.AddIns(ir.Ins(o.RET, []))
    return fun
def _GetOperand(unit: ir.Unit, fun: ir.Fun, ok: o.OP_KIND, v: Any) -> Any:
    """Convert the textual operand `v` into the object expected for `ok`.

    Operand kinds contained in `o.OKS_LIST` must be given as a list or as
    a double-quoted string; all other kinds must be given as a str.

    Args:
        unit: unit used to resolve Fun and Mem names.
        fun: Fun used to resolve Bbl, Stk and Jtb names and Bbl tables.
        ok: the operand kind to decode.
        v: the raw operand.

    Returns:
        The decoded operand; its type depends on `ok`.

    Raises:
        AssertionError: if `v` has the wrong shape, or names an unknown
            data kind or an invalid identifier.
        ir.ParseError: if `ok` is not handled.
    """
    if ok in o.OKS_LIST:
        assert isinstance(v, list) or v[0] == v[-1] == '"', f"operand {ok}: [{v}]"
    else:
        assert isinstance(v, str), f"bad operand {v} of type [{ok}]"

    if ok is o.OP_KIND.TYPE_LIST:
        kinds = []
        for kind_name in v:
            kind = o.SHORT_STR_TO_RK.get(kind_name)
            assert kind is not None, f"bad kind name [{kind_name}]"
            kinds.append(kind)
        return kinds
    if ok is o.OP_KIND.FUN:
        return unit.GetFunOrAddForwardDeclaration(v)
    if ok is o.OP_KIND.BBL:
        return fun.GetBblOrAddForwardDeclaration(v)
    if ok is o.OP_KIND.BBL_TAB:
        return ExtractBblTable(fun, v)
    if ok is o.OP_KIND.MEM:
        return unit.GetMem(v)
    if ok is o.OP_KIND.STK:
        return fun.GetStk(v)
    if ok is o.OP_KIND.FUN_KIND:
        return o.SHORT_STR_TO_FK[v]
    if ok is o.OP_KIND.DATA_KIND:
        data_kind = o.SHORT_STR_TO_RK.get(v)
        assert data_kind is not None, f"bad kind name [{v}]"
        return data_kind
    if ok is o.OP_KIND.NAME:
        assert parse.RE_IDENTIFIER.match(v), f"bad identifier [{v}]"
        return v
    if ok is o.OP_KIND.NAME_LIST:
        for x in v:
            assert parse.RE_IDENTIFIER.match(x), f"bad identifier [{x}]"
        return v
    if ok is o.OP_KIND.MEM_KIND:
        return o.SHORT_STR_TO_MK[v]
    if ok is o.OP_KIND.VALUE:
        return v
    if ok is o.OP_KIND.BYTES:
        return ExtractBytes(v)
    if ok is o.OP_KIND.JTB:
        return fun.GetJbl(v)
    raise ir.ParseError(f"cannot read op type: {ok}")
def _InsRewriteFltImmediates(ins: ir.Ins, fun: ir.Fun,
                             unit: ir.Unit) -> Optional[List[ir.Ins]]:
    """Replace float constant operands of `ins` by registers loaded from memory.

    For every `ir.Const` operand whose kind has flavor `o.DK_FLAVOR_F` a
    constant Mem is looked up (or created) in `unit`, a scratch register
    is requested from `fun`, and the operand of `ins` is replaced in place
    by that register.

    Returns:
        The LD_MEM instructions followed by the modified `ins`, or None if
        `ins` had no such operand.
    """
    loads = []
    for pos, operand in enumerate(ins.operands):
        if not isinstance(operand, ir.Const):
            continue
        if operand.kind.flavor() is not o.DK_FLAVOR_F:
            continue
        const_mem = unit.FindOrAddConstMem(operand)
        scratch = fun.GetScratchReg(operand.kind, "flt_const", True)
        loads.append(ir.Ins(o.LD_MEM, [scratch, const_mem, _ZERO_OFFSET]))
        ins.operands[pos] = scratch
    return loads + [ins] if loads else None
def GenerateStartup(unit: ir.Unit, global_argc, global_argv, main: ir.Fun,
                    init_global: ir.Fun, init_data: ir.Fun,
                    initial_heap_size_pages: int, addr_type: o.DK) -> ir.Fun:
    """Add the program entry Fun `main(argc, argv) -> U32` to `unit`.

    The single Bbl of the generated Fun does, in this order:

    1. pop `argc`/`argv` and store them to `global_argc`/`global_argv`
    2. call `__wasi_init`
    3. if `initial_heap_size_pages` is non-zero: call `__memory_grow` with
       that value and discard the result
    4. load `mem_base` from the extern Mem `__memory_base`
    5. call `init_global` (if given)
    6. call `init_data` with `mem_base` (if given)
    7. call `main` with `mem_base`
    8. return the constant 0

    Args:
        unit: unit receiving the new Fun and the `__memory_base` Mem; it
            must already contain `__wasi_init` and, if needed,
            `__memory_grow`.
        global_argc: Mem that receives argc.
        global_argv: Mem that receives argv.
        main: Fun called last, with `mem_base` as its only argument.
        init_global: optional Fun called without arguments.
        init_data: optional Fun called with `mem_base`.
        initial_heap_size_pages: argument for the initial `__memory_grow`
            call; the call is skipped if this is zero.
        addr_type: kind used for `argv` and `mem_base`.

    Returns:
        The new Fun.
    """
    global_mem_base = unit.AddMem(
        ir.Mem("__memory_base", 0, o.MEM_KIND.EXTERN))

    fun = unit.AddFun(
        ir.Fun("main", o.FUN_KIND.NORMAL, [o.DK.U32],
               [o.DK.U32, addr_type]))
    argc = fun.AddReg(ir.Reg("argc", o.DK.U32))
    argv = fun.AddReg(ir.Reg("argv", addr_type))

    bbl = fun.AddBbl(ir.Bbl("start"))
    bbl.AddIns(ir.Ins(o.POPARG, [argc]))
    bbl.AddIns(ir.Ins(o.POPARG, [argv]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argc, ZERO, argc]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argv, ZERO, argv]))

    bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__wasi_init")]))
    if initial_heap_size_pages:
        bbl.AddIns(
            ir.Ins(o.PUSHARG, [ir.Const(o.DK.S32, initial_heap_size_pages)]))
        bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
        # the result of __memory_grow is popped into a register nobody reads
        bbl.AddIns(ir.Ins(o.POPARG, [fun.AddReg(ir.Reg("dummy", o.DK.S32))]))

    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.LD_MEM, [mem_base, global_mem_base, ZERO]))

    if init_global:
        bbl.AddIns(ir.Ins(o.BSR, [init_global]))
    if init_data:
        bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
        bbl.AddIns(ir.Ins(o.BSR, [init_data]))

    bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
    bbl.AddIns(ir.Ins(o.BSR, [main]))
    bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, 0)]))
    bbl.AddIns(ir.Ins(o.RET, []))
    return fun
def GenerateInitDataFun(mod: wasm.Module, unit: ir.Unit, memcpy: ir.Fun,
                        addr_type: o.DK) -> typing.Optional[ir.Fun]:
    """Create `init_data_fun(mem_base)` which copies all data segments.

    For every item of the module's DATA section a RO Mem
    `global_init_mem_<index>` holding the segment bytes is added to
    `unit`, and code is appended to the "start" Bbl that computes
    `dst = mem_base + offset` and calls `memcpy(dst, src, len)`.
    The segment offset is either an `i32.const` or is loaded from the Mem
    of the global named by a `global.get`.

    Args:
        mod: wasm module providing the DATA section.
        unit: unit receiving the new Fun and Mems.
        memcpy: Fun called with (dst, src, count) for each segment.
        addr_type: kind of `mem_base`, `src` and `dst`.

    Returns:
        The new Fun, or None if the module has no DATA section.  Note
        that the Fun has already been added to `unit` when None is
        returned.

    Raises:
        AssertionError: if a segment does not target memory 0, its offset
            is not a `wasm.Expression`, or the offset instruction is
            neither `global.get` nor `i32.const`.
    """
    fun = unit.AddFun(
        ir.Fun("init_data_fun", o.FUN_KIND.NORMAL, [], [addr_type]))
    bbl = fun.AddBbl(ir.Bbl("start"))
    epilog = fun.AddBbl(ir.Bbl("end"))
    epilog.AddIns(ir.Ins(o.RET, []))
    section = mod.sections.get(wasm.SECTION_ID.DATA)

    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.POPARG, [mem_base]))
    if not section:
        return None

    offset = fun.AddReg(ir.Reg("offset", o.DK.S32))
    src = fun.AddReg(ir.Reg("src", addr_type))
    dst = fun.AddReg(ir.Reg("dst", addr_type))
    for n, data in enumerate(section.items):
        assert data.memory_index == 0
        assert isinstance(data.offset, wasm.Expression)
        ins = GetInsFromInitializerExpression(data.offset)
        init = unit.AddMem(ir.Mem(f"global_init_mem_{n}", 16, o.MEM_KIND.RO))
        init.AddData(ir.DataBytes(1, data.init))

        # materialize the segment's destination offset
        if ins.opcode is wasm_opc.GLOBAL_GET:
            src_mem = unit.GetMem(f"global_vars_{int(ins.args[0])}")
            bbl.AddIns(ir.Ins(o.LD_MEM, [offset, src_mem, ZERO]))
        elif ins.opcode is wasm_opc.I32_CONST:
            bbl.AddIns(
                ir.Ins(o.MOV, [offset, ir.Const(o.DK.S32, ins.args[0])]))
        else:
            # explicit raise: `assert False` would vanish under `python -O`
            raise AssertionError(f"unsupported init instructions {ins}")

        bbl.AddIns(ir.Ins(o.LEA, [dst, mem_base, offset]))
        bbl.AddIns(ir.Ins(o.LEA_MEM, [src, init, ZERO]))
        # arguments are pushed in reverse order: count, src, dst
        bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, len(data.init))]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [src]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [dst]))
        bbl.AddIns(ir.Ins(o.BSR, [memcpy]))
    return fun
def _InsMoveImmediatesToMemory(ins: ir.Ins, fun: ir.Fun, unit: ir.Unit,
                               kind: o.DK) -> Optional[List[ir.Ins]]:
    """Replace constant operands of kind `kind` by registers loaded from memory.

    Each matching `ir.Const` operand gets a constant Mem (looked up or
    created in `unit`) and a scratch register from `fun`; the operand of
    `ins` is replaced in place by that register.

    Returns:
        The LD_MEM instructions followed by the modified `ins`, or None if
        no operand matched.
    """
    loads = []
    for pos, operand in enumerate(ins.operands):
        if not (isinstance(operand, ir.Const) and operand.kind is kind):
            continue
        const_mem = unit.FindOrAddConstMem(operand)
        scratch = fun.GetScratchReg(kind, "mem_const", True)
        # TODO: pass the offset kind as a parameter
        zero_offset = ir.Const(o.DK.U32, 0)
        loads.append(ir.Ins(o.LD_MEM, [scratch, const_mem, zero_offset]))
        ins.operands[pos] = scratch
    return loads + [ins] if loads else None
def DirFun(unit: ir.Unit, operands: List):
    """Process a Fun directive.

    `operands` is `[name, kind, output_types, input_types]`.

    * unknown name: a new Fun is added to `unit`
    * existing Fun of kind INVALID (forward declared): it is completed via
      `unit.InitForwardDeclaredFun`
    * existing Fun or directive is EXTERN: the signatures must agree; an
      extern re-declaration is ignored, while a definition of a formerly
      extern Fun takes over `kind` and moves the Fun to the end of
      `unit.funs`

    Raises:
        ParseError: if a type list exceeds `o.MAX_PARAMETERS`, the
            signature disagrees with an earlier declaration, or the Fun
            is defined twice.
    """
    name, kind, output_types, input_types = operands
    if len(input_types) > o.MAX_PARAMETERS or len(
            output_types) > o.MAX_PARAMETERS:
        raise ParseError(f"parameter list too long {name}")

    fun = unit.GetFun(name)
    if fun is None:
        fun = ir.Fun(name, kind, output_types, input_types)
        unit.AddFun(fun)
    elif fun.kind is o.FUN_KIND.INVALID:
        # forward_declared
        unit.InitForwardDeclaredFun(fun, kind, output_types, input_types)
    elif fun.kind is o.FUN_KIND.EXTERN or kind is o.FUN_KIND.EXTERN:
        # This validates parsed input, so report it as a ParseError rather
        # than via `assert` (which is removed under `python -O`).
        if (output_types != fun.output_types or
                input_types != fun.input_types):
            raise ParseError(f"signature mismatch for Fun {name}")
        if kind is o.FUN_KIND.EXTERN:
            # we already have a proper function
            return
        # definition of a formerly extern function
        fun.kind = kind
        # move fun to the end of the list to make it current
        unit.funs.remove(fun)
        unit.funs.append(fun)
    else:
        raise ParseError(f"duplicate Fun {name}")
def InsEliminateImmediateViaMem(ins: ir.Ins, pos: int, fun: ir.Fun,
                                unit: ir.Unit, addr_kind: o.DK,
                                offset_kind: o.DK) -> List[ir.Ins]:
    """Replace the immediate operand `pos` of `ins` by a register loaded from memory.

    Useful when the target architecture cannot encode an immediate for
    this instruction or the immediate is too large.  This rewrite runs
    rather late and may already see machine registers.

    `ins` is modified in place.  PUSHARG is rejected because PUSHARG
    instructions have to stay consecutive, which would need extra work.

    Returns:
        The two instructions to be placed in front of `ins`: a LEA_MEM
        computing the address of the constant Mem and a LD reading it.
    """
    assert ins.opcode is not o.PUSHARG

    immediate = ins.operands[pos]
    const_mem = unit.FindOrAddConstMem(immediate)

    addr_reg = fun.GetScratchReg(addr_kind, "mem_const_addr", True)
    load_address = ir.Ins(
        o.LEA_MEM, [addr_reg, const_mem, ir.Const(offset_kind, 0)])

    value_reg = fun.GetScratchReg(immediate.kind, "mem_const", True)
    load_value = ir.Ins(
        o.LD, [value_reg, addr_reg, ir.Const(offset_kind, 0)])

    ins.operands[pos] = value_reg
    return [load_address, load_value]
def DirData(unit: ir.Unit, operands: List):
    """Process a data directive: `operands` is `[count, data]`.

    An `ir.DataBytes(count, data)` is handed to `unit.AddData`.
    """
    repeat, payload = operands
    data_bytes = ir.DataBytes(repeat, payload)
    unit.AddData(data_bytes)
def GenerateFun(unit: ir.Unit, mod: wasm.Module, wasm_fun: wasm.Function, fun: ir.Fun, global_table, addr_type):
    """Fill `fun` with Bbls and instructions translated from `wasm_fun`.

    The instructions in `wasm_fun.impl.expr.instructions` are visited once,
    in order.  Three pieces of state are maintained while doing so:

    * `op_stack`: the operands currently on the wasm value stack
    * `block_stack`: the currently open BLOCK/LOOP/IF constructs
    * `bbls`: the Bbls in emission order; new instructions always go to
      `bbls[-1]`.  At the end `fun.bbls` is replaced by this list.

    The first input of `fun` must be of kind `addr_type`; it is popped
    into the register `mem_base`, which is used as the base for all
    LOAD/STORE instructions and is handed to `EmitCall`.

    A comparison that is immediately followed by IF, BR_IF or SELECT emits
    nothing by itself; the follower passes the predecessor opcode to
    `MakeBranch` instead (presumably to fuse compare and branch - confirm
    in `MakeBranch`).

    Args:
        unit: unit used to look up Funs and the `global_vars_<n>` Mems.
        mod: wasm module (functions, GLOBAL and TYPE sections).
        wasm_fun: the wasm function to translate.
        fun: the Fun to populate.
        global_table: Mem used as the table base for CALL_INDIRECT.
        addr_type: address kind; must be `fun.input_types[0]`.

    Raises:
        AssertionError: on unsupported opcodes and on violated
            expectations (e.g. non-empty stacks at the end).
    """
    # op_stack contains regs produced by GetOpReg (a reg may only occur at
    # the stack position encoded in its name)
    op_stack: typing.List[typing.Union[ir.Reg, ir.Const]] = []
    block_stack: typing.List[Block] = []
    bbls: typing.List[ir.Bbl] = []
    # counters used to create unique Bbl / Jtb names
    bbl_count = 0
    jtb_count = 0
    bbls.append(fun.AddBbl(ir.Bbl("start")))
    loc_index = 0
    assert fun.input_types[0] is addr_type
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbls[-1].AddIns(ir.Ins(o.POPARG, [mem_base]))
    # the remaining parameters become the first locals: $loc_0, $loc_1, ...
    for dk in fun.input_types[1:]:
        reg = fun.AddReg(ir.Reg(f"$loc_{loc_index}", dk))
        loc_index += 1
        bbls[-1].AddIns(ir.Ins(o.POPARG, [reg]))
    # declared locals continue the numbering and are initialized to zero
    for locals in wasm_fun.impl.locals_list:
        for i in range(locals.count):
            reg = fun.AddReg(
                ir.Reg(f"$loc_{loc_index}", WASM_TYPE_TO_CWERG_TYPE[locals.kind]))
            loc_index += 1
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, ir.Const(reg.kind, 0)]))
    # flip to True to trace the translation on stdout
    DEBUG = False
    if DEBUG:
        print()
        print(
            f"# {fun.name} #ins:{len(wasm_fun.impl.expr.instructions)} in:{fun.input_types} out:{fun.output_types}"
        )
    opc = None
    last_opc = None
    for n, wasm_ins in enumerate(wasm_fun.impl.expr.instructions):
        last_opc = opc
        opc = wasm_ins.opcode
        args = wasm_ins.args
        # temporaries are allocated at positions above this stack height
        op_stack_size_before = len(op_stack)
        if DEBUG:
            print(f"#@@ {opc.name}", args, len(op_stack))
        if opc.kind is wasm_opc.OPC_KIND.CONST:
            # NOTE (kept from the original): breaks for floats
            kind = OPC_TYPE_TO_CWERG_TYPE[opc.op_type]
            dst = GetOpReg(fun, kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, ir.Const(kind, args[0])]))
            op_stack.append(dst)
        elif opc is wasm_opc.NOP:
            pass
        elif opc is wasm_opc.DROP:
            op_stack.pop(-1)
        elif opc is wasm_opc.LOCAL_GET:
            loc = GetLocalReg(fun, int(args[0]))
            dst = GetOpReg(fun, loc.kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, loc]))
            op_stack.append(dst)
        elif opc is wasm_opc.LOCAL_SET:
            op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.MOV, [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.LOCAL_TEE:
            op = op_stack[-1]  # no pop!
            bbls[-1].AddIns(ir.Ins(o.MOV, [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.GLOBAL_GET:
            var_index = int(args[0])
            var: wasm.Glob = mod.sections.get(
                wasm.SECTION_ID.GLOBAL).items[var_index]
            dst = GetOpReg(fun, WASM_TYPE_TO_CWERG_TYPE[var.global_type.value_type], len(op_stack))
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.LD_MEM, [dst, var_mem, ZERO]))
            op_stack.append(dst)
        elif opc is wasm_opc.GLOBAL_SET:
            op = op_stack.pop(-1)
            var_index = int(args[0])
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.ST_MEM, [var_mem, ZERO, op]))
        elif opc.kind is wasm_opc.OPC_KIND.ALU:
            if wasm_opc.FLAGS.UNARY in opc.flags:
                opcode, arg_factory = WASM_ALU1_TO_CWERG[opc.basename]
                op = op_stack.pop(-1)
                dst = GetOpReg(fun, op.kind, len(op_stack))
                bbls[-1].AddIns(ir.Ins(opcode, [dst] + arg_factory(op)))
                op_stack.append(dst)
            else:
                op2 = op_stack.pop(-1)
                op1 = op_stack.pop(-1)
                dst = GetOpReg(fun, op1.kind, len(op_stack))
                # alu is either an o.Opcode or a callable emitting the code
                alu = WASM_ALU_TO_CWERG[opc.basename]
                if wasm_opc.FLAGS.UNSIGNED in opc.flags:
                    # convert to unsigned, operate, convert the result back
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                    tmp3 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 3)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [tmp3, tmp1, tmp2]))
                    else:
                        alu(tmp3, tmp1, tmp2, bbls[-1])
                    bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp3]))
                else:
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [dst, op1, op2]))
                    else:
                        alu(dst, op1, op2, bbls[-1])
                op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CONV:
            conv, dst_unsigned, src_unsigned = WASM_CONV_TO_CWERG[opc.name]
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type], len(op_stack))
            if src_unsigned:
                tmp = GetOpReg(fun, ToUnsigned(op.kind), op_stack_size_before + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, op]))
                op = tmp
            if dst_unsigned:
                # convert into an unsigned temporary first; dst and tmp are
                # swapped here and the result is moved back below
                tmp = GetOpReg(fun, ToUnsigned(dst.kind), op_stack_size_before + 1)
                dst, tmp = tmp, dst
            bbls[-1].AddIns(ir.Ins(conv, [dst, op]))
            if dst_unsigned:
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, dst]))
                dst = tmp
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.BITCAST:
            assert opc.name in SUPPORTED_BITCAST
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type], len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.BITCAST, [dst, op]))
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CMP:
            # this always works because of the sentinel: "end"
            succ = wasm_fun.impl.expr.instructions[n + 1]
            # nothing is emitted here if IF/BR_IF/SELECT follows; those
            # cases look at their predecessor opcode themselves
            if succ.opcode not in {
                    wasm_opc.IF, wasm_opc.BR_IF, wasm_opc.SELECT
            }:
                cmp, res1, res2, op1, op2, unsigned = MakeCompare(
                    opc, op_stack)
                if unsigned:
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    op1 = tmp1
                    op2 = tmp2
                dst = GetOpReg(fun, o.DK.S32, len(op_stack))
                bbls[-1].AddIns(ir.Ins(cmp, [dst, res1, res2, op1, op2]))
                op_stack.append(dst)
        elif opc is wasm_opc.LOOP or opc is wasm_opc.BLOCK:
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            bbl_count += 1
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.IF:
            # note: the new bbl is only appended at the very end because the
            # branch still has to be added to the old one
            # this always works because the stack cannot be empty at this point
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, True)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            # assert len(block_stack[-1].param_types) == 0
            bbl_count += 1
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, block_stack[-1].else_bbl]))
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.ELSE:
            block = block_stack[-1]
            assert block.opcode is wasm_opc.IF
            block.FinalizeIf(op_stack, bbls[-1], fun,
                             last_opc in {wasm_opc.UNREACHABLE, wasm_opc.BR})
            if DEBUG:
                print(f"#@@ AFTER STACK RESTORE", len(op_stack))
            # the else branch starts with the stack height of the IF
            op_stack = op_stack[0:block.stack_start]
            bbls[-1].AddIns(ir.Ins(o.BRA, [block.end_bbl]))
            assert block.else_bbl is not None
            bbls.append(block.else_bbl)
            # cleared so that the END handling does not append it again
            block.else_bbl = None
        elif opc is wasm_opc.END:
            if block_stack:
                block = block_stack.pop(-1)
                block.FinalizeEnd(
                    op_stack, bbls[-1], fun,
                    last_opc in {wasm_opc.UNREACHABLE, wasm_opc.BR})
                if block.else_bbl:
                    bbls.append(block.else_bbl)
                bbls.append(block.end_bbl)
            else:
                # end of function
                assert n + 1 == len(wasm_fun.impl.expr.instructions)
                pred = wasm_fun.impl.expr.instructions[n - 1].opcode
                # emit the implicit return unless control cannot get here
                if pred not in {
                        wasm_opc.RETURN, wasm_opc.UNREACHABLE, wasm_opc.BR
                }:
                    for x in reversed(fun.output_types):
                        op = op_stack.pop(-1)
                        assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                        bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
                    bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.CALL:
            wasm_callee = mod.functions[int(wasm_ins.args[0])]
            callee = unit.GetFun(wasm_callee.name)
            assert callee, f"unknown fun: {wasm_callee.name}"
            EmitCall(fun, bbls[-1], ir.Ins(o.BSR, [callee]), op_stack, mem_base, callee)
        elif opc is wasm_opc.CALL_INDIRECT:
            assert isinstance(args[1], wasm.TableIdx), f"{type(args[1])}"
            assert int(args[1]) == 0, f"only one table supported"
            assert isinstance(args[0], wasm.TypeIdx), f"{type(args[0])}"
            type_sec = mod.sections.get(wasm.SECTION_ID.TYPE)
            func_type: wasm.FunctionType = type_sec.items[int(args[0])]
            # every translated function takes mem_base as extra first argument
            arguments = [addr_type] + TranslateTypeList(func_type.args)
            returns = TranslateTypeList(func_type.rets)
            # print (f"# @@@@ CALL INDIRECT {returns} <- {arguments} [{int(args[0])}] {func_type}")
            signature = FindFunWithSignature(unit, arguments, returns)
            table_reg = GetOpReg(fun, addr_type, len(op_stack))
            code_type = o.DK.C32 if addr_type is o.DK.A32 else o.DK.C64
            fun_reg = GetOpReg(fun, code_type, len(op_stack) + 1)
            index = op_stack.pop(-1)
            assert index.kind is o.DK.S32
            # scale the table index by the size of one code pointer
            bbls[-1].AddIns(
                ir.Ins(o.MUL, [
                    index, index,
                    ir.Const(o.DK.U32, code_type.bitwidth() // 8)
                ]))
            bbls[-1].AddIns(ir.Ins(o.LEA_MEM, [table_reg, global_table, ZERO]))
            bbls[-1].AddIns(ir.Ins(o.LD, [fun_reg, table_reg, index]))
            EmitCall(fun, bbls[-1], ir.Ins(o.JSR, [fun_reg, signature]), op_stack, mem_base, signature)
        elif opc is wasm_opc.RETURN:
            for x in reversed(fun.output_types):
                op = op_stack.pop(-1)
                assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.BR:
            assert isinstance(args[0], wasm.LabelIdx)
            block = GetTargetBlock(block_stack, args[0])
            # a branch to a LOOP targets its start, all others their end
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(o.BRA, [target]))
        elif opc is wasm_opc.BR_IF:
            assert isinstance(args[0], wasm.LabelIdx)
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, target]))
        elif opc is wasm_opc.SELECT:
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            val_f = op_stack.pop(-1)
            val_t = op_stack.pop(-1)
            assert val_f.kind == val_t.kind
            reg = GetOpReg(fun, val_f.kind, len(op_stack))
            op_stack.append(reg)
            # old bbl: reg = val_t; branch to the new bbl; reg = val_f
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, val_t]))
            bbls.append(fun.AddBbl(ir.Bbl(f"select_{bbl_count}")))
            bbl_count += 1
            # from here on bbls[-2] is the old bbl and bbls[-1] the new one
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind), op_stack_size_before + 2)
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            bbls[-2].AddIns(ir.Ins(br, [op1, op2, bbls[-1]]))
            bbls[-2].AddIns(ir.Ins(o.MOV, [reg, val_f]))
        elif opc.kind is wasm_opc.OPC_KIND.STORE:
            val = op_stack.pop(-1)
            offset = op_stack.pop(-1)
            # args[1] is added to the dynamic offset when it is non-zero
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD, [tmp, offset, ir.Const(offset.kind, args[1])]))
                offset = tmp
            # some stores convert the value to another kind first
            dk_tmp = STORE_TO_CWERG_TYPE.get(opc.name)
            if dk_tmp is not None:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack) + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, val]))
                val = tmp
            bbls[-1].AddIns(ir.Ins(o.ST, [mem_base, offset, val]))
        elif opc.kind is wasm_opc.OPC_KIND.LOAD:
            offset = op_stack.pop(-1)
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD, [tmp, offset, ir.Const(offset.kind, args[1])]))
                offset = tmp
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type], len(op_stack))
            op_stack.append(dst)
            # some loads read another kind and convert it afterwards
            dk_tmp = LOAD_TO_CWERG_TYPE.get(opc.basename)
            if dk_tmp:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack))
                bbls[-1].AddIns(ir.Ins(o.LD, [tmp, mem_base, offset]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp]))
            else:
                bbls[-1].AddIns(ir.Ins(o.LD, [dst, mem_base, offset]))
        elif opc is wasm_opc.BR_TABLE:
            bbl_tab = {
                n: GetTargetBbl(block_stack, x)
                for n, x in enumerate(args[0])
            }
            bbl_def = GetTargetBbl(block_stack, args[1])
            op = op_stack.pop(-1)
            tab_size = ir.Const(ToUnsigned(op.kind), len(bbl_tab))
            jtb_count += 1
            jtb = fun.AddJtb(
                ir.Jtb(f"jtb_{jtb_count}", bbl_def, bbl_tab, tab_size.value))
            reg_unsigned = GetOpReg(fun, ToUnsigned(op.kind), len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.CONV, [reg_unsigned, op]))
            # indices >= table size go to the default target
            bbls[-1].AddIns(ir.Ins(o.BLE, [tab_size, reg_unsigned, bbl_def]))
            bbls[-1].AddIns(ir.Ins(o.SWITCH, [reg_unsigned, jtb]))
        elif opc is wasm_opc.UNREACHABLE:
            bbls[-1].AddIns(ir.Ins(o.TRAP, []))
        elif opc is wasm_opc.MEMORY_GROW or opc is wasm_opc.MEMORY_SIZE:
            # both map to a call of __memory_grow; MEMORY_SIZE passes ZERO_S
            op = ZERO_S
            if opc is wasm_opc.MEMORY_GROW:
                op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            assert unit.GetFun("__memory_grow")
            bbls[-1].AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
            dst = GetOpReg(fun, o.DK.S32, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.POPARG, [dst]))
            op_stack.append(dst)
        else:
            assert False, f"unsupported opcode [{opc.name}]"
    assert not op_stack, f"op_stack not empty in {fun.name}: {op_stack}"
    assert not block_stack, f"block_stack not empty in {fun.name}: {block_stack}"
    # install the Bbls in emission order (fun.bbls holds them in creation order)
    assert len(bbls) == len(fun.bbls)
    fun.bbls = bbls
def DirAddrMem(unit: ir.Unit, operands: List):
    """Process an addr.mem directive.

    All `operands` are forwarded positionally to `ir.DataAddrMem`; the
    resulting object is handed to `unit.AddData`.
    """
    addr_data = ir.DataAddrMem(*operands)
    unit.AddData(addr_data)
def DirData(unit: ir.Unit, operands: List):
    """Process a data directive.

    All `operands` are forwarded positionally to `ir.DataBytes`; the
    resulting object is handed to `unit.AddData`.
    """
    data_bytes = ir.DataBytes(*operands)
    unit.AddData(data_bytes)
def DirMem(unit: ir.Unit, operands: List):
    """Process a Mem directive.

    All `operands` are forwarded positionally to `ir.Mem`; the resulting
    Mem is added to `unit`.
    """
    new_mem = ir.Mem(*operands)
    unit.AddMem(new_mem)
def DirAddrMem(unit: ir.Unit, operands: List):
    """Process an addr.mem directive: `operands` is `[size, mem, offset]`.

    An `ir.DataAddrMem(size, mem, offset)` is handed to `unit.AddData`.
    """
    size, mem, offset = operands
    addr_data = ir.DataAddrMem(size, mem, offset)
    unit.AddData(addr_data)
def DirAddrFun(unit: ir.Unit, operands: List):
    """Process an addr.fun directive: `operands` is `[size, fun]`.

    An `ir.DataAddrFun(size, fun)` is handed to `unit.AddData`.
    """
    size, fun = operands
    addr_data = ir.DataAddrFun(size, fun)
    unit.AddData(addr_data)