Exemple #1
0
def GenerateMemcpyFun(unit: ir.Unit, addr_type: o.DK) -> ir.Fun:
    """Add a byte-wise `$memcpy(dst, src, cnt)` helper function to `unit`.

    Arguments are popped in the order dst, src, cnt. The generated code
    copies `cnt` bytes one at a time, from index cnt - 1 down to 0, using a
    U8 scratch register.

    Returns the newly added function.
    """
    memcpy = unit.AddFun(
        ir.Fun("$memcpy", o.FUN_KIND.NORMAL, [],
               [addr_type, addr_type, o.DK.U32]))
    reg_specs = [("dst", addr_type), ("src", addr_type),
                 ("cnt", o.DK.U32), ("data", o.DK.U8)]
    dst, src, cnt, data = [memcpy.AddReg(ir.Reg(name, dk))
                           for name, dk in reg_specs]

    prolog, loop, epilog = [memcpy.AddBbl(ir.Bbl(label))
                            for label in ("prolog", "loop", "epilog")]

    # prolog: fetch the three arguments, then jump to the loop test
    for arg in (dst, src, cnt):
        prolog.AddIns(ir.Ins(o.POPARG, [arg]))
    prolog.AddIns(ir.Ins(o.BRA, [epilog]))

    # loop body: cnt -= 1; data = ld(src, cnt); st(dst, cnt, data).
    # It has no terminator, so control continues into `epilog`, the next bbl.
    loop.AddIns(ir.Ins(o.SUB, [cnt, cnt, ONE]))
    loop.AddIns(ir.Ins(o.LD, [data, src, cnt]))
    loop.AddIns(ir.Ins(o.ST, [dst, cnt, data]))

    # epilog: keep looping while 0 < cnt, otherwise return
    epilog.AddIns(ir.Ins(o.BLT, [ZERO, cnt, loop]))
    epilog.AddIns(ir.Ins(o.RET, []))
    return memcpy
Exemple #2
0
def BuildExample() -> ir.Unit:
    """Build a unit "fib" containing a naively recursive Fibonacci function.

    fib(in) returns `in` when in <= 1, otherwise fib(in - 1) + fib(in - 2).
    """
    u32 = o.DK.U32
    unit = ir.Unit("fib")
    fib = unit.AddFun(ir.Fun("fib", o.FUN_KIND.NORMAL, [u32], [u32]))
    start = fib.AddBbl(ir.Bbl("start"))
    recurse = fib.AddBbl(ir.Bbl("difficult"))

    arg = fib.AddReg(ir.Reg("in", u32))
    tmp = fib.AddReg(ir.Reg("x", u32))
    acc = fib.AddReg(ir.Reg("out", u32))

    # base case: if in <= 1 return in unchanged
    start.AddIns(ir.Ins(o.POPARG, [arg]))
    start.AddIns(ir.Ins(o.BLT, [ir.Const(u32, 1), arg, recurse]))
    start.AddIns(ir.Ins(o.PUSHARG, [arg]))
    start.AddIns(ir.Ins(o.RET, []))

    # recursive case: out = fib(in - 1) + fib(in - 2)
    recurse.AddIns(ir.Ins(o.MOV, [acc, ir.Const(u32, 0)]))
    for delta in (1, 2):
        recurse.AddIns(ir.Ins(o.SUB, [tmp, arg, ir.Const(u32, delta)]))
        recurse.AddIns(ir.Ins(o.PUSHARG, [tmp]))
        recurse.AddIns(ir.Ins(o.BSR, [fib]))
        recurse.AddIns(ir.Ins(o.POPARG, [tmp]))
        recurse.AddIns(ir.Ins(o.ADD, [acc, acc, tmp]))

    recurse.AddIns(ir.Ins(o.PUSHARG, [acc]))
    recurse.AddIns(ir.Ins(o.RET, []))
    return unit
Exemple #3
0
def GenerateInitGlobalVarsFun(mod: wasm.Module, unit: ir.Unit,
                              addr_type: o.DK) -> ir.Fun:
    """Create the memory for every wasm global plus `init_global_vars_fun`.

    For global n a memory `global_vars_<n>` is added to `unit`:
      * a constant initializer is stored directly as the memory's data;
      * a `global.get` initializer gets zero-filled data, plus a load/store
        pair in `init_global_vars_fun` that copies the source global at
        runtime.

    `addr_type` is not used here. Always returns `init_global_vars_fun`,
    which is just a bare return when the module has no GLOBAL section.

    Raises:
      AssertionError: for initializer instructions other than global.get or a
        constant.
    """
    fun = unit.AddFun(ir.Fun("init_global_vars_fun", o.FUN_KIND.NORMAL, [],
                             []))
    # "start" gets no terminator; it continues into "end", which returns.
    bbl = fun.AddBbl(ir.Bbl("start"))
    epilog = fun.AddBbl(ir.Bbl("end"))
    epilog.AddIns(ir.Ins(o.RET, []))

    section = mod.sections.get(wasm.SECTION_ID.GLOBAL)
    if not section:
        return fun
    val32 = fun.AddReg(ir.Reg("val32", o.DK.U32))
    val64 = fun.AddReg(ir.Reg("val64", o.DK.U64))
    for n, data in enumerate(section.items):
        ins = GetInsFromInitializerExpression(data.expr)
        var_type = data.global_type.value_type
        is_copied_at_runtime = ins.opcode is wasm_opc.GLOBAL_GET
        # Even an immutable global is written once by the ST_MEM below when it
        # is initialized from another global, so only globals whose value is
        # baked into the data can safely live in read-only memory.
        read_only = (data.global_type.mut is wasm.MUT.CONST
                     and not is_copied_at_runtime)
        kind = o.MEM_KIND.RO if read_only else o.MEM_KIND.RW
        mem = unit.AddMem(ir.Mem(f"global_vars_{n}", 16, kind))
        if is_copied_at_runtime:
            mem.AddData(
                ir.DataBytes(1, b"\0" * (4 if var_type.is_32bit() else 8)))
            src_mem = unit.GetMem(f"global_vars_{int(ins.args[0])}")
            reg = val32 if var_type.is_32bit() else val64
            bbl.AddIns(ir.Ins(o.LD_MEM, [reg, src_mem, ZERO]))
            bbl.AddIns(ir.Ins(o.ST_MEM, [mem, ZERO, reg]))
        elif ins.opcode.kind is wasm_opc.OPC_KIND.CONST:
            mem.AddData(
                ir.DataBytes(1, ExtractBytesFromConstIns(ins, var_type)))
        else:
            assert False, f"unsupported init instructions {ins}"
    return fun
Exemple #4
0
def GenerateStartup(unit: ir.Unit, global_argc, global_argv, main: ir.Fun,
                    init_global: ir.Fun, init_data: ir.Fun,
                    initial_heap_size_pages: int, addr_type: o.DK) -> ir.Fun:
    """Create the native entry point `main(argc, argv)` that boots the module.

    The generated code:
      * stores argc/argv into the `global_argc` / `global_argv` memories,
      * calls `__wasi_init`,
      * if `initial_heap_size_pages` is non-zero, calls `__memory_grow` with it
        and discards the result,
      * loads the memory base from the extern memory `__memory_base`,
      * calls `init_global()` and `init_data(mem_base)` when they are truthy
        (callers may pass None),
      * calls the translated `main(mem_base)` and returns 0.

    Also adds the extern memory `__memory_base` to `unit`. Returns the new
    entry-point function.
    """
    global_mem_base = unit.AddMem(ir.Mem("__memory_base", 0,
                                         o.MEM_KIND.EXTERN))

    fun = unit.AddFun(
        ir.Fun("main", o.FUN_KIND.NORMAL, [o.DK.U32], [o.DK.U32, addr_type]))
    argc = fun.AddReg(ir.Reg("argc", o.DK.U32))
    argv = fun.AddReg(ir.Reg("argv", addr_type))

    bbl = fun.AddBbl(ir.Bbl("start"))
    bbl.AddIns(ir.Ins(o.POPARG, [argc]))
    bbl.AddIns(ir.Ins(o.POPARG, [argv]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argc, ZERO, argc]))
    bbl.AddIns(ir.Ins(o.ST_MEM, [global_argv, ZERO, argv]))

    bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__wasi_init")]))
    if initial_heap_size_pages:
        bbl.AddIns(
            ir.Ins(o.PUSHARG, [ir.Const(o.DK.S32, initial_heap_size_pages)]))
        bbl.AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
        # the result of __memory_grow is popped into a throwaway register
        bbl.AddIns(ir.Ins(o.POPARG, [fun.AddReg(ir.Reg("dummy", o.DK.S32))]))

    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.LD_MEM, [mem_base, global_mem_base, ZERO]))

    if init_global:
        bbl.AddIns(ir.Ins(o.BSR, [init_global]))
    if init_data:
        bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
        bbl.AddIns(ir.Ins(o.BSR, [init_data]))

    bbl.AddIns(ir.Ins(o.PUSHARG, [mem_base]))
    bbl.AddIns(ir.Ins(o.BSR, [main]))
    # process exit status
    bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, 0)]))
    bbl.AddIns(ir.Ins(o.RET, []))
    return fun
Exemple #5
0
def _GetRegOrConstOperand(fun: ir.Fun, last_kind: o.DK, ok: o.OP_KIND,
                          tc: o.TC, token: str, regs_cpu: Dict[str,
                                                               ir.Reg]) -> Any:
    """Parse `token` into a register or constant operand of `fun`.

    Args:
      fun: function whose registers are looked up (or added, when the token
        carries a `:kind` suffix).
      last_kind: kind of the previous operand; used for SAME_AS_PREV constants
        and for the type-constraint check of newly declared registers.
      ok: operand kind; REG_OR_CONST is resolved by whether `token` looks like
        a constant.
      tc: type constraint of this operand.
      token: operand text. Registers: `name[:kind][@cpu_reg]` (`@STK` denotes a
        stack slot). Constants: `value[:kind]`.
      regs_cpu: cpu register name -> register, used for `@cpu_reg` suffixes.

    Returns:
      An `ir.Reg`, or a constant from `ir.ParseConst` / `ir.ParseOffsetConst`.

    Raises:
      AssertionError: on unknown cpu registers, unknown kind suffixes, failed
        type constraints, or constants whose type cannot be deduced.
    """
    if ok == o.OP_KIND.REG_OR_CONST:
        ok = o.OP_KIND.CONST if parse.IsLikelyConst(token) else o.OP_KIND.REG

    if ok is o.OP_KIND.REG:
        cpu_reg = None
        pos = token.find("@")
        if pos > 0:
            cpu_reg_name = token[pos + 1:]
            token = token[:pos]
            if cpu_reg_name == "STK":
                cpu_reg = ir.StackSlot(0)
            else:
                cpu_reg = regs_cpu.get(cpu_reg_name)
                # `token` was already truncated above, so report the saved name
                assert cpu_reg is not None, f"unknown cpu_reg {cpu_reg_name} known regs {regs_cpu.keys()}"
        pos = token.find(":")
        if pos < 0:
            reg = fun.GetReg(token)
        else:
            # `name:kind` declares a new register
            kind = token[pos + 1:]
            reg_name = token[:pos]
            reg_kind = o.SHORT_STR_TO_RK.get(kind)
            assert reg_kind is not None, f"unknown register kind [{kind}] in {token}"
            reg = ir.Reg(reg_name, reg_kind)
            fun.AddReg(reg)
            assert o.CheckTypeConstraint(last_kind, tc, reg.kind)
        if cpu_reg:
            # a register may be pinned to a cpu reg only once, consistently
            if reg.cpu_reg:
                assert reg.cpu_reg == cpu_reg
            else:
                reg.cpu_reg = cpu_reg
        return reg

    else:
        pos = token.find(":")
        if pos >= 0:
            # explicitly typed constant `value:kind`
            kind = token[pos + 1:]
            value_str = token[:pos]
            const_kind = o.SHORT_STR_TO_RK.get(kind)
            assert const_kind is not None, f"unknown const kind [{kind}] in {token}"
            const = ir.ParseConst(value_str, const_kind)
            return const
        elif tc == o.TC.SAME_AS_PREV:
            const = ir.ParseConst(token, last_kind)
            return const
        elif tc == o.TC.OFFSET:
            const = ir.ParseOffsetConst(token)
            return const
        elif tc == o.TC.UINT:
            assert token[0] != "-"
            const = ir.ParseOffsetConst(token)
            return const
        else:
            assert False, f"cannot deduce type for const {token} [{tc}]"
Exemple #6
0
def GenerateInitDataFun(mod: wasm.Module, unit: ir.Unit, memcpy: ir.Fun,
                        addr_type: o.DK) -> typing.Optional[ir.Fun]:
    """Create `init_data_fun(mem_base)`, which copies wasm data segments.

    Each DATA section item n becomes a read-only memory `global_init_mem_<n>`
    holding the segment bytes. The generated function computes the
    destination `mem_base + offset` (offset from an i32.const or from a
    global) and calls `memcpy` with (dst, src, len).

    Returns None, without adding anything to `unit`, when the module has no
    DATA section.

    Raises:
      AssertionError: for non-zero memory indices or unsupported offset
        initializers.
    """
    section = mod.sections.get(wasm.SECTION_ID.DATA)
    if not section:
        # Nothing to copy: do not leave an unused function behind in `unit`.
        return None

    fun = unit.AddFun(
        ir.Fun("init_data_fun", o.FUN_KIND.NORMAL, [], [addr_type]))
    # "start" gets no terminator; it continues into "end", which returns.
    bbl = fun.AddBbl(ir.Bbl("start"))
    epilog = fun.AddBbl(ir.Bbl("end"))
    epilog.AddIns(ir.Ins(o.RET, []))

    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbl.AddIns(ir.Ins(o.POPARG, [mem_base]))

    offset = fun.AddReg(ir.Reg("offset", o.DK.S32))
    src = fun.AddReg(ir.Reg("src", addr_type))
    dst = fun.AddReg(ir.Reg("dst", addr_type))

    for n, data in enumerate(section.items):
        assert data.memory_index == 0
        assert isinstance(data.offset, wasm.Expression)
        ins = GetInsFromInitializerExpression(data.offset)
        init = unit.AddMem(ir.Mem(f"global_init_mem_{n}", 16, o.MEM_KIND.RO))
        init.AddData(ir.DataBytes(1, data.init))
        if ins.opcode is wasm_opc.GLOBAL_GET:
            src_mem = unit.GetMem(f"global_vars_{int(ins.args[0])}")
            bbl.AddIns(ir.Ins(o.LD_MEM, [offset, src_mem, ZERO]))
        elif ins.opcode is wasm_opc.I32_CONST:
            bbl.AddIns(ir.Ins(o.MOV,
                              [offset, ir.Const(o.DK.S32, ins.args[0])]))
        else:
            assert False, f"unsupported init instructions {ins}"
        bbl.AddIns(ir.Ins(o.LEA, [dst, mem_base, offset]))
        bbl.AddIns(ir.Ins(o.LEA_MEM, [src, init, ZERO]))
        # Arguments are pushed as (len, src, dst); presumably POPARG is LIFO,
        # so $memcpy pops dst, src, cnt in that order.
        bbl.AddIns(ir.Ins(o.PUSHARG, [ir.Const(o.DK.U32, len(data.init))]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [src]))
        bbl.AddIns(ir.Ins(o.PUSHARG, [dst]))
        bbl.AddIns(ir.Ins(o.BSR, [memcpy]))
    return fun
Exemple #7
0
    def testNoChange(self):
        # One register, defined by poparg, read by blt and live-out of the
        # bbl: expect a single live range that crosses the bbl and is not LAC.
        reg = ir.Reg("x", o.DK.S32)
        bbl_target = ir.Bbl("target")
        block = ir.Bbl("bbl")
        block.live_out.add(reg)
        block.AddIns(ir.Ins(O("poparg"), [reg]))
        block.AddIns(ir.Ins(O("blt"), [bbl_target, ir.OffsetConst(1), reg]))

        DumpBbl(block)

        ranges = sorted(
            liveness.BblGetLiveRanges(block, None, block.live_out, False))
        cross_bbl = [r for r in ranges if r.is_cross_bbl()]
        lac = [r for r in ranges if liveness.LiveRangeFlag.LAC in r.flags]

        assert len(ranges) == 1
        assert len(cross_bbl) == 1
        assert len(lac) == 0, f"{lac}"
Exemple #8
0
def ParseLiveRanges(fin, cpu_reg_map: Dict[str, ir.CpuReg]) -> List[LiveRange]:
    """Parse textual live-range descriptions, one per line, from `fin`.

    Blank lines and lines whose first token starts with "#" are skipped.
    Every other line has the form `LR <start> - <end> [tokens...]`, where the
    optional tokens are:
      * PRE_ALLOC / LAC: set the corresponding LiveRangeFlag;
      * `def:<name>:<kind>[@<cpu_reg>]`: the defined register (optionally
        pinned to a cpu reg from `cpu_reg_map`); ends parsing of the line;
      * `uses:<name>:<def_pos>,...`: ranges this one uses, resolved with
        FindDefRange against the ranges parsed so far;
      * `#...`: a trailing comment; ends parsing of the line.

    Returns the parsed ranges in input order.

    Raises:
      AssertionError: on malformed lines.
    """
    out: List[LiveRange] = []
    for line in fin:
        token = line.split()
        # comment lines may be "# text" or "#text"
        if not token or token[0].startswith("#"):
            continue
        assert token.pop(0) == "LR"
        start = _ParsePos(token.pop(0))
        assert token.pop(0) == "-"
        end = _ParsePos(token.pop(0))
        lr = LiveRange(start, end, ir.REG_INVALID, 0)
        out.append(lr)
        while token:
            t = token.pop(0)
            if t == "PRE_ALLOC":
                lr.flags |= LiveRangeFlag.PRE_ALLOC
            elif t == "LAC":
                lr.flags |= LiveRangeFlag.LAC
            elif t.startswith("def:"):
                reg_str = t[4:]
                cpu_reg_str = ""
                reg_name, kind_str = reg_str.split(":")
                if "@" in kind_str:
                    kind_str, cpu_reg_str = kind_str.split("@")
                lr.reg = ir.Reg(reg_name, o.DK[kind_str])
                if cpu_reg_str:
                    lr.reg.cpu_reg = cpu_reg_map[cpu_reg_str]
                    lr.cpu_reg = lr.reg.cpu_reg
                # anything after the def: token is ignored
                break
            elif t.startswith("uses:"):
                reg_str = t[5:]
                uses = [u.split(":") for u in reg_str.split(",")]
                for reg_name, def_pos_str in uses:
                    lr.uses.append(
                        FindDefRange(reg_name, int(def_pos_str), out))
            elif t.startswith("#"):
                break
            else:
                assert False, f"parse error [{t}] {line}"
    return out
Exemple #9
0
def ParseLiveRanges(fin, cpu_reg_map: Dict[str, ir.CpuReg]) -> List[LiveRange]:
    """Parse textual live-range descriptions, one per line, from `fin`.

    Blank lines and lines whose first token starts with "#" are skipped.
    Every other line has the form `RANGE <start> - <end> [tokens...]`, where
    <start> is an int or BEFORE_BBL and <end> is an int or AFTER_BBL. The
    optional tokens are:
      * PRE_ALLOC / LAC: set the corresponding LiveRangeFlag;
      * `def: <name>:<kind>[@<cpu_reg>]`: the defined register (optionally
        pinned to a cpu reg from `cpu_reg_map`); ends parsing of the line;
      * `uses: <name>:<def_pos>,...`: ranges this one uses, resolved with
        FindDefRange against the ranges parsed so far;
      * `#...`: a trailing comment; ends parsing of the line.

    Returns the parsed ranges in input order.

    Raises:
      AssertionError: on malformed lines.
    """
    out: List[LiveRange] = []
    for line in fin:
        token = line.split()
        # comment lines may be "# text" or "#text"
        if not token or token[0].startswith("#"):
            continue
        assert token.pop(0) == "RANGE"
        start_str = token.pop(0)
        start = BEFORE_BBL if start_str == "BEFORE_BBL" else int(start_str)
        assert token.pop(0) == "-"
        end_str = token.pop(0)
        end = AFTER_BBL if end_str == "AFTER_BBL" else int(end_str)
        lr = LiveRange(start, end, ir.REG_INVALID, 0)
        out.append(lr)
        while token:
            t = token.pop(0)
            if t == "PRE_ALLOC":
                lr.flags |= LiveRangeFlag.PRE_ALLOC
            elif t == "LAC":
                lr.flags |= LiveRangeFlag.LAC
            elif t == "def:":
                reg_str = token.pop(0)
                cpu_reg_str = ""
                reg_name, kind_str = reg_str.split(":")
                if "@" in kind_str:
                    kind_str, cpu_reg_str = kind_str.split("@")
                lr.reg = ir.Reg(reg_name, o.DK[kind_str])
                if cpu_reg_str:
                    lr.reg.cpu_reg = cpu_reg_map[cpu_reg_str]
                    lr.cpu_reg = lr.reg.cpu_reg
                # anything after the def: operand is ignored
                break
            elif t == "uses:":
                uses = [u.split(":") for u in token.pop(0).split(",")]
                for reg_name, def_pos_str in uses:
                    lr.uses.append(FindDefRange(reg_name, int(def_pos_str), out))
            elif t.startswith("#"):
                break
            else:
                assert False, f"parse error [{t}] {line}"
    return out
Exemple #10
0
from typing import List, Dict, Optional, Tuple

from Base import canonicalize
from Base import reg_alloc
from Base import ir
from Base import liveness
from Base import lowering
from Base import opcode_tab as o
from Base import reg_stats
from Base import sanity
from Base import optimize
from Base import serialize
from CodeGenA64 import isel_tab
from CodeGenA64 import regs

# Shared A32 register named "dummy".
# NOTE(review): presumably used as a throwaway destination; confirm against
# its uses elsewhere in this file.
_DUMMY_A32 = ir.Reg("dummy", o.DK.A32)
# Shared U32 constant 0; the name suggests it serves as a zero offset operand.
_ZERO_OFFSET = ir.Const(o.DK.U32, 0)


def _InsRewriteOutOfBoundsImmediates(
        ins: ir.Ins, fun: ir.Fun, unit: ir.Unit) -> Optional[List[ir.Ins]]:
    if ins.opcode in isel_tab.OPCODES_REQUIRING_SPECIAL_HANDLING:
        return None
    inss = []
    mismatches = isel_tab.FindtImmediateMismatchesInBestMatchPattern(ins, True)
    assert mismatches != isel_tab.MATCH_IMPOSSIBLE, f"could not match opcode {ins} {ins.operands}"

    if mismatches == 0:
        return None
    for pos in range(o.MAX_OPERANDS):
        if mismatches & (1 << pos) != 0:
Exemple #11
0
#!/usr/bin/python3
"""
mov.r reg_s32 reg_u32
"""

import unittest

from Base import ir
from Base import sanity
from Base import opcode_tab as o

# Register fixtures, one per data kind, for the tests in this module.
reg_s64 = ir.Reg("reg_s64", o.DK.S64)
reg_s32 = ir.Reg("reg_s32", o.DK.S32)
reg_s16 = ir.Reg("reg_s16", o.DK.S16)
# Backward-compatible alias for the previously misspelled variable name.
reg_s18 = reg_s16
reg_s8 = ir.Reg("reg_s8", o.DK.S8)

reg_u64 = ir.Reg("reg_u64", o.DK.U64)
reg_u32 = ir.Reg("reg_u32", o.DK.U32)
reg_u16 = ir.Reg("reg_u16", o.DK.U16)
reg_u8 = ir.Reg("reg_u8", o.DK.U8)

reg_a64 = ir.Reg("reg_a64", o.DK.A64)
reg_a32 = ir.Reg("reg_a32", o.DK.A32)

reg_c64 = ir.Reg("reg_c64", o.DK.C64)
reg_c32 = ir.Reg("reg_c32", o.DK.C32)


class TestRegState(unittest.TestCase):
    def testMov(self):
        mov = o.Opcode.Lookup("mov")
Exemple #12
0
def DirReg(unit: ir.Unit, operands: List):
    """Handle a register-declaration directive.

    `operands[0]` is the data kind and `operands[1]` a list of register names.
    One register of that kind per name is added to the most recently added
    function of `unit`.
    """
    current_fun = unit.funs[-1]
    names = operands[1]
    assert isinstance(names, list)
    for name in names:
        current_fun.AddReg(ir.Reg(name, operands[0]))
Exemple #13
0
def GenerateFun(unit: ir.Unit, mod: wasm.Module, wasm_fun: wasm.Function,
                fun: ir.Fun, global_table, addr_type: o.DK) -> None:
    """Translate the body of `wasm_fun` into IR instructions of `fun`.

    `fun` must already exist and take `addr_type` (the memory base) as its
    first input; the remaining inputs are the wasm parameters. Parameters and
    declared locals become registers `$loc_<i>`, and declared locals are
    zero-initialized. The wasm operand stack is simulated with `op_stack` and
    structured control flow with `block_stack` (see MakeBlock). The generated
    basic blocks replace `fun.bbls` at the end, in emission order.

    Args:
      unit: unit containing `fun`; used to look up callees and the memories
        backing wasm globals.
      mod: the wasm module that `wasm_fun` belongs to.
      wasm_fun: the wasm function to translate.
      fun: the IR function to fill in.
      global_table: memory holding the function table used by call_indirect.
        NOTE(review): inferred from its use as a LEA_MEM operand; confirm.
      addr_type: address data kind. A32 selects 32-bit code pointers for
        call_indirect; any other kind selects 64-bit ones.

    Raises:
      AssertionError: on unsupported opcodes or inconsistent stack state.
    """
    # op_stack models the wasm operand stack. It holds regs produced by
    # GetOpReg; such a reg may only occur at the stack position encoded in
    # its name.
    op_stack: typing.List[typing.Union[ir.Reg, ir.Const]] = []
    block_stack: typing.List[Block] = []
    bbls: typing.List[ir.Bbl] = []
    # running counters used to name generated bbls and jump tables
    bbl_count = 0
    jtb_count = 0

    bbls.append(fun.AddBbl(ir.Bbl("start")))

    # Parameters: the memory base comes first, then one $loc_<i> per wasm
    # parameter.
    loc_index = 0
    assert fun.input_types[0] is addr_type
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbls[-1].AddIns(ir.Ins(o.POPARG, [mem_base]))
    for dk in fun.input_types[1:]:
        reg = fun.AddReg(ir.Reg(f"$loc_{loc_index}", dk))
        loc_index += 1
        bbls[-1].AddIns(ir.Ins(o.POPARG, [reg]))

    # Declared locals continue the $loc_<i> numbering and start out as zero.
    # NOTE: the loop variable `locals` shadows the builtin of the same name.
    for locals in wasm_fun.impl.locals_list:
        for i in range(locals.count):
            reg = fun.AddReg(
                ir.Reg(f"$loc_{loc_index}",
                       WASM_TYPE_TO_CWERG_TYPE[locals.kind]))
            loc_index += 1
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, ir.Const(reg.kind, 0)]))

    DEBUG = False
    if DEBUG:
        print()
        print(
            f"# {fun.name} #ins:{len(wasm_fun.impl.expr.instructions)} in:{fun.input_types} out:{fun.output_types}"
        )

    # opc / last_opc: the current and the previous wasm opcode
    opc = None
    last_opc = None
    for n, wasm_ins in enumerate(wasm_fun.impl.expr.instructions):
        last_opc = opc
        opc = wasm_ins.opcode
        args = wasm_ins.args
        # Temporaries below are allocated at positions above this height
        # (op_stack_size_before + k), presumably so they never alias a stack
        # reg that is still live.
        op_stack_size_before = len(op_stack)
        if DEBUG:
            print(f"#@@ {opc.name}", args, len(op_stack))
        if opc.kind is wasm_opc.OPC_KIND.CONST:
            # breaks for floats
            kind = OPC_TYPE_TO_CWERG_TYPE[opc.op_type]
            dst = GetOpReg(fun, kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, ir.Const(kind, args[0])]))
            op_stack.append(dst)
        elif opc is wasm_opc.NOP:
            pass
        elif opc is wasm_opc.DROP:
            op_stack.pop(-1)
        elif opc is wasm_opc.LOCAL_GET:
            loc = GetLocalReg(fun, int(args[0]))
            dst = GetOpReg(fun, loc.kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, loc]))
            op_stack.append(dst)
        elif opc is wasm_opc.LOCAL_SET:
            op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.MOV,
                                   [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.LOCAL_TEE:
            op = op_stack[-1]  # no pop!
            bbls[-1].AddIns(ir.Ins(o.MOV,
                                   [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.GLOBAL_GET:
            # globals live in memories named global_vars_<index>
            var_index = int(args[0])
            var: wasm.Glob = mod.sections.get(
                wasm.SECTION_ID.GLOBAL).items[var_index]
            dst = GetOpReg(fun,
                           WASM_TYPE_TO_CWERG_TYPE[var.global_type.value_type],
                           len(op_stack))
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.LD_MEM, [dst, var_mem, ZERO]))
            op_stack.append(dst)
        elif opc is wasm_opc.GLOBAL_SET:
            op = op_stack.pop(-1)
            var_index = int(args[0])
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.ST_MEM, [var_mem, ZERO, op]))
        elif opc.kind is wasm_opc.OPC_KIND.ALU:
            if wasm_opc.FLAGS.UNARY in opc.flags:
                opcode, arg_factory = WASM_ALU1_TO_CWERG[opc.basename]
                op = op_stack.pop(-1)
                dst = GetOpReg(fun, op.kind, len(op_stack))
                bbls[-1].AddIns(ir.Ins(opcode, [dst] + arg_factory(op)))
                op_stack.append(dst)
            else:
                op2 = op_stack.pop(-1)
                op1 = op_stack.pop(-1)
                dst = GetOpReg(fun, op1.kind, len(op_stack))
                # alu is either an IR opcode or a callable that emits code
                alu = WASM_ALU_TO_CWERG[opc.basename]
                if wasm_opc.FLAGS.UNSIGNED in opc.flags:
                    # unsigned variant: convert operands to unsigned temps,
                    # compute there, then convert the result back
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    tmp3 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 3)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [tmp3, tmp1, tmp2]))
                    else:
                        alu(tmp3, tmp1, tmp2, bbls[-1])
                    bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp3]))
                else:
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [dst, op1, op2]))
                    else:
                        alu(dst, op1, op2, bbls[-1])
                op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CONV:
            conv, dst_unsigned, src_unsigned = WASM_CONV_TO_CWERG[opc.name]
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            if src_unsigned:
                tmp = GetOpReg(fun, ToUnsigned(op.kind),
                               op_stack_size_before + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, op]))
                op = tmp
            if dst_unsigned:
                # convert into an unsigned temp first, then CONV to dst below
                tmp = GetOpReg(fun, ToUnsigned(dst.kind),
                               op_stack_size_before + 1)
                dst, tmp = tmp, dst
            bbls[-1].AddIns(ir.Ins(conv, [dst, op]))
            if dst_unsigned:
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, dst]))
                dst = tmp
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.BITCAST:
            assert opc.name in SUPPORTED_BITCAST
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.BITCAST, [dst, op]))
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CMP:
            # A compare feeding IF/BR_IF/SELECT leaves its operands on
            # op_stack; that instruction folds it into a branch via
            # MakeBranch. Otherwise the result is materialized as an S32 here.
            # this always works because of the sentinel: "end"
            succ = wasm_fun.impl.expr.instructions[n + 1]
            if succ.opcode not in {
                    wasm_opc.IF, wasm_opc.BR_IF, wasm_opc.SELECT
            }:
                cmp, res1, res2, op1, op2, unsigned = MakeCompare(
                    opc, op_stack)
                if unsigned:
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    op1 = tmp1
                    op2 = tmp2
                dst = GetOpReg(fun, o.DK.S32, len(op_stack))
                bbls[-1].AddIns(ir.Ins(cmp, [dst, res1, res2, op1, op2]))
                op_stack.append(dst)
        elif opc is wasm_opc.LOOP or opc is wasm_opc.BLOCK:
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            bbl_count += 1
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.IF:
            # note we do set the new bbl right away because we add some instructions to the old one
            # this always works because the stack cannot be empty at this point
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, True)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            # assert len(block_stack[-1].param_types) == 0
            bbl_count += 1
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, block_stack[-1].else_bbl]))
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.ELSE:
            block = block_stack[-1]
            assert block.opcode is wasm_opc.IF
            block.FinalizeIf(op_stack, bbls[-1], fun, last_opc
                             in {wasm_opc.UNREACHABLE, wasm_opc.BR})
            if DEBUG:
                print(f"#@@ AFTER STACK RESTORE", len(op_stack))
            # the else arm starts from the stack height at block entry
            op_stack = op_stack[0:block.stack_start]
            bbls[-1].AddIns(ir.Ins(o.BRA, [block.end_bbl]))
            assert block.else_bbl is not None
            bbls.append(block.else_bbl)
            # cleared so the matching END does not append it a second time
            block.else_bbl = None
        elif opc is wasm_opc.END:
            if block_stack:
                # closes the innermost block/loop/if
                block = block_stack.pop(-1)
                block.FinalizeEnd(
                    op_stack, bbls[-1], fun, last_opc
                    in {wasm_opc.UNREACHABLE, wasm_opc.BR})
                if block.else_bbl:
                    bbls.append(block.else_bbl)
                bbls.append(block.end_bbl)
            else:
                # end of function
                assert n + 1 == len(wasm_fun.impl.expr.instructions)
                pred = wasm_fun.impl.expr.instructions[n - 1].opcode
                if pred not in {
                        wasm_opc.RETURN, wasm_opc.UNREACHABLE, wasm_opc.BR
                }:
                    for x in reversed(fun.output_types):
                        op = op_stack.pop(-1)
                        assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                        bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
                    bbls[-1].AddIns(ir.Ins(o.RET, []))

        elif opc is wasm_opc.CALL:
            wasm_callee = mod.functions[int(wasm_ins.args[0])]
            callee = unit.GetFun(wasm_callee.name)
            assert callee, f"unknown fun: {wasm_callee.name}"
            EmitCall(fun, bbls[-1], ir.Ins(o.BSR, [callee]), op_stack,
                     mem_base, callee)
        elif opc is wasm_opc.CALL_INDIRECT:
            assert isinstance(args[1], wasm.TableIdx), f"{type(args[1])}"
            assert int(args[1]) == 0, f"only one table supported"
            assert isinstance(args[0], wasm.TypeIdx), f"{type(args[0])}"
            type_sec = mod.sections.get(wasm.SECTION_ID.TYPE)
            func_type: wasm.FunctionType = type_sec.items[int(args[0])]
            arguments = [addr_type] + TranslateTypeList(func_type.args)
            returns = TranslateTypeList(func_type.rets)
            # print (f"# @@@@ CALL INDIRECT {returns} <- {arguments}  [{int(args[0])}] {func_type}")
            signature = FindFunWithSignature(unit, arguments, returns)
            table_reg = GetOpReg(fun, addr_type, len(op_stack))
            code_type = o.DK.C32 if addr_type is o.DK.A32 else o.DK.C64
            fun_reg = GetOpReg(fun, code_type, len(op_stack) + 1)
            index = op_stack.pop(-1)
            assert index.kind is o.DK.S32

            # scale the table index to a byte offset (in place; index was
            # popped). NOTE(review): index is S32 but the scale constant is
            # U32 — confirm the MUL type constraint accepts this.
            bbls[-1].AddIns(
                ir.Ins(o.MUL, [
                    index, index,
                    ir.Const(o.DK.U32,
                             code_type.bitwidth() // 8)
                ]))
            bbls[-1].AddIns(ir.Ins(o.LEA_MEM, [table_reg, global_table, ZERO]))
            bbls[-1].AddIns(ir.Ins(o.LD, [fun_reg, table_reg, index]))
            EmitCall(fun, bbls[-1], ir.Ins(o.JSR, [fun_reg, signature]),
                     op_stack, mem_base, signature)
        elif opc is wasm_opc.RETURN:
            for x in reversed(fun.output_types):
                op = op_stack.pop(-1)
                assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.BR:
            # a branch to a loop targets its start; to anything else, its end
            assert isinstance(args[0], wasm.LabelIdx)
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(o.BRA, [target]))
        elif opc is wasm_opc.BR_IF:
            assert isinstance(args[0], wasm.LabelIdx)
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, target]))
        elif opc is wasm_opc.SELECT:
            # reg = val_t in the current bbl; the branch from MakeBranch then
            # jumps to the new select_<n> bbl, skipping `reg = val_f`
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            val_f = op_stack.pop(-1)
            val_t = op_stack.pop(-1)
            assert val_f.kind == val_t.kind
            reg = GetOpReg(fun, val_f.kind, len(op_stack))
            op_stack.append(reg)
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, val_t]))
            bbls.append(fun.AddBbl(ir.Bbl(f"select_{bbl_count}")))
            bbl_count += 1
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            bbls[-2].AddIns(ir.Ins(br, [op1, op2, bbls[-1]]))
            bbls[-2].AddIns(ir.Ins(o.MOV, [reg, val_f]))
        elif opc.kind is wasm_opc.OPC_KIND.STORE:
            val = op_stack.pop(-1)
            offset = op_stack.pop(-1)
            # fold the static offset (args[1]) into the dynamic address
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD,
                           [tmp, offset,
                            ir.Const(offset.kind, args[1])]))
                offset = tmp
            # narrow stores: convert the value to the stored width first
            dk_tmp = STORE_TO_CWERG_TYPE.get(opc.name)
            if dk_tmp is not None:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack) + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, val]))
                val = tmp
            bbls[-1].AddIns(ir.Ins(o.ST, [mem_base, offset, val]))
        elif opc.kind is wasm_opc.OPC_KIND.LOAD:
            offset = op_stack.pop(-1)
            # fold the static offset (args[1]) into the dynamic address
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD,
                           [tmp, offset,
                            ir.Const(offset.kind, args[1])]))
                offset = tmp
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            op_stack.append(dst)
            # narrow loads: load at the narrow kind, then CONV into dst
            dk_tmp = LOAD_TO_CWERG_TYPE.get(opc.basename)
            if dk_tmp:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack))
                bbls[-1].AddIns(ir.Ins(o.LD, [tmp, mem_base, offset]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp]))
            else:
                bbls[-1].AddIns(ir.Ins(o.LD, [dst, mem_base, offset]))
        elif opc is wasm_opc.BR_TABLE:
            bbl_tab = {
                n: GetTargetBbl(block_stack, x)
                for n, x in enumerate(args[0])
            }
            bbl_def = GetTargetBbl(block_stack, args[1])
            op = op_stack.pop(-1)
            tab_size = ir.Const(ToUnsigned(op.kind), len(bbl_tab))
            jtb_count += 1
            jtb = fun.AddJtb(
                ir.Jtb(f"jtb_{jtb_count}", bbl_def, bbl_tab, tab_size.value))
            # indices >= table size (compared unsigned) go to the default
            reg_unsigned = GetOpReg(fun, ToUnsigned(op.kind), len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.CONV, [reg_unsigned, op]))
            bbls[-1].AddIns(ir.Ins(o.BLE, [tab_size, reg_unsigned, bbl_def]))
            bbls[-1].AddIns(ir.Ins(o.SWITCH, [reg_unsigned, jtb]))
        elif opc is wasm_opc.UNREACHABLE:
            bbls[-1].AddIns(ir.Ins(o.TRAP, []))
        elif opc is wasm_opc.MEMORY_GROW or opc is wasm_opc.MEMORY_SIZE:
            # memory.size is emitted as __memory_grow(ZERO_S); presumably
            # growing by zero returns the current size
            op = ZERO_S
            if opc is wasm_opc.MEMORY_GROW:
                op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            assert unit.GetFun("__memory_grow")
            bbls[-1].AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
            dst = GetOpReg(fun, o.DK.S32, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.POPARG, [dst]))
            op_stack.append(dst)
        else:
            assert False, f"unsupported opcode [{opc.name}]"
    assert not op_stack, f"op_stack not empty in {fun.name}: {op_stack}"
    assert not block_stack, f"block_stack not empty in {fun.name}: {block_stack}"
    # every bbl added to fun must have been placed in emission order
    assert len(bbls) == len(fun.bbls)
    fun.bbls = bbls
Exemple #14
0
def GetOpReg(fun: ir.Fun, dk: o.DK, pos: int) -> ir.Reg:
    """Return the operand-stack register for slot `pos` of kind `dk`.

    The register is named `$op_<pos>_<kind name>`. It is looked up in `fun`
    via MaybeGetReg and only created (and added to `fun`) when that lookup
    finds nothing.
    """
    name = f"$op_{pos}_{dk.name}"
    existing = fun.MaybeGetReg(name)
    if existing:
        return existing
    return fun.AddReg(ir.Reg(name, dk))
Exemple #15
0
 def testMov(self):
     """A reg-reg MOV must have an isel pattern for S32, U32, A32 and F32."""
     for dk in (o.DK.S32, o.DK.U32, o.DK.A32, o.DK.F32):
         dst = ir.Reg(name="0", kind=dk)
         src = ir.Reg(name="1", kind=dk)
         pattern = isel_tab.FindMatchingPattern(ir.Ins(o.MOV, [dst, src]))
         assert pattern is not None