Beispiel #1
0
def _GetRegOrConstOperand(fun: ir.Fun, last_kind: o.DK, ok: o.OP_KIND,
                          tc: o.TC, token: str, regs_cpu: Dict[str,
                                                               ir.Reg]) -> Any:
    """Parse one textual operand into a register or a constant.

    Register syntax: ``name`` or ``name:KIND`` (the latter declares a new
    register of that kind). Either form may be followed by ``@CPUREG`` or
    ``@STK`` to pin the register to a cpu register or a stack slot.
    Constant syntax: ``value:KIND``, or a bare ``value`` whose kind is deduced
    from ``tc`` (using ``last_kind`` for TC.SAME_AS_PREV).

    Args:
      fun: function whose registers are looked up (``GetReg``) or extended
        (``AddReg``).
      last_kind: kind of the preceding operand.
      ok: expected operand kind. REG_OR_CONST is resolved with
        ``parse.IsLikelyConst``.
      tc: type constraint of this operand slot.
      token: the operand text.
      regs_cpu: map from cpu register name to register object, used for
        ``@NAME``.

    Returns:
      An ``ir.Reg`` for register operands. Otherwise, the value produced by
      ``ir.ParseConst`` / ``ir.ParseOffsetConst``.

    Raises:
      AssertionError: unknown cpu register, conflicting cpu register
        assignment, type-constraint violation, negative unsigned constant,
        or a constant whose kind cannot be deduced.
    """
    if ok == o.OP_KIND.REG_OR_CONST:
        ok = o.OP_KIND.CONST if parse.IsLikelyConst(token) else o.OP_KIND.REG

    if ok is o.OP_KIND.REG:
        cpu_reg = None
        pos = token.find("@")
        if pos > 0:
            cpu_reg_name = token[pos + 1:]
            token = token[:pos]
            if cpu_reg_name == "STK":
                cpu_reg = ir.StackSlot(0)
            else:
                cpu_reg = regs_cpu.get(cpu_reg_name)
                # Report the name that was looked up. `token` has already been
                # truncated above, so re-slicing it would print garbage.
                assert cpu_reg is not None, f"unknown cpu_reg {cpu_reg_name} known regs {regs_cpu.keys()}"
        pos = token.find(":")
        if pos < 0:
            reg = fun.GetReg(token)
        else:
            # "name:KIND" declares a new register of the given kind.
            kind = token[pos + 1:]
            reg_name = token[:pos]
            reg = ir.Reg(reg_name, o.SHORT_STR_TO_RK.get(kind))
            fun.AddReg(reg)
            assert o.CheckTypeConstraint(last_kind, tc, reg.kind)
        # Explicit None test: cpu_reg is None exactly when no "@" suffix was
        # given, regardless of how the slot/register object behaves as a bool.
        if cpu_reg is not None:
            if reg.cpu_reg:
                assert reg.cpu_reg == cpu_reg
            else:
                reg.cpu_reg = cpu_reg
        return reg

    pos = token.find(":")
    if pos >= 0:
        # Explicitly typed constant "value:KIND".
        kind = token[pos + 1:]
        value_str = token[:pos]
        return ir.ParseConst(value_str, o.SHORT_STR_TO_RK.get(kind))
    if tc == o.TC.SAME_AS_PREV:
        return ir.ParseConst(token, last_kind)
    if tc == o.TC.OFFSET:
        return ir.ParseOffsetConst(token)
    if tc == o.TC.UINT:
        # startswith() also copes with an empty token, where token[0] would
        # raise IndexError.
        assert not token.startswith("-"), f"negative unsigned const {token}"
        return ir.ParseOffsetConst(token)
    assert False, f"cannot deduce type for const {token} [{tc}]"
Beispiel #2
0
def GenerateFun(unit: ir.Unit, mod: wasm.Module, wasm_fun: wasm.Function,
                fun: ir.Fun, global_table, addr_type):
    """Translate the body of one wasm function into Cwerg IR inside `fun`.

    The wasm value stack is simulated at translation time. `op_stack` holds
    the IR registers/constants standing in for stack entries. Each wasm
    instruction appends IR instructions to the current basic block
    (`bbls[-1]`). Open block/loop/if constructs are tracked on `block_stack`.

    Args:
      unit: IR unit. Used to look up callees, the memory objects backing
        wasm globals (`global_vars_<i>`), and the `__memory_grow` helper.
      mod: the wasm module. Provides the global and type sections and the
        function list used for CALL.
      wasm_fun: the wasm function whose instructions are translated.
      fun: the IR function to fill in. Its first input type must be
        `addr_type` (it receives `mem_base`). The remaining inputs become
        the locals `$loc_0`, `$loc_1`, ...
      global_table: operand passed to LEA_MEM as the base of the table used
        by CALL_INDIRECT (NOTE(review): presumably an ir.Mem; confirm).
      addr_type: address kind. o.DK.A32 selects C32 code pointers for
        CALL_INDIRECT; anything else selects C64.

    On return, `fun.bbls` is replaced by the basic blocks in emission order.
    """
    # op_stack models the wasm value stack. Its entries are regs produced by
    # GetOpReg. Such a reg may only occur at the stack position encoded in
    # its name.
    op_stack: typing.List[typing.Union[ir.Reg, ir.Const]] = []
    # One Block per currently open wasm block/loop/if.
    block_stack: typing.List[Block] = []
    # Basic blocks in emission order; bbls[-1] is the one being filled.
    bbls: typing.List[ir.Bbl] = []
    bbl_count = 0  # counter used when naming generated bbls
    jtb_count = 0  # counter used when naming jump tables

    bbls.append(fun.AddBbl(ir.Bbl("start")))

    # Prologue: pop the incoming arguments. The first is mem_base (the base
    # for wasm memory loads/stores below; it is also handed to EmitCall). The
    # rest become the wasm locals $loc_0, $loc_1, ...
    loc_index = 0
    assert fun.input_types[0] is addr_type
    mem_base = fun.AddReg(ir.Reg("mem_base", addr_type))
    bbls[-1].AddIns(ir.Ins(o.POPARG, [mem_base]))
    for dk in fun.input_types[1:]:
        reg = fun.AddReg(ir.Reg(f"$loc_{loc_index}", dk))
        loc_index += 1
        bbls[-1].AddIns(ir.Ins(o.POPARG, [reg]))

    # Declared (non-parameter) locals continue the $loc_ numbering and are
    # zero-initialized.
    # NOTE(review): `locals` shadows the builtin of the same name.
    for locals in wasm_fun.impl.locals_list:
        for i in range(locals.count):
            reg = fun.AddReg(
                ir.Reg(f"$loc_{loc_index}",
                       WASM_TYPE_TO_CWERG_TYPE[locals.kind]))
            loc_index += 1
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, ir.Const(reg.kind, 0)]))

    # Set to True to trace the translation on stdout.
    DEBUG = False
    if DEBUG:
        print()
        print(
            f"# {fun.name} #ins:{len(wasm_fun.impl.expr.instructions)} in:{fun.input_types} out:{fun.output_types}"
        )

    opc = None
    # Opcode of the previous instruction. ELSE/END use it to tell whether the
    # preceding code ended in UNREACHABLE/BR.
    last_opc = None
    for n, wasm_ins in enumerate(wasm_fun.impl.expr.instructions):
        last_opc = opc
        opc = wasm_ins.opcode
        args = wasm_ins.args
        # Stack depth on entry. Scratch regs are allocated at positions above
        # this depth, so their names cannot coincide with live stack entries.
        op_stack_size_before = len(op_stack)
        if DEBUG:
            print(f"#@@ {opc.name}", args, len(op_stack))
        if opc.kind is wasm_opc.OPC_KIND.CONST:
            # TODO: breaks for floats (args[0] is wrapped in ir.Const as is).
            kind = OPC_TYPE_TO_CWERG_TYPE[opc.op_type]
            dst = GetOpReg(fun, kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, ir.Const(kind, args[0])]))
            op_stack.append(dst)
        elif opc is wasm_opc.NOP:
            pass
        elif opc is wasm_opc.DROP:
            op_stack.pop(-1)
        elif opc is wasm_opc.LOCAL_GET:
            loc = GetLocalReg(fun, int(args[0]))
            dst = GetOpReg(fun, loc.kind, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.MOV, [dst, loc]))
            op_stack.append(dst)
        elif opc is wasm_opc.LOCAL_SET:
            op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.MOV,
                                   [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.LOCAL_TEE:
            op = op_stack[-1]  # no pop! tee leaves the value on the stack
            bbls[-1].AddIns(ir.Ins(o.MOV,
                                   [GetLocalReg(fun, int(args[0])), op]))
        elif opc is wasm_opc.GLOBAL_GET:
            # Each wasm global is stored in a memory object named
            # "global_vars_<index>".
            var_index = int(args[0])
            var: wasm.Glob = mod.sections.get(
                wasm.SECTION_ID.GLOBAL).items[var_index]
            dst = GetOpReg(fun,
                           WASM_TYPE_TO_CWERG_TYPE[var.global_type.value_type],
                           len(op_stack))
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.LD_MEM, [dst, var_mem, ZERO]))
            op_stack.append(dst)
        elif opc is wasm_opc.GLOBAL_SET:
            op = op_stack.pop(-1)
            var_index = int(args[0])
            var_mem = unit.GetMem(f"global_vars_{var_index}")
            bbls[-1].AddIns(ir.Ins(o.ST_MEM, [var_mem, ZERO, op]))
        elif opc.kind is wasm_opc.OPC_KIND.ALU:
            if wasm_opc.FLAGS.UNARY in opc.flags:
                opcode, arg_factory = WASM_ALU1_TO_CWERG[opc.basename]
                op = op_stack.pop(-1)
                dst = GetOpReg(fun, op.kind, len(op_stack))
                bbls[-1].AddIns(ir.Ins(opcode, [dst] + arg_factory(op)))
                op_stack.append(dst)
            else:
                op2 = op_stack.pop(-1)
                op1 = op_stack.pop(-1)
                dst = GetOpReg(fun, op1.kind, len(op_stack))
                # A WASM_ALU_TO_CWERG entry is either a single IR opcode or a
                # callable that emits the instruction sequence itself.
                alu = WASM_ALU_TO_CWERG[opc.basename]
                if wasm_opc.FLAGS.UNSIGNED in opc.flags:
                    # Unsigned op: convert both operands to the unsigned kind,
                    # operate there, then convert the result back into dst.
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    tmp3 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 3)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [tmp3, tmp1, tmp2]))
                    else:
                        alu(tmp3, tmp1, tmp2, bbls[-1])
                    bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp3]))
                else:
                    if isinstance(alu, o.Opcode):
                        bbls[-1].AddIns(ir.Ins(alu, [dst, op1, op2]))
                    else:
                        alu(dst, op1, op2, bbls[-1])
                op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CONV:
            conv, dst_unsigned, src_unsigned = WASM_CONV_TO_CWERG[opc.name]
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            if src_unsigned:
                # Reinterpret the source as unsigned before converting.
                tmp = GetOpReg(fun, ToUnsigned(op.kind),
                               op_stack_size_before + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, op]))
                op = tmp
            if dst_unsigned:
                # Convert into an unsigned temp first (dst and tmp swap roles);
                # the temp is converted into the real dst below.
                tmp = GetOpReg(fun, ToUnsigned(dst.kind),
                               op_stack_size_before + 1)
                dst, tmp = tmp, dst
            bbls[-1].AddIns(ir.Ins(conv, [dst, op]))
            if dst_unsigned:
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, dst]))
                dst = tmp
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.BITCAST:
            assert opc.name in SUPPORTED_BITCAST
            op = op_stack.pop(-1)
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.BITCAST, [dst, op]))
            op_stack.append(dst)
        elif opc.kind is wasm_opc.OPC_KIND.CMP:
            # this always works because of the sentinel: "end"
            # A compare consumed directly by IF/BR_IF/SELECT emits nothing
            # here. That consumer calls MakeBranch with this compare's opcode
            # as `pred`, presumably fusing it into a conditional branch.
            succ = wasm_fun.impl.expr.instructions[n + 1]
            if succ.opcode not in {
                    wasm_opc.IF, wasm_opc.BR_IF, wasm_opc.SELECT
            }:
                # Otherwise materialize the result into an S32 register.
                # NOTE(review): presumably res1/res2 are the values chosen when
                # the comparison holds/fails; confirm against MakeCompare.
                cmp, res1, res2, op1, op2, unsigned = MakeCompare(
                    opc, op_stack)
                if unsigned:
                    tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 1)
                    tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                    op_stack_size_before + 2)
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                    bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                    op1 = tmp1
                    op2 = tmp2
                dst = GetOpReg(fun, o.DK.S32, len(op_stack))
                bbls[-1].AddIns(ir.Ins(cmp, [dst, res1, res2, op1, op2]))
                op_stack.append(dst)
        elif opc is wasm_opc.LOOP or opc is wasm_opc.BLOCK:
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            bbl_count += 1
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.IF:
            # note we do set the new bbl right away because we add some instructions to the old one
            # this always works because the stack cannot be empty at this point
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            # The branch emitted below jumps to the else bbl. MakeBranch's
            # third argument is True here, unlike BR_IF/SELECT (presumably it
            # inverts the condition; confirm).
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, True)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block_stack.append(
                MakeBlock(bbl_count, opc, args, fun, op_stack, mod))
            # assert len(block_stack[-1].param_types) == 0
            bbl_count += 1
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, block_stack[-1].else_bbl]))
            bbls.append(block_stack[-1].start_bbl)
        elif opc is wasm_opc.ELSE:
            # Close the then-part: finalize its results, jump to the end bbl,
            # restore the stack to the block's entry depth, and continue
            # emitting into the else bbl.
            block = block_stack[-1]
            assert block.opcode is wasm_opc.IF
            block.FinalizeIf(op_stack, bbls[-1], fun, last_opc
                             in {wasm_opc.UNREACHABLE, wasm_opc.BR})
            if DEBUG:
                print(f"#@@ AFTER STACK RESTORE", len(op_stack))
            op_stack = op_stack[0:block.stack_start]
            bbls[-1].AddIns(ir.Ins(o.BRA, [block.end_bbl]))
            assert block.else_bbl is not None
            bbls.append(block.else_bbl)
            # Mark the else bbl as already emitted (checked again at END).
            block.else_bbl = None
        elif opc is wasm_opc.END:
            if block_stack:
                block = block_stack.pop(-1)
                block.FinalizeEnd(
                    op_stack, bbls[-1], fun, last_opc
                    in {wasm_opc.UNREACHABLE, wasm_opc.BR})
                # An IF without ELSE still has its else bbl pending; emit it
                # right before the end bbl.
                if block.else_bbl:
                    bbls.append(block.else_bbl)
                bbls.append(block.end_bbl)
            else:
                # end of function
                assert n + 1 == len(wasm_fun.impl.expr.instructions)
                pred = wasm_fun.impl.expr.instructions[n - 1].opcode
                # Emit the implicit return unless the previous instruction
                # already left this point (RETURN/UNREACHABLE/BR).
                if pred not in {
                        wasm_opc.RETURN, wasm_opc.UNREACHABLE, wasm_opc.BR
                }:
                    for x in reversed(fun.output_types):
                        op = op_stack.pop(-1)
                        assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                        bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
                    bbls[-1].AddIns(ir.Ins(o.RET, []))

        elif opc is wasm_opc.CALL:
            wasm_callee = mod.functions[int(wasm_ins.args[0])]
            callee = unit.GetFun(wasm_callee.name)
            assert callee, f"unknown fun: {wasm_callee.name}"
            EmitCall(fun, bbls[-1], ir.Ins(o.BSR, [callee]), op_stack,
                     mem_base, callee)
        elif opc is wasm_opc.CALL_INDIRECT:
            assert isinstance(args[1], wasm.TableIdx), f"{type(args[1])}"
            assert int(args[1]) == 0, f"only one table supported"
            assert isinstance(args[0], wasm.TypeIdx), f"{type(args[0])}"
            type_sec = mod.sections.get(wasm.SECTION_ID.TYPE)
            func_type: wasm.FunctionType = type_sec.items[int(args[0])]
            # Callees receive mem_base as an extra leading argument.
            arguments = [addr_type] + TranslateTypeList(func_type.args)
            returns = TranslateTypeList(func_type.rets)
            # print (f"# @@@@ CALL INDIRECT {returns} <- {arguments}  [{int(args[0])}] {func_type}")
            signature = FindFunWithSignature(unit, arguments, returns)
            # Scratch regs are allocated before `index` is popped, so they sit
            # at/above index's position (and differ from it in kind).
            table_reg = GetOpReg(fun, addr_type, len(op_stack))
            code_type = o.DK.C32 if addr_type is o.DK.A32 else o.DK.C64
            fun_reg = GetOpReg(fun, code_type, len(op_stack) + 1)
            index = op_stack.pop(-1)
            assert index.kind is o.DK.S32

            # Load the code pointer at table base + index * pointer size.
            # `index` is scaled in place (it has been popped, so it is dead).
            # NOTE(review): the multiplier is a U32 constant while index is
            # S32; confirm MUL accepts the mixed kinds.
            bbls[-1].AddIns(
                ir.Ins(o.MUL, [
                    index, index,
                    ir.Const(o.DK.U32,
                             code_type.bitwidth() // 8)
                ]))
            bbls[-1].AddIns(ir.Ins(o.LEA_MEM, [table_reg, global_table, ZERO]))
            bbls[-1].AddIns(ir.Ins(o.LD, [fun_reg, table_reg, index]))
            EmitCall(fun, bbls[-1], ir.Ins(o.JSR, [fun_reg, signature]),
                     op_stack, mem_base, signature)
        elif opc is wasm_opc.RETURN:
            # Results are pushed last-to-first.
            for x in reversed(fun.output_types):
                op = op_stack.pop(-1)
                assert op.kind == x, f"outputs: {fun.output_types} mismatch {op.kind} vs {x}"
                bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            bbls[-1].AddIns(ir.Ins(o.RET, []))
        elif opc is wasm_opc.BR:
            # A branch to a LOOP targets its start bbl. A branch to any other
            # block targets its end bbl, after copying the block's results.
            assert isinstance(args[0], wasm.LabelIdx)
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(o.BRA, [target]))
        elif opc is wasm_opc.BR_IF:
            # Same target selection as BR, but with a conditional branch.
            assert isinstance(args[0], wasm.LabelIdx)
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            block = GetTargetBlock(block_stack, args[0])
            target = block.start_bbl
            if block.opcode is not wasm_opc.LOOP:
                target = block.end_bbl
                block.FinalizeResultsCopy(op_stack, bbls[-1], fun)
            bbls[-1].AddIns(ir.Ins(br, [op1, op2, target]))
        elif opc is wasm_opc.SELECT:
            # Lowered as: reg := val_t in the current bbl. Then `br` jumps to a
            # fresh bbl, skipping `reg := val_f`, which is the fall-through
            # path. After the new bbl is appended, the original bbl is
            # bbls[-2].
            pred = wasm_fun.impl.expr.instructions[n - 1].opcode
            br, op1, op2, unsigned = MakeBranch(pred, op_stack, False)
            val_f = op_stack.pop(-1)
            val_t = op_stack.pop(-1)
            assert val_f.kind == val_t.kind
            reg = GetOpReg(fun, val_f.kind, len(op_stack))
            op_stack.append(reg)
            bbls[-1].AddIns(ir.Ins(o.MOV, [reg, val_t]))
            bbls.append(fun.AddBbl(ir.Bbl(f"select_{bbl_count}")))
            bbl_count += 1
            if unsigned:
                tmp1 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 1)
                tmp2 = GetOpReg(fun, ToUnsigned(op1.kind),
                                op_stack_size_before + 2)
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp1, op1]))
                bbls[-2].AddIns(ir.Ins(o.CONV, [tmp2, op2]))
                op1 = tmp1
                op2 = tmp2
            bbls[-2].AddIns(ir.Ins(br, [op1, op2, bbls[-1]]))
            bbls[-2].AddIns(ir.Ins(o.MOV, [reg, val_f]))
        elif opc.kind is wasm_opc.OPC_KIND.STORE:
            val = op_stack.pop(-1)
            offset = op_stack.pop(-1)
            # Fold the instruction's static offset (args[1]) into the address.
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD,
                           [tmp, offset,
                            ir.Const(offset.kind, args[1])]))
                offset = tmp
            # Convert the value first when STORE_TO_CWERG_TYPE has an entry
            # for this opcode (presumably the narrow store variants).
            dk_tmp = STORE_TO_CWERG_TYPE.get(opc.name)
            if dk_tmp is not None:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack) + 1)
                bbls[-1].AddIns(ir.Ins(o.CONV, [tmp, val]))
                val = tmp
            bbls[-1].AddIns(ir.Ins(o.ST, [mem_base, offset, val]))
        elif opc.kind is wasm_opc.OPC_KIND.LOAD:
            offset = op_stack.pop(-1)
            # Fold the instruction's static offset (args[1]) into the address.
            if args[1] != 0:
                tmp = GetOpReg(fun, offset.kind, len(op_stack))
                bbls[-1].AddIns(
                    ir.Ins(o.ADD,
                           [tmp, offset,
                            ir.Const(offset.kind, args[1])]))
                offset = tmp
            dst = GetOpReg(fun, OPC_TYPE_TO_CWERG_TYPE[opc.op_type],
                           len(op_stack))
            op_stack.append(dst)
            # When LOAD_TO_CWERG_TYPE has an entry, load that kind into a temp
            # and convert it into dst (presumably narrow loads).
            dk_tmp = LOAD_TO_CWERG_TYPE.get(opc.basename)
            if dk_tmp:
                tmp = GetOpReg(fun, dk_tmp, len(op_stack))
                bbls[-1].AddIns(ir.Ins(o.LD, [tmp, mem_base, offset]))
                bbls[-1].AddIns(ir.Ins(o.CONV, [dst, tmp]))
            else:
                bbls[-1].AddIns(ir.Ins(o.LD, [dst, mem_base, offset]))
        elif opc is wasm_opc.BR_TABLE:
            # The comprehension's `n` is local to it and does not clobber the
            # loop's instruction index.
            bbl_tab = {
                n: GetTargetBbl(block_stack, x)
                for n, x in enumerate(args[0])
            }
            bbl_def = GetTargetBbl(block_stack, args[1])
            op = op_stack.pop(-1)
            tab_size = ir.Const(ToUnsigned(op.kind), len(bbl_tab))
            jtb_count += 1
            jtb = fun.AddJtb(
                ir.Jtb(f"jtb_{jtb_count}", bbl_def, bbl_tab, tab_size.value))
            # Compare unsigned so that negative indices also count as out of
            # range: take the default target when tab_size <= index,
            # otherwise dispatch through the jump table.
            reg_unsigned = GetOpReg(fun, ToUnsigned(op.kind), len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.CONV, [reg_unsigned, op]))
            bbls[-1].AddIns(ir.Ins(o.BLE, [tab_size, reg_unsigned, bbl_def]))
            bbls[-1].AddIns(ir.Ins(o.SWITCH, [reg_unsigned, jtb]))
        elif opc is wasm_opc.UNREACHABLE:
            bbls[-1].AddIns(ir.Ins(o.TRAP, []))
        elif opc is wasm_opc.MEMORY_GROW or opc is wasm_opc.MEMORY_SIZE:
            # Both go through the __memory_grow helper. MEMORY_SIZE passes
            # ZERO_S (presumably "grow by 0" reports the current size).
            op = ZERO_S
            if opc is wasm_opc.MEMORY_GROW:
                op = op_stack.pop(-1)
            bbls[-1].AddIns(ir.Ins(o.PUSHARG, [op]))
            assert unit.GetFun("__memory_grow")
            bbls[-1].AddIns(ir.Ins(o.BSR, [unit.GetFun("__memory_grow")]))
            dst = GetOpReg(fun, o.DK.S32, len(op_stack))
            bbls[-1].AddIns(ir.Ins(o.POPARG, [dst]))
            op_stack.append(dst)
        else:
            assert False, f"unsupported opcode [{opc.name}]"
    # Every value and every block must have been consumed, and every emitted
    # bbl must belong to fun before its bbl list is replaced by emission order.
    assert not op_stack, f"op_stack not empty in {fun.name}: {op_stack}"
    assert not block_stack, f"block_stack not empty in {fun.name}: {block_stack}"
    assert len(bbls) == len(fun.bbls)
    fun.bbls = bbls
Beispiel #3
0
def GetOpReg(fun: ir.Fun, dk: o.DK, pos: int) -> ir.Reg:
    """Return the register modelling operand-stack slot `pos` with kind `dk`.

    The register is named "$op_<pos>_<kind name>". If `fun` already has a
    register of that name, it is reused; otherwise a new one is created and
    added to `fun`. The same (pos, dk) pair therefore always maps to the same
    register within one function.
    """
    name = f"$op_{pos}_{dk.name}"
    existing = fun.MaybeGetReg(name)
    if existing:
        return existing
    return fun.AddReg(ir.Reg(name, dk))