def MaybeMakeGlobalTable(mod: wasm.Module, unit: ir.Unit, addr_type: o.DK):
    """Turn the module's single funcref table into a read-only mem.

    Returns None when the module has no table section. Otherwise a mem named
    "global_table" is added to `unit`. It holds one entry per table slot,
    and each entry is `addr_type.bitwidth() // 8` bytes wide:
      * slots initialized by an element segment get an ir.DataAddrFun that
        refers to the Cwerg function with the same name as the wasm function;
      * all other slots get ir.DataBytes(entry_size, b"\\0").

    The added mem is returned.
    """
    table_section = mod.sections.get(wasm.SECTION_ID.TABLE)
    element_section = mod.sections.get(wasm.SECTION_ID.ELEMENT)
    if not table_section:
        return None
    assert element_section
    assert len(table_section.items) == 1
    table_type: wasm.TableType = table_section.items[0].table_type
    assert table_type.element_type == wasm.REF_TYPE.FUNCREF

    # slot -> wasm.FuncIdx, None for slots no element segment touches.
    # NOTE(review): sized by limits.max; wasm allows tables without a maximum,
    # confirm the front end always provides one.
    slots = [None] * table_type.limits.max
    for segment in element_section.items:
        offset_ins = GetInsFromInitializerExpression(segment.expr)
        # only constant (i32.const) segment offsets are supported
        assert offset_ins.opcode is wasm_opc.I32_CONST
        first_slot = offset_ins.args[0]
        for delta, func_idx in enumerate(segment.funcidxs):
            slots[first_slot + delta] = func_idx

    entry_size = addr_type.bitwidth() // 8
    table_mem = unit.AddMem(
        ir.Mem("global_table", entry_size, o.MEM_KIND.RO))
    for func_idx in slots:
        if func_idx is None:
            table_mem.AddData(ir.DataBytes(entry_size, b"\0"))
            continue
        assert isinstance(func_idx, wasm.FuncIdx)
        target = unit.GetFun(mod.functions[int(func_idx)].name)
        table_mem.AddData(ir.DataAddrFun(entry_size, target))
    return table_mem
def GenerateStartup(unit: ir.Unit, global_argc, global_argv, main: ir.Fun,
                    init_global: ir.Fun, init_data: ir.Fun,
                    initial_heap_size_pages: int, addr_type: o.DK) -> ir.Fun:
    """Create the native entry function "main" and return it.

    The generated single-bbl function:
      1. pops argc (U32) and argv (addr_type) and stores them into
         `global_argc` / `global_argv`;
      2. calls "__wasi_init";
      3. if `initial_heap_size_pages` is non-zero, calls "__memory_grow" with
         it and discards the result;
      4. loads the memory base from the extern mem "__memory_base";
      5. calls `init_global` (no args) and `init_data` (memory base) when given;
      6. calls `main` with the memory base and returns the constant 0.
    """
    memory_base_sym = unit.AddMem(
        ir.Mem("__memory_base", 0, o.MEM_KIND.EXTERN))
    entry = unit.AddFun(
        ir.Fun("main", o.FUN_KIND.NORMAL, [o.DK.U32], [o.DK.U32, addr_type]))
    argc = entry.AddReg(ir.Reg("argc", o.DK.U32))
    argv = entry.AddReg(ir.Reg("argv", addr_type))
    start = entry.AddBbl(ir.Bbl("start"))

    def emit(opcode, operands):
        start.AddIns(ir.Ins(opcode, operands))

    emit(o.POPARG, [argc])
    emit(o.POPARG, [argv])
    emit(o.ST_MEM, [global_argc, ZERO, argc])
    emit(o.ST_MEM, [global_argv, ZERO, argv])
    emit(o.BSR, [unit.GetFun("__wasi_init")])

    if initial_heap_size_pages:
        emit(o.PUSHARG, [ir.Const(o.DK.S32, initial_heap_size_pages)])
        emit(o.BSR, [unit.GetFun("__memory_grow")])
        # the grow result is popped into a throw-away register
        discard = entry.AddReg(ir.Reg("dummy", o.DK.S32))
        emit(o.POPARG, [discard])

    memory_base = entry.AddReg(ir.Reg("mem_base", addr_type))
    emit(o.LD_MEM, [memory_base, memory_base_sym, ZERO])

    if init_global:
        emit(o.BSR, [init_global])
    if init_data:
        emit(o.PUSHARG, [memory_base])
        emit(o.BSR, [init_data])

    emit(o.PUSHARG, [memory_base])
    emit(o.BSR, [main])
    emit(o.PUSHARG, [ir.Const(o.DK.U32, 0)])
    emit(o.RET, [])
    return entry
def DirFun(unit: ir.Unit, operands: List):
    """Register the function described by `operands` in `unit`.

    `operands` is unpacked as (name, kind, output_types, input_types).
    A new ir.Fun is created for an unknown name. A function that exists but
    is marked `forward_declared` is completed via
    unit.InitForwardDeclaredFun.

    Raises:
      ParseError: if either type list is longer than o.MAX_PARAMETERS, or if
        a function with this name already exists and is not forward declared.

    NOTE(review): if this definition lives in the same module as a later
    `DirFun`, the later one shadows it — confirm which one is intended.
    """
    name, kind, output_types, input_types = operands
    if max(len(input_types), len(output_types)) > o.MAX_PARAMETERS:
        raise ParseError(f"parameter list too long {name}")

    existing = unit.GetFun(name)
    if existing is None:
        unit.AddFun(ir.Fun(name, kind, output_types, input_types))
        return
    if not existing.forward_declared:
        raise ParseError(f"duplicate Fun {name}")
    unit.InitForwardDeclaredFun(existing, kind, output_types, input_types)
def DirFun(unit: ir.Unit, operands: List):
    """Register the function described by `operands` in `unit`.

    `operands` is unpacked as (name, kind, output_types, input_types).

    Cases, based on any function already registered under `name`:
      * none: a new ir.Fun is created and added to `unit`.
      * kind INVALID (forward declared): kind and signature are filled in via
        unit.InitForwardDeclaredFun.
      * the existing or the new declaration is EXTERN: both signatures must
        match. A new EXTERN declaration is then ignored. A real definition of
        a formerly EXTERN function adopts the new kind and is moved to the end
        of unit.funs.
      * otherwise: duplicate definition.

    Raises:
      ParseError: if either type list is longer than o.MAX_PARAMETERS, if an
        extern declaration and another declaration/definition disagree on the
        signature, or on a duplicate definition.
    """
    name, kind, output_types, input_types = operands
    if (len(input_types) > o.MAX_PARAMETERS or
            len(output_types) > o.MAX_PARAMETERS):
        raise ParseError(f"parameter list too long {name}")
    fun = unit.GetFun(name)
    if fun is None:
        fun = ir.Fun(name, kind, output_types, input_types)
        unit.AddFun(fun)
    elif fun.kind is o.FUN_KIND.INVALID:  # forward declared
        unit.InitForwardDeclaredFun(fun, kind, output_types, input_types)
    elif fun.kind is o.FUN_KIND.EXTERN or kind is o.FUN_KIND.EXTERN:
        # This validates parsed input, so raise instead of `assert`
        # (asserts are stripped under `python -O`).
        if output_types != fun.output_types or input_types != fun.input_types:
            raise ParseError(
                f"signature mismatch for Fun {name}: "
                f"{fun.output_types} <- {fun.input_types} vs "
                f"{output_types} <- {input_types}")
        if kind is o.FUN_KIND.EXTERN:
            # we already have a proper function (or an identical extern decl)
            return
        # definition of a formerly extern function
        fun.kind = kind
        # move fun to the end of unit.funs to make it current
        unit.funs.remove(fun)
        unit.funs.append(fun)
    else:
        raise ParseError(f"duplicate Fun {name}")
def GenerateFun(unit: ir.Unit, mod: wasm.Module, wasm_fun: wasm.Function,
                fun: ir.Fun, global_table, addr_type):
    """Translate the body of `wasm_fun` into Cwerg IR inside `fun`.

    The wasm operand stack is simulated by `op_stack`. Its entries are
    registers obtained from GetOpReg, whose names encode the stack position.
    Structured control flow (block/loop/if) is tracked on `block_stack`.

    The first input of `fun` must be `addr_type`. It is popped into the
    register "mem_base", which is the base for every wasm load/store and is
    forwarded to callees through EmitCall. The remaining inputs are popped
    into "$loc_0", "$loc_1", ..., followed by the wasm-declared locals, which
    are zero-initialized.

    Args:
      unit: unit that holds callees, global-variable mems and helpers such as
        "__memory_grow".
      mod: the wasm module (global, type and function sections are consulted).
      wasm_fun: the wasm function whose instructions are translated.
      fun: the Cwerg function that receives regs, bbls and jtbs.
      global_table: the mem used as the funcref table by call_indirect
        (presumably the result of MaybeMakeGlobalTable).
      addr_type: address kind. o.DK.A32 selects C32 code pointers for
        call_indirect; any other value selects C64.

    Side effects: adds regs/bbls/jtbs to `fun` and finally replaces
    `fun.bbls` with the bbls in emission order.
    """
    # op_stack contains regs produced by GetOpReg; a reg may only occur at the
    # stack position encoded in its name
    op_stack: typing.List[typing.Union[ir.Reg, ir.Const]] = []
    block_stack: typing.List[Block] = []
    # bbls in emission order; installed as fun.bbls at the very end
    bbls: typing.List[ir.Bbl] = []
    bbl_count = 0
    jtb_count = 0
    bbls.append(fun.AddBbl(ir.Bbl("start")))

    loc_index = 0
    assert fun.input_types[0] is addr_type
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbls[-1].AddIns(ir.Ins(o.POPARG, [mem_base]))
    for dk in fun.input_types[1:]:
        reg = fun.AddReg(ir.Reg(f"$loc_{loc_index}", dk))
        loc_index += 1
        bbls[-1].AddIns(ir.Ins(o.POPARG, [reg]))
    for local_decl in wasm_fun.impl.locals_list:
        for _ in range(local_decl.count):
            reg = fun.AddReg(
                ir.Reg(f"$loc_{loc_index}",
                       WASM_TYPE_TO_CWERG_TYPE[local_decl.kind]))
            loc_index += 1
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, ir.Const(reg.kind, 0)]))

    DEBUG = False
    if DEBUG:
        print()
        print(
            f"# {fun.name} #ins:{len(wasm_fun.impl.expr.instructions)} in:{fun.input_types} out:{fun.output_types}"
        )

    opc = None
    last_opc = None
    for n, wasm_ins in enumerate(wasm_fun.impl.expr.instructions):
        last_opc = opc  # opcode of the previous instruction (None at start)
        opc = wasm_ins.opcode
        args = wasm_ins.args
        # scratch regs for unsigned conversions are placed above this depth
        op_stack_size_before = len(op_stack)
        if DEBUG:
            print(f"#@@ {opc.name}", args, len(op_stack))

        if opc.kind is wasm_opc.OPC_KIND.CONST:
            # breaks for floats
            kind = OPC_TYPE_TO_CWERG_TYPE[opc.op_type]
            dst = GetOpReg(fun, kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, ir.Const(kind, args[0])]))
            op_stack.append(dst)
        elif opc is wasm_opc.NOP:
            pass
        elif opc is wasm_opc.DROP:
            op_stack.pop(-1)
        elif opc is wasm_opc.LOCAL_GET:
            loc = GetLocalReg(fun, int(args[0]))
            dst = GetOpReg(fun, loc.kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, loc]))
            op_stack.append(dst)
        elif opc is wasm_opc.LOCAL_SET:
            op = op_stack.pop(-1)
            bbls[-1].AddIns(
                ir.Ins(o.MOV, [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.LOCAL_TEE:
            op = op_stack[-1]  # no pop!
            bbls[-1].AddIns(
                ir.Ins(o.MOV, [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.GLOBAL_GET:
            var_index = int(args[0])
            var: wasm.Glob = mod.sections.get(
                wasm.SECTION_ID.GLOBAL).items[var_index]
            dst = GetOpReg(fun,
                           WASM_TYPE_TO_CWERG_TYPE[var.global_type.value_type],
                           len(op_stack))
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.LD_MEM, [dst, var_mem, ZERO]))
            op_stack.append(dst)
        elif opc is wasm_opc.GLOBAL_SET:
            op = op_stack.pop(-1)
            var_index = int(args[0])
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.ST_MEM, [var_mem, ZERO, op]))
        elif opc.kind is wasm_opc.OPC_KIND.ALU:
            if wasm_opc.FLAGS.UNARY in opc.flags:
                opcode, arg_factory = WASM_ALU1_TO_CWERG[opc.basename]
                op = op_stack.pop(-1)
                dst = GetOpReg(fun, op.kind, len(op_stack))
                bbls[-1].AddIns(ir.Ins(opcode, [dst] + arg_factory(op)))
                op_stack.append(dst)
            else:
                op2 = op_stack.pop(-1)
                op1 = op_stack.pop(-1)
                dst = GetOpReg(fun, op1.kind, len(op_stack))
                # either a plain opcode or a callable emitting a sequence
                alu = WASM_ALU_TO_CWERG[opc.basename]
                if wasm_opc.FLAGS.UNSIGNED in opc.flags:
                    # compute in the unsigned kind, then convert back
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    tmp3 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 3)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [tmp3, tmp1, tmp2]))
                    else:
                        alu(tmp3, tmp1, tmp2, bbls[-1])
                    bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp3]))
                else:
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [dst, op1, op2]))
                    else:
                        alu(dst, op1, op2, bbls[-1])
                op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CONV:
            conv, dst_unsigned, src_unsigned = WASM_CONV_TO_CWERG[opc.name]
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            if src_unsigned:
                tmp = GetOpReg(fun, ToUnsigned(op.kind),
                               op_stack_size_before + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, op]))
                op = tmp
            if dst_unsigned:
                # convert into an unsigned scratch reg first, then into dst
                tmp = GetOpReg(fun, ToUnsigned(dst.kind),
                               op_stack_size_before + 1)
                dst, tmp = tmp, dst
            bbls[-1].AddIns(ir.Ins(conv, [dst, op]))
            if dst_unsigned:
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, dst]))
                dst = tmp
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.BITCAST:
            assert opc.name in SUPPORTED_BITCAST
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.BITCAST, [dst, op]))
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CMP:
            # this always works because of the sentinel: "end"
            succ = wasm_fun.impl.expr.instructions[n + 1]
            # A compare directly followed by if/br_if/select is not
            # materialized here: that successor reads it via MakeBranch.
            if succ.opcode not in {
                    wasm_opc.IF, wasm_opc.BR_IF, wasm_opc.SELECT
            }:
                cmp, res1, res2, op1, op2, unsigned = MakeCompare(
                    opc, op_stack)
                if unsigned:
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    op1 = tmp1
                    op2 = tmp2
                dst = GetOpReg(fun, o.DK.S32, len(op_stack))
                bbls[-1].AddIns(ir.Ins(cmp, [dst, res1, res2, op1, op2]))
                op_stack.append(dst)
        elif opc is wasm_opc.LOOP or opc is wasm_opc.BLOCK:
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            bbl_count += 1
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.IF:
            # note: we do not switch to the new bbl right away because the
            # conditional branch still has to be added to the current one
            # this always works because the stack cannot be empty at this point
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, True)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            # assert len(block_stack[-1].param_types) == 0
            bbl_count += 1
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, block_stack[-1].else_bbl]))
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.ELSE:
            block = block_stack[-1]
            assert block.opcode is wasm_opc.IF
            # NOTE(review): only UNREACHABLE and BR are treated as "previous
            # code is unreachable" here; RETURN and BR_TABLE also end a
            # straight-line sequence - confirm FinalizeIf handles them.
            block.FinalizeIf(op_stack, bbls[-1], fun,
                             last_opc in {wasm_opc.UNREACHABLE, wasm_opc.BR})
            if DEBUG:
                print(f"#@@ AFTER STACK RESTORE", len(op_stack))
            # the else branch starts from the stack depth at block entry
            op_stack = op_stack[0:block.stack_start]
            bbls[-1].AddIns(ir.Ins(o.BRA, [block.end_bbl]))
            assert block.else_bbl is not None
            bbls.append(block.else_bbl)
            block.else_bbl = None
        elif opc is wasm_opc.END:
            if block_stack:
                block = block_stack.pop(-1)
                block.FinalizeEnd(
                    op_stack, bbls[-1], fun,
                    last_opc in {wasm_opc.UNREACHABLE, wasm_opc.BR})
                # an IF without ELSE still needs its (empty) else bbl emitted
                if block.else_bbl:
                    bbls.append(block.else_bbl)
                bbls.append(block.end_bbl)
            else:
                # end of function
                assert n + 1 == len(wasm_fun.impl.expr.instructions)
                pred = wasm_fun.impl.expr.instructions[n - 1].opcode
                if pred not in {
                        wasm_opc.RETURN, wasm_opc.UNREACHABLE, wasm_opc.BR
                }:
                    for x in reversed(fun.output_types):
                        op = op_stack.pop(-1)
                        assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                        bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
                    bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.CALL:
            wasm_callee = mod.functions[int(wasm_ins.args[0])]
            callee = unit.GetFun(wasm_callee.name)
            assert callee, f"unknown fun: {wasm_callee.name}"
            EmitCall(fun, bbls[-1], ir.Ins(o.BSR, [callee]), op_stack,
                     mem_base, callee)
        elif opc is wasm_opc.CALL_INDIRECT:
            assert isinstance(args[1], wasm.TableIdx), f"{type(args[1])}"
            assert int(args[1]) == 0, f"only one table supported"
            assert isinstance(args[0], wasm.TypeIdx), f"{type(args[0])}"
            type_sec = mod.sections.get(wasm.SECTION_ID.TYPE)
            func_type: wasm.FunctionType = type_sec.items[int(args[0])]
            arguments = [addr_type] + TranslateTypeList(func_type.args)
            returns = TranslateTypeList(func_type.rets)
            signature = FindFunWithSignature(unit, arguments, returns)
            table_reg = GetOpReg(fun, addr_type, len(op_stack))
            code_type = o.DK.C32 if addr_type is o.DK.A32 else o.DK.C64
            fun_reg = GetOpReg(fun, code_type, len(op_stack) + 1)
            index = op_stack.pop(-1)
            assert index.kind is o.DK.S32
            # Scale the slot index to a byte offset. The constant has the same
            # kind as `index` (S32) so all MUL operands share one kind.
            bbls[-1].AddIns(
                ir.Ins(o.MUL, [
                    index, index,
                    ir.Const(index.kind, code_type.bitwidth() // 8)
                ]))
            # NOTE(review): no bounds or signature check is emitted here,
            # although wasm requires call_indirect to trap in those cases.
            bbls[-1].AddIns(ir.Ins(o.LEA_MEM, [table_reg, global_table, ZERO]))
            bbls[-1].AddIns(ir.Ins(o.LD, [fun_reg, table_reg, index]))
            EmitCall(fun, bbls[-1], ir.Ins(o.JSR, [fun_reg, signature]),
                     op_stack, mem_base, signature)
        elif opc is wasm_opc.RETURN:
            for x in reversed(fun.output_types):
                op = op_stack.pop(-1)
                assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.BR:
            assert isinstance(args[0], wasm.LabelIdx)
            block = GetTargetBlock(block_stack, args[0])
            # a branch to a loop jumps back to its start; to any other block,
            # it jumps to the end after copying the block results
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(o.BRA, [target]))
        elif opc is wasm_opc.BR_IF:
            assert isinstance(args[0], wasm.LabelIdx)
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, target]))
        elif opc is wasm_opc.SELECT:
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            val_f = op_stack.pop(-1)
            val_t = op_stack.pop(-1)
            assert val_f.kind == val_t.kind
            reg = GetOpReg(fun, val_f.kind, len(op_stack))
            op_stack.append(reg)
            # reg := val_t; if the condition does not hold, fall through
            # into reg := val_f before reaching the new bbl
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, val_t]))
            bbls.append(fun.AddBbl(ir.Bbl(f"select_{bbl_count}")))
            bbl_count += 1
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            bbls[-2].AddIns(ir.Ins(br, [op1, op2, bbls[-1]]))
            bbls[-2].AddIns(ir.Ins(o.MOV, [reg, val_f]))
        elif opc.kind is wasm_opc.OPC_KIND.STORE:
            val = op_stack.pop(-1)
            offset = op_stack.pop(-1)
            if args[1] != 0:
                # fold the static memarg offset into the dynamic one
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD,
                           [tmp, offset, ir.Const(offset.kind, args[1])]))
                offset = tmp
            dk_tmp = STORE_TO_CWERG_TYPE.get(opc.name)
            if dk_tmp is not None:
                # narrowing store: convert the value to the stored kind first
                tmp = GetOpReg(fun, dk_tmp, len(op_stack) + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, val]))
                val = tmp
            bbls[-1].AddIns(ir.Ins(o.ST, [mem_base, offset, val]))
        elif opc.kind is wasm_opc.OPC_KIND.LOAD:
            offset = op_stack.pop(-1)
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD,
                           [tmp, offset, ir.Const(offset.kind, args[1])]))
                offset = tmp
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            op_stack.append(dst)
            dk_tmp = LOAD_TO_CWERG_TYPE.get(opc.basename)
            if dk_tmp:
                # narrow load: load into a reg of the memory kind, then widen
                tmp = GetOpReg(fun, dk_tmp, len(op_stack))
                bbls[-1].AddIns(ir.Ins(o.LD, [tmp, mem_base, offset]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp]))
            else:
                bbls[-1].AddIns(ir.Ins(o.LD, [dst, mem_base, offset]))
        elif opc is wasm_opc.BR_TABLE:
            bbl_tab = {
                i: GetTargetBbl(block_stack, x)
                for i, x in enumerate(args[0])
            }
            bbl_def = GetTargetBbl(block_stack, args[1])
            op = op_stack.pop(-1)
            tab_size = ir.Const(ToUnsigned(op.kind), len(bbl_tab))
            jtb_count += 1
            jtb = fun.AddJtb(
                ir.Jtb(f"jtb_{jtb_count}", bbl_def, bbl_tab, tab_size.value))
            reg_unsigned = GetOpReg(fun, ToUnsigned(op.kind), len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.CONV, [reg_unsigned, op]))
            # out-of-range selector (tab_size <= index) -> default target
            bbls[-1].AddIns(ir.Ins(o.BLE, [tab_size, reg_unsigned, bbl_def]))
            bbls[-1].AddIns(ir.Ins(o.SWITCH, [reg_unsigned, jtb]))
        elif opc is wasm_opc.UNREACHABLE:
            bbls[-1].AddIns(ir.Ins(o.TRAP, []))
        elif opc is wasm_opc.MEMORY_GROW or opc is wasm_opc.MEMORY_SIZE:
            # memory.size is emitted as __memory_grow with argument ZERO_S
            op = ZERO_S
            if opc is wasm_opc.MEMORY_GROW:
                op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            assert unit.GetFun("__memory_grow")
            bbls[-1].AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
            dst = GetOpReg(fun, o.DK.S32, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.POPARG, [dst]))
            op_stack.append(dst)
        else:
            assert False, f"unsupported opcode [{opc.name}]"

    assert not op_stack, f"op_stack not empty in {fun.name}: {op_stack}"
    assert not block_stack, f"block_stack not empty in {fun.name}: {block_stack}"
    assert len(bbls) == len(fun.bbls)
    # install the bbls in emission order
    fun.bbls = bbls