Example 1
def _InsEliminateRem(ins: ir.Ins, fun: ir.Fun) -> Optional[List[ir.Ins]]:
    """Rewrites modulo instructions like so:
    z = a % b
    becomes
    z = a // b
    z = z * b
    z = a - z
    TODO: double check that this works out for corner-cases
    """

    if ins.opcode is not o.REM:
        return None
    ops = ins.operands
    out = []
    tmp_reg1 = fun.GetScratchReg(ops[0].kind, "elim_rem1", True)
    out.append(ir.Ins(o.DIV, [tmp_reg1, ops[1], ops[2]]))
    # NOTE: this implementation for floating mod may have precision issues.
    if ops[0].kind.flavor() is o.DK_FLAVOR_F:
        tmp_reg3 = fun.GetScratchReg(ops[0].kind, "elim_rem3", True)
        out.append(ir.Ins(o.TRUNC, [tmp_reg3, tmp_reg1]))
        tmp_reg1 = tmp_reg3
    tmp_reg2 = fun.GetScratchReg(ops[0].kind, "elim_rem2", True)
    out.append(ir.Ins(o.MUL, [tmp_reg2, tmp_reg1, ops[2]]))
    out.append(ir.Ins(o.SUB, [ops[0], ops[1], tmp_reg2]))
    return out
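
The TODO about corner cases can be checked with a small standalone sketch, assuming the IR's DIV truncates toward zero for signed integers (C semantics, which Python's flooring // does not follow). Under that assumption the rewrite computes a - trunc(a / b) * b, and the floating-point flavor adds the explicit TRUNC; note that Python's own % would give 2 for -7 % 3, which is exactly the kind of corner case the TODO alludes to.

import math

# Minimal sketch, not part of the pass above: emulate DIV (+ TRUNC for floats),
# MUL, SUB and compare against C-style remainder semantics (math.fmod).
def emulated_rem(a, b):
    q = math.trunc(a / b)   # DIV, with TRUNC for the floating-point flavor
    return a - q * b        # MUL followed by SUB

for a, b in [(7, 3), (-7, 3), (7, -3), (7.5, 2.25)]:
    assert math.isclose(emulated_rem(a, b), math.fmod(a, b))
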
Example 2
def _InsEliminateCopySign(ins: ir.Ins, fun: ir.Fun) -> Optional[List[ir.Ins]]:
    """Rewrites copysign instructions like so:
    z = copysign a  b
    aa = int(a) & 0x7f...f
    bb = int(b) & 0x80...0
    z = flt(aa | bb)
    """

    if ins.opcode is not o.COPYSIGN:
        return None
    ops = ins.operands
    out = []
    if ops[0].kind == o.DK.F32:
        kind = o.DK.U32
        sign = 1 << 31
        mask = sign - 1
    else:
        kind = o.DK.U64
        sign = 1 << 63
        mask = sign - 1

    tmp_src1 = fun.GetScratchReg(kind, "elim_copysign1", False)
    out.append(ir.Ins(o.BITCAST, [tmp_src1, ops[1]]))
    out.append(ir.Ins(o.AND, [tmp_src1, tmp_src1, ir.Const(kind, mask)]))
    #
    tmp_src2 = fun.GetScratchReg(kind, "elim_copysign2", False)
    out.append(ir.Ins(o.BITCAST, [tmp_src2, ops[2]]))
    out.append(ir.Ins(o.AND, [tmp_src2, tmp_src2, ir.Const(kind, sign)]))
    #
    out.append(ir.Ins(o.OR, [tmp_src1, tmp_src1, tmp_src2]))
    out.append(ir.Ins(o.BITCAST, [ops[0], tmp_src1]))
    return out
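
For the F32 case the masks are sign = 0x8000_0000 and mask = 0x7FFF_FFFF. A minimal sketch of the same bit manipulation in plain Python, using struct to stand in for BITCAST and assuming IEEE-754 binary32:

import struct

def copysign_f32(a: float, b: float) -> float:
    ia = struct.unpack("<I", struct.pack("<f", a))[0]   # BITCAST a to U32
    ib = struct.unpack("<I", struct.pack("<f", b))[0]   # BITCAST b to U32
    sign = 1 << 31
    bits = (ia & (sign - 1)) | (ib & sign)               # AND, AND, OR
    return struct.unpack("<f", struct.pack("<I", bits))[0]  # BITCAST back to F32

assert copysign_f32(2.5, -1.0) == -2.5
assert copysign_f32(-2.5, 0.0) == 2.5
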
Example 3
def _InsEliminateMemLoadStore(ins: ir.Ins, fun: ir.Fun, base_kind: o.DK,
                              offset_kind: o.DK) -> Optional[List[ir.Ins]]:
    """This rewrite is usually applied as prep step by some backends
     to get rid of Mem operands.
     It allows the register allocator to see the scratch register but
     it will obscure the fact that a ld/st is from a static location.

     Note: this function may add local registers which does not affect liveness or use-deg chains
    """
    opc = ins.opcode
    ops = ins.operands
    if opc is o.ST_MEM:
        st_offset = ops[1]
        lea_offset = ir.Const(offset_kind, 0)
        if isinstance(st_offset, ir.Const):
            st_offset, lea_offset = lea_offset, st_offset
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        lea = ir.Ins(o.LEA_MEM, [scratch_reg, ops[0], lea_offset])
        ins.Init(o.ST, [scratch_reg, st_offset, ops[2]])
        return [lea, ins]
    elif opc is o.LD_MEM:
        ld_offset = ops[2]
        lea_offset = ir.Const(offset_kind, 0)
        if isinstance(ld_offset, ir.Const):
            ld_offset, lea_offset = lea_offset, ld_offset
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        # TODO: should the Zero Offset stay with the ld op?
        lea = ir.Ins(o.LEA_MEM, [scratch_reg, ops[1], lea_offset])
        ins.Init(o.LD, [ops[0], scratch_reg, ld_offset])
        return [lea, ins]
    else:
        return None
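
The offset swap above exists so that a constant offset is folded into the LEA_MEM while a register offset stays on the LD/ST; either way the effective address is unchanged. A tiny sketch of that invariant, with hypothetical numeric values:

def effective_address(mem_base: int, offset: int, offset_is_const: bool) -> int:
    if offset_is_const:
        base = mem_base + offset   # lea_mem base = mem, offset
        return base + 0            # ld/st uses a zero offset
    else:
        base = mem_base + 0        # lea_mem base = mem, 0
        return base + offset       # ld/st keeps the register offset

assert effective_address(0x1000, 8, True) == effective_address(0x1000, 8, False)
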
Example 4
def FunSeparateLocalRegUsage(fun: ir.Fun) -> int:
    """ Split life ranges for (BBL) local regs

    This is works in coordination with the liverange computation AND
    the local register allocator which assigns one cpu register to each
    liverange.
    """
    count = 0
    for bbl in fun.bbls:
        for pos, ins in enumerate(bbl.inss):
            num_defs = ins.opcode.def_ops_count()
            for n, reg in enumerate(ins.operands[:num_defs]):
                assert isinstance(reg, ir.Reg)
                # do not separate if:
                # * this is the first definition of this reg
                # * the reg is global
                # * the reg is part of a two address "situation" (for x64)
                # * the reg has been assigned a cpu_reg
                if (reg.def_ins is ins or ir.REG_FLAG.GLOBAL in reg.flags
                        or (ir.REG_FLAG.TWO_ADDRESS in reg.flags
                            and len(ins.operands) >= 2
                            and ins.operands[0] == ins.operands[1])
                        or reg.cpu_reg is not None):
                    continue
                purpose = reg.name
                if purpose.startswith("$"):
                    underscore_pos = purpose.find("_")
                    purpose = purpose[underscore_pos + 1:]
                new_reg = fun.GetScratchReg(reg.kind, purpose, False)
                if ir.REG_FLAG.TWO_ADDRESS in reg.flags:
                    new_reg.flags |= ir.REG_FLAG.TWO_ADDRESS
                ins.operands[n] = new_reg
                _BblRenameReg(bbl, pos + 1, reg, new_reg)
                count += 1
    return count
Example 5
def InsEliminateCmp(ins: ir.Ins, bbl: ir.Bbl, fun: ir.Fun):
    """Rewrites cmpXX a, b, c, x, y instructions like so:
    canonicalization ensures that a != c
    mov z b
    bXX skip, x, y
      mov z c
    .bbl skip
      mov a z

    TODO: This is very coarse
    """
    assert ins.opcode.kind is o.OPC_KIND.CMP
    bbl_skip = cfg.BblSplit(ins, bbl, fun, bbl.name + "_split")
    bbl_prev = cfg.BblSplit(ins, bbl_skip, fun, bbl.name + "_split")
    assert not bbl_skip.inss
    assert bbl_prev.inss[-1] is ins
    assert bbl_prev.edge_out == [bbl_skip]
    assert bbl_skip.edge_in == [bbl_prev]
    assert bbl_skip.edge_out == [bbl]
    assert bbl.edge_in == [bbl_skip]

    reg = fun.GetScratchReg(ins.operands[0].kind, "cmp", False)

    del bbl_prev.inss[-1]
    ops = ins.operands
    bbl_prev.inss.append(ir.Ins(o.MOV, [reg, ops[1]]))
    bbl_prev.inss.append(
        ir.Ins(o.BEQ if ins.opcode == o.CMPEQ else o.BLT,
               [ops[3], ops[4], bbl]))
    bbl_skip.inss.append(ir.Ins(o.MOV, [reg, ops[2]]))
    bbl.inss.insert(0, ir.Ins(o.MOV, [ops[0], reg]))
    bbl_prev.edge_out.append(bbl)
    bbl.edge_in.append(bbl_prev)
Example 6
def _InsLimitShiftAmounts(ins: ir.Ins, fun: ir.Fun,
                          width: int) -> Optional[List[ir.Ins]]:
    """This rewrite is usually applied as prep step by some backends
     to get rid of Stk operands.
     It allows the register allocator to see the scratch register but
     it will obscure the fact that a memory access is a stack access.

     Note, a stack address already implies a `sp+offset` addressing mode and risk
     ISAs do no usually support  `sp+offset+reg` addressing mode.
    """
    opc = ins.opcode
    ops = ins.operands
    if (opc is not o.SHL
            and opc is not o.SHR) or ops[0].kind.bitwidth() != width:
        return None
    amount = ops[2]
    if isinstance(amount, ir.Const):
        if 0 <= amount.value < width:
            return None
        else:
            ops[2] = ir.Const(amount.kind, amount.value % width)
            return [ins]
    else:
        tmp = fun.GetScratchReg(amount.kind, "shift", False)
        mask = ir.Ins(o.AND, [tmp, amount, ir.Const(amount.kind, width - 1)])
        ins.Init(opc, [ops[0], ops[1], tmp])
        return [mask, ins]
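
The effect of the rewrite can be sanity-checked in isolation: reducing a constant amount modulo the width and masking a register amount with width - 1 give the same shift as an amount that was already in range. A minimal sketch for 32-bit SHL, assuming the IR defines out-of-range amounts modulo the width:

def shl32(value: int, amount: int) -> int:
    width = 32
    return (value << (amount & (width - 1))) & ((1 << width) - 1)

assert shl32(1, 33) == shl32(1, 33 % 32) == 2   # constant case: reduce modulo width
assert shl32(1, 32) == 1                         # amount 32 behaves like 0
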
Example 7
def _InsAddNop1ForCodeSel(ins: ir.Ins, fun: ir.Fun) -> Optional[List[ir.Ins]]:
    opc = ins.opcode
    if opc is o.SWITCH:
        # needs scratch to compute the jmp address into
        scratch = fun.GetScratchReg(o.DK.C32, "switch", False)
        return [ir.Ins(o.NOP1, [scratch]), ins]
    elif (opc is o.CONV and o.RegIsInt(ins.operands[0].kind) and
          ins.operands[1].kind.flavor() == o.DK_FLAVOR_F):
        # need scratch for intermediate flt result
        # we know the result cannot be wider than 32bit for this CPU
        scratch = fun.GetScratchReg(o.DK.F32, "ftoi", False)
        return [ir.Ins(o.NOP1, [scratch]), ins]
    elif (opc is o.CONV and o.RegIsInt(ins.operands[1].kind) and
          ins.operands[0].kind is o.DK.F64):
        # need scratch for intermediate flt result
        # we know the result cannot be wider than 32bit for this CPU
        scratch = fun.GetScratchReg(o.DK.F32, "itof", False)
        return [ir.Ins(o.NOP1, [scratch]), ins]
    return [ins]
Example 8
def _InsRewriteFltImmediates(ins: ir.Ins, fun: ir.Fun,
                             unit: ir.Unit) -> Optional[List[ir.Ins]]:
    inss = []
    for n, op in enumerate(ins.operands):
        if isinstance(op, ir.Const) and op.kind.flavor() is o.DK_FLAVOR_F:
            mem = unit.FindOrAddConstMem(op)
            tmp = fun.GetScratchReg(op.kind, "flt_const", True)
            inss.append(ir.Ins(o.LD_MEM, [tmp, mem, _ZERO_OFFSET]))
            ins.operands[n] = tmp
    if inss:
        return inss + [ins]
    return None
Example 9
def InsEliminateImmediateViaMem(ins: ir.Ins, pos: int, fun: ir.Fun,
                                unit: ir.Unit, addr_kind: o.DK,
                                offset_kind: o.DK) -> List[ir.Ins]:
    """Rewrite instruction with an immediate as load of the immediate


    This is useful if the target architecture does not support immediate
    for that instruction, or the immediate is too large.

    This optimization is run rather late and may already see machine registers.
    """
    # support of PUSHARG would require additional work because they need to stay consecutive
    assert ins.opcode is not o.PUSHARG
    const = ins.operands[pos]
    mem = unit.FindOrAddConstMem(const)
    tmp_addr = fun.GetScratchReg(addr_kind, "mem_const_addr", True)
    lea_ins = ir.Ins(o.LEA_MEM, [tmp_addr, mem, ir.Const(offset_kind, 0)])
    tmp = fun.GetScratchReg(const.kind, "mem_const", True)
    ld_ins = ir.Ins(o.LD, [tmp, tmp_addr, ir.Const(offset_kind, 0)])
    ins.operands[pos] = tmp
    return [lea_ins, ld_ins]
Example 10
def InsSpillRegs(ins: ir.Ins, fun: ir.Fun, zero_const, reg_to_stk) -> Optional[List[ir.Ins]]:
    before: List[ir.Ins] = []
    after: List[ir.Ins] = []
    num_defs = ins.opcode.def_ops_count()
    for n, reg in reversed(list(enumerate(ins.operands))):
        if not isinstance(reg, ir.Reg):
            continue
        stk = reg_to_stk.get(reg)
        if stk is None:
            continue
        if n < num_defs:
            scratch = fun.GetScratchReg(reg.kind, "stspill", False)
            ins.operands[n] = scratch
            after.append(ir.Ins(o.ST_STK, [stk, zero_const, scratch]))
        else:
            scratch = fun.GetScratchReg(reg.kind, "ldspill", False)
            ins.operands[n] = scratch
            before.append(ir.Ins(o.LD_STK, [scratch, stk, zero_const]))
    if before or after:
        return before + [ins] + after
    else:
        return None
Example 11
def _InsMoveImmediatesToMemory(ins: ir.Ins, fun: ir.Fun, unit: ir.Unit,
                               kind: o.DK) -> Optional[List[ir.Ins]]:
    inss = []
    for n, op in enumerate(ins.operands):
        if isinstance(op, ir.Const) and op.kind is kind:
            mem = unit.FindOrAddConstMem(op)
            tmp = fun.GetScratchReg(kind, "mem_const", True)
            # TODO: pass the offset kind as a parameter
            inss.append(ir.Ins(o.LD_MEM, [tmp, mem, ir.Const(o.DK.U32, 0)]))
            ins.operands[n] = tmp
    if inss:
        return inss + [ins]
    return None
Example 12
def _InsEliminateStkLoadStoreWithRegOffset(
        ins: ir.Ins, fun: ir.Fun, base_kind: o.DK,
        offset_kind: o.DK) -> Optional[List[ir.Ins]]:
    """This rewrite is usually applied as prep step by some backends
     to get rid of Stk operands.
     It allows the register allocator to see the scratch register but
     it will obscure the fact that a memory access is a stack access.

     Note, a stack address already implies a `sp+offset` addressing mode and risk
     ISAs do no usually support  `sp+offset+reg` addressing mode.
    """
    opc = ins.opcode
    ops = ins.operands
    if opc is o.ST_STK and isinstance(ops[1], ir.Reg):
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        lea = ir.Ins(o.LEA_STK,
                     [scratch_reg, ops[0],
                      ir.Const(offset_kind, 0)])
        ins.Init(o.ST, [scratch_reg, ops[1], ops[2]])
        return [lea, ins]
    elif opc is o.LD_STK and isinstance(ops[2], ir.Reg):
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        lea = ir.Ins(o.LEA_STK,
                     [scratch_reg, ops[1],
                      ir.Const(offset_kind, 0)])
        ins.Init(o.LD, [ops[0], scratch_reg, ops[2]])
        return [lea, ins]
    elif opc is o.LEA_STK and isinstance(ops[2], ir.Reg):
        scratch_reg = fun.GetScratchReg(base_kind, "base", False)
        # TODO: maybe reverse the order so that we can tell that ops[0] holds a stack
        # location
        lea = ir.Ins(o.LEA_STK,
                     [scratch_reg, ops[1],
                      ir.Const(offset_kind, 0)])
        ins.Init(o.LEA, [ops[0], scratch_reg, ops[2]])
        return [lea, ins]
    else:
        return None
Example 13
def _InsEliminateImmediateStores(ins: ir.Ins,
                                 fun: ir.Fun) -> Optional[List[ir.Ins]]:
    """RISC architectures typically do not allow immediates to be stored directly

    TODO: maybe allow zero immediates
    """
    opc = ins.opcode
    ops = ins.operands
    if opc in {o.ST_MEM, o.ST, o.ST_STK} and isinstance(ops[2], ir.Const):
        scratch_reg = fun.GetScratchReg(ops[2].kind, "st_imm", False)
        mov = ir.Ins(o.MOV, [scratch_reg, ops[2]])
        ops[2] = scratch_reg
        return [mov, ins]
    else:
        return None
Example 14
def FunSeparateLocalRegUsage(fun: ir.Fun) -> int:
    count = 0
    for bbl in fun.bbls:
        for pos, ins in enumerate(bbl.inss):
            num_defs = ins.opcode.def_ops_count()
            for n, reg in enumerate(ins.operands[:num_defs]):
                assert isinstance(reg, ir.Reg)
                if reg.def_ins is ins or ir.REG_FLAG.GLOBAL in reg.flags or reg.cpu_reg is not None:
                    continue
                purpose = reg.name
                if purpose.startswith("$"):
                    underscore_pos = purpose.find("_")
                    purpose = purpose[underscore_pos + 1:]
                new_reg = fun.GetScratchReg(reg.kind, purpose, False)
                ins.operands[n] = new_reg
                _BblRenameReg(bbl, pos + 1, reg, new_reg)
                count += 1
    return count
Example 15
def _InsRewriteIntoAABForm(ins: ir.Ins, fun: ir.Fun) -> Optional[List[ir.Ins]]:
    ops = ins.operands
    if not NeedsAABFromRewrite(ins):
        return None
    if ops[0] == ops[1]:
        ops[0].flags |= ir.REG_FLAG.TWO_ADDRESS
        return None
    if ops[0] == ops[2] and o.OA.COMMUTATIVE in ins.opcode.attributes:
        ir.InsSwapOps(ins, 1, 2)
        ops[0].flags |= ir.REG_FLAG.TWO_ADDRESS
        return [ins]
    else:
        reg = fun.GetScratchReg(ins.operands[0].kind, "aab", False)
        reg.flags |= ir.REG_FLAG.TWO_ADDRESS
        return [
            ir.Ins(o.MOV, [reg, ops[1]]),
            ir.Ins(ins.opcode, [reg, reg, ops[2]]),
            ir.Ins(o.MOV, [ops[0], reg])
        ]
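
The scratch register in the fallback path matters when the destination aliases the second source of a non-commutative op: naively lowering to two-address form would read an already-clobbered value. A small sketch with a dict standing in for the register file (register names are made up for illustration):

# naive two-address lowering of "sub z = a z": mov z a; sub z z
regs = {"a": 10, "z": 3}
regs["z"] = regs["a"]
regs["z"] = regs["z"] - regs["z"]          # reads the clobbered z -> 0, not 7

# lowering via a scratch register, as in _InsRewriteIntoAABForm
regs = {"a": 10, "z": 3}
regs["aab"] = regs["a"]                    # mov aab a
regs["aab"] = regs["aab"] - regs["z"]      # sub aab aab z
regs["z"] = regs["aab"]                    # mov z aab
assert regs["z"] == 7
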
Example 16
def InsEliminateImmediate(ins: ir.Ins, pos: int, fun: ir.Fun) -> ir.Ins:
    """Rewrite instruction with an immediate as load of the immediate
    followed by a pure register version of that instruction, e.g.

    mul z = a 666
    becomes
    mov scratch = 666
    mul z = a scratch

    This is useful if the target architecture does not support immediate
    for that instruction, or the immediate is too large.

    This optimization is run rather late and may already see machine
    registers like the sp.
    Hence we are careful to use and update ins.orig_operand
    """
    const = ins.operands[pos]
    assert isinstance(const, ir.Const)
    reg = fun.GetScratchReg(const.kind, "imm", True)
    ins.operands[pos] = reg
    return ir.Ins(o.MOV, [reg, const])
Example 17
def InsEliminateImmediateViaMov(ins: ir.Ins, pos: int, fun: ir.Fun) -> ir.Ins:
    """Rewrite instruction with an immediate as mov of the immediate

    mul z = a 666
    becomes
    mov scratch = 666
    mul z = a scratch

    This is useful if the target architecture does not support immediate
    for that instruction, or the immediate is too large.

    This optimization is run rather late and may already see machine registers.
    Ideally, the generated mov instruction should be iselectable by the target architecture or
    else another pass may be necessary.
    """
    # support of PUSHARG would require additional work because they need to stay consecutive
    assert ins.opcode is not o.PUSHARG
    const = ins.operands[pos]
    assert isinstance(const, ir.Const)
    reg = fun.GetScratchReg(const.kind, "imm", True)
    ins.operands[pos] = reg
    return ir.Ins(o.MOV, [reg, const])
Example 18
def _InsRewriteOutOfBoundsOffsetsStk(ins: ir.Ins,
                                     fun: ir.Fun) -> Optional[List[ir.Ins]]:
    # Note, we can handle any LEA_STK as long as it is adding a constant
    if ins.opcode not in {o.LD_STK, o.ST_STK}:
        return None
    mismatches = isel_tab.FindtImmediateMismatchesInBestMatchPattern(ins)
    assert mismatches != isel_tab.MATCH_IMPOSSIBLE, f"could not match opcode {ins} {ins.operands}"

    if mismatches == 0:
        return None

    inss = []
    tmp = fun.GetScratchReg(o.DK.A32, "imm_stk", False)
    if ins.opcode is o.ST_STK:
        # note we do not have to worry about ins.operands[2] being Const
        # because those were dealt with by FunEliminateImmediateStores
        assert mismatches == (1 << 1)
        if isinstance(ins.operands[1], ir.Const):
            inss.append(
                ir.Ins(o.LEA_STK, [tmp, ins.operands[0], ins.operands[1]]))
            ins.Init(o.ST, [tmp, _ZERO_OFFSET, ins.operands[2]])
        else:
            inss.append(ir.Ins(o.LEA_STK,
                               [tmp, ins.operands[0], _ZERO_OFFSET]))
            ins.Init(o.ST, [tmp, ins.operands[1], ins.operands[2]])
    else:
        assert ins.opcode is o.LD_STK
        assert mismatches & (1 << 2)
        if isinstance(ins.operands[2], ir.Const):
            inss.append(
                ir.Ins(o.LEA_STK, [tmp, ins.operands[1], ins.operands[2]]))
            ins.Init(o.LD, [ins.operands[0], tmp, _ZERO_OFFSET])
        else:
            inss.append(ir.Ins(o.LEA_STK,
                               [tmp, ins.operands[1], _ZERO_OFFSET]))
            ins.Init(o.LD, [ins.operands[0], tmp, ins.operands[2]])
    inss.append(ins)
    return inss
Example 19
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """
    Change the type of all registers (and constants) of type narrow_kind into wide_kind.
    Add compensation code where necessary.
    wide_kind must be wider than narrow_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output regs. So this must run
      for all functions including prototypes

      TODO: double check if we are doing the right thing with o.CONV
      TODO: there are more subtle bugs. For example
              mul x:U8  = 43 * 47    (= 229)
              div y:u8  = x   13      (= 17)
            whereas:
              mul x:U16  = 43 * 47    (= 2021)
              div y:u16  = x   13      (= 155)

      Other problematic operations: rem, popcnt, ...
      """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [
        wide_kind if x == narrow_kind else x for x in fun.input_types
    ]
    fun.output_types = [
        wide_kind if x == narrow_kind else x for x in fun.output_types
    ]

    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {
        reg
        for reg in fun.reg_syms.values() if reg.kind == narrow_kind
    }

    for reg in narrow_regs:
        reg.kind = wide_kind

    count = 0
    for bbl in fun.bbls:
        inss = []

        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind

            for n, reg in enumerate(ops):
                if n == 2 and kind is o.OPC_KIND.ST or n == 0 and kind is o.OPC_KIND.LD:
                    continue
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    # if ins.opcode.constraints[n] == o.TC.OFFSET:
                    #    continue
                    ops[n] = ir.Const(wide_kind, reg.value)
            kind = ins.opcode.kind
            if kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg)
                  and ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            else:
                inss.append(ins)

        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
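
The subtle bug flagged in the docstring's TODO is easy to reproduce: U8 arithmetic wraps modulo 256, so widening the mul/div chain to U16 without extra masking changes the observable result. A quick standalone check of the numbers quoted in the docstring:

def mul_then_div(width: int):
    mask = (1 << width) - 1
    x = (43 * 47) & mask    # mul x = 43 47
    y = (x // 13) & mask    # div y = x 13
    return x, y

assert mul_then_div(8) == (229, 17)     # U8
assert mul_then_div(16) == (2021, 155)  # U16
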
Example 20
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """
    Change the type of all registers (and constants) of type narrow_kind into wide_kind.
    Add compensation code where necessary.
    wide_kind must be wider than narrow_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output regs. So this must run
      for all functions including prototypes

      TODO: double check if we are doing the right thing with o.CONV
      TODO: there are more subtle bugs. For example
              mul x:U8  = 43 * 47    (= 229)
              div y:u8  = x   13      (= 17)
            whereas:
              mul x:U16  = 43 * 47    (= 2021)
              div y:u16  = x   13      (= 155)

      Other problematic operations: rem, popcnt, ...

      The invariant we are maintaining is this one:
      if reg a gets widened into reg b with bitwidth(a) = w then
      the lower w bits of reg b will always contain the same data as reg a would have.
      """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [
        wide_kind if x == narrow_kind else x for x in fun.input_types
    ]
    fun.output_types = [
        wide_kind if x == narrow_kind else x for x in fun.output_types
    ]

    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {
        reg
        for reg in fun.reg_syms.values() if reg.kind == narrow_kind
    }

    for reg in narrow_regs:
        reg.kind = wide_kind

    count = 0
    for bbl in fun.bbls:
        inss = []

        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind
            changed = False
            for n, reg in enumerate(ops):
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    if kind is o.OPC_KIND.ST and n == 2:
                        continue
                    ops[n] = ir.Const(wide_kind, reg.value)
                    changed = True
                if isinstance(reg, ir.Reg) and reg in narrow_regs:
                    changed = True
            if not changed:
                inss.append(ins)
                continue
            kind = ins.opcode.kind
            if ins.opcode is o.SHL or ins.opcode is o.SHR:
                # the shift amount is implicitly taken modulo the (narrow) bitwidth;
                # widening the reg loses this information, so make the masking explicit
                tmp_reg = fun.GetScratchReg(wide_kind, "tricky", False)
                inss.append(
                    ir.Ins(o.AND, [
                        tmp_reg, ops[2],
                        ir.Const(wide_kind,
                                 narrow_kind.bitwidth() - 1)
                    ]))
                ops[2] = tmp_reg
                if ins.opcode is o.SHR and isinstance(ops[1], ir.Reg):
                    # for SHR we also need to make sure the new high order bits are correct
                    tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                    inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                    # the implicit understanding is that this will become a nop or a move and not modify the
                    # high-order bits we just set in the previous instruction
                    inss.append(ir.Ins(o.CONV, [ops[1], tmp_reg]))
                inss.append(ins)
            elif ins.opcode is o.CNTLZ:
                inss.append(ins)
                excess = wide_kind.bitwidth() - narrow_kind.bitwidth()
                inss.append(
                    ir.Ins(o.SUB,
                           [ops[0], ops[0],
                            ir.Const(wide_kind, excess)]))
            elif ins.opcode is o.CNTTZ:
                inss.append(ins)
                inss.append(
                    ir.Ins(o.CMPLT, [
                        ops[0], ops[0],
                        ir.Const(wide_kind, narrow_kind.bitwidth()), ops[0],
                        ir.Const(wide_kind, narrow_kind.bitwidth())
                    ]))
            elif kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg)
                  and ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            elif ins.opcode is o.CONV:
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
            else:
                inss.append(ins)

        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
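
The CNTLZ compensation can be checked the same way: counting leading zeros on the widened register over-counts by exactly wide_width - narrow_width, which the appended SUB removes. A minimal sketch with a pure-Python count-leading-zeros helper (not part of the pass above):

def clz(x: int, width: int) -> int:
    return width - x.bit_length() if x else width

NARROW, WIDE = 8, 32
for x in (1, 0x2F, 0xFF):
    assert clz(x, WIDE) - (WIDE - NARROW) == clz(x, NARROW)
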