def _truncate(dk: o.DK, val):
    if dk.flavor() == o.DK_FLAVOR_F:
        return val
    elif dk.flavor() == o.DK_FLAVOR_U:
        return val & ((1 << dk.bitwidth()) - 1)
    else:
        return SignedIntFromBits(val, dk.bitwidth())
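# Behavior sketch (values worked out by hand; assumes SignedIntFromBits
# reinterprets the low bits of its argument as a two's complement number):
#   _truncate(<U8 kind>, 300) == 300 & 0xff == 44
#   _truncate(<S8 kind>, 200) == 200 - 256  == -56
#   floating point kinds are passed through unchanged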
def MaybeMakeGlobalTable(mod: wasm.Module, unit: ir.Unit, addr_type: o.DK):
    bit_width = addr_type.bitwidth()
    table_sec = mod.sections.get(wasm.SECTION_ID.TABLE)
    table_elements = mod.sections.get(wasm.SECTION_ID.ELEMENT)
    if not table_sec:
        return None
    assert table_elements
    assert len(table_sec.items) == 1
    table_type: wasm.TableType = table_sec.items[0].table_type
    assert table_type.element_type == wasm.REF_TYPE.FUNCREF
    table_data = [None] * table_type.limits.max
    for elem in table_elements.items:
        ins = GetInsFromInitializerExpression(elem.expr)
        assert ins.opcode is wasm_opc.I32_CONST
        start = ins.args[0]
        for n, fun in enumerate(elem.funcidxs):
            table_data[start + n] = fun
    global_table = unit.AddMem(
        ir.Mem("global_table", bit_width // 8, o.MEM_KIND.RO))
    width = addr_type.bitwidth() // 8
    for fun in table_data:
        if fun is None:
            global_table.AddData(ir.DataBytes(width, b"\0"))
        else:
            assert isinstance(fun, wasm.FuncIdx)
            fun = unit.GetFun(mod.functions[int(fun)].name)
            global_table.AddData(ir.DataAddrFun(width, fun))
    return global_table
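# Layout sketch of the emitted table, assuming a 32-bit addr_type and a
# module with one element segment whose offset expression is "i32.const 1"
# and whose funcidxs are [f, g]: slot 0 (offset 0) stays 4 zero bytes,
# slot 1 (offset 4) holds the address of f, slot 2 (offset 8) that of g.
# Indirect calls index into this memory at runtime.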
def Cntlz(kind: o.DK, val: int):
    if val == 0:
        return kind.bitwidth()
    mask = 1 << (kind.bitwidth() - 1)
    n = 0
    # count leading zeros: advance while the current bit is NOT set
    while val & mask == 0:
        mask >>= 1
        n += 1
    return n
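# Expected values, illustrated for an 8-bit kind:
#   Cntlz(kind8, 0)    == 8   (no bit set, handled by the early return)
#   Cntlz(kind8, 1)    == 7   (only the lowest bit set)
#   Cntlz(kind8, 0x80) == 0   (the top bit set)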
def ParseConst(value_str: str, kind: o.DK) -> Const:
    flavor = kind.flavor()
    if flavor is o.DK_FLAVOR_F:
        return Const(kind, float(value_str))
    bit_width = kind.bitwidth()
    x = int(value_str, 0)
    if flavor is o.DK_FLAVOR_U:
        assert x >= 0
        assert x < (1 << bit_width)
    elif x >= 0:
        assert x < (1 << (bit_width - 1))
    else:
        assert -x <= (1 << (bit_width - 1))
    return Const(kind, x)
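# Accepted ranges, illustrated for 8-bit kinds:
#   ParseConst("0xff", <U8>) -> Const(<U8>, 255)   # unsigned: [0, 2**8)
#   ParseConst("-128", <S8>) -> Const(<S8>, -128)  # signed: [-2**7, 2**7)
#   ParseConst("256", <U8>)  -> assertion failure (out of range)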
def ConvertIntValue(kind_dst: o.DK, val: ir.Const) -> ir.Const:
    kind_src = val.kind
    width_dst = kind_dst.bitwidth()
    width_src = kind_src.bitwidth()
    masked = val.value & ((1 << width_dst) - 1)
    if kind_dst.flavor() == o.DK_FLAVOR_U:
        return ir.Const(kind_dst, masked)
    elif width_dst > width_src:
        return ir.Const(kind_dst, val.value)
    else:
        # dst is signed and width_dst <= width_src
        will_be_negative = val.value & (1 << (width_dst - 1))
        if will_be_negative:
            return ir.Const(kind_dst, masked - (1 << width_dst))
        return ir.Const(kind_dst, masked)
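# Narrowing sketch: converting the U16 constant 0x80 to a signed 8-bit kind
# takes the will_be_negative branch (0x80 & 0x80 != 0) and yields
# 0x80 - 0x100 == -128; converting 0x7f instead stays positive (127).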
def Translate(mod: wasm.Module, addr_type: o.DK) -> ir.Unit:
    table_import = mod.sections.get(wasm.SECTION_ID.IMPORT)
    for i in table_import.items:
        assert isinstance(i, wasm.Import)
        assert isinstance(i.desc, wasm.TypeIdx), f"cannot handle strange imports: {i}"
    bit_width = addr_type.bitwidth()
    unit = ir.Unit("unit")
    global_argv = unit.AddMem(
        ir.Mem("global_argv", 2 * bit_width // 8, o.MEM_KIND.RW))
    global_argv.AddData(ir.DataBytes(bit_width // 8, b"\0"))
    global_argc = unit.AddMem(ir.Mem("global_argc", 4, o.MEM_KIND.RW))
    global_argc.AddData(ir.DataBytes(4, b"\0"))
    memcpy = GenerateMemcpyFun(unit, addr_type)
    init_global = GenerateInitGlobalVarsFun(mod, unit, addr_type)
    init_data = GenerateInitDataFun(mod, unit, memcpy, addr_type)
    unit.AddFun(ir.Fun("__wasi_init", o.FUN_KIND.EXTERN, [], []))
    unit.AddFun(
        ir.Fun("__memory_grow", o.FUN_KIND.EXTERN, [o.DK.S32], [o.DK.S32]))
    main = None
    # forward declare everything since we cannot rely on a topological sort of the funs
    for wasm_fun in mod.functions:
        if isinstance(wasm_fun.impl, wasm.Import):
            assert wasm_fun.name in WASI_FUNCTIONS, f"unimplemented external function: {wasm_fun.name}"
        arguments = [addr_type] + TranslateTypeList(wasm_fun.func_type.args)
        returns = TranslateTypeList(wasm_fun.func_type.rets)
        # assert len(returns) <= 1
        unit.AddFun(
            ir.Fun(wasm_fun.name, o.FUN_KIND.EXTERN, returns, arguments))
    global_table = MaybeMakeGlobalTable(mod, unit, addr_type)
    for wasm_fun in mod.functions:
        if isinstance(wasm_fun.impl, wasm.Import):
            continue
        fun = unit.GetFun(wasm_fun.name)
        fun.kind = o.FUN_KIND.NORMAL
        if fun.name == "_start":
            fun.name = "$main"
            main = fun
        GenerateFun(unit, mod, wasm_fun, fun, global_table, addr_type)
        # print("\n".join(serialize.FunRenderToAsm(fun)))
        sanity.FunCheck(fun, unit, check_cfg=False)
    initial_heap_size_pages = 0
    sec_memory = mod.sections.get(wasm.SECTION_ID.MEMORY)
    if sec_memory:
        assert len(sec_memory.items) == 1
        heap: wasm.Mem = sec_memory.items[0]
        initial_heap_size_pages = heap.mem_type.limits.min
    assert main, "missing main function"
    GenerateStartup(unit, global_argc, global_argv, main, init_global,
                    init_data, initial_heap_size_pages, addr_type)
    return unit
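# Hypothetical usage sketch (the parser and serializer entry points named
# below are assumptions for illustration, not definitions from this module):
#
#   with open("prog.wasm", "rb") as fin:
#       mod = ParseWasmModule(fin)        # hypothetical parser entry point
#   unit = Translate(mod, o.DK.A32)       # translate with 32-bit addresses
#   EmitUnitAsAsm(unit)                   # hypothetical serializer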
def Cnttz(kind: o.DK, val: int):
    if val == 0:
        return kind.bitwidth()
    n = 0
    while val & 1 == 0:
        val >>= 1
        n += 1
    return n
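# Expected values, again for an 8-bit kind:
#   Cnttz(kind8, 0) == 8   (defined as the full bitwidth)
#   Cnttz(kind8, 8) == 3   (0b1000 has three trailing zeros)
#   Cnttz(kind8, 1) == 0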
def ConvertIntValue(kind_dst: o.DK, val: ir.Const) -> ir.Const:
    kind_src = val.kind
    width_dst = kind_dst.bitwidth()
    width_src = kind_src.bitwidth()
    masked = val.value & ((1 << width_dst) - 1)
    if width_dst > width_src:
        if (kind_dst.flavor() == kind_src.flavor() or
                kind_src.flavor() == o.DK_FLAVOR_U):
            return ir.Const(kind_dst, val.value)
        # dst is unsigned, src is signed: strip the sign extension
        return ir.Const(kind_dst, masked)
    elif kind_dst.flavor() == o.DK_FLAVOR_U:
        return ir.Const(kind_dst, masked)
    else:
        # dst is signed and width_dst <= width_src
        sign = val.value & (1 << (width_dst - 1))
        if sign == 0:
            return ir.Const(kind_dst, masked)
        return ir.Const(kind_dst, masked - (1 << width_dst))
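# Widening sketch: converting a signed 8-bit -1 to an unsigned 32-bit kind
# takes the masked path (flavors differ, src is signed) and yields
# 0xffffffff; converting an unsigned 8-bit 200 to a signed 32-bit kind
# returns 200 unchanged, since an unsigned source needs no sign extension.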
def GenerateStartup(unit: ir.Unit, global_argc, global_argv, main: ir.Fun,
                    init_global: ir.Fun, init_data: ir.Fun,
                    initial_heap_size_pages: int, addr_type: o.DK) -> ir.Fun:
    bit_width = addr_type.bitwidth()
    global_mem_base = unit.AddMem(ir.Mem("__memory_base", 0, o.MEM_KIND.EXTERN))
    fun = unit.AddFun(
        ir.Fun("main", o.FUN_KIND.NORMAL, [o.DK.U32], [o.DK.U32, addr_type]))
    argc = fun.AddReg(ir.Reg("argc", o.DK.U32))
    argv = fun.AddReg(ir.Reg("argv", addr_type))
    bbl = fun.AddBbl(ir.Bbl("start"))
    bbl.AddIns(ir.Ins(o.POPARG, [argc]))
    bbl.AddIns(ir.Ins(o.POPARG, [argv]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argc, ZERO, argc]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argv, ZERO, argv]))
    bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__wasi_init")]))
    if initial_heap_size_pages:
        bbl.AddIns(
            ir.Ins(o.PUSHARG, [ir.Const(o.DK.S32, initial_heap_size_pages)]))
        bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
        bbl.AddIns(ir.Ins(o.POPARG, [fun.AddReg(ir.Reg("dummy", o.DK.S32))]))
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.LD_MEM, [mem_base, global_mem_base, ZERO]))
    if init_global:
        bbl.AddIns(ir.Ins(o.BSR, [init_global]))
    if init_data:
        bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
        bbl.AddIns(ir.Ins(o.BSR, [init_data]))
    bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
    bbl.AddIns(ir.Ins(o.BSR, [main]))
    bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, 0)]))
    bbl.AddIns(ir.Ins(o.RET, []))
    return fun
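# The emitted entry point corresponds roughly to this pseudo code
# (a descriptive sketch of the instructions added above, not emitted text):
#
#   u32 main(u32 argc, addr argv):
#       global_argc = argc; global_argv = argv
#       __wasi_init()
#       if initial_heap_size_pages: __memory_grow(initial_heap_size_pages)
#       mem_base = load(__memory_base)
#       if init_global: init_global()
#       if init_data:   init_data(mem_base)
#       $main(mem_base)
#       return 0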
def get_cpu_reg_family(self, kind: o.DK) -> int:
    return FLT_NOT_LAC if kind.flavor() is o.DK_FLAVOR_F else GPR_NOT_LAC
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """Change the type of all registers (and constants) of type narrow_kind into wide_kind.

    Add compensation code where necessary.
    wide_kind must be wider than narrow_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output regs. So this must run
    for all functions, including prototypes.

    TODO: double check if we are doing the right thing with o.CONV
    TODO: there are more subtle bugs. For example:
        mul x:U8 = 43 * 47    (= 229)
        div y:U8 = x / 13     (= 17)
    whereas:
        mul x:U16 = 43 * 47   (= 2021)
        div y:U16 = x / 13    (= 155)
    Other problematic operations: rem, popcnt, ...
    """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [
        wide_kind if x == narrow_kind else x for x in fun.input_types
    ]
    fun.output_types = [
        wide_kind if x == narrow_kind else x for x in fun.output_types
    ]
    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {
        reg for reg in fun.reg_syms.values() if reg.kind == narrow_kind
    }
    for reg in narrow_regs:
        reg.kind = wide_kind
    count = 0
    for bbl in fun.bbls:
        inss = []
        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind
            for n, reg in enumerate(ops):
                if ((n == 2 and kind is o.OPC_KIND.ST) or
                        (n == 0 and kind is o.OPC_KIND.LD)):
                    continue
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    # if ins.opcode.constraints[n] == o.TC.OFFSET:
                    #     continue
                    ops[n] = ir.Const(wide_kind, reg.value)
            kind = ins.opcode.kind
            if kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg) and
                  ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            else:
                inss.append(ins)
        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
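# The docstring's overflow example, worked out: 43 * 47 == 2021; truncated
# to U8 that is 2021 & 0xff == 229 and 229 // 13 == 17, whereas in U16 the
# untruncated 2021 // 13 == 155.  Widening thus changes the observable
# result of div (and rem, popcnt, ...) unless extra truncation is inserted.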
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """Change the type of all registers (and constants) of type narrow_kind into wide_kind.

    Add compensation code where necessary.
    wide_kind must be wider than narrow_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output regs. So this must run
    for all functions, including prototypes.

    TODO: double check if we are doing the right thing with o.CONV
    TODO: there are more subtle bugs. For example:
        mul x:U8 = 43 * 47    (= 229)
        div y:U8 = x / 13     (= 17)
    whereas:
        mul x:U16 = 43 * 47   (= 2021)
        div y:U16 = x / 13    (= 155)
    Other problematic operations: rem, popcnt, ...

    The invariant we are maintaining is this one:
    if reg a gets widened into reg b with bitwidth(a) == w, then the lower
    w bits of reg b always contain the same data that reg a would have.
    """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [
        wide_kind if x == narrow_kind else x for x in fun.input_types
    ]
    fun.output_types = [
        wide_kind if x == narrow_kind else x for x in fun.output_types
    ]
    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {
        reg for reg in fun.reg_syms.values() if reg.kind == narrow_kind
    }
    for reg in narrow_regs:
        reg.kind = wide_kind
    count = 0
    for bbl in fun.bbls:
        inss = []
        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind
            changed = False
            for n, reg in enumerate(ops):
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    if kind is o.OPC_KIND.ST and n == 2:
                        continue
                    ops[n] = ir.Const(wide_kind, reg.value)
                    changed = True
                if isinstance(reg, ir.Reg) and reg in narrow_regs:
                    changed = True
            if not changed:
                inss.append(ins)
                continue
            kind = ins.opcode.kind
            if ins.opcode is o.SHL or ins.opcode is o.SHR:
                # the shift amount is implicitly taken modulo the (narrow) bitwidth,
                # i.e. masked with "bitwidth - 1"; widening the reg loses this
                # information, so mask the amount explicitly
                tmp_reg = fun.GetScratchReg(wide_kind, "tricky", False)
                inss.append(
                    ir.Ins(o.AND, [
                        tmp_reg, ops[2],
                        ir.Const(wide_kind, narrow_kind.bitwidth() - 1)
                    ]))
                ops[2] = tmp_reg
                if ins.opcode is o.SHR and isinstance(ops[1], ir.Reg):
                    # for SHR we also need to make sure the new high-order bits are correct
                    tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                    inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                    # the implicit understanding is that this will become a nop or a move
                    # and not modify the high-order bits we just set in the previous
                    # instruction
                    inss.append(ir.Ins(o.CONV, [ops[1], tmp_reg]))
                inss.append(ins)
            elif ins.opcode is o.CNTLZ:
                # the widened reg has extra leading zeros that must be subtracted
                inss.append(ins)
                excess = wide_kind.bitwidth() - narrow_kind.bitwidth()
                inss.append(
                    ir.Ins(o.SUB, [ops[0], ops[0], ir.Const(wide_kind, excess)]))
            elif ins.opcode is o.CNTTZ:
                # clamp the result to the narrow bitwidth (a narrow zero input
                # would otherwise count trailing zeros into the high-order bits)
                inss.append(ins)
                inss.append(
                    ir.Ins(o.CMPLT, [
                        ops[0], ops[0],
                        ir.Const(wide_kind, narrow_kind.bitwidth()), ops[0],
                        ir.Const(wide_kind, narrow_kind.bitwidth())
                    ]))
            elif kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg) and
                  ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            elif ins.opcode is o.CONV:
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
            else:
                inss.append(ins)
        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
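# Worked example of the shift compensation (assuming U8 widened to U32):
# in 8-bit arithmetic "1 << 9" yields 2 because the shift amount is taken
# modulo 8 (9 & 7 == 1), whereas a plain 32-bit shift would yield 512.
# The AND emitted above masks the amount first, preserving the 8-bit result.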