def _InsEliminateMemLoadStore(ins: ir.Ins, fun: ir.Fun, base_kind: o.DK,
                              offset_kind: o.DK) -> Optional[List[ir.Ins]]:
    """This rewrite is usually applied as a prep step by some backends
    to get rid of Mem operands.
    It allows the register allocator to see the scratch register, but it
    obscures the fact that the ld/st accesses a static location.

    Note: this function may add local registers which do not affect
    liveness or use-def chains.
    """
    opc = ins.opcode
    ops = ins.operands
    if opc is o.ST_MEM:
        st_offset = ops[1]
        lea_offset = ir.Const(offset_kind, 0)
        if isinstance(st_offset, ir.Const):
            st_offset, lea_offset = lea_offset, st_offset
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        lea = ir.Ins(o.LEA_MEM, [scratch_reg, ops[0], lea_offset])
        ins.Init(o.ST, [scratch_reg, st_offset, ops[2]])
        return [lea, ins]
    elif opc is o.LD_MEM:
        ld_offset = ops[2]
        lea_offset = ir.Const(offset_kind, 0)
        if isinstance(ld_offset, ir.Const):
            ld_offset, lea_offset = lea_offset, ld_offset
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        # TODO: should the zero offset stay with the ld op?
        lea = ir.Ins(o.LEA_MEM, [scratch_reg, ops[1], lea_offset])
        ins.Init(o.LD, [ops[0], scratch_reg, ld_offset])
        return [lea, ins]
    else:
        return None
def _InsEliminateCopySign(ins: ir.Ins, fun: ir.Fun) -> Optional[List[ir.Ins]]:
    """Rewrites copysign instructions like so:
    z = copysign a b
    aa = int(a) & 0x7f...f
    bb = int(b) & 0x80...0
    z = flt(aa | bb)
    """
    if ins.opcode is not o.COPYSIGN:
        return None
    ops = ins.operands
    out = []
    if ops[0].kind == o.DK.F32:
        kind = o.DK.U32
        sign = 1 << 31
        mask = sign - 1
    else:
        kind = o.DK.U64
        sign = 1 << 63
        mask = sign - 1
    tmp_src1 = fun.GetScratchReg(kind, "elim_copysign1", False)
    out.append(ir.Ins(o.BITCAST, [tmp_src1, ops[1]]))
    out.append(ir.Ins(o.AND, [tmp_src1, tmp_src1, ir.Const(kind, mask)]))
    #
    tmp_src2 = fun.GetScratchReg(kind, "elim_copysign2", False)
    out.append(ir.Ins(o.BITCAST, [tmp_src2, ops[2]]))
    out.append(ir.Ins(o.AND, [tmp_src2, tmp_src2, ir.Const(kind, sign)]))
    #
    out.append(ir.Ins(o.OR, [tmp_src1, tmp_src1, tmp_src2]))
    out.append(ir.Ins(o.BITCAST, [ops[0], tmp_src1]))
    return out
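
# For reference: the bit trick the rewrite above emits, as a standalone
# Python sketch. `copysign_f32` and the use of `struct` to emulate BITCAST
# are illustrative assumptions, not part of this codebase.
import struct


def copysign_f32(a: float, b: float) -> float:
    # Reinterpret the floats as 32-bit integers (the BITCAST steps).
    ia = struct.unpack("<I", struct.pack("<f", a))[0]
    ib = struct.unpack("<I", struct.pack("<f", b))[0]
    # Magnitude bits of a OR-ed with the sign bit of b.
    bits = (ia & 0x7fffffff) | (ib & 0x80000000)
    return struct.unpack("<f", struct.pack("<I", bits))[0]


assert copysign_f32(1.5, -0.0) == -1.5
assert copysign_f32(-2.25, 3.0) == 2.25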
def BuildExample() -> ir.Unit:
    unit = ir.Unit("fib")
    fun_fib = unit.AddFun(
        ir.Fun("fib", o.FUN_KIND.NORMAL, [o.DK.U32], [o.DK.U32]))
    bbl_start = fun_fib.AddBbl(ir.Bbl("start"))
    bbl_difficult = fun_fib.AddBbl(ir.Bbl("difficult"))
    reg_in = fun_fib.AddReg(ir.Reg("in", o.DK.U32))
    reg_x = fun_fib.AddReg(ir.Reg("x", o.DK.U32))
    reg_out = fun_fib.AddReg(ir.Reg("out", o.DK.U32))

    bbl_start.AddIns(ir.Ins(o.POPARG, [reg_in]))
    bbl_start.AddIns(
        ir.Ins(o.BLT, [ir.Const(o.DK.U32, 1), reg_in, bbl_difficult]))
    bbl_start.AddIns(ir.Ins(o.PUSHARG, [reg_in]))
    bbl_start.AddIns(ir.Ins(o.RET, []))

    bbl_difficult.AddIns(ir.Ins(o.MOV, [reg_out, ir.Const(o.DK.U32, 0)]))
    bbl_difficult.AddIns(ir.Ins(o.SUB, [reg_x, reg_in, ir.Const(o.DK.U32, 1)]))
    bbl_difficult.AddIns(ir.Ins(o.PUSHARG, [reg_x]))
    bbl_difficult.AddIns(ir.Ins(o.BSR, [fun_fib]))
    bbl_difficult.AddIns(ir.Ins(o.POPARG, [reg_x]))
    bbl_difficult.AddIns(ir.Ins(o.ADD, [reg_out, reg_out, reg_x]))
    bbl_difficult.AddIns(ir.Ins(o.SUB, [reg_x, reg_in, ir.Const(o.DK.U32, 2)]))
    bbl_difficult.AddIns(ir.Ins(o.PUSHARG, [reg_x]))
    bbl_difficult.AddIns(ir.Ins(o.BSR, [fun_fib]))
    bbl_difficult.AddIns(ir.Ins(o.POPARG, [reg_x]))
    bbl_difficult.AddIns(ir.Ins(o.ADD, [reg_out, reg_out, reg_x]))
    bbl_difficult.AddIns(ir.Ins(o.PUSHARG, [reg_out]))
    bbl_difficult.AddIns(ir.Ins(o.RET, []))
    return unit
def _InsStrengthReduction(ins: ir.Ins,
                          _fun: ir.Fun) -> Optional[List[ir.Ins]]:
    """Miscellaneous standard strength reduction rewrites

    TODO
    """
    opc = ins.opcode
    ops = ins.operands
    if _InsIsNop1(ins):
        ops.pop(2)
        ins.operand_defs.pop(2)
        return [ins.Init(o.MOV, ops)]
    elif _InsIsNop2(ins):
        ops.pop(1)
        ins.operand_defs.pop(1)
        return [ins.Init(o.MOV, ops)]
    elif _InsIsZero(ins):
        ins.Init(o.MOV, [ops[0], ir.Const(ops[0].kind, 0)])
        return [ins]
    elif (opc is o.MUL and ops[0].IsIntReg() and
          isinstance(ops[2], ir.Const) and ops[2].IsIntPowerOfTwo()):
        shift = ops[2].IntBinaryLog()
        ins.Init(o.SHL, [ops[0], ops[1], ir.Const(ops[0].kind, shift)])
        return [ins]
    elif (opc is o.MUL and ops[0].IsIntReg() and
          isinstance(ops[1], ir.Const) and ops[1].IsIntPowerOfTwo()):
        shift = ops[1].IntBinaryLog()
        # TODO: orig_operand update
        ins.Init(o.SHL, [ops[0], ops[2], ir.Const(ops[0].kind, shift)])
        return [ins]
    # TODO: DIV for unsigned int
    return None
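
# Sanity check for the mul-to-shift case above, as plain arithmetic
# (a sketch; `mul_pow2_as_shift` is a hypothetical helper and the final
# mask emulates fixed-width wrap-around):
def mul_pow2_as_shift(x: int, c: int, width: int = 32) -> int:
    assert c > 0 and c & (c - 1) == 0, "c must be a power of two"
    shift = c.bit_length() - 1  # equivalent of IntBinaryLog
    return (x << shift) & ((1 << width) - 1)


assert mul_pow2_as_shift(13, 8) == 13 * 8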
def _InsLimitShiftAmounts(ins: ir.Ins, fun: ir.Fun,
                          width: int) -> Optional[List[ir.Ins]]:
    """Limits the shift amounts of SHL/SHR instructions to [0, width).

    This rewrite is usually applied as a prep step by backends whose shift
    instructions only honor the low log2(width) bits of the shift amount.
    Constant amounts are reduced modulo width; register amounts are masked
    with an explicit AND.
    """
    opc = ins.opcode
    ops = ins.operands
    if (opc is not o.SHL and opc is not o.SHR) or \
            ops[0].kind.bitwidth() != width:
        return None
    amount = ops[2]
    if isinstance(amount, ir.Const):
        if 0 <= amount.value < width:
            return None
        else:
            ops[2] = ir.Const(amount.kind, amount.value % width)
            return [ins]
    else:
        tmp = fun.GetScratchReg(amount.kind, "shift", False)
        mask = ir.Ins(o.AND, [tmp, amount, ir.Const(amount.kind, width - 1)])
        ins.Init(opc, [ops[0], ops[1], tmp])
        return [mask, ins]
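
# The masking performed for register shift amounts, spelled out as plain
# Python (a sketch; `limited_shl` is a hypothetical helper):
def limited_shl(x: int, amount: int, width: int = 32) -> int:
    amount &= width - 1  # emulate ISAs that only honor the low bits
    return (x << amount) & ((1 << width) - 1)


# A shift by 33 on a 32-bit value behaves like a shift by 1.
assert limited_shl(1, 33) == limited_shl(1, 1) == 2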
def ConvertIntValue(kind_dst: o.DK, val: ir.Const) -> ir.Const:
    kind_src = val.kind
    width_dst = kind_dst.bitwidth()
    width_src = kind_src.bitwidth()
    masked = val.value & ((1 << width_dst) - 1)
    if kind_dst.flavor() == o.DK_FLAVOR_U:
        return ir.Const(kind_dst, masked)
    elif width_dst > width_src:
        return ir.Const(kind_dst, val.value)
    else:
        # dst is signed and width_dst <= width_src
        will_be_negative = val.value & (1 << (width_dst - 1))
        if will_be_negative:
            return ir.Const(kind_dst, masked - (1 << width_dst))
        return ir.Const(kind_dst, masked)
def HandleRotl(dst: ir.Reg, op1: ir.Reg, op2: ir.Reg, bbl: ir.Bbl):
    # Note: this helper clobbers both op1 and op2.
    assert dst != op1
    assert dst.kind is o.DK.U32 or dst.kind is o.DK.U64, f"{dst}"
    bitwidth = ir.Const(dst.kind, dst.kind.bitwidth())
    bbl.AddIns(ir.Ins(o.SHL, [dst, op1, op2]))
    bbl.AddIns(ir.Ins(o.SUB, [op2, bitwidth, op2]))
    bbl.AddIns(ir.Ins(
        o.SHR, [op1, op1, op2]))  # here the unsigned requirement kicks in
    bbl.AddIns(ir.Ins(o.OR, [dst, dst, op1]))
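
# The emitted sequence computes a rotate-left; as plain Python (a sketch;
# `rotl` is a hypothetical helper, assuming 0 < amount < width):
def rotl(x: int, amount: int, width: int = 32) -> int:
    mask = (1 << width) - 1
    return ((x << amount) | (x >> (width - amount))) & mask


assert rotl(0x80000001, 1) == 0x00000003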
def FunSpillRegs(fun: ir.Fun, offset_kind: o.DK, regs: List[ir.Reg]) -> int:
    reg_to_stk: Dict[ir.Reg, ir.Stk] = {}
    for reg in regs:
        size = ir.OffsetConst(reg.kind.bitwidth() // 8)
        stk = ir.Stk(f"$spill_{reg.name}", size, size)
        reg_to_stk[reg] = stk
        fun.AddStk(stk)
    return ir.FunGenericRewrite(fun, InsSpillRegs,
                                zero_const=ir.Const(offset_kind, 0),
                                reg_to_stk=reg_to_stk)
def MakeBranch(pred: wasm_opc.Opcode, op_stack, inverse):
    if pred.kind is wasm_opc.OPC_KIND.CMP:
        if wasm_opc.FLAGS.UNARY in pred.flags:
            # eqz
            op1 = op_stack.pop(-1)
            op2 = ir.Const(op1.kind, 0)
            return o.BNE if inverse else o.BEQ, op1, op2, False
        else:
            # std two op cmp
            op2 = op_stack.pop(-1)
            op1 = op_stack.pop(-1)
            tab = WASM_CMP_TO_CWERG_CBR_INV if inverse else WASM_CMP_TO_CWERG_CBR
            br, swap, unsigned = tab[pred.basename]
            if swap:
                op1, op2 = op2, op1
            return br, op1, op2, unsigned
    else:
        op1 = op_stack.pop(-1)
        op2 = ir.Const(op1.kind, 0)
        return o.BEQ if inverse else o.BNE, op1, op2, False
def ConvertIntValue(kind_dst: o.DK, val: ir.Const) -> ir.Const:
    kind_src = val.kind
    width_dst = kind_dst.bitwidth()
    width_src = kind_src.bitwidth()
    masked = val.value & ((1 << width_dst) - 1)
    if width_dst > width_src:
        if (kind_dst.flavor() == kind_src.flavor() or
                kind_src.flavor() == o.DK_FLAVOR_U):
            return ir.Const(kind_dst, val.value)
        # dst flavor is unsigned, src flavor is signed
        return ir.Const(kind_dst, masked)
    elif kind_dst.flavor() == o.DK_FLAVOR_U:
        return ir.Const(kind_dst, masked)
    else:
        # dst flavor is signed
        sign = val.value & (1 << (width_dst - 1))
        if sign == 0:
            return ir.Const(kind_dst, masked)
        return ir.Const(kind_dst, masked - (1 << width_dst))
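
# Both ConvertIntValue variants implement standard two's-complement
# truncation/extension. The underlying arithmetic as a self-contained
# sketch (`convert_int` is a hypothetical helper):
def convert_int(value: int, width_dst: int, signed_dst: bool) -> int:
    masked = value & ((1 << width_dst) - 1)  # truncate to destination width
    if signed_dst and masked & (1 << (width_dst - 1)):
        return masked - (1 << width_dst)  # sign-extend
    return masked


assert convert_int(-1000, 16, signed_dst=False) == 0xfc18
assert convert_int(-1300, 8, signed_dst=True) == -20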
def InsEliminateImmediateViaMem(ins: ir.Ins, pos: int, fun: ir.Fun,
                                unit: ir.Unit, addr_kind: o.DK,
                                offset_kind: o.DK) -> List[ir.Ins]:
    """Rewrite instruction with an immediate as load of the immediate

    This is useful if the target architecture does not support immediate
    for that instruction, or the immediate is too large.

    This optimization is run rather late and may already see machine
    registers.
    """
    # support of PUSHARG would require additional work because they need to
    # stay consecutive
    assert ins.opcode is not o.PUSHARG
    const = ins.operands[pos]
    mem = unit.FindOrAddConstMem(const)
    tmp_addr = fun.GetScratchReg(addr_kind, "mem_const_addr", True)
    lea_ins = ir.Ins(o.LEA_MEM, [tmp_addr, mem, ir.Const(offset_kind, 0)])
    tmp = fun.GetScratchReg(const.kind, "mem_const", True)
    ld_ins = ir.Ins(o.LD, [tmp, tmp_addr, ir.Const(offset_kind, 0)])
    ins.operands[pos] = tmp
    return [lea_ins, ld_ins]
def GenerateStartup(unit: ir.Unit, global_argc, global_argv, main: ir.Fun,
                    init_global: ir.Fun, init_data: ir.Fun,
                    initial_heap_size_pages: int, addr_type: o.DK) -> ir.Fun:
    bit_width = addr_type.bitwidth()
    global_mem_base = unit.AddMem(
        ir.Mem("__memory_base", 0, o.MEM_KIND.EXTERN))
    fun = unit.AddFun(
        ir.Fun("main", o.FUN_KIND.NORMAL, [o.DK.U32], [o.DK.U32, addr_type]))
    argc = fun.AddReg(ir.Reg("argc", o.DK.U32))
    argv = fun.AddReg(ir.Reg("argv", addr_type))
    bbl = fun.AddBbl(ir.Bbl("start"))
    bbl.AddIns(ir.Ins(o.POPARG, [argc]))
    bbl.AddIns(ir.Ins(o.POPARG, [argv]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argc, ZERO, argc]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argv, ZERO, argv]))
    bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__wasi_init")]))
    if initial_heap_size_pages:
        bbl.AddIns(
            ir.Ins(o.PUSHARG, [ir.Const(o.DK.S32, initial_heap_size_pages)]))
        bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
        bbl.AddIns(ir.Ins(o.POPARG, [fun.AddReg(ir.Reg("dummy", o.DK.S32))]))
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.LD_MEM, [mem_base, global_mem_base, ZERO]))
    if init_global:
        bbl.AddIns(ir.Ins(o.BSR, [init_global]))
    if init_data:
        bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
        bbl.AddIns(ir.Ins(o.BSR, [init_data]))
    bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
    bbl.AddIns(ir.Ins(o.BSR, [main]))
    bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, 0)]))
    bbl.AddIns(ir.Ins(o.RET, []))
    return fun
def GenerateInitDataFun(mod: wasm.Module, unit: ir.Unit, memcpy: ir.Fun,
                        addr_type: o.DK) -> typing.Optional[ir.Fun]:
    fun = unit.AddFun(
        ir.Fun("init_data_fun", o.FUN_KIND.NORMAL, [], [addr_type]))
    bbl = fun.AddBbl(ir.Bbl("start"))
    epilog = fun.AddBbl(ir.Bbl("end"))
    epilog.AddIns(ir.Ins(o.RET, []))
    section = mod.sections.get(wasm.SECTION_ID.DATA)
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.POPARG, [mem_base]))
    if not section:
        return None
    offset = fun.AddReg(ir.Reg("offset", o.DK.S32))
    src = fun.AddReg(ir.Reg("src", addr_type))
    dst = fun.AddReg(ir.Reg("dst", addr_type))
    for n, data in enumerate(section.items):
        assert data.memory_index == 0
        assert isinstance(data.offset, wasm.Expression)
        ins = GetInsFromInitializerExpression(data.offset)
        init = unit.AddMem(ir.Mem(f"global_init_mem_{n}", 16, o.MEM_KIND.RO))
        init.AddData(ir.DataBytes(1, data.init))
        if ins.opcode is wasm_opc.GLOBAL_GET:
            src_mem = unit.GetMem(f"global_vars_{int(ins.args[0])}")
            bbl.AddIns(ir.Ins(o.LD_MEM, [offset, src_mem, ZERO]))
        elif ins.opcode is wasm_opc.I32_CONST:
            bbl.AddIns(ir.Ins(o.MOV,
                              [offset, ir.Const(o.DK.S32, ins.args[0])]))
        else:
            assert False, f"unsupported init instructions {ins}"
        bbl.AddIns(ir.Ins(o.LEA, [dst, mem_base, offset]))
        bbl.AddIns(ir.Ins(o.LEA_MEM, [src, init, ZERO]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, len(data.init))]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [src]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [dst]))
        bbl.AddIns(ir.Ins(o.BSR, [memcpy]))
    return fun
def BblSpillRegs(bbl: ir.Bbl, fun: ir.Fun, regs: List[ir.Reg],
                 offset_kind: o.DK, prefix) -> int:
    reg_to_stk: Dict[ir.Reg, ir.Stk] = {}
    for reg in regs:
        size = reg.kind.bitwidth() // 8
        stk = ir.Stk(f"{prefix}_{reg.name}", size, size)
        reg_to_stk[reg] = stk
        fun.AddStk(stk)
    return ir.BblGenericRewrite(bbl, fun, InsSpillRegs,
                                zero_const=ir.Const(offset_kind, 0),
                                reg_to_stk=reg_to_stk)
def _InsMoveImmediatesToMemory(ins: ir.Ins, fun: ir.Fun, unit: ir.Unit,
                               kind: o.DK) -> Optional[List[ir.Ins]]:
    inss = []
    for n, op in enumerate(ins.operands):
        if isinstance(op, ir.Const) and op.kind is kind:
            mem = unit.FindOrAddConstMem(op)
            tmp = fun.GetScratchReg(kind, "mem_const", True)
            # TODO: pass the offset kind as a parameter
            inss.append(ir.Ins(o.LD_MEM, [tmp, mem, ir.Const(o.DK.U32, 0)]))
            ins.operands[n] = tmp
    if inss:
        return inss + [ins]
    return None
def _InsEliminateStkLoadStoreWithRegOffset(
        ins: ir.Ins, fun: ir.Fun, base_kind: o.DK,
        offset_kind: o.DK) -> Optional[List[ir.Ins]]:
    """This rewrite is usually applied as a prep step by some backends
    to get rid of Stk operands.
    It allows the register allocator to see the scratch register, but it
    obscures the fact that a memory access is a stack access.

    Note: a stack address already implies a `sp+offset` addressing mode and
    RISC ISAs do not usually support a `sp+offset+reg` addressing mode.
    """
    opc = ins.opcode
    ops = ins.operands
    if opc is o.ST_STK and isinstance(ops[1], ir.Reg):
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        lea = ir.Ins(o.LEA_STK, [scratch_reg, ops[0],
                                 ir.Const(offset_kind, 0)])
        ins.Init(o.ST, [scratch_reg, ops[1], ops[2]])
        return [lea, ins]
    elif opc is o.LD_STK and isinstance(ops[2], ir.Reg):
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        lea = ir.Ins(o.LEA_STK, [scratch_reg, ops[1],
                                 ir.Const(offset_kind, 0)])
        ins.Init(o.LD, [ops[0], scratch_reg, ops[2]])
        return [lea, ins]
    elif opc is o.LEA_STK and isinstance(ops[2], ir.Reg):
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        # TODO: maybe reverse the order so that we can tell that ops[0]
        # holds a stack location
        lea = ir.Ins(o.LEA_STK, [scratch_reg, ops[1],
                                 ir.Const(offset_kind, 0)])
        ins.Init(o.LEA, [ops[0], scratch_reg, ops[2]])
        return [lea, ins]
    else:
        return None
def MakeCompare(opc: wasm_opc.Opcode, op_stack):
    if wasm_opc.FLAGS.UNARY in opc.flags:
        # eqz
        op1 = op_stack.pop(-1)
        op2 = ir.Const(op1.kind, 0)
        return o.CMPEQ, ONE_S, ZERO_S, op1, op2, False
    cmp, swap_op, swap_res, unsigned = WASM_CMP_TO_CWERG_CMP[opc.basename]
    op2 = op_stack.pop(-1)
    op1 = op_stack.pop(-1)
    if swap_op:
        op1, op2 = op2, op1
    res1 = ONE_S
    res2 = ZERO_S
    if swap_res:
        res1, res2 = res2, res1
    return cmp, res1, res2, op1, op2, unsigned
            else:
                inss.append(
                    lowering.InsEliminateImmediateViaMov(ins, pos, fun))
    inss.append(ins)
    return inss


def _FunRewriteOutOfBoundsImmediates(fun: ir.Fun, unit: ir.Unit) -> int:
    return ir.FunGenericRewrite(fun, _InsRewriteOutOfBoundsImmediates,
                                unit=unit)


_SHIFT_MASK = {
    o.DK.S8: ir.Const(o.DK.S8, 7),
    o.DK.U8: ir.Const(o.DK.U8, 7),
    o.DK.S16: ir.Const(o.DK.S16, 15),
    o.DK.U16: ir.Const(o.DK.U16, 15),
    o.DK.S32: ir.Const(o.DK.S32, 31),
    o.DK.U32: ir.Const(o.DK.U32, 31),
}


def _InsRewriteDivRemShifts(ins: ir.Ins,
                            fun: ir.Fun) -> Optional[List[ir.Ins]]:
    opc = ins.opcode
    ops = ins.operands
    if opc is o.DIV and ops[0].kind.flavor() != o.DK_FLAVOR_F:
        # note: we could leave it to the register allocator to pick a CpuReg
        # for ops[2] but then we would somehow have to ensure that the reg
        # is NOT rdx.
from Base import canonicalize
from Base import reg_alloc
from Base import ir
from Base import liveness
from Base import lowering
from Base import opcode_tab as o
from Base import reg_stats
from Base import sanity
from Base import optimize
from Base import serialize
from CodeGenA64 import isel_tab
from CodeGenA64 import regs

_DUMMY_A32 = ir.Reg("dummy", o.DK.A32)
_ZERO_OFFSET = ir.Const(o.DK.U32, 0)


def _InsRewriteOutOfBoundsImmediates(
        ins: ir.Ins, fun: ir.Fun, unit: ir.Unit) -> Optional[List[ir.Ins]]:
    if ins.opcode in isel_tab.OPCODES_REQUIRING_SPECIAL_HANDLING:
        return None
    inss = []
    mismatches = isel_tab.FindtImmediateMismatchesInBestMatchPattern(ins, True)
    assert mismatches != isel_tab.MATCH_IMPOSSIBLE, \
        f"could not match opcode {ins} {ins.operands}"
    if mismatches == 0:
        return None
    for pos in range(o.MAX_OPERANDS):
        if mismatches & (1 << pos) != 0:
            const_kind = ins.operands[pos].kind
""" Convert WASM files to Cwerg """ import logging import typing import dataclasses import struct from FrontEndWASM import opcode_tab as wasm_opc import FrontEndWASM.parser as wasm from Base import ir from Base import opcode_tab as o from Base import serialize from Base import sanity ZERO = ir.Const(o.DK.U32, 0) ONE = ir.Const(o.DK.U32, 1) ZERO_S = ir.Const(o.DK.S32, 0) ONE_S = ir.Const(o.DK.S32, 1) WASI_FUNCTIONS = { "$wasi$args_get", "$wasi$args_sizes_get", "$wasi$environ_get", "$wasi$environ_sizes_get", # "$wasi$fd_write", "$wasi$fd_read", "$wasi$fd_seek", "$wasi$fd_close",
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """Change the type of all registers (and constants) of type narrow_kind
    into wide_kind. Add compensation code where necessary.
    wide_kind must be wider than narrow_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output regs. So this must run
    for all functions, including prototypes.

    TODO: double check if we are doing the right thing with o.CONV
    TODO: there are more subtle bugs. For example:
          mul x:U8 = 43 * 47    (= 229)
          div y:U8 = x 13       (= 17)
          whereas:
          mul x:U16 = 43 * 47   (= 2021)
          div y:U16 = x 13      (= 155)
    Other problematic operations: rem, popcnt, ...
    """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [wide_kind if x == narrow_kind else x
                       for x in fun.input_types]
    fun.output_types = [wide_kind if x == narrow_kind else x
                        for x in fun.output_types]
    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {reg for reg in fun.reg_syms.values()
                   if reg.kind == narrow_kind}
    for reg in narrow_regs:
        reg.kind = wide_kind
    count = 0
    for bbl in fun.bbls:
        inss = []
        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind
            for n, reg in enumerate(ops):
                if (n == 2 and kind is o.OPC_KIND.ST or
                        n == 0 and kind is o.OPC_KIND.LD):
                    continue
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    # if ins.opcode.constraints[n] == o.TC.OFFSET:
                    #     continue
                    ops[n] = ir.Const(wide_kind, reg.value)
            kind = ins.opcode.kind
            if kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg) and
                  ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            else:
                inss.append(ins)
        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
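
# The div example from the docstring, spelled out as plain arithmetic to
# show why naive widening changes results (a sketch, not library code):
x8 = (43 * 47) & 0xff      # U8 mul wraps: 229
y8 = x8 // 13              # 17
x16 = (43 * 47) & 0xffff   # after widening the product no longer wraps: 2021
y16 = x16 // 13            # 155
assert (y8, y16) == (17, 155)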
def GenerateFun(unit: ir.Unit, mod: wasm.Module, wasm_fun: wasm.Function,
                fun: ir.Fun, global_table, addr_type):
    # op_stack contains regs produced by GetOpReg (a reg may only occur at
    # the position encoded in its name)
    op_stack: typing.List[typing.Union[ir.Reg, ir.Const]] = []
    block_stack: typing.List[Block] = []
    bbls: typing.List[ir.Bbl] = []
    bbl_count = 0
    jtb_count = 0
    bbls.append(fun.AddBbl(ir.Bbl("start")))
    loc_index = 0
    assert fun.input_types[0] is addr_type
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbls[-1].AddIns(ir.Ins(o.POPARG, [mem_base]))
    for dk in fun.input_types[1:]:
        reg = fun.AddReg(ir.Reg(f"$loc_{loc_index}", dk))
        loc_index += 1
        bbls[-1].AddIns(ir.Ins(o.POPARG, [reg]))
    for locals in wasm_fun.impl.locals_list:
        for i in range(locals.count):
            reg = fun.AddReg(
                ir.Reg(f"$loc_{loc_index}",
                       WASM_TYPE_TO_CWERG_TYPE[locals.kind]))
            loc_index += 1
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, ir.Const(reg.kind, 0)]))
    DEBUG = False
    if DEBUG:
        print()
        print(f"# {fun.name} #ins:{len(wasm_fun.impl.expr.instructions)} "
              f"in:{fun.input_types} out:{fun.output_types}")
    opc = None
    last_opc = None
    for n, wasm_ins in enumerate(wasm_fun.impl.expr.instructions):
        last_opc = opc
        opc = wasm_ins.opcode
        args = wasm_ins.args
        op_stack_size_before = len(op_stack)
        if DEBUG:
            print(f"#@@ {opc.name}", args, len(op_stack))
        if opc.kind is wasm_opc.OPC_KIND.CONST:
            # breaks for floats
            kind = OPC_TYPE_TO_CWERG_TYPE[opc.op_type]
            dst = GetOpReg(fun, kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, ir.Const(kind, args[0])]))
            op_stack.append(dst)
        elif opc is wasm_opc.NOP:
            pass
        elif opc is wasm_opc.DROP:
            op_stack.pop(-1)
        elif opc is wasm_opc.LOCAL_GET:
            loc = GetLocalReg(fun, int(args[0]))
            dst = GetOpReg(fun, loc.kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, loc]))
            op_stack.append(dst)
        elif opc is wasm_opc.LOCAL_SET:
            op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.MOV,
                                   [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.LOCAL_TEE:
            op = op_stack[-1]  # no pop!
            bbls[-1].AddIns(ir.Ins(o.MOV,
                                   [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.GLOBAL_GET:
            var_index = int(args[0])
            var: wasm.Glob = mod.sections.get(
                wasm.SECTION_ID.GLOBAL).items[var_index]
            dst = GetOpReg(fun,
                           WASM_TYPE_TO_CWERG_TYPE[var.global_type.value_type],
                           len(op_stack))
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.LD_MEM, [dst, var_mem, ZERO]))
            op_stack.append(dst)
        elif opc is wasm_opc.GLOBAL_SET:
            op = op_stack.pop(-1)
            var_index = int(args[0])
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.ST_MEM, [var_mem, ZERO, op]))
        elif opc.kind is wasm_opc.OPC_KIND.ALU:
            if wasm_opc.FLAGS.UNARY in opc.flags:
                opcode, arg_factory = WASM_ALU1_TO_CWERG[opc.basename]
                op = op_stack.pop(-1)
                dst = GetOpReg(fun, op.kind, len(op_stack))
                bbls[-1].AddIns(ir.Ins(opcode, [dst] + arg_factory(op)))
                op_stack.append(dst)
            else:
                op2 = op_stack.pop(-1)
                op1 = op_stack.pop(-1)
                dst = GetOpReg(fun, op1.kind, len(op_stack))
                alu = WASM_ALU_TO_CWERG[opc.basename]
                if wasm_opc.FLAGS.UNSIGNED in opc.flags:
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    tmp3 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 3)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [tmp3, tmp1, tmp2]))
                    else:
                        alu(tmp3, tmp1, tmp2, bbls[-1])
                    bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp3]))
                else:
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [dst, op1, op2]))
                    else:
                        alu(dst, op1, op2, bbls[-1])
                op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CONV:
            conv, dst_unsigned, src_unsigned = WASM_CONV_TO_CWERG[opc.name]
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            if src_unsigned:
                tmp = GetOpReg(fun, ToUnsigned(op.kind),
                               op_stack_size_before + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, op]))
                op = tmp
            if dst_unsigned:
                tmp = GetOpReg(fun, ToUnsigned(dst.kind),
                               op_stack_size_before + 1)
                dst, tmp = tmp, dst
            bbls[-1].AddIns(ir.Ins(conv, [dst, op]))
            if dst_unsigned:
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, dst]))
                dst = tmp
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.BITCAST:
            assert opc.name in SUPPORTED_BITCAST
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.BITCAST, [dst, op]))
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CMP:
            # this always works because of the sentinel: "end"
            succ = wasm_fun.impl.expr.instructions[n + 1]
            if succ.opcode not in {wasm_opc.IF, wasm_opc.BR_IF,
                                   wasm_opc.SELECT}:
                cmp, res1, res2, op1, op2, unsigned = MakeCompare(
                    opc, op_stack)
                if unsigned:
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    op1 = tmp1
                    op2 = tmp2
                dst = GetOpReg(fun, o.DK.S32, len(op_stack))
                bbls[-1].AddIns(ir.Ins(cmp, [dst, res1, res2, op1, op2]))
                op_stack.append(dst)
        elif opc is wasm_opc.LOOP or opc is wasm_opc.BLOCK:
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            bbl_count += 1
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.IF:
            # note we do set the new bbl right away because we add some
            # instructions to the old one
            # this always works because the stack cannot be empty at this
            # point
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, True)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            # assert len(block_stack[-1].param_types) == 0
            bbl_count += 1
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, block_stack[-1].else_bbl]))
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.ELSE:
            block = block_stack[-1]
            assert block.opcode is wasm_opc.IF
            block.FinalizeIf(op_stack, bbls[-1], fun,
                             last_opc in {wasm_opc.UNREACHABLE, wasm_opc.BR})
            if DEBUG:
                print(f"#@@ AFTER STACK RESTORE", len(op_stack))
            op_stack = op_stack[0:block.stack_start]
            bbls[-1].AddIns(ir.Ins(o.BRA, [block.end_bbl]))
            assert block.else_bbl is not None
            bbls.append(block.else_bbl)
            block.else_bbl = None
        elif opc is wasm_opc.END:
            if block_stack:
                block = block_stack.pop(-1)
                block.FinalizeEnd(
                    op_stack, bbls[-1], fun,
                    last_opc in {wasm_opc.UNREACHABLE, wasm_opc.BR})
                if block.else_bbl:
                    bbls.append(block.else_bbl)
                bbls.append(block.end_bbl)
            else:
                # end of function
                assert n + 1 == len(wasm_fun.impl.expr.instructions)
                pred = wasm_fun.impl.expr.instructions[n - 1].opcode
                if pred not in {wasm_opc.RETURN, wasm_opc.UNREACHABLE,
                                wasm_opc.BR}:
                    for x in reversed(fun.output_types):
                        op = op_stack.pop(-1)
                        assert op.kind == x, \
                            f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                        bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
                    bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.CALL:
            wasm_callee = mod.functions[int(wasm_ins.args[0])]
            callee = unit.GetFun(wasm_callee.name)
            assert callee, f"unknown fun: {wasm_callee.name}"
            EmitCall(fun, bbls[-1], ir.Ins(o.BSR, [callee]), op_stack,
                     mem_base, callee)
        elif opc is wasm_opc.CALL_INDIRECT:
            assert isinstance(args[1], wasm.TableIdx), f"{type(args[1])}"
            assert int(args[1]) == 0, f"only one table supported"
            assert isinstance(args[0], wasm.TypeIdx), f"{type(args[0])}"
            type_sec = mod.sections.get(wasm.SECTION_ID.TYPE)
            func_type: wasm.FunctionType = type_sec.items[int(args[0])]
            arguments = [addr_type] + TranslateTypeList(func_type.args)
            returns = TranslateTypeList(func_type.rets)
            # print(f"# @@@@ CALL INDIRECT {returns} <- {arguments} [{int(args[0])}] {func_type}")
            signature = FindFunWithSignature(unit, arguments, returns)
            table_reg = GetOpReg(fun, addr_type, len(op_stack))
            code_type = o.DK.C32 if addr_type is o.DK.A32 else o.DK.C64
            fun_reg = GetOpReg(fun, code_type, len(op_stack) + 1)
            index = op_stack.pop(-1)
            assert index.kind is o.DK.S32
            bbls[-1].AddIns(
                ir.Ins(o.MUL, [index, index,
                               ir.Const(o.DK.U32, code_type.bitwidth() // 8)]))
            bbls[-1].AddIns(ir.Ins(o.LEA_MEM, [table_reg, global_table, ZERO]))
            bbls[-1].AddIns(ir.Ins(o.LD, [fun_reg, table_reg, index]))
            EmitCall(fun, bbls[-1], ir.Ins(o.JSR, [fun_reg, signature]),
                     op_stack, mem_base, signature)
        elif opc is wasm_opc.RETURN:
            for x in reversed(fun.output_types):
                op = op_stack.pop(-1)
                assert op.kind == x, \
                    f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.BR:
            assert isinstance(args[0], wasm.LabelIdx)
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
            block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(o.BRA, [target]))
        elif opc is wasm_opc.BR_IF:
            assert isinstance(args[0], wasm.LabelIdx)
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
            block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, target]))
        elif opc is wasm_opc.SELECT:
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            val_f = op_stack.pop(-1)
            val_t = op_stack.pop(-1)
            assert val_f.kind == val_t.kind
            reg = GetOpReg(fun, val_f.kind, len(op_stack))
            op_stack.append(reg)
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, val_t]))
            bbls.append(fun.AddBbl(ir.Bbl(f"select_{bbl_count}")))
            bbl_count += 1
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            bbls[-2].AddIns(ir.Ins(br, [op1, op2, bbls[-1]]))
            bbls[-2].AddIns(ir.Ins(o.MOV, [reg, val_f]))
        elif opc.kind is wasm_opc.OPC_KIND.STORE:
            val = op_stack.pop(-1)
            offset = op_stack.pop(-1)
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD, [tmp, offset,
                                   ir.Const(offset.kind, args[1])]))
                offset = tmp
            dk_tmp = STORE_TO_CWERG_TYPE.get(opc.name)
            if dk_tmp is not None:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack) + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, val]))
                val = tmp
            bbls[-1].AddIns(ir.Ins(o.ST, [mem_base, offset, val]))
        elif opc.kind is wasm_opc.OPC_KIND.LOAD:
            offset = op_stack.pop(-1)
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD, [tmp, offset,
                                   ir.Const(offset.kind, args[1])]))
                offset = tmp
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            op_stack.append(dst)
            dk_tmp = LOAD_TO_CWERG_TYPE.get(opc.basename)
            if dk_tmp:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack))
                bbls[-1].AddIns(ir.Ins(o.LD, [tmp, mem_base, offset]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp]))
            else:
                bbls[-1].AddIns(ir.Ins(o.LD, [dst, mem_base, offset]))
        elif opc is wasm_opc.BR_TABLE:
            bbl_tab = {n: GetTargetBbl(block_stack, x)
                       for n, x in enumerate(args[0])}
            bbl_def = GetTargetBbl(block_stack, args[1])
            op = op_stack.pop(-1)
            tab_size = ir.Const(ToUnsigned(op.kind), len(bbl_tab))
            jtb_count += 1
            jtb = fun.AddJtb(ir.Jtb(f"jtb_{jtb_count}", bbl_def, bbl_tab,
                                    tab_size.value))
            reg_unsigned = GetOpReg(fun, ToUnsigned(op.kind), len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.CONV, [reg_unsigned, op]))
            bbls[-1].AddIns(ir.Ins(o.BLE, [tab_size, reg_unsigned, bbl_def]))
            bbls[-1].AddIns(ir.Ins(o.SWITCH, [reg_unsigned, jtb]))
        elif opc is wasm_opc.UNREACHABLE:
            bbls[-1].AddIns(ir.Ins(o.TRAP, []))
        elif opc is wasm_opc.MEMORY_GROW or opc is wasm_opc.MEMORY_SIZE:
            op = ZERO_S
            if opc is wasm_opc.MEMORY_GROW:
                op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            assert unit.GetFun("__memory_grow")
            bbls[-1].AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
            dst = GetOpReg(fun, o.DK.S32, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.POPARG, [dst]))
            op_stack.append(dst)
        else:
            assert False, f"unsupported opcode [{opc.name}]"
    assert not op_stack, f"op_stack not empty in {fun.name}: {op_stack}"
    assert not block_stack, \
        f"block_stack not empty in {fun.name}: {block_stack}"
    assert len(bbls) == len(fun.bbls)
    fun.bbls = bbls
def EvaluatateALU1(opcode: o.Opcode, op: ir.Const) -> Optional[ir.Const]:
    evaluator = _EVALUATORS_ALU1.get(opcode)
    assert evaluator, f"Evaluator NYI for: {opcode}"
    return ir.Const(op.kind, _truncate(op.kind, evaluator(op.kind, op.value)))
def EvaluatateALU(opcode: o.Opcode, op1: ir.Const, op2: ir.Const) -> ir.Const:
    evaluator = _EVALUATORS_ALU.get(opcode)
    assert evaluator, f"Evaluator NYI for: {opcode}"
    return ir.Const(op1.kind,
                    _truncate(op1.kind, evaluator(op1.value, op2.value)))
def testA(self):
    self.assertEqual(
        1000,
        reaching_defs.ConvertIntValue(o.DK.U32,
                                      ir.Const(o.DK.U16, 1000)).value)
    self.assertEqual(
        100,
        reaching_defs.ConvertIntValue(o.DK.U32,
                                      ir.Const(o.DK.U8, 100)).value)
    self.assertEqual(
        -1000,
        reaching_defs.ConvertIntValue(o.DK.S32,
                                      ir.Const(o.DK.S16, -1000)).value)
    self.assertEqual(
        -100,
        reaching_defs.ConvertIntValue(o.DK.S32,
                                      ir.Const(o.DK.S8, -100)).value)
    self.assertEqual(
        -127,
        reaching_defs.ConvertIntValue(o.DK.S32,
                                      ir.Const(o.DK.S8, -127)).value)
    self.assertEqual(
        24,
        reaching_defs.ConvertIntValue(o.DK.S8,
                                      ir.Const(o.DK.S32, -1000)).value)
    self.assertEqual(
        -20,
        reaching_defs.ConvertIntValue(o.DK.S8,
                                      ir.Const(o.DK.S32, -1300)).value)
    self.assertEqual(
        0xfffffc18,
        reaching_defs.ConvertIntValue(o.DK.U32,
                                      ir.Const(o.DK.S16, -1000)).value)
    self.assertEqual(
        0xfc18,
        reaching_defs.ConvertIntValue(o.DK.U16,
                                      ir.Const(o.DK.S16, -1000)).value)
    self.assertEqual(
        0xfffffc18,
        reaching_defs.ConvertIntValue(o.DK.U32,
                                      ir.Const(o.DK.S32, -1000)).value)
    self.assertEqual(
        -1000,
        reaching_defs.ConvertIntValue(o.DK.S32,
                                      ir.Const(o.DK.U32, 0xfffffc18)).value)
def _InsConstantFold(ins: ir.Ins, bbl: ir.Bbl, _fun: ir.Fun,
                     allow_conv_conversion: bool) -> Optional[List[ir.Ins]]:
    """Try evaluating ins if all of its inputs are constants.

    Returns the replacement instructions if a change was made,
    None otherwise.

    Note: None of the transformations must change the def register -
    otherwise the reaching_defs will be stale.
    """
    ops = ins.operands
    kind = ins.opcode.kind
    if kind is o.OPC_KIND.COND_BRA:
        if not isinstance(ops[0], ir.Const) or not isinstance(ops[1],
                                                              ir.Const):
            return None
        # TODO: implement this, needs access to BBL for CFG changes
        evaluator = _EVALUATORS_COND_BRA.get(ins.opcode)
        assert evaluator, f"Evaluator NYI for: {ins} {ins.operands}"
        branch_taken = evaluator(ops[0].value, ops[1].value)
        target = ops[2]
        assert len(bbl.edge_out) == 2
        if branch_taken:
            succ_to_drop = (bbl.edge_out[1] if bbl.edge_out[0] == target
                            else bbl.edge_out[0])
        else:
            succ_to_drop = target
        bbl.DelEdgeOut(succ_to_drop)
        return []
    elif kind is o.OPC_KIND.ALU1:
        if not isinstance(ops[1], ir.Const):
            return None
        assert False, f"Evaluator NYI for ALU1: {ins} {ins.operands}"
    elif kind is o.OPC_KIND.ALU:
        if not isinstance(ops[1], ir.Const) or not isinstance(ops[2],
                                                              ir.Const):
            return None
        evaluator = _EVALUATORS_ALU.get(ins.opcode)
        assert evaluator, f"Evaluator NYI for: {ins} {ins.operands}"
        val = ir.Const(ops[1].kind, evaluator(ops[1].value, ops[2].value))
        ins.opcode = o.MOV
        ins.operands.pop(-1)
        ins.operands[1] = val
        ins.operand_defs.pop(-1)
        ins.operand_defs[1] = ir.INS_INVALID
        return [ins]
    elif ins.opcode is o.CONV:
        # TODO: this needs some more thought generally, but in particular
        # when we apply register widening transformations, conv instructions
        # end up being the only ones with narrow width regs, which simplifies
        # code generation. By allowing this to be converted into a mov
        # instruction we may leak the narrow register.
        if not allow_conv_conversion or not isinstance(ops[1], ir.Const):
            return None
        dst: ir.Reg = ops[0]
        src = ops[1]
        if not o.RegIsAddrInt(src.kind) or not o.RegIsAddrInt(dst.kind):
            return None
        new_val = ConvertIntValue(dst.kind, src)
        ins.Init(o.MOV, [dst, new_val])
        return [ins]
    else:
        return None
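
# What folding an ALU op amounts to: evaluate at arbitrary precision, then
# truncate to the operand width (cf. `_truncate` used by the evaluators
# earlier). A sketch with a hypothetical helper:
def fold_add(a: int, b: int, width: int = 32, signed: bool = False) -> int:
    masked = (a + b) & ((1 << width) - 1)
    if signed and masked & (1 << (width - 1)):
        return masked - (1 << width)
    return masked


assert fold_add(0xffffffff, 1) == 0                        # U32 wrap-around
assert fold_add(0x7fffffff, 1, signed=True) == -(1 << 31)  # S32 overflow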
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """Change the type of all registers (and constants) of type narrow_kind
    into wide_kind. Add compensation code where necessary.
    wide_kind must be wider than narrow_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output regs. So this must run
    for all functions, including prototypes.

    TODO: double check if we are doing the right thing with o.CONV
    TODO: there are more subtle bugs. For example:
          mul x:U8 = 43 * 47    (= 229)
          div y:U8 = x 13       (= 17)
          whereas:
          mul x:U16 = 43 * 47   (= 2021)
          div y:U16 = x 13      (= 155)
    Other problematic operations: rem, popcnt, ...

    The invariant we are maintaining is this one:
    if reg a gets widened into reg b with bitwidth(a) = w, then the lower
    w bits of reg b will always contain the same data as reg a would have.
    """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [wide_kind if x == narrow_kind else x
                       for x in fun.input_types]
    fun.output_types = [wide_kind if x == narrow_kind else x
                        for x in fun.output_types]
    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {reg for reg in fun.reg_syms.values()
                   if reg.kind == narrow_kind}
    for reg in narrow_regs:
        reg.kind = wide_kind
    count = 0
    for bbl in fun.bbls:
        inss = []
        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind
            changed = False
            for n, reg in enumerate(ops):
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    if kind is o.OPC_KIND.ST and n == 2:
                        continue
                    ops[n] = ir.Const(wide_kind, reg.value)
                    changed = True
                if isinstance(reg, ir.Reg) and reg in narrow_regs:
                    changed = True
            if not changed:
                inss.append(ins)
                continue
            kind = ins.opcode.kind
            if ins.opcode is o.SHL or ins.opcode is o.SHR:
                # deal with the shift amount, which is subject to an implicit
                # "and" with (bitwidth - 1) - by changing the width of the
                # reg we lose this information
                tmp_reg = fun.GetScratchReg(wide_kind, "tricky", False)
                inss.append(
                    ir.Ins(o.AND, [tmp_reg, ops[2],
                                   ir.Const(wide_kind,
                                            narrow_kind.bitwidth() - 1)]))
                ops[2] = tmp_reg
                if ins.opcode is o.SHR and isinstance(ops[1], ir.Reg):
                    # for SHR we also need to make sure the new high order
                    # bits are correct
                    tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                    inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                    # the implicit understanding is that this will become a
                    # nop or a move and not modify the high-order bits we
                    # just set in the previous instruction
                    inss.append(ir.Ins(o.CONV, [ops[1], tmp_reg]))
                inss.append(ins)
            elif ins.opcode is o.CNTLZ:
                inss.append(ins)
                excess = wide_kind.bitwidth() - narrow_kind.bitwidth()
                inss.append(
                    ir.Ins(o.SUB, [ops[0], ops[0],
                                   ir.Const(wide_kind, excess)]))
            elif ins.opcode is o.CNTTZ:
                inss.append(ins)
                inss.append(
                    ir.Ins(o.CMPLT, [ops[0], ops[0],
                                     ir.Const(wide_kind,
                                              narrow_kind.bitwidth()),
                                     ops[0],
                                     ir.Const(wide_kind,
                                              narrow_kind.bitwidth())]))
            elif kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg) and
                  ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            elif ins.opcode is o.CONV:
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
            else:
                inss.append(ins)
        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count