Пример #1
0
 def get_available_reg(self, lr: reg_alloc.LiveRange) -> ir.CpuReg:
     """Pick a free CPU register for live range `lr` and mark it as taken.

     The register list searched depends on `lr.reg.kind`:
       * o.DK.F64: DBL_REGS[n] occupies bits 2n and 2n+1 of the availability
         mask; both bits must be set and neither `self._flt_reserved[2n]` nor
         `self._flt_reserved[2n + 1]` may conflict with `lr`.
       * o.DK.F32: _FLT_REGS[n], checked against `self._flt_reserved[n]`.
       * any other kind: _GPR_REGS[n], checked against `self._gpr_reserved[n]`.
     The lowest-numbered register that qualifies is chosen, and its bit(s) are
     cleared in the availability mask selected by (LAC flag of `lr`,
     gpr-vs-float flavor of `lr.reg.kind`).

     Returns:
       The chosen register, or ir.CPU_REG_SPILL when none qualifies and
       spilling is allowed.

     Raises:
       AssertionError: when none qualifies and spilling is not allowed; the
         current Bbl is dumped to stdout first.
     """
     lac = liveness.LiveRangeFlag.LAC in lr.flags
     is_gpr = lr.reg.kind.flavor() != o.DK_FLAVOR_F
     available = self.get_available(lac, is_gpr)
     # print(f"GET {lr} {self}  avail:{available:x}")
     if lr.reg.kind == o.DK.F64:
         for n in range(len(DBL_REGS)):
             mask = 3 << (n * 2)  # two adjacent bits at an even bit position
             if (available & mask) == mask:
                 if (not self._flt_reserved[n * 2 + 0].has_conflict(lr) and
                         not self._flt_reserved[n * 2 + 1].has_conflict(lr)):
                     self.set_available(lac, is_gpr, available & ~mask)
                     return DBL_REGS[n]
     elif lr.reg.kind == o.DK.F32:
         for n in range(len(_FLT_REGS)):
             mask = 1 << n
             if (available & mask) == mask:
                 if not self._flt_reserved[n].has_conflict(lr):
                     self.set_available(lac, is_gpr, available & ~mask)
                     return _FLT_REGS[n]
     else:
         for n in range(len(_GPR_REGS)):
             mask = 1 << n
             if (available & mask) == mask:
                 if not self._gpr_reserved[n].has_conflict(lr):
                     self.set_available(lac, is_gpr, available & ~mask)
                     return _GPR_REGS[n]
     if self._allow_spilling:
         return ir.CPU_REG_SPILL
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
     print("\n".join(serialize.BblRenderToAsm(self._bbl)))
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
     # Raise explicitly instead of `assert False`: under `python -O` asserts are
     # stripped and this method would silently return None instead of a register.
     raise AssertionError(
         f"in {self._fun.name}:{self._bbl.name} no reg available for {lr} in {self}")
Пример #2
0
 def get_available_reg(self, lr: reg_alloc.LiveRange) -> ir.CpuReg:
     """Pick a free CPU register for live range `lr` and mark it as taken.

     Float-flavored kinds (flavor o.DK_FLAVOR_F) scan bit n for n in
     range(len(_FLT_REGS)) and check `self._flt_reserved[n]`; all other kinds
     scan bit n for n in range(len(_GPR_REGS)) and check
     `self._gpr_reserved[n]`. The lowest-numbered free, non-conflicting index n
     wins. Its bit is cleared in the availability mask selected by (LAC flag of
     `lr`, gpr-vs-float flavor), and `_KIND_TO_CPU_REG_LIST[lr.reg.kind][n]`
     is returned.

     Returns:
       The chosen register, or ir.CPU_REG_SPILL when none qualifies and
       spilling is allowed.

     Raises:
       AssertionError: when none qualifies and spilling is not allowed; the
         current Bbl (with line numbers) and the allocator state are dumped to
         stdout first.
     """
     lac = liveness.LiveRangeFlag.LAC in lr.flags
     is_gpr = lr.reg.kind.flavor() != o.DK_FLAVOR_F
     available = self.get_available(lac, is_gpr)
     # print(f"GET {lr} {self}  avail:{available:x}")
     if not is_gpr:
         for n in range(len(_FLT_REGS)):
             mask = 1 << n
             if (available & mask) == mask:
                 if not self._flt_reserved[n].has_conflict(lr):
                     self.set_available(lac, is_gpr, available & ~mask)
                     return _KIND_TO_CPU_REG_LIST[lr.reg.kind][n]
     else:
         for n in range(len(_GPR_REGS)):
             mask = 1 << n
             if (available & mask) == mask:
                 if not self._gpr_reserved[n].has_conflict(lr):
                     self.set_available(lac, is_gpr, available & ~mask)
                     return _KIND_TO_CPU_REG_LIST[lr.reg.kind][n]
     if self._allow_spilling:
         return ir.CPU_REG_SPILL
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
     # Numbering starts at -1 so the first rendered line gets -1 and the
     # following lines are numbered from 0.
     lines = [f"{n - 1:2} {x}" for n, x in enumerate(serialize.BblRenderToAsm(self._bbl))]
     print("\n".join(lines))
     print(f"# ALLOCATION IMPOSSIBLE - no spilling allowed in {self._fun.name}:{self._bbl.name}")
     print(f"# {lr}")
     print(f"# ALLOCATOR status: {self}")
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
     # Raise explicitly instead of `assert False`: under `python -O` asserts are
     # stripped and this method would silently return None instead of a register.
     raise AssertionError(
         f"in {self._fun.name}:{self._bbl.name} no reg available for {lr} in {self}")
Пример #3
0
def FunCheckCFG(fun: ir.Fun, check_fallthroughs):
    """Assert that the control-flow graph of `fun` is consistent.

    Verified properties:
      * `fun.bbls` and `fun.bbl_syms` have the same length, and every Bbl as
        well as every in/out edge target is registered in `fun.bbl_syms`;
      * a Bbl terminator appears only as the last instruction of a Bbl;
      * InsCheckConstraints is run on every instruction;
      * out-edges agree with the last instruction. There are none after RET.
        There are exactly two after a conditional branch, one of which is its
        operand 2. There is exactly one after an unconditional branch (its
        operand 0), after any other non-SWITCH instruction, and for an empty
        Bbl. SWITCH is not checked yet. Each checked successor lists the Bbl
        among its in-edges.

    Args:
      fun: the function to check.
      check_fallthroughs: if true, a Bbl ending in a conditional branch or in
        a non-branch instruction must also have the next Bbl in `fun.bbls` as
        a successor; the last Bbl of `fun` therefore cannot end that way.

    Raises:
      AssertionError: on the first violation found.
    """
    num_bbls = len(fun.bbls)
    assert num_bbls == len(fun.bbl_syms), (
        f"bbl mismatch {len(fun.bbls)} {fun.bbl_syms}")
    for n, bbl in enumerate(fun.bbls):
        assert bbl.name in fun.bbl_syms
        for x in bbl.edge_out:
            if x.name not in fun.bbl_syms:
                # dump the offending Bbl before the assert below fires
                print("\n".join(serialize.BblRenderToAsm(bbl)))
            assert x.name in fun.bbl_syms, f"missing bbl out edge {x}  from {bbl.name} in {fun.name}"
        for x in bbl.edge_in:
            assert x.name in fun.bbl_syms, f"missing bbl in edge {x} to {bbl.name} in {fun.name}"
        # Only the last Ins may be a terminator; it is checked separately below.
        for ins in bbl.inss[:-1]:
            assert not ins.opcode.is_bbl_terminator(), (
                f"{fun.name} {bbl}: bbl terminator in middle of bbl {ins} {bbl.inss[-1]}")
            InsCheckConstraints(ins)
        # A fallthrough needs a physically following Bbl. Looking it up safely
        # makes a fallthrough from the last Bbl fail with an AssertionError
        # rather than an IndexError.
        next_bbl = fun.bbls[n + 1] if n + 1 < num_bbls else None
        if not bbl.inss:
            assert len(bbl.edge_out) == 1, f"{bbl} {bbl.edge_out}"
            succ = bbl.edge_out[0]
            assert bbl in succ.edge_in
        else:
            last_ins = bbl.inss[-1]
            last_ins_kind = last_ins.opcode.kind
            InsCheckConstraints(last_ins)
            if last_ins_kind == o.OPC_KIND.SWITCH:
                # TODO: SWITCH out-edges are not validated yet
                pass
            elif last_ins_kind == o.OPC_KIND.COND_BRA:
                assert len(bbl.edge_out) == 2, f"expected two out edges for bbl {bbl.name} {fun.name}"
                succ1 = bbl.edge_out[0]
                assert bbl in succ1.edge_in, f"cond bra dst inconsistency in {fun.name}"
                succ2 = bbl.edge_out[1]
                assert bbl in succ2.edge_in, f"cond bra dst inconsistency in {fun.name}"
                assert last_ins.operands[2] in bbl.edge_out, last_ins
                if check_fallthroughs:
                    assert next_bbl is not None and next_bbl in bbl.edge_out, (
                        f"cond bra in {bbl.name} lacks a fallthrough successor in {fun.name}")
            elif last_ins_kind == o.OPC_KIND.BRA:
                assert len(bbl.edge_out) == 1
                succ = bbl.edge_out[0]
                assert bbl in succ.edge_in
                assert last_ins.operands[0] == succ
            elif last_ins_kind == o.OPC_KIND.RET:
                assert len(bbl.edge_out) == 0
            else:
                assert len(bbl.edge_out) == 1
                succ = bbl.edge_out[0]
                assert bbl in succ.edge_in
                if check_fallthroughs:
                    assert next_bbl is not None and succ == next_bbl, (
                        f"{bbl.name} does not fall through to the next bbl in {fun.name}")
Пример #4
0
 def get_available_reg(self, lr: reg_alloc.LiveRange) -> ir.CpuReg:
     """Pick a free CPU register for live range `lr` and mark it as taken.

     The register list searched depends on `lr.reg.kind`:
       * o.DK.F64: DBL_REGS[n] occupies bits 2n and 2n+1 of the availability
         mask; both bits must be set and neither `self._flt_reserved[2n]` nor
         `self._flt_reserved[2n + 1]` may conflict with `lr`.
       * o.DK.F32: FLT_REGS[n], checked against `self._flt_reserved[n]`.
       * any other kind: GPR_REGS[n], checked against `self._gpr_reserved[n]`.
     The lowest-numbered register that qualifies is chosen, and its bit(s) are
     cleared in the availability mask selected by (LAC flag of `lr`,
     gpr-vs-float flavor of `lr.reg.kind`).

     Returns:
       The chosen register, or ir.CPU_REG_SPILL when none qualifies and
       spilling is allowed.

     Raises:
       AssertionError: when none qualifies and spilling is not allowed; the
         current Bbl (with line numbers) and the allocator state are dumped to
         stdout first.
     """
     lac = liveness.LiveRangeFlag.LAC in lr.flags
     is_gpr = lr.reg.kind.flavor() != o.DK_FLAVOR_F
     available = self.get_available(lac, is_gpr)
     # print(f"GET {lr} {self}  avail:{available:x}")
     if lr.reg.kind == o.DK.F64:
         for n in range(len(DBL_REGS)):
             mask = 3 << (n * 2)  # two adjacent bits at an even bit position
             if (available & mask) == mask:
                 if (not self._flt_reserved[n * 2 + 0].has_conflict(lr)
                         and not self._flt_reserved[n * 2 + 1].has_conflict(lr)):
                     self.set_available(lac, is_gpr, available & ~mask)
                     return DBL_REGS[n]
     elif lr.reg.kind == o.DK.F32:
         for n in range(len(FLT_REGS)):
             mask = 1 << n
             if (available & mask) == mask:
                 if not self._flt_reserved[n].has_conflict(lr):
                     self.set_available(lac, is_gpr, available & ~mask)
                     return FLT_REGS[n]
     else:
         for n in range(len(GPR_REGS)):
             mask = 1 << n
             if (available & mask) == mask:
                 if not self._gpr_reserved[n].has_conflict(lr):
                     self.set_available(lac, is_gpr, available & ~mask)
                     return GPR_REGS[n]
     if self._allow_spilling:
         return ir.CPU_REG_SPILL
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
     # Numbering starts at -1 so the first rendered line gets -1 and the
     # following lines are numbered from 0.
     lines = [
         f"{n - 1:2} {x}"
         for n, x in enumerate(serialize.BblRenderToAsm(self._bbl))
     ]
     print("\n".join(lines))
     print(
         f"# ALLOCATION IMPOSSIBLE - no spilling allowed in {self._fun.name}:{self._bbl.name}"
     )
     print(f"# {lr}")
     print(f"# ALLOCATOR status: {self}")
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
     # Raise explicitly instead of a bare `assert False`: under `python -O`
     # asserts are stripped and this method would silently return None. The
     # message also identifies the failing function/Bbl in tracebacks.
     raise AssertionError(
         f"in {self._fun.name}:{self._bbl.name} no reg available for {lr} in {self}"
     )
Пример #5
0
def FunComputeBblRegUsageStats(
        fun: ir.Fun, reg_kind_map: Dict[o.DK, int]) -> Dict[REG_KIND_LAC, int]:
    """Return the register-usage statistics collected while allocating locals.

    Every Bbl of `fun` is run through the linear-scan assigner against one
    shared BblRegUsageStatsRegPool. The pool's usage() result is returned;
    per the original intent it reports the maximum number of registers
    needed for locals across all Bbls.

    Requires liveness information (each Bbl's `live_out`) to be up to date.
    """
    stats_pool = BblRegUsageStatsRegPool(reg_kind_map)
    for blk in fun.bbls:
        ranges = liveness.BblGetLiveRanges(blk, fun, blk.live_out)
        ranges.sort()
        if TRACE_REG_ALLOC:
            print("@" * 60)
            print("\n".join(serialize.BblRenderToAsm(blk)))
            for rng in ranges:
                print(rng)
        # Registers that do not come from the pool must not be re-used, so
        # their live ranges are flagged IGNORE before assignment.
        ignorable = (r for r in ranges if LiveRangeShouldBeIgnored(r, reg_kind_map))
        for rng in ignorable:
            rng.flags |= liveness.LiveRangeFlag.IGNORE
        reg_alloc.RegisterAssignerLinearScan(ranges, stats_pool)
    return stats_pool.usage()
Пример #6
0
def FunComputeReachingDefs(fun: ir.Fun):
    """Compute reaching definitions at every Bbl entry and operand use of `fun`.

    This serves as a poor man's SSA and runs in three phases:
      1. collect per-Bbl defs/uses via _BblComputeDefs and seed the entry Bbl
         so that every used register initially reaches from it;
      2. propagate to a fixpoint with a LIFO worklist;
      3. hand each Bbl a copy of its incoming defs via _BblPropagateDefs.

    Unreachable code must already have been removed. A Bbl other than the
    entry that has no in-edges triggers an assertion that carries its
    rendered assembly.
    """
    # Phase 1: local information per Bbl, keyed by Bbl name.
    per_bbl: Dict[str, ReachingDefs] = {}
    used_regs = set()
    for blk in fun.bbls:
        local_defs, local_uses = _BblComputeDefs(blk)
        used_regs.update(local_uses)
        per_bbl[blk.name] = ReachingDefs(local_defs)

    entry = fun.bbls[0]
    # Every used register starts out as reaching from the entry Bbl
    # (presumably standing for "defined before the body" — confirm).
    per_bbl[entry.name].defs_in = dict.fromkeys(used_regs, entry)

    # Phase 2: fixpoint. The stack holds the Bbls in reverse order so that
    # pop() yields the entry Bbl first.
    worklist = list(reversed(fun.bbls))
    while worklist:
        blk = worklist.pop()
        info = per_bbl[blk.name]
        if UpdateReachingDefsOut(info):
            out_defs = info.defs_out
            for succ in blk.edge_out:
                if _MergeReachingDefs(per_bbl[succ.name].defs_in, out_defs, succ):
                    worklist.append(succ)

    # Phase 3: publish results; reject non-entry Bbls without predecessors.
    for blk in fun.bbls:
        if blk != fun.bbls[0] and not blk.edge_in:
            asm_text = '\n'.join(serialize.BblRenderToAsm(blk))
            assert False, f"found unreachable bbl in fun {fun.name}:\n{asm_text}"
        _BblPropagateDefs(blk, per_bbl[blk.name].defs_in.copy())
Пример #7
0
def DumpBbl(bbl):
    """Print `bbl` rendered as assembly, numbering every line after the first.

    The first rendered line (presumably the Bbl header) is removed from the
    list returned by serialize.BblRenderToAsm and printed unchanged. Each
    remaining line is prefixed with its 0-based index, right-aligned to
    width 2.
    """
    rendered = serialize.BblRenderToAsm(bbl)
    header = rendered.pop(0)
    print(header)
    for idx, text in enumerate(rendered):
        print(f"{idx:2d} {text}")
Пример #8
0
def DumpBbl(bbl: ir.Bbl):
    """Print the assembly rendering of `bbl` to stdout as one newline-joined block."""
    asm_text = "\n".join(serialize.BblRenderToAsm(bbl))
    print(asm_text)
Пример #9
0
def _DumpBblWithLineNumbers(bbl):
    """Print `bbl` as assembly with 0-based, width-2 numbers on all but the first line.

    The first rendered line (presumably the Bbl header) is popped off the
    rendered list and printed as-is.
    """
    asm_lines = serialize.BblRenderToAsm(bbl)
    first_line = asm_lines.pop(0)
    numbered = (f"{i:2d} {line}" for i, line in enumerate(asm_lines))
    print(first_line)
    for entry in numbered:
        print(entry)