def __init__(self, T, entries, pipe=False, flow=False):
        """Build a ready/valid FIFO holding ``entries`` elements of type ``T``.

        Args:
            T: element type stored in the queue.
            entries: queue depth; must be non-negative.
            pipe: not supported; a truthy value raises NotImplementedError.
            flow: not supported; a truthy value raises NotImplementedError.

        Raises:
            NotImplementedError: if ``flow`` or ``pipe`` is requested.
        """
        assert entries >= 0
        # Reject unsupported modes before any hardware is instantiated.
        if flow:
            raise NotImplementedError()
        if pipe:
            raise NotImplementedError()

        self.io = m.IO(
            # Flipped since enq/deq is from perspective of the client
            enq=m.DeqIO[T],
            deq=m.EnqIO[T],
            count=m.Out(m.UInt[m.bitutils.clog2(entries + 1)])) + m.ClockIO()

        ram = m.Memory(entries, T)()
        enq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        deq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        # Disambiguates "pointers equal" between empty and full.
        maybe_full = m.Register(init=False, has_enable=True)()

        ptr_match = enq_ptr.O == deq_ptr.O
        empty = ptr_match & ~maybe_full.O
        full = ptr_match & maybe_full.O

        self.io.deq.valid @= ~empty
        self.io.enq.ready @= ~full

        do_enq = self.io.enq.fired()
        do_deq = self.io.deq.fired()

        # The top pointer bit is dropped to form the RAM address.
        ram.write(self.io.enq.data, enq_ptr.O[:-1], m.enable(do_enq))

        enq_ptr.CE @= m.enable(do_enq)
        deq_ptr.CE @= m.enable(do_deq)

        # maybe_full only changes when exactly one of enq/deq fires, and then
        # it records whether that event was an enqueue.
        maybe_full.I @= m.enable(do_enq)
        maybe_full.CE @= m.enable(do_enq != do_deq)
        self.io.deq.data @= ram[deq_ptr.O[:-1]]

        def ispow2(n):
            return (n & (n - 1) == 0) and n != 0

        count_len = len(self.io.count)
        if ispow2(entries):
            # NOTE(review): this branch only ever reports 0 or `entries`; the
            # occupancy of a partially filled queue is not reflected here.
            self.io.count @= m.mux([m.bits(0, count_len), entries],
                                   maybe_full.O & ptr_match)
        else:
            ptr_diff = enq_ptr.O - deq_ptr.O
            # Pointers equal: the queue is either empty or completely full.
            matched_count = m.mux([m.bits(0, count_len), entries],
                                  maybe_full.O)
            # Pointers differ: add `entries` back when the enqueue pointer has
            # wrapped around behind the dequeue pointer.
            unmatched_count = m.mux([ptr_diff, entries + ptr_diff],
                                    deq_ptr.O > enq_ptr.O)
            # m.mux picks element 1 when the select is high, so the
            # pointers-equal value must sit in slot 1 (the operands were
            # previously swapped).
            self.io.count @= m.mux([unmatched_count, matched_count], ptr_match)
Exemple #2
0
def writeport(addr_width, width, regs, WADDR, I, WE):
    """Connect a decoded write port to the first ``2**addr_width`` registers.

    ``WADDR`` is passed through a ``Decoder``; each decoder output is ANDed
    with ``WE`` and the result drives the clock enable of the matching
    register, whose data input is ``I``. ``width`` is accepted but not used
    by this function.
    """
    reg_count = 1 << addr_width

    addr_decoder = Decoder(addr_width)
    gated = And(2, reg_count)
    # One gate per register: decoded select line AND the shared write enable.
    gated(addr_decoder(WADDR), repeat(WE, reg_count))

    for index in range(reg_count):
        write_enable = m.enable(gated.O[index])
        regs[index](I, CE=write_enable)
Exemple #3
0
    class Main(m.Circuit):
        """Circuit that drives `clocks.clk1` from a toggle register on clk0.

        A 3-bit counter and a toggle register are both clocked by
        `clocks.clk0`. The toggle register is enabled only while the counter
        reads 3, and its output is converted to a clock and driven out on
        `clocks.clk1`.
        """
        io = m.IO(clocks=ClocksT, count=m.Out(m.UInt[3]))
        # 3-bit register that is fed its own output plus one; the register
        # output is exposed on io.count.
        count = m.Register(m.UInt[3])()
        count.CLK @= io.clocks.clk0
        io.count @= count(count.O + 1)

        # Toggle register: input is its own output XOR 1, enabled only when
        # the counter equals 3.
        tff = m.Register(m.Bit, has_enable=True)()
        tff.CLK @= io.clocks.clk0
        tff.CE @= m.enable(count.O == 3)
        io.clocks.clk1 @= m.clock(tff(tff.O ^ 1))
Exemple #4
0
    def __init__(self, T, entries, with_bug=False):
        """Build a ready/valid FIFO holding ``entries`` elements of type ``T``.

        Args:
            T: element type stored in the queue.
            entries: queue depth; must be non-negative.
            with_bug: when true, the full check is deliberately omitted so
                that ``enq.ready`` is always asserted.
        """
        assert entries >= 0
        self.io = m.IO(
            # Flipped since enq/deq is from perspective of the client
            enq=m.DeqIO[T],
            deq=m.EnqIO[T]) + m.ClockIO()

        ram = m.Memory(entries, T)()
        enq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        deq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        # Disambiguates "pointers equal" between empty and full.
        maybe_full = m.Register(init=False, has_enable=True)()

        ptr_match = enq_ptr.O == deq_ptr.O
        empty = ptr_match & ~maybe_full.O
        # None marks the injected bug: the queue never reports full.
        full = None if with_bug else ptr_match & maybe_full.O

        self.io.deq.valid @= ~empty
        if full is None:
            # Tie ready high explicitly. The previous `~False` evaluated to
            # the int -1 and only worked through truthiness; `~` on a bool is
            # also deprecated in Python 3.12.
            self.io.enq.ready @= True
        else:
            self.io.enq.ready @= ~full

        do_enq = self.io.enq.fired()
        do_deq = self.io.deq.fired()

        # The top pointer bit is dropped to form the RAM address.
        ram.write(self.io.enq.data, enq_ptr.O[:-1], m.enable(do_enq))

        enq_ptr.CE @= m.enable(do_enq)
        deq_ptr.CE @= m.enable(do_deq)

        # maybe_full only changes when exactly one of enq/deq fires, and then
        # it records whether that event was an enqueue.
        maybe_full.I @= m.enable(do_enq)
        maybe_full.CE @= m.enable(do_enq != do_deq)
        self.io.deq.data @= ram[deq_ptr.O[:-1]]
Exemple #5
0
def test_fdce():
    """Wrap a single FDCE flop in a circuit and print it in two forms.

    The flop's clock enable is tied to 1 and its clear to 0; the circuit's
    ``I`` feeds ``D`` and ``Q`` feeds the circuit's ``O``.
    """
    circuit = DefineCircuit('main', 'I', In(Bit), "O", Out(Bit), "CLK", In(Clock))
    flop = FDCE()
    wire(m.enable(1), flop.CE)
    wire(0, flop.CLR)
    wire(circuit.I, flop.D)
    wire(flop.Q, circuit.O)
    EndCircuit()

    # Print the compiled output first, then the repr.
    for render in (compile, repr):
        print(render(circuit))
Exemple #6
0
def test_clock():
    """Check that ``clock()`` returns a ``ClockType`` for each listed input."""
    sources = (
        0,
        1,
        VCC,
        GND,
        bit(0),
        clock(0),
        reset(0),
        enable(0),
        bits(0, 1),
        uint(0, 1),
        sint(0, 1),
    )
    for source in sources:
        assert isinstance(clock(source), ClockType)
Exemple #7
0
def test_reset():
    """Check that ``reset()`` returns a ``ResetType`` for each listed input."""
    sources = (
        0,
        1,
        VCC,
        GND,
        bit(0),
        clock(0),
        enable(0),
        reset(0),
        bits(0, 1),
        uint(0, 1),
        sint(0, 1),
    )
    for source in sources:
        assert isinstance(reset(source), ResetType)
Exemple #8
0
 def __init__(self, x_len: int):
     """Build a 32-entry register file with two read ports and one write port.

     Reads of address 0 return constant 0, and writes to address 0 are
     suppressed. ``x_len`` is the width of each register.
     """
     addr_t = m.UInt[5]
     data_t = m.UInt[x_len]
     self.io = io = m.IO(
         raddr1=m.In(addr_t),
         raddr2=m.In(addr_t),
         rdata1=m.Out(data_t),
         rdata2=m.Out(data_t),
         wen=m.In(m.Enable),
         waddr=m.In(addr_t),
         wdata=m.In(data_t)
     ) + m.ClockIO(has_reset=True)
     regs = RegFileBuilder("reg_file", 32, x_len, write_forward=False,
                           reset_type=m.Reset, backend="verilog")

     def read_port(addr):
         # Select 0 when no address bit is set, otherwise the stored value.
         return m.mux([0, regs[addr]], addr.reduce_or())

     io.rdata1 @= read_port(io.raddr1)
     io.rdata2 @= read_port(io.raddr2)

     # Only write when enabled and the target is not address 0.
     write_en = m.bit(io.wen) & io.waddr.reduce_or()
     regs.write(io.waddr, io.wdata, enable=m.enable(write_en))
Exemple #9
0
    def __init__(self, height, array_length, T, read_latency, has_read_enable):
        """Build ``array_length`` parallel memories with a per-lane write mask.

        Every lane is an ``m.Memory`` of ``height`` entries of type ``T``
        sharing the read address, write address and (optional) read enable.
        Lane ``i`` is written only when ``WE`` and ``WMASK[i]`` are both set.
        ``read`` and ``write`` helper functions are attached to ``self``.
        """
        addr_t = m.Bits[m.bitutils.clog2(height)]
        lanes_t = m.Array[array_length, T]

        self.io = m.IO(
            RADDR=m.In(addr_t),
            RDATA=m.Out(lanes_t),
        ) + m.ClockIO()
        if has_read_enable:
            self.io += m.IO(RE=m.In(m.Enable))
        self.io += m.IO(WADDR=m.In(addr_t),
                        WDATA=m.In(lanes_t),
                        WMASK=m.In(m.Bits[array_length]),
                        WE=m.In(m.Enable))

        for lane in range(array_length):
            bank = m.Memory(height,
                            T,
                            read_latency,
                            has_read_enable=has_read_enable)()
            bank.RADDR @= self.io.RADDR
            if has_read_enable:
                bank.RE @= self.io.RE
            self.io.RDATA[lane] @= bank.RDATA

            # A lane is written only when the global enable and its own mask
            # bit are both high.
            lane_we = m.enable(m.bit(self.io.WE) & self.io.WMASK[lane])
            bank.write(self.io.WDATA[lane], self.io.WADDR, lane_we)

        def read(self, addr, enable=None):
            # Drive the read address (and optionally the read enable) and
            # hand back the read data port.
            self.RADDR @= addr
            if enable is None:
                return self.RDATA
            if not has_read_enable:
                raise Exception("Cannot use `enable` with no read enable")
            self.RE @= enable
            return self.RDATA

        self.read = read

        def write(self, data, addr, mask, enable):
            # Drive all four write-side ports in one call.
            self.WDATA @= data
            self.WADDR @= addr
            self.WMASK @= mask
            self.WE @= enable

        self.write = write
Exemple #10
0
    class _ConfigRegister(magma.Circuit):
        """Addressable configuration register with asynchronous reset.

        The register captures the low ``width`` bits of ``config_data`` when
        ``config_addr`` equals ``addr`` (and, if ``use_config_en`` is set,
        ``config_en`` is high). ``width``, ``addr_width``, ``data_width``,
        ``addr``, ``T`` and ``use_config_en`` come from the enclosing scope.
        """
        # The circuit name encodes every parameter.
        name = f"ConfigRegister_{width}_{addr_width}_{data_width}_{addr}"
        ports = {
            "clk": magma.In(magma.Clock),
            "reset": magma.In(magma.AsyncReset),
            "O": magma.Out(T),
            "config_addr": magma.In(magma.Bits[addr_width]),
            "config_data": magma.In(magma.Bits[data_width]),
        }
        # The enable port only exists when requested.
        if use_config_en:
            ports["config_en"] = magma.In(magma.Bit)
        io = magma.IO(**ports)

        reg = magma.Register(magma.Bits[width],
                             has_enable=True,
                             reset_type=magma.AsyncReset)()
        magma.wire(io.clk, reg.CLK)
        # Clock enable: address match, optionally gated by config_en below.
        ce = (io.config_addr == magma.bits(addr, addr_width))
        magma.wire(io.reset, reg.ASYNCRESET)
        if use_config_en:
            ce = ce & io.config_en
        # Only the low `width` bits of the config data are stored.
        magma.wire(io.config_data[0:width], reg.I)
        magma.wire(magma.enable(ce), reg.CE)
        magma.wire(reg.O, io.O)
Exemple #11
0
def test_enable():
    """Check that ``enable()`` returns an ``EnableType`` for each listed input."""
    sources = (
        0,
        1,
        VCC,
        GND,
        bit(0),
        clock(0),
        reset(0),
        enable(0),
        bits(0, 1),
        uint(0, 1),
        sint(0, 1),
    )
    for source in sources:
        assert isinstance(enable(source), EnableType)
        class DUT(m.Circuit):
            """Test bench wrapping ``Core`` with instruction and data memories.

            While in the INIT state the contents of ``loadmem`` are copied
            into both memories and the core is held in reset; in RUN the core
            executes and a cycle counter advances. ``done`` is raised when
            the core's ``tohost`` is non-zero. ``x_len``, ``loadmem``,
            ``ImmGen`` and ``test`` come from the enclosing scope.
            """
            io = m.IO(done=m.Out(m.Bit)) + m.ClockIO(has_reset=True)
            core = Core(
                x_len, data_path_kwargs=m.generator.ParamDict(ImmGen=ImmGen))()
            core.host.fromhost.data.undriven()
            core.host.fromhost.valid @= False

            # reverse concat because we're using utils with chisel ordering
            _hex = [concat(*reversed(x)) for x in loadmem]
            imem = RegFileBuilder("imem",
                                  1 << 20,
                                  x_len,
                                  write_forward=False,
                                  reset_type=m.Reset,
                                  backend="verilog")
            dmem = RegFileBuilder("dmem",
                                  1 << 20,
                                  x_len,
                                  write_forward=False,
                                  reset_type=m.Reset,
                                  backend="verilog")

            INIT, RUN = False, True

            state = m.Register(init=INIT)()
            cycle = m.Register(m.UInt[32])()

            # Counts through the load image while in INIT.
            n = len(_hex)
            counter = CounterModM(n, n.bit_length(), has_ce=True)
            counter.CE @= m.enable(state.O == INIT)
            cntr, done = counter.O, counter.COUT

            # Byte addresses from the core become word indices (20 bits).
            iaddr = (core.icache.req.data.addr // (x_len // 8))[:20]
            daddr = (core.dcache.req.data.addr // (x_len // 8))[:20]

            dmem_data = dmem[daddr]
            imem_data = imem[iaddr]
            # Merge the store data into the current word one byte at a time:
            # a byte comes from the request when the request is valid and its
            # mask bit is set, otherwise from the existing memory word.
            write = 0
            for i in range(x_len // 8):
                # Extend to the full word width (was hard-coded to 32, which
                # only matches x_len == 32).
                write |= m.zext_to(
                    m.mux([dmem_data, core.dcache.req.data.data],
                          core.dcache.req.valid
                          & core.dcache.req.data.mask[i])[8 * i:8 * (i + 1)],
                    x_len) << (8 * i)

            # The core is held in reset for the whole load phase.
            core.RESET @= m.reset(state.O == INIT)

            core.icache.resp.valid @= state.O == RUN
            core.dcache.resp.valid @= state.O == RUN

            # Responses are registered, so data arrives one cycle after the
            # address.
            core.icache.resp.data.data @= m.Register(
                m.UInt[x_len])()(imem_data)
            core.dcache.resp.data.data @= m.Register(
                m.UInt[x_len])()(dmem_data)

            chunk = m.mux(_hex, cntr)

            imem.write(m.zext_to(cntr, 20), chunk, m.enable(state.O == INIT))

            # dmem is written by the loader in INIT and by masked core stores
            # otherwise.
            dmem.write(
                m.mux([m.zext_to(cntr, 20), daddr], state.O == INIT),
                m.mux([chunk, write], state.O == INIT),
                m.enable((state.O == INIT)
                         | (core.dcache.req.valid
                            & core.dcache.req.data.mask.reduce_or())))

            # INIT -> RUN once the load counter wraps; cycle counts RUN cycles.
            @m.inline_combinational()
            def logic():
                state.I @= state.O
                cycle.I @= cycle.O
                if state.O == INIT:
                    if done:
                        state.I @= RUN
                if state.O == RUN:
                    cycle.I @= cycle.O + 1

            debug = False
            if debug:
                m.display("LOADMEM[%x] <= %x", cntr * (x_len // 8),
                          chunk).when(m.posedge(io.CLK)).if_(state.O == INIT)

                # Instruction fetch trace: show the instruction memory word
                # (this previously printed dmem_data).
                m.display("INST[%x] => %x",
                          iaddr * (x_len // 8), imem_data).when(
                              m.posedge(io.CLK)).if_((state.O == RUN)
                                                     & core.icache.req.valid)

                m.display("MEM[%x] <= %x", daddr * (x_len // 8), write).when(
                    m.posedge(
                        io.CLK)).if_((state.O == RUN) & core.dcache.req.valid
                                     & core.dcache.req.data.mask.reduce_or())

                m.display(
                    "MEM[%x] => %x", daddr * (x_len // 8),
                    dmem_data).when(m.posedge(
                        io.CLK)).if_((state.O == RUN) & core.dcache.req.valid
                                     & ~core.dcache.req.data.mask.reduce_or())

                m.display("cycles: %d", cycle.O).when(m.posedge(
                    io.CLK)).if_(io.done.value() == 1)
            f.assert_immediate(cycle.O < test.maxcycles)
            io.done @= core.host.tohost != 0
            # Any tohost value above 1 is reported as a failure.
            f.assert_immediate(
                (core.host.tohost >> 1) == 0,
                failure_msg=("* tohost: %d *", core.host.tohost))
    def __init__(self, x_len, n_ways: int, n_sets: int, b_bytes: int):
        """Build a cache model with a NASTI memory port.

        The model keeps one line per set in four memories (data, tag, valid,
        dirty) and walks an IDLE / WRITE / WRITE_ACK / READ state machine to
        write back dirty lines and refill on a miss.

        Args:
            x_len: width of request addresses and of request/response data.
            n_ways: accepted but not referenced in this body.
            n_sets: number of sets (depth of every internal memory).
            b_bytes: cache line size in bytes.
        """
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)

        self.io = m.IO(req=m.Consumer(m.Decoupled[make_CacheReq(x_len)]),
                       resp=m.Producer(m.Decoupled[make_CacheResp(x_len)]),
                       nasti=make_NastiIO(nasti_params)) + m.ClockIO()
        # Burst size is log2 of the bytes per beat, not of the bits per beat
        # (the `// 8` was previously missing).
        size = m.bitutils.clog2(nasti_params.x_data_bits // 8)
        b_bits = b_bytes << 3
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        data_beats = b_bits // nasti_params.x_data_bits
        length = data_beats - 1

        data = m.Memory(n_sets, m.UInt[b_bits])()
        tags = m.Memory(n_sets, m.UInt[t_len])()
        v = m.Memory(n_sets, m.Bit)()
        d = m.Memory(n_sets, m.Bit)()

        # Address layout, low to high: offset (b_len), index (s_len), tag.
        req = self.io.req.data
        tag = (req.addr >> (b_len + s_len))[:t_len]
        idx = req.addr[b_len:b_len + s_len]
        off = req.addr[:b_len]
        read = data.read(idx)
        # Build the updated line byte by byte: a byte is taken from the
        # request when it lies in the addressed 4-byte word and its mask bit
        # is set, otherwise from the current line.
        write = m.bits(0, b_bits)
        for i in range(b_bytes):
            write |= m.mux([(read & (0xff << (8 * i))),
                            ((m.zext_to(req.data, b_bits) >>
                              ((8 * (i & 0x3)))) & 0xff) << (8 * i)],
                           ((off // 4) == (i // 4)) & (req.mask >>
                                                       (i & 0x3))[0])[:b_bits]

        class State(m.Enum):
            IDLE = 0
            WRITE = 1
            WRITE_ACK = 2
            READ = 3

        state = m.Register(init=State.IDLE)()

        # Beat counters for the write-back and refill bursts.
        write_counter = mantle.CounterModM(data_beats,
                                           max(data_beats.bit_length(), 1),
                                           has_ce=True)
        write_counter.CE @= m.enable(state.O == State.WRITE)
        w_cnt, w_done = write_counter.O, write_counter.COUT

        read_counter = mantle.CounterModM(data_beats,
                                          max(data_beats.bit_length(), 1),
                                          has_ce=True)
        read_counter.CE @= m.enable((state.O == State.READ)
                                    & self.io.nasti.r.valid)
        r_cnt, r_done = read_counter.O, read_counter.COUT

        # The response is the addressed word of the current line.
        self.io.resp.data.data @= (read >> (m.zext_to(
            (off // 4), b_bits) * x_len))[:x_len]
        # Refills are requested at the line-aligned address.
        self.io.nasti.ar.data @= NastiReadAddressChannel(
            nasti_params, 0, (req.addr >> b_len) << b_len, size, length)
        tags_rdata = tags.read(idx)
        # Write-backs target the address rebuilt from the stored tag and the
        # index.
        self.io.nasti.aw.data @= NastiWriteAddressChannel(
            nasti_params, 0,
            m.bits(m.concat(idx, tags_rdata), nasti_params.x_addr_bits) <<
            b_len, size, length)
        # Each write beat carries the slice of the line selected by w_cnt.
        self.io.nasti.w.data @= NastiWriteDataChannel(
            nasti_params,
            (read >> (m.zext_to(w_cnt, b_bits) *
                      nasti_params.x_data_bits))[:nasti_params.x_data_bits],
            None, w_done)
        self.io.nasti.w.valid @= state.O == State.WRITE
        self.io.nasti.b.ready @= state.O == State.WRITE_ACK
        self.io.nasti.r.ready @= state.O == State.READ

        # Write enables are named wires driven by the combinational logic
        # below.
        d_wen = m.Bit(name="d_wen")
        d.write(True, idx, m.enable(d_wen))

        data_wen = m.Bit(name="data_wen")
        data_wdata = m.UInt[b_bits](name="data_wdata")
        data.write(data_wdata, idx, m.enable(data_wen))
        # m.display("data_wdata=%x", data_wdata).when(m.posedge(self.io.CLK))

        v_wen = m.Bit(name="v_wen")
        v.write(True, idx, m.enable(v_wen))
        v_rdata = v.read(idx)

        tags_wen = m.Bit(name="tags_wen")
        tags.write(tag, idx, m.enable(tags_wen))

        d_rdata = d.read(idx)

        # m.display("gold_state=%x", state.O).when(m.posedge(self.io.CLK))
        # m.display("gold_w_done=%x", w_done).when(m.posedge(self.io.CLK))
        # m.display("gold_b_valid=%x",
        #           self.io.nasti.b.valid).when(m.posedge(self.io.CLK))

        if TRACE:
            m.display(
                "[%0t] [cache] data[%x] <= %x, off: %x, req: %x, mask: %b",
                m.time(), idx, write, off, self.io.req.data.data,
                self.io.req.data.mask)\
                .when(m.posedge(self.io.CLK))\
                .if_((state.O == State.IDLE) &
                     (self.io.req.valid & self.io.resp.ready) &
                     (v_rdata & (tags_rdata == tag)) & req.mask.reduce_or())

            m.display(
                "[%0t] [cache] data[%x] => %x, off: %x, resp: %x", m.time(),
                idx, write, off, self.io.resp.data.data.value())\
                .when(m.posedge(self.io.CLK))\
                .if_((state.O == State.IDLE) &
                     (self.io.req.valid & self.io.resp.ready) &
                     (v_rdata & (tags_rdata == tag)) & ~req.mask.reduce_or())

        # State machine: defaults are assigned first and overridden by the
        # branch for the current state.
        @m.inline_combinational()
        def logic():
            self.io.resp.valid @= False
            self.io.req.ready @= False
            self.io.nasti.ar.valid @= False
            self.io.nasti.aw.valid @= False

            d_wen @= False

            data_wen @= False
            data_wdata @= m.UInt[b_bits](0)
            state.I @= state.O

            tags_wen @= False
            v_wen @= False

            if state.O == State.IDLE:
                if self.io.req.valid & self.io.resp.ready:
                    if v_rdata & (tags_rdata == tag):
                        if req.mask.reduce_or():
                            d_wen @= True
                            data_wdata @= write
                            data_wen @= True
                        self.io.req.ready @= True
                        self.io.resp.valid @= True
                    else:
                        if d_rdata:
                            self.io.nasti.aw.valid @= True
                            state.I @= State.WRITE
                        else:
                            data_wdata @= 0
                            data_wen @= True
                            self.io.nasti.ar.valid @= True
                            state.I @= State.READ
            elif state.O == State.WRITE:
                if w_done:
                    state.I @= State.WRITE_ACK
            elif state.O == State.WRITE_ACK:
                if self.io.nasti.b.valid:
                    data_wdata @= 0
                    data_wen @= True
                    self.io.nasti.ar.valid @= True
                    state.I @= State.READ
            elif state.O == State.READ:
                if self.io.nasti.r.valid:
                    data_wdata @= read | (
                        m.zext_to(self.io.nasti.r.data.data, b_bits) <<
                        (m.zext_to(r_cnt, b_bits) * nasti_params.x_data_bits))
                    data_wen @= True
                if r_done:
                    tags_wen @= True
                    v_wen @= True
                    state.I @= State.IDLE
    class DUT(m.Circuit):
        io = m.IO(done=m.Out(m.Bit)) + m.ClockIO()
        x_len = 32
        n_sets = 256
        b_bytes = 4 * (x_len >> 3)
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)

        dut = Cache(x_len, 1, n_sets, b_bytes)()
        dut_mem = make_NastiIO(nasti_params).undirected_t(name="dut_mem")
        dut_mem.ar @= make_Queue(dut.nasti.ar, 32)
        dut_mem.aw @= make_Queue(dut.nasti.aw, 32)
        dut_mem.w @= make_Queue(dut.nasti.w, 32)
        dut.nasti.b @= make_Queue(dut_mem.b, 32)
        dut.nasti.r @= make_Queue(dut_mem.r, 32)

        gold = GoldCache(x_len, 1, n_sets, b_bytes)()
        gold_req = type(gold.req).undirected_t(name="gold_req")
        gold_resp = type(gold.resp).undirected_t(name="gold_resp")
        gold_mem = make_NastiIO(nasti_params).undirected_t(name="gold_mem")
        gold.req @= make_Queue(gold_req, 32)
        gold_resp @= make_Queue(gold.resp, 32)
        gold_mem.ar @= make_Queue(gold.nasti.ar, 32)
        gold_mem.aw @= make_Queue(gold.nasti.aw, 32)
        gold_mem.w @= make_Queue(gold.nasti.w, 32)
        gold.nasti.b @= make_Queue(gold_mem.b, 32)
        gold.nasti.r @= make_Queue(gold_mem.r, 32)

        size = m.bitutils.clog2(nasti_params.x_data_bits // 8)
        b_bits = b_bytes << 3
        data_beats = b_bits // nasti_params.x_data_bits

        mem = m.Memory(1 << 20, m.UInt[nasti_params.x_data_bits])()

        class MemState(m.Enum):
            IDLE = 0
            WRITE = 1
            WRITE_ACK = 2
            READ = 3

        mem_state = m.Register(init=MemState.IDLE)()

        write_counter = mantle.CounterModM(data_beats,
                                           data_beats.bit_length(),
                                           has_ce=True)
        write_counter.CE @= m.enable((mem_state.O == MemState.WRITE)
                                     & dut_mem.w.valid & gold_mem.w.valid)
        read_counter = mantle.CounterModM(data_beats,
                                          data_beats.bit_length(),
                                          has_ce=True)
        read_counter.CE @= m.enable((mem_state.O == MemState.READ)
                                    & dut_mem.r.ready & gold_mem.r.ready)

        dut_mem.b.valid @= mem_state.O == MemState.WRITE_ACK
        dut_mem.b.data @= NastiWriteResponseChannel(nasti_params, 0)
        dut_mem.r.valid @= mem_state.O == MemState.READ
        dut_mem.r.data @= NastiReadDataChannel(
            nasti_params, 0,
            mem.read(
                ((gold_mem.ar.data.addr) +
                 m.zext_to(read_counter.O, nasti_params.x_addr_bits))[:20]),
            read_counter.COUT)
        gold_mem.ar.ready @= dut_mem.ar.ready
        gold_mem.aw.ready @= dut_mem.aw.ready
        gold_mem.w.ready @= dut_mem.w.ready
        gold_mem.b.valid @= dut_mem.b.valid
        gold_mem.b.data @= dut_mem.b.data
        gold_mem.r.valid @= dut_mem.r.valid
        gold_mem.r.data @= dut_mem.r.data

        mem_wen0 = m.Bit(name="mem_wen0")
        mem_wdata0 = m.UInt[nasti_params.x_data_bits](name="mem_wdata0")
        mem_wen1 = m.Bit(name="mem_wen1")
        mem_wdata1 = m.UInt[nasti_params.x_data_bits](name="mem_wdata1")
        mem_waddr1 = m.UInt[20](name="mem_waddr1")
        mem.write(
            m.mux([dut_mem.w.data.data, mem_wdata1], mem_wen1),
            m.mux([((dut_mem.aw.data.addr) +
                    m.zext_to(write_counter.O, nasti_params.x_addr_bits))[:20],
                   mem_waddr1], mem_wen1), m.enable(mem_wen0 | mem_wen1))
        # m.display("mem_wen0 = %x, mem_wen1 = %x", mem_wen0,
        #           mem_wen1).when(m.posedge(io.CLK))
        # m.display("dut_mem.w.valid = %x",
        #           dut_mem.w.valid).when(m.posedge(io.CLK))
        # m.display("gold_mem.w.valid = %x",
        #           gold_mem.w.valid).when(m.posedge(io.CLK))

        f.assert_immediate(
            (mem_state.O != MemState.IDLE)
            | ~(gold_mem.aw.valid & dut_mem.aw.valid) |
            (dut_mem.aw.data.addr == gold_mem.aw.data.addr),
            failure_msg=(
                "[dut_mem.aw.data.addr] %x != [gold_mem.aw.data.addr] %x",
                dut_mem.aw.data.addr, gold_mem.aw.data.addr))

        f.assert_immediate(
            (mem_state.O != MemState.IDLE)
            | ~(gold_mem.aw.valid & dut_mem.aw.valid)
            | ~(gold_mem.ar.valid & dut_mem.ar.valid) |
            (dut_mem.ar.data.addr == gold_mem.ar.data.addr),
            failure_msg=(
                "[dut_mem.ar.data.addr] %x != [gold_mem.ar.data.addr] %x",
                dut_mem.ar.data.addr, gold_mem.ar.data.addr))

        f.assert_immediate(
            (mem_state.O != MemState.WRITE)
            | ~(gold_mem.w.valid & dut_mem.w.valid) |
            (dut_mem.w.data.data == gold_mem.w.data.data),
            failure_msg=(
                "[dut_mem.w.data.data] %x != [gold_mem.w.data.data] %x",
                dut_mem.w.data.data, gold_mem.w.data.data))

        # Combinational logic for the memory model shared by the DUT and the
        # gold cache. Defaults are assigned first; the branch for the current
        # mem_state overrides them. The model only advances when both caches
        # present the same kind of request.
        @m.inline_combinational()
        def mem_fsm():
            dut_mem.w.ready @= False
            dut_mem.aw.ready @= False
            dut_mem.ar.ready @= False

            mem_wen0 @= False

            mem_state.I @= mem_state.O

            if mem_state.O == MemState.IDLE:
                # A write request takes priority over a read request.
                if gold_mem.aw.valid & dut_mem.aw.valid:
                    mem_state.I @= MemState.WRITE
                elif gold_mem.ar.valid & dut_mem.ar.valid:
                    mem_state.I @= MemState.READ
            elif mem_state.O == MemState.WRITE:
                # Accept a data beat only when both sides offer one.
                if gold_mem.w.valid & dut_mem.w.valid:
                    mem_wen0 @= True
                    dut_mem.w.ready @= True
                # The address channel is acknowledged on the counter's COUT.
                if write_counter.COUT:
                    dut_mem.aw.ready @= True
                    mem_state.I @= MemState.WRITE_ACK
            elif mem_state.O == MemState.WRITE_ACK:
                if gold_mem.b.ready & dut_mem.b.ready:
                    mem_state.I @= MemState.IDLE
            elif mem_state.O == MemState.READ:
                # The address channel is acknowledged on the counter's COUT.
                if read_counter.COUT:
                    dut_mem.ar.ready @= True
                    mem_state.I @= MemState.IDLE

        if TRACE:
            m.display("[%0t]: [write] mem[%x] <= %x", m.time(),
                      mem.WADDR.value(), dut_mem.w.data.data).when(
                          m.posedge(io.CLK)).if_(mem_wen0)
            m.display("[%0t]: [read] mem[%x] => %x", m.time(),
                      mem.RADDR.value(),
                      dut_mem.r.data.data).when(m.posedge(
                          io.CLK)).if_((mem_state.O == MemState.READ)
                                       & dut_mem.r.ready & gold_mem.r.ready)

        def rand_data(nasti_params):
            """Return a random BitVector as wide as the NASTI data bus.

            The value is assembled from one ``random.randint(0, 0xff)`` draw
            per byte, lowest byte first.
            """
            value = BitVector[nasti_params.x_data_bits](0)
            for byte_idx in range(nasti_params.x_data_bits // 8):
                word_t = BitVector[nasti_params.x_data_bits]
                shifted_byte = random.randint(0, 0xff) << (8 * byte_idx)
                value |= word_t(shifted_byte)
            return value

        def rand_mask(x_len):
            """Return a random byte mask with one bit per byte of ``x_len``.

            The value is drawn from ``1 .. 2**(x_len // 8) - 2``, so it is
            never all zeros and never all ones.
            """
            mask_t = BitVector[x_len // 8]
            highest = (1 << (x_len // 8)) - 2
            return mask_t(random.randint(1, highest))

        def make_test(rand_data, nasti_params, x_len):
            """Return a ``test`` function that packs one test vector.

            The returned function concatenates ``off``, ``idx``, ``tag``,
            ``b_bits`` of random data and ``mask`` (in that argument order)
            into a single unsigned value. ``mask`` defaults to all zeros.
            """
            # Wrapper because function definition in side class namespace
            # doesn't inherit class variables
            def test(b_bits, tag, idx, off, mask=BitVector[x_len // 8](0)):
                # One bus-wide random word, then one more per remaining beat.
                payload = rand_data(nasti_params)
                extra_beats = (b_bits // nasti_params.x_data_bits) - 1
                for _ in range(extra_beats):
                    payload = payload.concat(rand_data(nasti_params))
                packed = m.concat(off, idx, tag, payload, mask)
                return m.uint(packed)

            return test

        test = make_test(rand_data, nasti_params, x_len)

        tags = []
        for _ in range(3):
            tags.append(BitVector.random(t_len))
        idxs = []
        for _ in range(2):
            idxs.append(BitVector.random(s_len))
        offs = []
        for _ in range(6):
            offs.append(BitVector.random(b_len) & -4)

        init_addr = []
        init_data = []
        _iter = itertools.product(tags, idxs, range(0, data_beats))
        for tag, idx, off in _iter:
            init_addr.append(m.uint(m.concat(BitVector[b_len](off), idx, tag)))
            init_data.append(rand_data(nasti_params))

        test_vec = [
            test(b_bits, tags[0], idxs[0], offs[0]),  # 0: read miss
            test(b_bits, tags[0], idxs[0], offs[1]),  # 1: read hit
            test(b_bits, tags[1], idxs[0], offs[0]),  # 2: read miss
            test(b_bits, tags[1], idxs[0], offs[2]),  # 3: read hit
            test(b_bits, tags[1], idxs[0], offs[3]),  # 4: read hit
            test(b_bits, tags[1], idxs[0], offs[4],
                 rand_mask(x_len)),  # 5: write hit  # noqa
            test(b_bits, tags[1], idxs[0], offs[4]),  # 6: read hit
            test(b_bits, tags[2], idxs[0],
                 offs[5]),  # 7: read miss & write back  # noqa
            test(b_bits, tags[0], idxs[1], offs[0],
                 rand_mask(x_len)),  # 8: write miss  # noqa
            test(b_bits, tags[0], idxs[1], offs[0]),  # 9: read hit
            test(b_bits, tags[0], idxs[1], offs[1]),  # 10: read hit
            test(b_bits, tags[1], idxs[1], offs[2],
                 rand_mask(x_len)),  # 11: write miss & write back  # noqa
            test(b_bits, tags[1], idxs[1], offs[3]),  # 12: read hit
            test(b_bits, tags[2], idxs[1], offs[4]),  # 13: read write back
            test(b_bits, tags[2], idxs[1], offs[5])  # 14: read hit
        ]

        class TestState(m.Enum):
            INIT = 0
            START = 1
            WAIT = 2
            DONE = 3

        state = m.Register(init=TestState.INIT)()
        timeout = m.Register(m.UInt[32])()
        init_m = len(init_addr) - 1
        init_counter = mantle.CounterModM(init_m,
                                          init_m.bit_length(),
                                          has_ce=True)
        init_counter.CE @= m.enable(state.O == TestState.INIT)

        test_m = len(test_vec) - 1
        test_counter = mantle.CounterModM(test_m,
                                          test_m.bit_length(),
                                          has_ce=True)
        test_counter.CE @= m.enable(state.O == TestState.DONE)
        curr_vec = m.mux(test_vec, test_counter.O)
        mask = (curr_vec >> (b_len + s_len + t_len + b_bits))[:x_len // 8]
        data = (curr_vec >> (b_len + s_len + t_len))[:b_bits]
        tag = (curr_vec >> (b_len + s_len))[:t_len]
        idx = (curr_vec >> b_len)[:s_len]
        off = curr_vec[:b_len]

        dut.cpu.req.data.addr @= m.concat(off, idx, tag)
        # TODO: Is truncating this fine?
        req_data = data[:x_len]
        dut.cpu.req.data.data @= req_data
        dut.cpu.req.data.mask @= mask
        dut.cpu.req.valid @= state.O == TestState.WAIT
        dut.cpu.abort @= 0
        gold_req.data @= dut.cpu.req.data.value()
        gold_req.valid @= state.O == TestState.START
        gold_resp.ready @= state.O == TestState.DONE

        mem_waddr1 @= m.mux(init_addr, init_counter.O)[:20]
        mem_wdata1 @= m.mux(init_data, init_counter.O)

        check_resp_data = m.Bit()
        if TRACE:
            m.display("[%0t]: [init] mem[%x] <= %x", m.time(),
                      mem_waddr1, mem_wdata1)\
                .when(m.posedge(io.CLK))\
                .if_(state.O == TestState.INIT)

        @m.inline_combinational()
        def state_fsm():
            timeout.I @= timeout.O
            mem_wen1 @= m.bit(False)
            check_resp_data @= m.bit(False)
            state.I @= state.O
            if state.O == TestState.INIT:
                mem_wen1 @= m.bit(True)
                if init_counter.COUT:
                    state.I @= TestState.START
            elif state.O == TestState.START:
                if gold_req.ready:
                    timeout.I @= m.bits(0, 32)
                    state.I @= TestState.WAIT
            elif state.O == TestState.WAIT:
                timeout.I @= timeout.O + 1
                if dut.cpu.resp.valid & gold_resp.valid:
                    if ~mask.reduce_or():
                        check_resp_data @= m.bit(True)
                    state.I @= TestState.DONE
            elif state.O == TestState.DONE:
                state.I @= TestState.START

        f.assert_immediate((state.O != TestState.WAIT) | (timeout.O < 100))
        f.assert_immediate(
            ~check_resp_data | (dut.cpu.resp.data.data == gold_resp.data.data),
            failure_msg=("dut.cpu.resp.data.data => %x != %x",
                         dut.cpu.resp.data.data, gold_resp.data.data))
        # m.display("mem_state=%x", mem_state.O).when(m.posedge(io.CLK))
        # m.display("test_state=%x", state.O).when(m.posedge(io.CLK))
        # m.display("dut req valid = %x",
        #           dut.cpu.req.valid).when(m.posedge(io.CLK))
        # m.display("gold req valid = %x, ready = %x", gold_req.valid,
        #           gold_req.ready).when(m.posedge(io.CLK))
        # m.display("[%0t]: dut resp data = %x, gold resp data = %x", m.time(),
        #           dut.cpu.resp.data.data, gold_resp.data.data)\
        #     .when(m.posedge(io.CLK))
        io.done @= test_counter.COUT
    def __init__(self,
                 x_len,
                 ALU=ALUArea,
                 ImmGen=ImmGenWire,
                 BrCond=BrCondArea):
        """Build the datapath: fetch, execute and write-back logic.

        The constructor instantiates the CSR unit, register file, ALU,
        immediate generator and branch-condition unit, then wires them
        together through a set of fetch/execute and execute/write-back
        registers.

        Args:
            x_len: Bit width used for the PC, ALU result and CSR input
                registers, and passed to every sub-generator.
            ALU: Generator called as ``ALU(x_len)`` to obtain the ALU
                circuit.
            ImmGen: Generator called as ``ImmGen(x_len)`` to obtain the
                immediate-generator circuit.
            BrCond: Generator called as ``BrCond(x_len)`` to obtain the
                branch-condition circuit.
        """
        self.io = make_DatapathIO(x_len) + m.ClockIO(has_reset=True)
        csr = CSRGen(x_len)()
        reg_file = RegFile(x_len)()
        alu = ALU(x_len)()
        imm_gen = ImmGen(x_len)()
        # Instantiate the generator the caller supplied (previously the
        # default BrCondArea was hard-coded and the parameter was ignored).
        br_cond = BrCond(x_len)()

        # Fetch / Execute Registers
        fe_inst = m.Register(init=Instructions.NOP, has_enable=True)()
        fe_pc = m.Register(m.UInt[x_len], has_enable=True)()

        # Execute / Write Back Registers
        ew_inst = m.Register(init=Instructions.NOP)()
        ew_pc = m.Register(m.UInt[x_len])()
        ew_alu = m.Register(m.UInt[x_len])()
        csr_in = m.Register(m.UInt[x_len])()

        # Control signals
        st_type = m.Register(type(self.io.ctrl.st_type).undirected_t)()
        ld_type = m.Register(type(self.io.ctrl.ld_type).undirected_t)()
        wb_sel = m.Register(type(self.io.ctrl.wb_sel).undirected_t)()
        wb_en = m.Register(m.Bit)()
        csr_cmd = m.Register(type(self.io.ctrl.csr_cmd).undirected_t)()
        illegal = m.Register(m.Bit)()
        pc_check = m.Register(m.Bit)()

        # Fetch
        # `started` is a registered copy of RESET; while it is high the
        # fetched instruction is replaced by a NOP (see `inst` below).
        started = m.Register(m.Bit)()(m.bit(self.io.RESET))
        # The pipeline stalls whenever either cache has no valid response.
        stall = ~self.io.icache.resp.valid | ~self.io.dcache.resp.valid
        # The PC register starts one instruction (4 bytes) before PC_START.
        pc = m.Register(init=UIntVector[x_len](Const.PC_START) -
                        UIntVector[x_len](4))()
        # Next-PC selection; the outermost mux has the highest priority:
        #   stall -> hold pc, exception -> csr.evec, PC_EPC -> csr.epc,
        #   PC_ALU or taken branch -> ALU sum with bit 0 cleared,
        #   PC_0 -> hold pc, otherwise pc + 4.
        npc = m.mux([
            m.mux([
                m.mux([
                    m.mux([
                        m.mux([pc.O + m.uint(4, x_len), pc.O],
                              self.io.ctrl.pc_sel == PC_0), alu.sum_ >> 1 << 1
                    ], (self.io.ctrl.pc_sel == PC_ALU) | br_cond.taken),
                    csr.epc
                ], self.io.ctrl.pc_sel == PC_EPC), csr.evec
            ], csr.expt), pc.O
        ], stall)

        # Squash the fetched instruction to a NOP right after reset, on an
        # explicit kill from control, on a taken branch, or on an exception.
        inst = m.mux([self.io.icache.resp.data.data, Instructions.NOP], started
                     | self.io.ctrl.inst_kill | br_cond.taken | csr.expt)

        pc.I @= npc
        # The instruction cache is only ever read: data and mask are zero.
        self.io.icache.req.data.addr @= npc
        self.io.icache.req.data.data @= 0
        self.io.icache.req.data.mask @= 0
        self.io.icache.req.valid @= ~stall
        self.io.icache.abort @= False

        # Fetch/execute registers only advance when not stalled.
        fe_pc.I @= pc.O
        fe_pc.CE @= m.enable(~stall)
        fe_inst.I @= inst
        fe_inst.CE @= m.enable(~stall)

        # Execute
        # Decode
        self.io.ctrl.inst @= fe_inst.O

        # reg_file read
        rs1_addr = fe_inst.O[15:20]
        rs2_addr = fe_inst.O[20:25]
        reg_file.raddr1 @= rs1_addr
        reg_file.raddr2 @= rs2_addr

        # gen immediates
        imm_gen.inst @= fe_inst.O
        imm_gen.sel @= self.io.ctrl.imm_sel

        # bypass
        # A hazard exists when the write-back stage writes a non-zero
        # register address that matches a source register of this stage.
        wb_rd_addr = ew_inst.O[7:12]
        rs1_hazard = wb_en.O & rs1_addr.reduce_or() & (rs1_addr == wb_rd_addr)
        rs2_hazard = wb_en.O & rs2_addr.reduce_or() & (rs2_addr == wb_rd_addr)
        # Only ALU results are forwarded; other write-back sources fall
        # through to the register-file read data.
        rs1 = m.mux([reg_file.rdata1, ew_alu.O],
                    (wb_sel.O == WB_ALU) & rs1_hazard)
        rs2 = m.mux([reg_file.rdata2, ew_alu.O],
                    (wb_sel.O == WB_ALU) & rs2_hazard)

        # ALU operations
        alu.A @= m.mux([fe_pc.O, rs1], self.io.ctrl.A_sel == A_RS1)
        alu.B @= m.mux([imm_gen.O, rs2], self.io.ctrl.B_sel == B_RS2)
        alu.op @= self.io.ctrl.alu_op

        # Branch condition calc
        br_cond.rs1 @= rs1
        br_cond.rs2 @= rs2
        br_cond.br_type @= self.io.ctrl.br_type

        # D$ access
        # While stalled, re-present the registered ALU result as the
        # address; the low two bits are cleared in either case.
        daddr = m.mux([alu.sum_, ew_alu.O], stall) >> 2 << 2
        # Bit offset of the addressed byte within the word: the two low
        # address bits scaled by 8.
        w_offset = ((m.bits(alu.sum_[1], x_len) << 4) |
                    (m.bits(alu.sum_[0], x_len) << 3))
        self.io.dcache.req.valid @= ~stall & (self.io.ctrl.st_type.reduce_or()
                                              |
                                              self.io.ctrl.ld_type.reduce_or())
        self.io.dcache.req.data.addr @= daddr
        self.io.dcache.req.data.data @= rs2 << w_offset
        # Byte mask per store type; anything else (including loads) gets an
        # all-zero mask.
        self.io.dcache.req.data.mask @= m.dict_lookup(
            {
                ST_SW: m.bits(0b1111, 4),
                ST_SH: m.bits(0b11, 4) << m.zext(alu.sum_[0:2], 2),
                ST_SB: m.bits(0b1, 4) << m.zext(alu.sum_[0:2], 2),
            }, m.mux([self.io.ctrl.st_type, st_type.O], stall), m.bits(0, 4))

        # Pipelining
        @m.inline_combinational()
        def pipeline_logic():
            # Default: every execute/write-back register holds its value.
            ew_pc.I @= ew_pc.O
            ew_inst.I @= ew_inst.O
            ew_alu.I @= ew_alu.O
            csr_in.I @= csr_in.O
            st_type.I @= st_type.O
            ld_type.I @= ld_type.O
            wb_sel.I @= wb_sel.O
            wb_en.I @= wb_en.O
            csr_cmd.I @= csr_cmd.O
            illegal.I @= illegal.O
            pc_check.I @= pc_check.O
            # `&` binds tighter than `|`: RESET | (~stall & csr.expt).
            # On reset or an un-stalled exception, clear the control
            # registers that would otherwise cause side effects.
            if m.bit(self.io.RESET) | ~stall & csr.expt:
                st_type.I @= 0
                ld_type.I @= 0
                wb_en.I @= 0
                csr_cmd.I @= 0
                illegal.I @= False
                pc_check.I @= False
            elif ~stall & ~csr.expt:
                # Normal advance: latch the execute-stage values.
                ew_pc.I @= fe_pc.O
                ew_inst.I @= fe_inst.O
                ew_alu.I @= alu.O
                csr_in.I @= m.mux([rs1, imm_gen.O],
                                  self.io.ctrl.imm_sel == IMM_Z)
                st_type.I @= self.io.ctrl.st_type
                ld_type.I @= self.io.ctrl.ld_type
                wb_sel.I @= self.io.ctrl.wb_sel
                wb_en.I @= self.io.ctrl.wb_en
                csr_cmd.I @= self.io.ctrl.csr_cmd
                illegal.I @= self.io.ctrl.illegal
                pc_check.I @= self.io.ctrl.pc_sel == PC_ALU

        # Load
        # Shift the response right by the bit offset derived from the two
        # low bits of the registered address, then extend per load type.
        l_offset = ((m.uint(ew_alu.O[1], x_len) << 4) |
                    (m.uint(ew_alu.O[0], x_len) << 3))
        l_shift = self.io.dcache.resp.data.data >> l_offset
        load = m.dict_lookup(
            {
                LD_LH: m.sext_to(m.sint(l_shift[0:16]), x_len),
                LD_LHU: m.sint(m.zext_to(l_shift[0:16], x_len)),
                LD_LB: m.sext_to(m.sint(l_shift[0:8]), x_len),
                LD_LBU: m.sint(m.zext_to(l_shift[0:8], x_len))
            }, ld_type.O, m.sint(self.io.dcache.resp.data.data))

        # CSR access
        csr.stall @= stall
        csr.I @= csr_in.O
        csr.cmd @= csr_cmd.O
        csr.inst @= ew_inst.O
        csr.pc @= ew_pc.O
        csr.addr @= ew_alu.O
        csr.illegal @= illegal.O
        csr.pc_check @= pc_check.O
        csr.ld_type @= ld_type.O
        csr.st_type @= st_type.O
        self.io.host @= csr.host

        # Regfile write
        # The registered ALU result is the default write-back value.
        reg_write = m.dict_lookup(
            {
                WB_MEM: m.uint(load),
                WB_PC4: (ew_pc.O + 4),
                WB_CSR: csr.O
            }, wb_sel.O, ew_alu.O)

        reg_file.wen @= m.enable(wb_en.O & ~stall & ~csr.expt)
        reg_file.waddr @= wb_rd_addr
        reg_file.wdata @= reg_write

        # Abort store when there's an exception
        self.io.dcache.abort @= csr.expt
Exemple #16
0
    def __init__(self, x_len, n_ways: int, n_sets: int, b_bytes: int):
        """Build the cache with a CPU-side port and a NASTI memory port.

        The cache keeps one line per set: a valid bit, a dirty bit, a tag
        and ``b_bytes`` of data.  A state machine serves hits directly and,
        on a miss, optionally writes the dirty line back before refilling
        it over the NASTI interface.

        Args:
            x_len: Width in bits of a CPU word and of the address.
            n_ways: Not referenced by this constructor.
            n_sets: Number of sets (lines) in the cache.
            b_bytes: Number of bytes in one cache line.
        """
        b_bits = b_bytes << 3
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        n_words = b_bits // x_len
        w_bytes = x_len // 8
        byte_offset_bits = m.bitutils.clog2(w_bytes)
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)
        # Number of NASTI transfers needed to move one cache line.
        data_beats = b_bits // nasti_params.x_data_bits

        class MetaData(m.Product):
            tag = m.UInt[t_len]

        self.io = m.IO(**make_cache_ports(x_len, nasti_params))
        self.io += m.ClockIO()

        class State(m.Enum):
            IDLE = 0
            READ_CACHE = 1
            WRITE_CACHE = 2
            WRITE_BACK = 3
            WRITE_ACK = 4
            REFILL_READY = 5
            REFILL = 6

        state = m.Register(init=State.IDLE)()

        # memory
        # One valid bit and one dirty bit per set.
        v = m.Register(m.UInt[n_sets], has_enable=True)()
        d = m.Register(m.UInt[n_sets], has_enable=True)()
        meta_mem = m.Memory(n_sets,
                            MetaData,
                            read_latency=1,
                            has_read_enable=True)()
        # One byte-maskable memory per word of the cache line.
        data_mem = [
            ArrayMaskMem(n_sets,
                         w_bytes,
                         m.UInt[8],
                         read_latency=1,
                         has_read_enable=True)() for _ in range(n_words)
        ]

        # Registered copy of the CPU request currently being served.
        addr_reg = m.Register(type(self.io.cpu.req.data.addr).undirected_t,
                              has_enable=True)()
        cpu_data = m.Register(type(self.io.cpu.req.data.data).undirected_t,
                              has_enable=True)()
        cpu_mask = m.Register(type(self.io.cpu.req.data.mask).undirected_t,
                              has_enable=True)()

        self.io.nasti.r.ready @= state.O == State.REFILL
        # Counters
        assert data_beats > 0
        if data_beats > 1:
            read_counter = mantle.CounterModM(data_beats,
                                              max(data_beats.bit_length(), 1),
                                              has_ce=True)
            read_counter.CE @= m.enable(self.io.nasti.r.fired())
            read_count, read_wrap_out = read_counter.O, read_counter.COUT

            write_counter = mantle.CounterModM(data_beats,
                                               max(data_beats.bit_length(), 1),
                                               has_ce=True)
            write_count, write_wrap_out = write_counter.O, write_counter.COUT
        else:
            # A single beat needs no counters: the count is always 0 and
            # every beat is the last one.  Note these are plain Python
            # ints, not magma values.
            read_count, read_wrap_out = 0, 1
            write_count, write_wrap_out = 0, 1

        # Holds the beats received so far during a refill.
        refill_buf = m.Register(m.Array[data_beats,
                                        m.UInt[nasti_params.x_data_bits]],
                                has_enable=True)()
        if data_beats == 1:
            refill_buf.I[0] @= self.io.nasti.r.data.data
        else:
            # The counter output carries one extra high bit; drop it to
            # index the buffer.
            refill_buf.I @= m.set_index(refill_buf.O,
                                        self.io.nasti.r.data.data,
                                        read_count[:-1])
        refill_buf.CE @= m.enable(self.io.nasti.r.fired())

        is_idle = state.O == State.IDLE
        is_read = state.O == State.READ_CACHE
        is_write = state.O == State.WRITE_CACHE
        # High on the last beat of a refill.
        is_alloc = (state.O == State.REFILL) & read_wrap_out
        # m.display("[%0t]: is_alloc = %x", m.time(), is_alloc)\
        #     .when(m.posedge(self.io.CLK))
        is_alloc_reg = m.Register(m.Bit)()(is_alloc)

        hit = m.Bit(name="hit")
        # `&` binds tighter than `|`: write on a (non-aborted) CPU write
        # that hits or follows an allocation, or on the allocation itself.
        wen = is_write & (hit | is_alloc_reg) & ~self.io.cpu.abort | is_alloc
        # m.display("[%0t]: wen = %x", m.time(), wen)\
        #     .when(m.posedge(self.io.CLK))
        ren = m.enable(~wen & (is_idle | is_read) & self.io.cpu.req.valid)
        ren_reg = m.enable(m.Register(m.Bit)()(ren))

        # Address fields: [byte offset | word offset | set index | tag]
        # from least to most significant bits.
        addr = self.io.cpu.req.data.addr
        idx = addr[b_len:s_len + b_len]
        tag_reg = addr_reg.O[s_len + b_len:x_len]
        idx_reg = addr_reg.O[b_len:s_len + b_len]
        off_reg = addr_reg.O[byte_offset_bits:b_len]

        rmeta = meta_mem.read(idx, ren)
        rdata = m.concat(*(mem.read(idx, ren) for mem in data_mem))
        # Keep the last read data so it stays available in cycles where no
        # new read was issued.
        rdata_buf = m.Register(type(rdata), has_enable=True)()(rdata,
                                                               CE=ren_reg)

        # Line data seen by the CPU / write-back path: the refill buffer
        # right after an allocation, otherwise the (possibly buffered)
        # memory read data.
        read = m.mux([
            m.as_bits(m.mux([rdata_buf, rdata], ren_reg)),
            m.as_bits(refill_buf.O)
        ], is_alloc_reg)
        # m.display("is_alloc_reg=%x", is_alloc_reg)\
        #     .when(m.posedge(self.io.CLK))

        hit @= v.O[idx_reg] & (rmeta.tag == tag_reg)

        # read mux
        self.io.cpu.resp.data.data @= m.array(
            [read[i * x_len:(i + 1) * x_len] for i in range(n_words)])[off_reg]
        self.io.cpu.resp.valid @= (is_idle | (is_read & hit) |
                                   (is_alloc_reg & ~cpu_mask.O.reduce_or()))
        m.display("resp.valid=%x", self.io.cpu.resp.valid.value())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: valid = %x", m.time(),
                  self.io.cpu.resp.valid.value())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: is_idle = %x, is_read = %x, hit = %x, is_alloc_reg = "
                  "%x, ~cpu_mask.O.reduce_or() = %x", m.time(), is_idle,
                  is_read, hit, is_alloc_reg, ~cpu_mask.O.reduce_or())\
            .when(m.posedge(self.io.CLK))
        # One "%x" per refill beat so the format always matches the number
        # of values passed (identical to the old string for two beats).
        refill_fmt = "[%0t]: refill_buf.O=" + ", ".join(["%x"] * data_beats)
        m.display(refill_fmt, m.time(), *refill_buf.O)\
            .when(m.posedge(self.io.CLK))\
            .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)
        m.display("[%0t]: read=%x", m.time(), read)\
            .when(m.posedge(self.io.CLK))\
            .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)

        # A new request is latched whenever the previous one is answered.
        addr_reg.I @= addr
        addr_reg.CE @= m.enable(self.io.cpu.resp.valid.value())

        cpu_data.I @= self.io.cpu.req.data.data
        cpu_data.CE @= m.enable(self.io.cpu.resp.valid.value())

        cpu_mask.I @= self.io.cpu.req.data.mask
        cpu_mask.CE @= m.enable(self.io.cpu.resp.valid.value())

        wmeta = MetaData(name="wmeta")
        wmeta.tag @= tag_reg

        # CPU byte mask moved to the position of the addressed word within
        # the line (shift amount = word offset * bytes per word).
        offset_mask = (m.zext_to(cpu_mask.O, w_bytes * 8) << m.concat(
            m.bits(0, byte_offset_bits), off_reg))
        # An allocation writes every byte (all-ones mask); a CPU write only
        # writes the bytes selected by the shifted mask.
        wmask = m.mux([m.SInt[w_bytes * 8](-1),
                       m.sint(offset_mask)], ~is_alloc)

        if len(refill_buf.O) == 1:
            wdata_alloc = self.io.nasti.r.data.data
        else:
            # The last beat is still on the bus when the line is written,
            # so take it from nasti.r and the earlier ones from the buffer.
            wdata_alloc = m.concat(
                # TODO: not sure why they use `init.reverse`
                # https://github.com/ucb-bar/riscv-mini/blob/release/src/main/scala/Cache.scala#L116
                m.concat(*refill_buf.O[:-1]),
                self.io.nasti.r.data.data)
        # For a CPU write the word is replicated across the line; the byte
        # mask picks the destination.
        wdata = m.mux([wdata_alloc,
                       m.as_bits(m.repeat(cpu_data.O, n_words))], ~is_alloc)

        # Any write marks the line valid; it is dirty unless the write is
        # the allocation itself.
        v.I @= m.set_index(v.O, m.bit(True), idx_reg)
        v.CE @= m.enable(wen)
        d.I @= m.set_index(d.O, ~is_alloc, idx_reg)
        d.CE @= m.enable(wen)
        # m.display("[%0t]: refill_buf.O = %x", m.time(),
        #           m.concat(*refill_buf.O)).when(m.posedge(self.io.CLK)).if_(wen)
        # m.display("[%0t]: nasti.r.data.data = %x", m.time(),
        #           self.io.nasti.r.data.data).when(m.posedge(self.io.CLK)).if_(wen)

        # The tag is only rewritten when a new line is allocated.
        meta_mem.write(wmeta, idx_reg, m.enable(wen & is_alloc))
        for i, mem in enumerate(data_mem):
            data = [
                wdata[i * x_len + j * 8:i * x_len + (j + 1) * 8]
                for j in range(w_bytes)
            ]
            mem.write(m.array(data), idx_reg,
                      wmask[i * w_bytes:(i + 1) * w_bytes], m.enable(wen))
            # m.display("[%0t]: wdata = %x, %x, %x, %x", m.time(),
            #           *mem.WDATA.value()).when(m.posedge(self.io.CLK)).if_(wen)
            # m.display("[%0t]: wmask = %x, %x, %x, %x", m.time(),
            #           *mem.WMASK.value()).when(m.posedge(self.io.CLK)).if_(wen)

        # Refill address: the requested tag and index, shifted up to a
        # line-aligned address.
        tag_and_idx = m.zext_to(m.concat(idx_reg, tag_reg),
                                nasti_params.x_addr_bits)
        self.io.nasti.ar.data @= NastiReadAddressChannel(
            nasti_params, 0, tag_and_idx << m.Bits[len(tag_and_idx)](b_len),
            m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

        # Write-back address: same index, but the tag of the line being
        # evicted (as read from the metadata memory).
        rmeta_and_idx = m.zext_to(m.concat(idx_reg, rmeta.tag),
                                  nasti_params.x_addr_bits)
        self.io.nasti.aw.data @= NastiWriteAddressChannel(
            nasti_params, 0,
            rmeta_and_idx << m.Bits[len(rmeta_and_idx)](b_len),
            m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

        write_beats = [
            read[i * nasti_params.x_data_bits:(i + 1) *
                 nasti_params.x_data_bits] for i in range(data_beats)
        ]
        if data_beats > 1:
            write_beat = m.array(write_beats)[write_count[:-1]]
        else:
            # Without a write counter `write_count` is the Python int 0,
            # which cannot be sliced; the only beat is sent directly.
            write_beat = write_beats[0]
        self.io.nasti.w.data @= NastiWriteDataChannel(
            nasti_params, write_beat, None, write_wrap_out)

        is_dirty = v.O[idx_reg] & d.O[idx_reg]

        # TODO: Have to use temporary so we can invoke `fired()`
        aw_valid = m.Bit(name="aw_valid")
        self.io.nasti.aw.valid @= aw_valid

        ar_valid = m.Bit(name="ar_valid")
        self.io.nasti.ar.valid @= ar_valid

        b_ready = m.Bit(name="b_ready")
        self.io.nasti.b.ready @= b_ready

        @m.inline_combinational()
        def logic():
            # Defaults: stay in the current state, drive no NASTI requests.
            state.I @= state.O
            aw_valid @= False
            ar_valid @= False
            self.io.nasti.w.valid @= False
            b_ready @= False
            if state.O == State.IDLE:
                # A non-zero byte mask marks the request as a write.
                if self.io.cpu.req.valid:
                    if self.io.cpu.req.data.mask.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.READ_CACHE
            elif state.O == State.READ_CACHE:
                if hit:
                    # Hit: accept the next request back-to-back if any.
                    if self.io.cpu.req.valid:
                        if self.io.cpu.req.data.mask.reduce_or():
                            state.I @= State.WRITE_CACHE
                        else:
                            state.I @= State.READ_CACHE
                    else:
                        state.I @= State.IDLE
                else:
                    # Miss: write back first if the line is dirty,
                    # otherwise go straight to the refill.
                    aw_valid @= is_dirty
                    ar_valid @= ~is_dirty
                    if self.io.nasti.aw.fired():
                        state.I @= State.WRITE_BACK
                    elif self.io.nasti.ar.fired():
                        state.I @= State.REFILL
            elif state.O == State.WRITE_CACHE:
                if hit | is_alloc_reg | self.io.cpu.abort:
                    state.I @= State.IDLE
                else:
                    aw_valid @= is_dirty
                    ar_valid @= ~is_dirty
                    if self.io.nasti.aw.fired():
                        state.I @= State.WRITE_BACK
                    elif self.io.nasti.ar.fired():
                        state.I @= State.REFILL
            elif state.O == State.WRITE_BACK:
                self.io.nasti.w.valid @= True
                if write_wrap_out:
                    state.I @= State.WRITE_ACK
            elif state.O == State.WRITE_ACK:
                b_ready @= True
                if self.io.nasti.b.fired():
                    state.I @= State.REFILL_READY
            elif state.O == State.REFILL_READY:
                ar_valid @= True
                if self.io.nasti.ar.fired():
                    state.I @= State.REFILL
            elif state.O == State.REFILL:
                if read_wrap_out:
                    # After the last beat, a pending write still has to be
                    # applied to the freshly allocated line.
                    if cpu_mask.O.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.IDLE

        if data_beats > 1:
            # TODO: Have to do this at the end since the inline comb logic
            # wires up nasti.w
            write_counter.CE @= m.enable(self.io.nasti.w.fired())