Beispiel #1
0
def _truncate(dk: o.DK, val):
    """Normalize ``val`` into the value range of data kind ``dk``.

    Floating-point kinds are returned unchanged. Unsigned kinds keep only the
    low ``dk.bitwidth()`` bits. Every other flavor is reinterpreted as a
    signed number of that width via ``SignedIntFromBits``.
    """
    flavor = dk.flavor()
    if flavor == o.DK_FLAVOR_F:
        return val
    width = dk.bitwidth()
    if flavor == o.DK_FLAVOR_U:
        return val & ((1 << width) - 1)
    return SignedIntFromBits(val, width)
Beispiel #2
0
def MaybeMakeGlobalTable(mod: wasm.Module, unit: ir.Unit, addr_type: o.DK):
    """Materialize the wasm function table as a read-only memory "global_table".

    Returns None when the module has no TABLE section. Otherwise the module
    must have exactly one funcref table and an ELEMENT section whose segments
    use an ``i32.const`` offset expression.

    The memory gets one address-sized slot per table entry
    (``table_type.limits.max`` entries). A slot that some element segment
    fills holds the address of that function (``ir.DataAddrFun``). Every
    other slot gets ``ir.DataBytes(width, b"\\0")``, presumably ``width``
    zero bytes (TODO confirm).
    """
    table_sec = mod.sections.get(wasm.SECTION_ID.TABLE)
    table_elements = mod.sections.get(wasm.SECTION_ID.ELEMENT)
    if not table_sec:
        return None

    assert table_elements
    assert len(table_sec.items) == 1
    table_type: wasm.TableType = table_sec.items[0].table_type
    assert table_type.element_type == wasm.REF_TYPE.FUNCREF

    # Flatten all element segments into one slot list; None marks an unused slot.
    table_data = [None] * table_type.limits.max
    for elem in table_elements.items:
        ins = GetInsFromInitializerExpression(elem.expr)
        # Only constant segment offsets are supported.
        assert ins.opcode is wasm_opc.I32_CONST
        start = ins.args[0]
        for n, func_idx in enumerate(elem.funcidxs):
            table_data[start + n] = func_idx

    # Size of one table slot in bytes. The same value is also passed as the
    # second ir.Mem argument (presumably the alignment -- TODO confirm).
    width = addr_type.bitwidth() // 8
    global_table = unit.AddMem(ir.Mem("global_table", width, o.MEM_KIND.RO))
    for func_idx in table_data:
        if func_idx is None:
            global_table.AddData(ir.DataBytes(width, b"\0"))
        else:
            assert isinstance(func_idx, wasm.FuncIdx)
            target = unit.GetFun(mod.functions[int(func_idx)].name)
            global_table.AddData(ir.DataAddrFun(width, target))
    return global_table
Beispiel #3
0
def Cntlz(kind: o.DK, val: int):
    """Return the number of leading zero bits of ``val`` as a ``kind``-wide value.

    Only the low ``kind.bitwidth()`` bits of ``val`` are considered. An
    all-zero value therefore yields the full bit width, and a negative
    (two's complement) value yields 0 because its top bit is set.
    """
    width = kind.bitwidth()
    # Reduce to the kind's width. This also maps negative Python ints to their
    # two's complement bit pattern and drops stray bits above the width (which
    # previously could make the scan never terminate).
    val &= (1 << width) - 1
    if val == 0:
        return width
    # The highest set bit sits at index bit_length() - 1; every bit above it
    # (up to width - 1) is a leading zero.
    return width - val.bit_length()
Beispiel #4
0
def ParseConst(value_str: str, kind: o.DK) -> Const:
    """Parse ``value_str`` into a ``Const`` of data kind ``kind``.

    Float kinds are parsed with ``float()``. Integer kinds are parsed with
    ``int(value_str, 0)``, so base prefixes such as ``0x`` are honored. The
    result is asserted to fit the kind's bit width ``w``:
    ``[0, 2**w)`` for unsigned kinds and ``[-2**(w-1), 2**(w-1))`` otherwise.
    """
    flavor = kind.flavor()
    if flavor is o.DK_FLAVOR_F:
        return Const(kind, float(value_str))

    bits = kind.bitwidth()
    number = int(value_str, 0)
    if flavor is o.DK_FLAVOR_U:
        lower, upper = 0, 1 << bits
    else:
        half_range = 1 << (bits - 1)
        lower, upper = -half_range, half_range
    assert lower <= number < upper

    return Const(kind, number)
Beispiel #5
0
def ConvertIntValue(kind_dst: o.DK, val: ir.Const) -> ir.Const:
    """Convert the integer constant ``val`` into a constant of kind ``kind_dst``.

    * Unsigned destination: keep the low ``kind_dst.bitwidth()`` bits.
    * Other destination wider than the source: keep the numeric value.
    * Other destination not wider than the source: reinterpret the low bits
      as a two's complement number of the destination width.
    """
    dst_bits = kind_dst.bitwidth()
    src_bits = val.kind.bitwidth()
    low_bits = val.value & ((1 << dst_bits) - 1)

    if kind_dst.flavor() == o.DK_FLAVOR_U:
        return ir.Const(kind_dst, low_bits)
    if dst_bits > src_bits:
        return ir.Const(kind_dst, val.value)
    # Narrowing (or same width) into a non-unsigned kind: sign-extend from the
    # destination's top bit.
    if val.value & (1 << (dst_bits - 1)):
        low_bits -= 1 << dst_bits
    return ir.Const(kind_dst, low_bits)
Beispiel #6
0
def Translate(mod: wasm.Module, addr_type: o.DK) -> ir.Unit:
    """Translate the wasm module ``mod`` into an ``ir.Unit``.

    ``addr_type`` is the data kind used for addresses. Its bit width sizes the
    argv global, and it is prepended as an extra first parameter to every
    translated function.

    The resulting unit contains:
      * the ``global_argc`` / ``global_argv`` memories,
      * helper functions (memcpy, global-variable and data initializers),
      * EXTERN declarations for ``__wasi_init`` and ``__memory_grow``,
      * one function per wasm function. Imported ones stay EXTERN and must be
        listed in ``WASI_FUNCTIONS``. The wasm ``_start`` function is renamed
        to ``$main``.
      * a native ``main`` entry point produced by ``GenerateStartup``.

    Only imports whose descriptor is a ``wasm.TypeIdx`` are supported.
    """
    import_sec = mod.sections.get(wasm.SECTION_ID.IMPORT)
    # A module without imports has no IMPORT section. Guard it the same way as
    # the optional MEMORY section below instead of dereferencing None.
    if import_sec:
        for imp in import_sec.items:
            assert isinstance(imp, wasm.Import)
            assert isinstance(
                imp.desc, wasm.TypeIdx), f"cannot handle strange imports: {imp}"

    bit_width = addr_type.bitwidth()
    unit = ir.Unit("unit")

    global_argv = unit.AddMem(
        ir.Mem("global_argv", 2 * bit_width // 8, o.MEM_KIND.RW))
    global_argv.AddData(ir.DataBytes(bit_width // 8, b"\0"))

    global_argc = unit.AddMem(ir.Mem("global_argc", 4, o.MEM_KIND.RW))
    global_argc.AddData(ir.DataBytes(4, b"\0"))

    memcpy = GenerateMemcpyFun(unit, addr_type)
    init_global = GenerateInitGlobalVarsFun(mod, unit, addr_type)
    init_data = GenerateInitDataFun(mod, unit, memcpy, addr_type)
    unit.AddFun(ir.Fun("__wasi_init", o.FUN_KIND.EXTERN, [], []))
    unit.AddFun(
        ir.Fun("__memory_grow", o.FUN_KIND.EXTERN, [o.DK.S32], [o.DK.S32]))

    # Pass 1: forward-declare every function as EXTERN, since we cannot rely
    # on a topological sort of the functions. Each one gets an extra leading
    # addr_type parameter.
    for wasm_fun in mod.functions:
        if isinstance(wasm_fun.impl, wasm.Import):
            assert wasm_fun.name in WASI_FUNCTIONS, f"unimplemented external function: {wasm_fun.name}"
        arguments = [addr_type] + TranslateTypeList(wasm_fun.func_type.args)
        returns = TranslateTypeList(wasm_fun.func_type.rets)
        unit.AddFun(
            ir.Fun(wasm_fun.name, o.FUN_KIND.EXTERN, returns, arguments))

    global_table = MaybeMakeGlobalTable(mod, unit, addr_type)

    # Pass 2: generate bodies for all non-imported functions.
    main = None
    for wasm_fun in mod.functions:
        if isinstance(wasm_fun.impl, wasm.Import):
            continue
        fun = unit.GetFun(wasm_fun.name)
        fun.kind = o.FUN_KIND.NORMAL
        if fun.name == "_start":
            # The wasm entry point becomes "$main"; the name "main" is taken
            # by the native startup function generated below.
            fun.name = "$main"
            main = fun
        GenerateFun(unit, mod, wasm_fun, fun, global_table, addr_type)
        sanity.FunCheck(fun, unit, check_cfg=False)

    # The initial heap size (in wasm pages) comes from the optional MEMORY section.
    initial_heap_size_pages = 0
    sec_memory = mod.sections.get(wasm.SECTION_ID.MEMORY)
    if sec_memory:
        assert len(sec_memory.items) == 1
        heap: wasm.Mem = sec_memory.items[0]
        initial_heap_size_pages = heap.mem_type.limits.min
    assert main, "missing main function"
    GenerateStartup(unit, global_argc, global_argv, main, init_global,
                    init_data, initial_heap_size_pages, addr_type)

    return unit
Beispiel #7
0
def Cnttz(kind: o.DK, val: int):
    """Return the number of trailing zero bits of ``val``.

    A zero ``val`` yields ``kind.bitwidth()``; ``kind`` is not consulted for
    any other value.
    """
    if val == 0:
        return kind.bitwidth()
    # val & -val isolates the lowest set bit (also for negative ints);
    # its bit index equals the number of trailing zeros.
    lowest_set_bit = val & -val
    return lowest_set_bit.bit_length() - 1
Beispiel #8
0
def ConvertIntValue(kind_dst: o.DK, val: ir.Const) -> ir.Const:
    """Convert the integer constant ``val`` into a constant of kind ``kind_dst``.

    Widening keeps the numeric value unless the flavors differ and the source
    is not unsigned (e.g. signed -> unsigned). In that case only the low
    ``kind_dst.bitwidth()`` bits are kept.

    Non-widening conversions keep the low bits. For a non-unsigned
    destination, those bits are then sign-extended from the destination's
    top bit.
    """
    src_kind = val.kind
    dst_bits = kind_dst.bitwidth()
    src_bits = src_kind.bitwidth()
    low_bits = val.value & ((1 << dst_bits) - 1)

    if dst_bits > src_bits:
        keep_value = (kind_dst.flavor() == src_kind.flavor()
                      or src_kind.flavor() == o.DK_FLAVOR_U)
        return ir.Const(kind_dst, val.value if keep_value else low_bits)

    if kind_dst.flavor() == o.DK_FLAVOR_U:
        return ir.Const(kind_dst, low_bits)

    # Non-unsigned destination: reinterpret the low bits as two's complement.
    top_bit = 1 << (dst_bits - 1)
    if (val.value & top_bit) == 0:
        return ir.Const(kind_dst, low_bits)
    return ir.Const(kind_dst, low_bits - (1 << dst_bits))
Beispiel #9
0
def GenerateStartup(unit: ir.Unit, global_argc, global_argv, main: ir.Fun,
                    init_global: ir.Fun, init_data: ir.Fun,
                    initial_heap_size_pages: int, addr_type: o.DK) -> ir.Fun:
    """Add the native entry function ``main`` to ``unit`` and return it.

    The generated function takes ``argc`` (U32) and ``argv`` (``addr_type``),
    returns a U32, and performs these steps:
      1. stores argc/argv into the ``global_argc`` / ``global_argv`` memories,
      2. calls ``__wasi_init``,
      3. if ``initial_heap_size_pages`` is non-zero, calls ``__memory_grow``
         with that page count and discards the result,
      4. loads ``mem_base`` (an ``addr_type`` value) from the EXTERN memory
         ``__memory_base``,
      5. calls ``init_global`` (no arguments) and ``init_data(mem_base)`` when
         they are truthy,
      6. calls the translated function ``main`` with ``mem_base`` and
         returns 0.

    ``__wasi_init`` and ``__memory_grow`` are looked up via ``unit.GetFun``,
    so they must already be declared in ``unit``.
    """
    global_mem_base = unit.AddMem(ir.Mem("__memory_base", 0,
                                         o.MEM_KIND.EXTERN))

    fun = unit.AddFun(
        ir.Fun("main", o.FUN_KIND.NORMAL, [o.DK.U32], [o.DK.U32, addr_type]))
    argc = fun.AddReg(ir.Reg("argc", o.DK.U32))
    argv = fun.AddReg(ir.Reg("argv", addr_type))

    bbl = fun.AddBbl(ir.Bbl("start"))
    # Pop the arguments in declaration order and stash them in the globals.
    bbl.AddIns(ir.Ins(o.POPARG, [argc]))
    bbl.AddIns(ir.Ins(o.POPARG, [argv]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argc, ZERO, argc]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argv, ZERO, argv]))

    bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__wasi_init")]))
    if initial_heap_size_pages:
        bbl.AddIns(
            ir.Ins(o.PUSHARG, [ir.Const(o.DK.S32, initial_heap_size_pages)]))
        bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
        # The result of __memory_grow is popped into a throwaway register.
        bbl.AddIns(ir.Ins(o.POPARG, [fun.AddReg(ir.Reg("dummy", o.DK.S32))]))

    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.LD_MEM, [mem_base, global_mem_base, ZERO]))

    if init_global:
        bbl.AddIns(ir.Ins(o.BSR, [init_global]))
    if init_data:
        bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
        bbl.AddIns(ir.Ins(o.BSR, [init_data]))

    bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
    bbl.AddIns(ir.Ins(o.BSR, [main]))
    # Return value 0 from the native main.
    bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, 0)]))
    bbl.AddIns(ir.Ins(o.RET, []))
    return fun
Beispiel #10
0
 def get_cpu_reg_family(self, kind: o.DK) -> int:
     """Return the register family used for values of data kind ``kind``.

     Floating-point kinds map to FLT_NOT_LAC; every other kind maps to
     GPR_NOT_LAC.
     """
     if kind.flavor() is o.DK_FLAVOR_F:
         return FLT_NOT_LAC
     return GPR_NOT_LAC
Beispiel #11
0
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """
    Change the type of all registers (and constants) of kind narrow_kind into wide_kind.
    Add compensation code where necessary.
    wide_kind must be wider than narrow_kind and have the same flavor.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output types. So this must run
      for all functions including prototypes.

    Compensation code: loads into and stores from widened registers are routed
    through a narrow_kind scratch register plus a CONV, so the memory access
    itself keeps the original narrow width.

    Returns the number of instructions added.

      TODO: double check if we are doing the right thing with o.CONV
      TODO: there are more subtle bugs. For example
              mul x:U8  = 43 * 47    (= 229)
              div y:u8  = x   13      (= 17)
            whereas:
              mul x:U16  = 43 * 47    (= 2021)
              div y:u16  = x   13      (= 155)

      Other problematic operations: rem, popcnt, ...
      """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    # Widen the function signature.
    fun.input_types = [
        wide_kind if x == narrow_kind else x for x in fun.input_types
    ]
    fun.output_types = [
        wide_kind if x == narrow_kind else x for x in fun.output_types
    ]

    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    # Remember which registers were narrow: after retyping them below, this
    # set is the only way to tell which operands need compensation code.
    narrow_regs = {
        reg
        for reg in fun.reg_syms.values() if reg.kind == narrow_kind
    }

    for reg in narrow_regs:
        reg.kind = wide_kind

    count = 0
    for bbl in fun.bbls:
        inss = []

        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind

            for n, reg in enumerate(ops):
                # Leave the stored value of ST (operand 2) and the destination
                # of LD (operand 0) alone: the memory access stays narrow.
                if n == 2 and kind is o.OPC_KIND.ST or n == 0 and kind is o.OPC_KIND.LD:
                    continue
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    # if ins.opcode.constraints[n] == o.TC.OFFSET:
                    #    continue
                    ops[n] = ir.Const(wide_kind, reg.value)
            kind = ins.opcode.kind
            if kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                # Load into a narrow scratch register, then CONV it into the
                # (now wide) original destination. ``ins`` is appended before
                # its operand is rewritten; the list holds the same object, so
                # the rewrite below still takes effect.
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg)
                  and ops[2] in narrow_regs):
                # CONV the widened value into a narrow scratch register and
                # store that instead.
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            else:
                inss.append(ins)

        # Net number of instructions added to this basic block.
        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count
Beispiel #12
0
def FunRegWidthWidening(fun: ir.Fun, narrow_kind: o.DK, wide_kind: o.DK):
    """
    Change the type of all registers (and constants) of kind narrow_kind into wide_kind.
    Add compensation code where necessary.
    wide_kind must be wider than narrow_kind and have the same flavor.

    This is useful for target architectures that do not support operations
    for all operand widths.

    Note, this also widens input and output types. So this must run
      for all functions including prototypes.

    Compensation code is added for SHL/SHR (shift amount masking), CNTLZ,
    CNTTZ, LD, ST and CONV. Instructions that touch no narrow register or
    constant are left untouched.

    Returns the number of instructions added.

      TODO: double check if we are doing the right thing with o.CONV
      TODO: there are more subtle bugs. For example
              mul x:U8  = 43 * 47    (= 229)
              div y:u8  = x   13      (= 17)
            whereas:
              mul x:U16  = 43 * 47    (= 2021)
              div y:u16  = x   13      (= 155)

      Other problematic operations: rem, popcnt, ...

      The invariant we are maintaining is this one:
      if reg a gets widened into reg b with bitwidth(a) = w then
      the lower w bits of reg b will always contain the same data as reg a would have.
      """
    assert ir.FUN_FLAG.STACK_FINALIZED not in fun.flags
    # Widen the function signature.
    fun.input_types = [
        wide_kind if x == narrow_kind else x for x in fun.input_types
    ]
    fun.output_types = [
        wide_kind if x == narrow_kind else x for x in fun.output_types
    ]

    assert narrow_kind.flavor() == wide_kind.flavor()
    assert narrow_kind.bitwidth() < wide_kind.bitwidth()
    # Remember which registers were narrow: after retyping them below, this
    # set is the only way to tell which operands need compensation code.
    narrow_regs = {
        reg
        for reg in fun.reg_syms.values() if reg.kind == narrow_kind
    }

    for reg in narrow_regs:
        reg.kind = wide_kind

    count = 0
    for bbl in fun.bbls:
        inss = []

        for ins in bbl.inss:
            ops = ins.operands
            kind = ins.opcode.kind
            # Widen narrow constants (except a stored ST value, which keeps the
            # memory access narrow) and note whether the instruction touches
            # anything narrow at all.
            changed = False
            for n, reg in enumerate(ops):
                if isinstance(reg, ir.Const) and reg.kind is narrow_kind:
                    if kind is o.OPC_KIND.ST and n == 2:
                        continue
                    ops[n] = ir.Const(wide_kind, reg.value)
                    changed = True
                if isinstance(reg, ir.Reg) and reg in narrow_regs:
                    changed = True
            if not changed:
                inss.append(ins)
                continue
            kind = ins.opcode.kind
            if ins.opcode is o.SHL or ins.opcode is o.SHR:
                # deal with the shift amount which is subject to an implicit modulo "bitwidth -1"
                # by changing the width of the reg - we lose this information
                # so re-apply it explicitly: AND the amount with (narrow width - 1).
                tmp_reg = fun.GetScratchReg(wide_kind, "tricky", False)
                inss.append(
                    ir.Ins(o.AND, [
                        tmp_reg, ops[2],
                        ir.Const(wide_kind,
                                 narrow_kind.bitwidth() - 1)
                    ]))
                ops[2] = tmp_reg
                if ins.opcode is o.SHR and isinstance(ops[1], ir.Reg):
                    # for SHR we also need to make sure the new high order bits are correct
                    # (they get shifted down into the low bits): round-trip the source
                    # through a narrow scratch register to re-extend it in place.
                    tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                    inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                    # the implicit understanding is that this will become nop or a move and not modify the
                    # high-order bit we just set in the previous instruction
                    inss.append(ir.Ins(o.CONV, [ops[1], tmp_reg]))
                inss.append(ins)
            elif ins.opcode is o.CNTLZ:
                # Counting in the wide register adds `excess` extra leading
                # zeros; subtract them.
                # NOTE(review): this assumes the bits of the source above the
                # narrow width are zero, which the invariant above does not
                # guarantee -- confirm or normalize the source first.
                inss.append(ins)
                excess = wide_kind.bitwidth() - narrow_kind.bitwidth()
                inss.append(
                    ir.Ins(o.SUB,
                           [ops[0], ops[0],
                            ir.Const(wide_kind, excess)]))
            elif ins.opcode is o.CNTTZ:
                # Cap the result at the narrow width. Presumably CMPLT selects
                # ops[0] when ops[0] < narrow width, else the narrow width,
                # i.e. min(result, width) -- confirm CMPLT operand order.
                inss.append(ins)
                inss.append(
                    ir.Ins(o.CMPLT, [
                        ops[0], ops[0],
                        ir.Const(wide_kind, narrow_kind.bitwidth()), ops[0],
                        ir.Const(wide_kind, narrow_kind.bitwidth())
                    ]))
            elif kind is o.OPC_KIND.LD and ops[0] in narrow_regs:
                # Load into a narrow scratch register, then CONV it into the
                # (now wide) original destination. ``ins`` is appended before
                # its operand is rewritten; the list holds the same object.
                inss.append(ins)
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
                ops[0] = tmp_reg
            elif (kind is o.OPC_KIND.ST and isinstance(ops[2], ir.Reg)
                  and ops[2] in narrow_regs):
                # CONV the widened value into a narrow scratch register and
                # store that instead.
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[2]]))
                inss.append(ins)
                ops[2] = tmp_reg
            elif ins.opcode is o.CONV:
                # Replace the original CONV (it is not appended) by a CONV
                # through a narrow scratch register, so the conversion sees
                # exactly the narrow value.
                tmp_reg = fun.GetScratchReg(narrow_kind, "narrowed", True)
                inss.append(ir.Ins(o.CONV, [tmp_reg, ops[1]]))
                inss.append(ir.Ins(o.CONV, [ops[0], tmp_reg]))
            else:
                inss.append(ins)

        # Net number of instructions added to this basic block.
        count += len(inss) - len(bbl.inss)
        bbl.inss = inss
    return count