Example #1
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """
    Change the type of all register (and constants) of type src_kind into dst_kind.
    Add compensation code where necessary.
    dst_kind must be wider than src_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output regs. So this must run
      for all functions including prototypes

      TODO: double check if we are doing the right thing with o.CONV
      TODO: there are more subtle bugs. For example
              mul x:U8  = 43 * 47    (= 229)
              div y:u8  = x   13      (= 17)
            whereas:
              mul x:U16  = 43 * 47    (= 2021)
              div y:u16  = x   13      (= 155)

      Other problematic operations: rem, popcnt, ...
      """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [
        wide_kind if x == narrow_kind else x for x in fun.input_types
    ]
    fun.output_types = [
        wide_kind if x == narrow_kind else x for x in fun.output_types
    ]

    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {
        reg
        for reg in fun.reg_syms.values() if reg.kind == narrow_kind
    }

    for reg in narrow_regs:
        reg.kind = wide_kind

    count = 0
    for bbl in fun.bbls:
        inss = []

        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind

            for n, reg in enumerate(ops):
                if (n == 2 and kind is o.OPC_KIND.ST) or (n == 0 and kind is o.OPC_KIND.LD):
                    # memory accesses keep their original width: leave the stored
                    # value and the load destination alone here
                    continue
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    # if ins.opcode.constraints[n] == o.TC.OFFSET:
                    #    continue
                    ops[n] = ir.Const(wide_kind, reg.value)
            kind = ins.opcode.kind
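            # loads and stores keep their memory width; compensate by routing
            # the value through a scratch reg of the original narrow kind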
            if kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg)
                  and ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            else:
                inss.append(ins)

        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
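
The docstring's second TODO can be reproduced with plain Python integers. The following is a self-contained sketch (no Cwerg imports; the masking stands in for register width) of why widening a mul/div chain without compensation changes the result:

def mul_div(bits: int) -> int:
    mask = (1 << bits) - 1      # 0xFF for U8, 0xFFFF for U16
    x = (43 * 47) & mask        # U8 wraps 2021 to 229; U16 keeps 2021
    return (x // 13) & mask     # U8: 229 // 13 = 17; U16: 2021 // 13 = 155

assert mul_div(8) == 17
assert mul_div(16) == 155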
Example #2
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """
    Change the type of all register (and constants) of type src_kind into dst_kind.
    Add compensation code where necessary.
    dst_kind must be wider than src_kind.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output regs. So this must run
      for all functions including prototypes

      TODO: double check if we are doing the right thing with o.CONV
      TODO: there are more subtle bugs. For example
              mul x:U8  = 43 * 47    (= 229)
              div y:u8  = x   13      (= 17)
            whereas:
              mul x:U16  = 43 * 47    (= 2021)
              div y:u16  = x   13      (= 155)

      Other problematic operations: rem, popcnt, ...

      The invariant we are maintaining is this one:
      if reg a gets widened into reg b with bitwidth(a) = w then
      the lower w bits of reg b will always contain the same data as reg a would have.
      """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    fun.input_types = [
        wide_kind if x == narrow_kind else x for x in fun.input_types
    ]
    fun.output_types = [
        wide_kind if x == narrow_kind else x for x in fun.output_types
    ]

    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    narrow_regs = {
        reg
        for reg in fun.reg_syms.values() if reg.kind == narrow_kind
    }

    for reg in narrow_regs:
        reg.kind = wide_kind

    count = 0
    for bbl in fun.bbls:
        inss = []

        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind
            changed = False
            for n, reg in enumerate(ops):
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    if kind is o.OPC_KIND.ST and n == 2:
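                        # the stored value keeps the width of the memory access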
                        continue
                    ops[n] = ir.Const(wide_kind, reg.value)
                    changed = True
                if isinstance(reg, ir.Reg) and reg in narrow_regs:
                    changed = True
            if not changed:
                inss.append(ins)
                continue
            kind = ins.opcode.kind
            if ins.opcode is o.SHL or ins.opcode is o.SHR:
                # deal with the shift amount, which is subject to an implicit modulo
                # by the operand's bitwidth; widening the reg loses this behavior,
                # so re-establish it explicitly by ANDing with (bitwidth - 1)
                tmp_reg = fun.GetScratchReg(wide_kind, "tricky", False)
                inss.append(
                    ir.Ins(o.AND, [
                        tmp_reg, ops[2],
                        ir.Const(wide_kind,
                                 narrow_kind.bitwidth() - 1)
                    ]))
                ops[2] = tmp_reg
                if ins.opcode is o.SHR and isinstance(ops[1], ir.Reg):
                    # for SHR we also need to make sure the new high order bits are correct
                    tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                    inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                    # the implicit understanding is that this conv will become a nop
                    # or a move and will not modify the high-order bits we just set
                    # up with the previous instruction
                    inss.append(ir.Ins(o.CONV, [ops[1], tmp_reg]))
                inss.append(ins)
            elif ins.opcode is o.CNTLZ:
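                # the widened reg has (wide - narrow) extra leading zeros;
                # subtract that excess from the wide cntlz result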
                inss.append(ins)
                excess = wide_kind.bitwidth() - narrow_kind.bitwidth()
                inss.append(
                    ir.Ins(o.SUB,
                           [ops[0], ops[0],
                            ir.Const(wide_kind, excess)]))
            elif ins.opcode is o.CNTTZ:
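                # if the low narrow bits are all zero, cnttz on the wide reg can
                # count past the narrow width; the cmplt (acting as a conditional
                # select) clamps the result to the narrow bitwidth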
                inss.append(ins)
                inss.append(
                    ir.Ins(o.CMPLT, [
                        ops[0], ops[0],
                        ir.Const(wide_kind, narrow_kind.bitwidth()), ops[0],
                        ir.Const(wide_kind, narrow_kind.bitwidth())
                    ]))
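            # loads and stores keep their memory width; compensate by routing
            # the value through a scratch reg of the original narrow kind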
            elif kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg)
                  and ops[2] in narrow_regs):
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            elif ins.opcode is o.CONV:
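                # route the conversion through the original narrow width so it
                # observes the value the un-widened conv would have seen (see the
                # o.CONV TODO in the docstring)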
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
            else:
                inss.append(ins)

        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
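
The shift and bit-count compensation in Example #2 follows directly from the invariant stated in the docstring. Below is a minimal plain-Python sketch (widths chosen as 8 and 32 for illustration; not the Cwerg API) of why SHL/SHR amounts must be masked and why CNTLZ needs a correction:

NARROW, WIDE = 8, 32

def shl_narrow(x: int, amt: int) -> int:
    # hardware shifts typically reduce the amount modulo the operand width
    return (x << (amt % NARROW)) & ((1 << NARROW) - 1)

def shl_widened(x: int, amt: int) -> int:
    amt &= NARROW - 1           # the AND the pass inserts before the shift
    return (x << amt) & ((1 << WIDE) - 1)

assert shl_narrow(3, 9) == shl_widened(3, 9) == 6   # shift by 9 acts like shift by 1

def clz(x: int, bits: int) -> int:
    # count leading zeros of x within a register of the given width
    return bits - x.bit_length()

x = 0b0001_0110
# the widened reg has (WIDE - NARROW) extra leading zeros, hence the SUB
assert clz(x, NARROW) == clz(x, WIDE) - (WIDE - NARROW)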