def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """
    Change the type of all registers (and constants) of type narrow_kind into wide_kind.
    Add compensation code where necessary.
    wide_kind must be wider than narrow_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note that this also widens input and output regs, so it must run
    for all functions, including prototypes.

    TODO: double check if we are doing the right thing with o.CONV
    TODO: there are more subtle bugs. For example:
          mul x:U8  = 43 * 47   (= 229)
          div y:U8  = x / 13    (= 17)
          whereas:
          mul x:U16 = 43 * 47   (= 2021)
          div y:U16 = x / 13    (= 155)
          Other problematic operations: rem, popcnt, ...

    The invariant we are maintaining is this one:
    if reg a gets widened into reg b with bitwidth(a) = w, then
    the lower w bits of reg b will always contain the same data as reg a would have.
    """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [wide_kind if x == narrow_kind else x for x in fun.input_types]
    fun.output_types = [wide_kind if x == narrow_kind else x for x in fun.output_types]

    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {reg for reg in fun.reg_syms.values() if reg.kind == narrow_kind}

    for reg in narrow_regs:
        reg.kind = wide_kind

    count = 0
    for bbl in fun.bbls:
        inss = []
        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind
            changed = False
            for n, reg in enumerate(ops):
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    if kind is o.OPC_KIND.ST and n == 2:
                        continue
                    ops[n] = ir.Const(wide_kind, reg.value)
                    changed = True
                if isinstance(reg, ir.Reg) and reg in narrow_regs:
                    changed = True
            if not changed:
                inss.append(ins)
                continue
            kind = ins.opcode.kind
            if ins.opcode is o.SHL or ins.opcode is o.SHR:
                # The shift amount is subject to an implicit "modulo bitwidth" (a mask
                # with bitwidth - 1). Widening the reg loses this, so mask explicitly.
                tmp_reg = fun.GetScratchReg(wide_kind, "tricky", False)
                inss.append(ir.Ins(o.AND, [tmp_reg, ops[2],
                                           ir.Const(wide_kind, narrow_kind.bitwidth() - 1)]))
                ops[2] = tmp_reg
                if ins.opcode is o.SHR and isinstance(ops[1], ir.Reg):
                    # for SHR we also need to make sure the new high order bits are correct
                    tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                    inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                    # the implicit understanding is that this will become a nop or a move and
                    # not modify the high order bits we just set up in the previous instruction
                    inss.append(ir.Ins(o.CONV, [ops[1], tmp_reg]))
                inss.append(ins)
            elif ins.opcode is o.CNTLZ:
                inss.append(ins)
                excess = wide_kind.bitwidth() - narrow_kind.bitwidth()
                inss.append(ir.Ins(o.SUB, [ops[0], ops[0], ir.Const(wide_kind, excess)]))
            elif ins.opcode is o.CNTTZ:
                inss.append(ins)
                inss.append(ir.Ins(o.CMPLT, [ops[0], ops[0],
                                             ir.Const(wide_kind, narrow_kind.bitwidth()),
                                             ops[0],
                                             ir.Const(wide_kind, narrow_kind.bitwidth())]))
            elif kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg) and
                  ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            elif ins.opcode is o.CONV:
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
            else:
                inss.append(ins)
        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
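

# Usage sketch (illustrative, not part of the pass above): a backend whose ISA only
# offers 32-bit ALU operations might widen all 8- and 16-bit regs during legalization.
# The helper name _WidenNarrowRegs and the particular (narrow, wide) DK pairs chosen
# here are assumptions for the example; pick the pairs your target actually needs and
# run this for every Fun, including prototypes, since input/output types are rewritten.
def _WidenNarrowRegs(fun: ir.Fun) -> int:
    count = 0
    for narrow, wide in [(o.DK.U8, o.DK.U32), (o.DK.S8, o.DK.S32),
                         (o.DK.U16, o.DK.U32), (o.DK.S16, o.DK.S32)]:
        count += FunRegWidthWidening(fun, narrow, wide)
    return count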