Ejemplo n.º 1
0
    def __init__(self, T, entries, pipe=False, flow=False):
        """Build a ready/valid FIFO of ``entries`` items of type ``T``.

        Storage is an ``m.Memory`` addressed by two wrapping pointers
        (``enq_ptr``/``deq_ptr``).  Because equal pointers can mean either
        "empty" or "full", a ``maybe_full`` register records whether the last
        pointer-changing operation was an enqueue.

        Args:
            T: magma type of the queued data.
            entries: capacity of the queue (asserted non-negative).
            pipe: unsupported; a truthy value raises ``NotImplementedError``.
            flow: unsupported; a truthy value raises ``NotImplementedError``.

        Raises:
            NotImplementedError: if ``flow`` or ``pipe`` is requested.
        """
        assert entries >= 0
        # Reject unsupported modes before any hardware is instantiated.
        if flow:
            raise NotImplementedError()
        if pipe:
            raise NotImplementedError()

        self.io = m.IO(
            # Flipped since enq/deq is from perspective of the client
            enq=m.DeqIO[T],
            deq=m.EnqIO[T],
            count=m.Out(m.UInt[m.bitutils.clog2(entries + 1)])) + m.ClockIO()

        ram = m.Memory(entries, T)()
        enq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        deq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        maybe_full = m.Register(init=False, has_enable=True)()

        # Equal pointers are ambiguous; maybe_full disambiguates them.
        ptr_match = enq_ptr.O == deq_ptr.O
        empty = ptr_match & ~maybe_full.O
        full = ptr_match & maybe_full.O

        self.io.deq.valid @= ~empty
        self.io.enq.ready @= ~full

        do_enq = self.io.enq.fired()
        do_deq = self.io.deq.fired()

        # The pointers carry one more bit than the slice used as RAM address.
        ram.write(self.io.enq.data, enq_ptr.O[:-1], m.enable(do_enq))

        enq_ptr.CE @= m.enable(do_enq)
        deq_ptr.CE @= m.enable(do_deq)

        # maybe_full only changes when exactly one of enq/deq fires; it then
        # takes the value "an enqueue happened".
        maybe_full.I @= m.enable(do_enq)
        maybe_full.CE @= m.enable(do_enq != do_deq)
        self.io.deq.data @= ram[deq_ptr.O[:-1]]

        # Occupancy.  m.mux([a, b], sel) yields ``b`` when ``sel`` is true, so
        # the outer select is ``~ptr_match``:
        #   * pointers equal   -> 0 when empty, ``entries`` when full;
        #   * pointers differ  -> enq - deq, plus ``entries`` when the enqueue
        #     pointer has wrapped past the dequeue pointer.
        # The pointers count modulo ``entries`` in ``entries.bit_length()``
        # bits, so this single formula covers power-of-two capacities as well.
        count_len = len(self.io.count)
        ptr_diff = enq_ptr.O - deq_ptr.O
        self.io.count @= m.mux([
            m.mux([m.bits(0, count_len), entries], maybe_full.O),
            m.mux([ptr_diff, entries + ptr_diff], deq_ptr.O > enq_ptr.O)
        ], ~ptr_match)
Ejemplo n.º 2
0
    class DUT(m.Circuit):
        """Self-checking harness comparing ``ImmGen`` with a reference value.

        A counter steps through ``insts``; for each one the expected immediate
        is picked according to ``ctrl.imm_sel`` and asserted equal to the
        immediate generator's output.  ``done`` follows the counter's COUT.
        """
        io = m.IO(done=m.Out(m.Bit)) + m.ClockIO()
        imm = ImmGen(32)()
        ctrl = Control(32)()

        counter = mantle.CounterModM(len(insts), len(insts).bit_length())

        # Reference immediates for every format, indexed by the counter.
        i = m.mux([iimm(i) for i in insts], counter.O)
        s = m.mux([simm(i) for i in insts], counter.O)
        b = m.mux([bimm(i) for i in insts], counter.O)
        u = m.mux([uimm(i) for i in insts], counter.O)
        j = m.mux([jimm(i) for i in insts], counter.O)
        z = m.mux([zimm(i) for i in insts], counter.O)
        x = m.mux([iimm(i) & -2 for i in insts], counter.O)

        # Priority chain: start from the fallback ``x`` and let each matching
        # selector override the running value; the last stage (IMM_I) wins.
        O = m.mux([x, z], ctrl.imm_sel == IMM_Z)
        O = m.mux([O, j], ctrl.imm_sel == IMM_J)
        O = m.mux([O, u], ctrl.imm_sel == IMM_U)
        O = m.mux([O, b], ctrl.imm_sel == IMM_B)
        O = m.mux([O, s], ctrl.imm_sel == IMM_S)
        O = m.mux([O, i], ctrl.imm_sel == IMM_I)

        # Feed the current instruction to both the control unit and ImmGen.
        inst = m.mux(insts, counter.O)
        ctrl.inst @= inst
        imm.inst @= inst
        imm.sel @= ctrl.imm_sel
        io.done @= counter.COUT

        msg_fmt = "Counter: %d, Type: 0x%x, O: %x ?= %x"
        f.assert_immediate(
            imm.O == O,
            failure_msg=(msg_fmt, counter.O, imm.sel, imm.O, O))
        m.display(msg_fmt, counter.O, imm.sel, imm.O, O)
Ejemplo n.º 3
0
    def __init__(self, T, entries, with_bug=False):
        """Build a ready/valid FIFO of ``entries`` items of type ``T``.

        Args:
            T: magma type of the queued data.
            entries: capacity of the queue (asserted non-negative).
            with_bug: when true, the ``full`` condition is forced to ``False``
                so the queue never reports itself as full.
        """
        assert entries >= 0
        self.io = m.IO(
            # Flipped since enq/deq is from perspective of the client
            enq=m.DeqIO[T],
            deq=m.EnqIO[T]) + m.ClockIO()

        ram = m.Memory(entries, T)()
        # Write and read pointers; both advance only when their CE is high.
        enq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        deq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        # Distinguishes "full" from "empty" when the two pointers are equal.
        maybe_full = m.Register(init=False, has_enable=True)()

        ptr_match = enq_ptr.O == deq_ptr.O
        empty = ptr_match & ~maybe_full.O
        if with_bug:
            # never full
            # NOTE(review): ``full`` is a plain Python bool here, so the
            # ``~full`` below evaluates to the integer -1 rather than a
            # hardware value -- confirm magma wires that as logic high.
            full = False
        else:
            full = ptr_match & maybe_full.O

        self.io.deq.valid @= ~empty
        self.io.enq.ready @= ~full

        do_enq = self.io.enq.fired()
        do_deq = self.io.deq.fired()

        # The RAM address is the pointer without its top bit.
        ram.write(self.io.enq.data, enq_ptr.O[:-1], m.enable(do_enq))

        enq_ptr.CE @= m.enable(do_enq)
        deq_ptr.CE @= m.enable(do_deq)

        # maybe_full is updated only when exactly one of enq/deq fires, and
        # then records whether that operation was an enqueue.
        maybe_full.I @= m.enable(do_enq)
        maybe_full.CE @= m.enable(do_enq != do_deq)
        self.io.deq.data @= ram[deq_ptr.O[:-1]]

        # NOTE(review): not referenced in the visible part of this method.
        def ispow2(n):
            return (n & (n - 1) == 0) and n != 0
Ejemplo n.º 4
0
    def definition(io):
        """Wire the downscale-and-threshold datapath.

        Incoming ``io.DATA`` words are registered on ``io.LOAD``, pushed
        through the externally declared ``Downscale`` circuit and compared
        against ``THRESH``.  ``io.VALID`` pulses on the sample indices picked
        out of a free-running counter; ``io.O`` is the thresholded pixel gated
        by that same valid signal.
        """
        load = io.LOAD
        sck_rise = rising(io.SCK)
        sck_fall = falling(io.SCK)
        baud = sck_rise | sck_fall

        # Counts accepted samples; it advances when LOAD coincides with an
        # SCK edge.
        sample_counter = mantle.CounterModM(buf_size, 12, has_ce=True)
        count_enable = load & baud
        m.wire(count_enable, sample_counter.CE)

        # OR together one decoder per sample index that should be flagged.
        valid = m.GND
        for k in range(1, wo + 1):
            index = wi * (b - 1) + k * a - 1
            valid = valid | mantle.Decode(index, 12)(sample_counter.O)

        # Input register, loaded whenever LOAD is asserted.
        in_reg = mantle.Register(width, has_ce=True)
        in_reg(io.DATA)
        m.wire(load, in_reg.CE)

        # ------------------------- downscaling ---------------------------- #
        # Downscale the image from 352x288 to 32x32.
        pixel_in_t = m.In(m.Array(1, m.Array(1, m.Array(width, m.Bit))))
        Downscale = m.DeclareCircuit(
            'Downscale',
            "I_0_0", pixel_in_t,
            "WE", m.In(m.Bit),
            "CLK", m.In(m.Clock),
            "O", m.Out(m.Array(width, m.Bit)),
            "V", m.Out(m.Bit))

        dscale = Downscale()
        m.wire(in_reg.O, dscale.I_0_0[0][0])
        m.wire(1, dscale.WE)
        m.wire(load, dscale.CLK)

        add16 = mantle.Add(width)  # needed for Add16 definition

        # Threshold the downscaled value and gate it with ``valid``.
        below_thresh = mantle.ULE(16)(dscale.O, m.uint(THRESH, 16))
        px_bit = below_thresh & valid

        # ---------------------------- outputs ----------------------------- #
        m.wire(px_bit, io.O)
        m.wire(valid, io.VALID)
Ejemplo n.º 5
0
# Enable the board clock and the J2 header pins used as outputs.
hx8kboard.Clock.on()

hx8kboard.J2[9].output().on()
hx8kboard.J2[10].output().on()
hx8kboard.J2[11].output().on()
hx8kboard.J2[12].output().on()

main = hx8kboard.main()

# "test" data
# A ROM holding the 16-bit values 0..15, addressed by the ``printf`` counter.
init = [m.uint(i, 16) for i in range(16)]
printf = mantle.Counter(4, has_ce=True)
rom = ROM16(4, init, printf.O)

# baud for uart output
# ``baud`` is the carry-out of a counter built with modulus 103.
clock = mantle.CounterModM(103, 8)
baud = clock.COUT

# Counts baud ticks; ``load`` is high while the 5-bit count is zero.
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)

load = mantle.Decode(0, 5)(bit_counter.O)

# Advances once per loaded word (LOAD coinciding with a baud tick).
valid_counter = mantle.CounterModM(buf_size, 13, has_ce=True)
m.wire(load & baud, valid_counter.CE)

# Counter values at which the ``valid`` signal should be raised.
valid_list = [wi * (b - 1) + i * a - 1 for i in range(1, wo + 1)]  # len = 16

valid = m.GND

for i in valid_list:
Ejemplo n.º 6
0
    def definition(cam):
        """Wire the camera SPI command sequencer and its UART read-out.

        Commands come from a 16-entry ROM and are shifted out on MOSI; bytes
        shifted in on MISO are forwarded through a UART once the burst phase
        starts, and ``cam.DONE`` is raised after the expected number of bytes.

        NOTE(review): ``trigger``, ``init``, ``I0``..``I3``, ``rising`` and
        ``falling`` are not defined in this function; they are presumably
        provided by the enclosing scope -- verify.
        """
        edge_f = falling(cam.SCK)
        edge_r = rising(cam.SCK)

        # ROM to store commands
        rom_index = mantle.Counter(4, has_ce=True)
        rom = ROM16(4, init, rom_index.O)

        # Message length is 16 bits, setup counter to generate done signal
        # after EOM
        done_counter = mantle.Counter(5, has_ce=True, has_reset=True)
        count = done_counter.O
        done = mantle.Decode(16, 5)(count)

        # State machine to generate run signal (enable)
        # ``run`` is updated on falling SCK edges from a 3-input LUT of
        # (done, trigger, run).
        run = mantle.DFF(has_ce=True)
        run_n = mantle.LUT3([0, 0, 1, 0, 1, 0, 1, 0])
        run_n(done, trigger, run)
        run(run_n)
        m.wire(edge_f, run.CE)

        # Reset the message length counter after done
        run_reset = mantle.LUT2(I0 | ~I1)(done, run)
        done_counter(CE=edge_r, RESET=run_reset)

        # State variables for high-level state machine
        # start: rom_index <= 3, burst: rom_index >= 9.
        ready = mantle.LUT2(~I0 & I1)(run, edge_f)
        start = mantle.ULE(4)(rom_index.O, m.uint(3, 4))
        burst = mantle.UGE(4)(rom_index.O, m.uint(9, 4))

        # Shift register to store 16-bit command|data to send
        mosi = mantle.PISO(16, has_ce=True)
        # SPI enable is negative of load-don't load and shift out data at the
        # same time
        enable = mantle.LUT3(I0 & ~I1 & ~I2)(trigger, run, burst)
        mosi(~burst, rom.O, enable)
        m.wire(edge_f, mosi.CE)

        # Shift register to read in 8-bit data
        miso = mantle.SIPO(8, has_ce=True)
        miso(cam.MISO)
        valid = mantle.LUT2(~I0 & I1)(enable, edge_r)
        m.wire(valid, miso.CE)

        # Capture done state variable
        # Set when the received byte equals 0x08.
        cap_done = mantle.SRFF(has_ce=True)
        cap_done(mantle.EQ(8)(miso.O, m.bits(0x08, 8)), 0)
        m.wire(enable & edge_r, cap_done.CE)

        # Use state variables to determine what commands are sent (how)
        increment = mantle.LUT4(I0 & (I1 | I2) & ~I3)(
            ready, start, cap_done, burst)
        m.wire(increment, rom_index.CE)

        # wire outputs
        m.wire(enable, cam.EN)
        m.wire(mosi.O, cam.MOSI)
        m.wire(miso.O, cam.DATA)
        m.wire(burst,  cam.VALID)

        # --------------------------UART OUTPUT---------------------------- #

        # run UART at 2x SPI rate to allow it to keep up
        baud = edge_r | edge_f

        # reset when SPI burst read (image transfer) begins
        # ``u_reset`` is high when ``burst`` is set but its registered copy
        # is not yet.
        ff = mantle.FF(has_ce=True)
        m.wire(edge_r, ff.CE)
        u_reset = mantle.LUT2(I0 & ~I1)(burst, ff(burst))

        # UART data out every 8 bits
        u_counter = mantle.CounterModM(8, 3, has_ce=True, has_reset=True)
        u_counter(CE=edge_r, RESET=u_reset)
        load = burst & rising(u_counter.COUT)

        # NOTE(review): the ``miso`` instance itself (not ``miso.O``) is passed
        # as DATA, and ``uart``/``tx_done`` instances are wired directly
        # below -- presumably magma resolves these to their outputs; confirm.
        uart = UART(8)
        uart(CLK=cam.CLK, BAUD=baud, DATA=miso, LOAD=load)

        # wire output
        m.wire(uart, cam.UART)

        # generate signal for when transfer is done
        data_count = mantle.Counter(18, has_ce=True)
        tx_done = mantle.SRFF(has_ce=True)
        # transfer has size 153600 bytes, first 2 bytes are ignored
        tx_done(mantle.EQ(18)(data_count.O, m.bits(153602, 18)), 0)
        m.wire(load, tx_done.CE)
        m.wire(load, data_count.CE)

        # wire output
        m.wire(tx_done, cam.DONE)
Ejemplo n.º 7
0
Archivo: ftdi.py Proyecto: splhack/loam
# Enable the board clock and the TX pin as an output.
icestick.Clock.on()
icestick.TX.output().on()

main = icestick.main()

# Constant input to the run LUT below.
valid = 1

# One 8-bit array per character of the message; stored in a 16-entry ROM
# addressed by the ``printf`` counter.
init = [m.array(int2seq(ord(c), 8)) for c in 'hello, world  \r\n']

printf = mantle.Counter(4, has_ce=True)
rom = ROM(4, init, printf.O)

# 9-bit word: the ROM byte with its bit order reversed, followed by a 0.
data = m.array([rom.O[7], rom.O[6], rom.O[5], rom.O[4],
                rom.O[3], rom.O[2], rom.O[1], rom.O[0], 0])

# ``baud`` is the carry-out of a counter built with modulus 103.
counter = mantle.CounterModM(103, 8)
baud = counter.COUT

# Per-word tick counter; ``done`` is high when the count equals 15.
count = mantle.Counter(4, has_ce=True, has_reset=True)
decode = mantle.Decode(15, 4)
done = decode(count.O)

# ``run`` flip-flop, updated on baud ticks from a LUT of (done, valid, run).
run = mantle.DFF(has_ce=True)
run_n = mantle.LUT3([0,0,1,0, 1,0,1,0])
run_n(done, valid, run.O)
run(run_n, ce=baud)

# Reset the tick counter when ``done`` is high and ``run`` is low.
reset = mantle.LUT2(mantle.I0&~mantle.I1)(done, run)
count(CE=baud, RESET=reset)

# 9-bit parallel-in/serial-out shift register.
shift = mantle.PISO(9, has_ce=True)
Ejemplo n.º 8
0
from uart import UART

# Enable the board clock and J3 pins 0..7 as outputs.
icestick = IceStick()
icestick.Clock.on()
for i in range(8):
    icestick.J3[i].output().on()

main = icestick.main()

# "test" data
# A ROM holding the 16-bit values 0..15, addressed by the ``printf`` counter.
init = [m.uint(n, 16) for n in range(16)]
printf = mantle.Counter(4, has_ce=True)
rom = ROM16(4, init, printf.O)

# baud for uart output
# ``baud`` is the carry-out of a counter built with modulus 103.
clock = mantle.CounterModM(103, 8)
baud = clock.COUT

# Counts baud ticks; ``load`` is high while the 5-bit count is zero.
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)

load = mantle.Decode(0, 5)(bit_counter.O)

# Advances once per loaded word (LOAD coinciding with a baud tick).
valid_counter = mantle.CounterModM(8, 3, has_ce=True)
m.wire(load & baud, valid_counter.CE)

# Counter values at which the ``valid`` signal should be raised.
valid_list = [5, 7]

valid = m.GND

for i in valid_list:
Ejemplo n.º 9
0
# Enable the J2 header pins used as outputs.
hx8kboard.J2[5].output().on()
hx8kboard.J2[8].output().on()
hx8kboard.J2[9].output().on()
hx8kboard.J2[10].output().on()
hx8kboard.J2[11].output().on()
hx8kboard.J2[12].output().on()

main = hx8kboard.main()

# "test" data
# A ROM holding the 16-bit values 0..15, addressed by the ``printf`` counter.
init = [m.uint(i, 16) for i in range(16)]
printf = mantle.Counter(4, has_ce=True)
rom = ROM16(4, init, printf.O)

# baud for uart output
# ``baud`` is the carry-out of a counter built with modulus 16.
clock = mantle.CounterModM(16, 8)
baud = clock.COUT

# Counts baud ticks; ``load`` is high while the 5-bit count is zero.
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)

load = mantle.Decode(0, 5)(bit_counter.O)

# Step to the next ROM word whenever a load coincides with a baud tick.
m.wire(load & baud, printf.CE)

rescale = Rescale()

# inputs
m.wire(main.CLKIN, rescale.CLK)
m.wire(rom.O, rescale.DATA)
m.wire(baud, rescale.SCK)
Ejemplo n.º 10
0
    def definition(io):
        """Wire the downscaler, the row/column bookkeeping and the UART.

        ``io.DATA`` is registered on ``io.LOAD`` and fed to the externally
        declared ``Downscale`` circuit.  Two 4-bit counters track the column
        and row of the downscaled image, ``io.DONE`` reflects the image-full
        flag, and the downscaler output is streamed out through a 16-bit UART.
        """
        load = io.LOAD
        baud = io.BAUD

        # Counts loaded words; it advances when LOAD coincides with BAUD.
        sample_counter = mantle.CounterModM(buf_size, 13, has_ce=True)
        count_enable = load & baud
        m.wire(count_enable, sample_counter.CE)

        # One decoder per flagged sample index, OR-ed together.
        valid = m.GND
        for k in range(1, wo + 1):
            index = wi * (b - 1) + k * a - 1
            valid = valid | mantle.Decode(index, 13)(sample_counter.O)

        # Input register, loaded whenever LOAD is asserted.
        in_reg = mantle.Register(16, has_ce=True)
        in_reg(io.DATA)
        m.wire(load, in_reg.CE)

        # ------------------------- downscaling ---------------------------- #
        # Downscale the image from 320x240 to 16x16.
        pixel_in_t = m.In(m.Array(1, m.Array(1, m.Array(16, m.Bit))))
        Downscale = m.DeclareCircuit(
            'Downscale',
            "I_0_0", pixel_in_t,
            "WE", m.In(m.Bit),
            "CLK", m.In(m.Clock),
            "O", m.Out(m.Array(16, m.Bit)),
            "V", m.Out(m.Bit))

        dscale = Downscale()
        m.wire(in_reg.O, dscale.I_0_0[0][0])
        m.wire(1, dscale.WE)
        m.wire(load, dscale.CLK)

        add16 = mantle.Add(16)  # needed for Add16 definition

        # ------------------------- fill image RAM ------------------------- #
        # Each valid downscaler output is one entry of the 16x16 binary
        # image; sixteen entries make up one 16-bit row.
        col = mantle.Counter(4, has_ce=True)

        # Latched once the column counter reads 15; sampled on falling V.
        row_full = mantle.SRFF(has_ce=True)
        last_col = mantle.EQ(4)(col.O, m.bits(15, 4))
        row_full(last_col, 0)
        m.wire(falling(dscale.V), row_full.CE)
        m.wire(rising(dscale.V) & ~row_full.O, col.CE)

        row = mantle.Counter(4, has_ce=True)

        # Latched once the row counter reads 15; sampled on falling col.COUT.
        img_full = mantle.SRFF(has_ce=True)
        last_row = mantle.EQ(4)(row.O, m.bits(15, 4))
        img_full(last_row, 0)
        m.wire(falling(col.COUT), img_full.CE)
        m.wire(rising(col.COUT) & ~img_full.O, row.CE)

        # --------------------------- UART output -------------------------- #
        uart_st = UART(16)
        uart_st(CLK=io.CLK, BAUD=baud, DATA=dscale.O, LOAD=load)

        m.wire(row.O, io.ROW)
        m.wire(img_full.O, io.DONE)
        m.wire(uart_st.O, io.UART)
Ejemplo n.º 11
0
from loam.boards.hx8kboard import HX8KBoard
from uart import UART

# Enable the board clock and the J2 header pins used as outputs.
hx8kboard = HX8KBoard()
hx8kboard.Clock.on()

hx8kboard.J2[3].output().on()
hx8kboard.J2[4].output().on()
hx8kboard.J2[5].output().on()
hx8kboard.J2[8].output().on()
hx8kboard.J2[9].output().on()

main = hx8kboard.main()

# baud for uart output
# ``baud`` is the carry-out of a counter built with modulus 100.
clock = mantle.CounterModM(100, 8)
baud = clock.COUT

# Counts baud ticks; ``load`` is high while the 5-bit count is zero.
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)

load = mantle.Decode(0, 5)(bit_counter.O)

# Advances once per loaded word (LOAD coinciding with a baud tick).
valid_counter = mantle.CounterModM(4800, 13, has_ce=True)
m.wire(load & baud, valid_counter.CE)

# Counter values at which the ``valid`` signal should be raised.
valid_list = [10 * i for i in range(16)]  # len = 16

valid = m.GND

for i in valid_list:
Ejemplo n.º 12
0
    def __init__(self, x_len, n_ways: int, n_sets: int, b_bytes: int):
        """Build a reference cache with a NASTI memory port.

        The model keeps one block, tag, valid bit and dirty bit per set and
        serves requests through a four-state machine (IDLE, WRITE, WRITE_ACK,
        READ): hits are answered in IDLE, misses refill the block over the
        NASTI read channel, and dirty blocks are written back first.

        Args:
            x_len: address width in bits; also the width of the slice of the
                block returned on ``resp``.
            n_ways: accepted for interface compatibility; not used here.
            n_sets: number of sets (depth of every internal memory).
            b_bytes: block size in bytes.
        """
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)

        self.io = m.IO(req=m.Consumer(m.Decoupled[make_CacheReq(x_len)]),
                       resp=m.Producer(m.Decoupled[make_CacheResp(x_len)]),
                       nasti=make_NastiIO(nasti_params)) + m.ClockIO()
        # Burst size is log2 of the number of *bytes* per beat, hence the
        # division by 8.
        size = m.bitutils.clog2(nasti_params.x_data_bits // 8)
        b_bits = b_bytes << 3
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        data_beats = b_bits // nasti_params.x_data_bits
        length = data_beats - 1

        # Per-set storage: block data, tag, valid bit and dirty bit.
        data = m.Memory(n_sets, m.UInt[b_bits])()
        tags = m.Memory(n_sets, m.UInt[t_len])()
        v = m.Memory(n_sets, m.Bit)()
        d = m.Memory(n_sets, m.Bit)()

        # Split the request address into tag | set index | byte offset.
        req = self.io.req.data
        tag = (req.addr >> (b_len + s_len))[:t_len]
        idx = req.addr[b_len:b_len + s_len]
        off = req.addr[:b_len]
        read = data.read(idx)
        # Merge the request data into the stored block one byte lane at a
        # time: a lane takes the new byte when it lies in the addressed word
        # and its mask bit is set, otherwise it keeps the stored byte.
        write = m.bits(0, b_bits)
        for i in range(b_bytes):
            write |= m.mux([(read & (0xff << (8 * i))),
                            ((m.zext_to(req.data, b_bits) >>
                              ((8 * (i & 0x3)))) & 0xff) << (8 * i)],
                           ((off // 4) == (i // 4)) & (req.mask >>
                                                       (i & 0x3))[0])[:b_bits]

        class State(m.Enum):
            IDLE = 0
            WRITE = 1
            WRITE_ACK = 2
            READ = 3

        state = m.Register(init=State.IDLE)()

        # Beat counter for write-back; advances every cycle spent in WRITE.
        write_counter = mantle.CounterModM(data_beats,
                                           max(data_beats.bit_length(), 1),
                                           has_ce=True)
        write_counter.CE @= m.enable(state.O == State.WRITE)
        w_cnt, w_done = write_counter.O, write_counter.COUT

        # Beat counter for refill; advances on every valid read beat in READ.
        read_counter = mantle.CounterModM(data_beats,
                                          max(data_beats.bit_length(), 1),
                                          has_ce=True)
        read_counter.CE @= m.enable((state.O == State.READ)
                                    & self.io.nasti.r.valid)
        r_cnt, r_done = read_counter.O, read_counter.COUT

        # Response data: the x_len-bit word of the block selected by the
        # offset.
        self.io.resp.data.data @= (read >> (m.zext_to(
            (off // 4), b_bits) * x_len))[:x_len]
        # Refill address: the request address with the offset bits cleared.
        self.io.nasti.ar.data @= NastiReadAddressChannel(
            nasti_params, 0, (req.addr >> b_len) << b_len, size, length)
        tags_rdata = tags.read(idx)
        # Write-back address: rebuilt from the stored tag and the set index.
        self.io.nasti.aw.data @= NastiWriteAddressChannel(
            nasti_params, 0,
            m.bits(m.concat(idx, tags_rdata), nasti_params.x_addr_bits) <<
            b_len, size, length)
        # Write-back data: the beat of the stored block selected by w_cnt.
        self.io.nasti.w.data @= NastiWriteDataChannel(
            nasti_params,
            (read >> (m.zext_to(w_cnt, b_bits) *
                      nasti_params.x_data_bits))[:nasti_params.x_data_bits],
            None, w_done)
        self.io.nasti.w.valid @= state.O == State.WRITE
        self.io.nasti.b.ready @= state.O == State.WRITE_ACK
        self.io.nasti.r.ready @= state.O == State.READ

        # Named temporaries driven by the combinational block below and used
        # as write enables / write data of the memories.
        d_wen = m.Bit(name="d_wen")
        d.write(True, idx, m.enable(d_wen))

        data_wen = m.Bit(name="data_wen")
        data_wdata = m.UInt[b_bits](name="data_wdata")
        data.write(data_wdata, idx, m.enable(data_wen))
        # m.display("data_wdata=%x", data_wdata).when(m.posedge(self.io.CLK))

        v_wen = m.Bit(name="v_wen")
        v.write(True, idx, m.enable(v_wen))
        v_rdata = v.read(idx)

        tags_wen = m.Bit(name="tags_wen")
        tags.write(tag, idx, m.enable(tags_wen))

        d_rdata = d.read(idx)

        # m.display("gold_state=%x", state.O).when(m.posedge(self.io.CLK))
        # m.display("gold_w_done=%x", w_done).when(m.posedge(self.io.CLK))
        # m.display("gold_b_valid=%x",
        #           self.io.nasti.b.valid).when(m.posedge(self.io.CLK))

        if TRACE:
            # Trace hits that carry a non-zero mask (writes) ...
            m.display(
                "[%0t] [cache] data[%x] <= %x, off: %x, req: %x, mask: %b",
                m.time(), idx, write, off, self.io.req.data.data,
                self.io.req.data.mask)\
                .when(m.posedge(self.io.CLK))\
                .if_((state.O == State.IDLE) &
                     (self.io.req.valid & self.io.resp.ready) &
                     (v_rdata & (tags_rdata == tag)) & req.mask.reduce_or())

            # ... and hits with an all-zero mask (reads).
            m.display(
                "[%0t] [cache] data[%x] => %x, off: %x, resp: %x", m.time(),
                idx, write, off, self.io.resp.data.data.value())\
                .when(m.posedge(self.io.CLK))\
                .if_((state.O == State.IDLE) &
                     (self.io.req.valid & self.io.resp.ready) &
                     (v_rdata & (tags_rdata == tag)) & ~req.mask.reduce_or())

        @m.inline_combinational()
        def logic():
            # Defaults: nothing valid/ready, no memory writes, hold state.
            self.io.resp.valid @= False
            self.io.req.ready @= False
            self.io.nasti.ar.valid @= False
            self.io.nasti.aw.valid @= False

            d_wen @= False

            data_wen @= False
            data_wdata @= m.UInt[b_bits](0)
            state.I @= state.O

            tags_wen @= False
            v_wen @= False

            if state.O == State.IDLE:
                if self.io.req.valid & self.io.resp.ready:
                    if v_rdata & (tags_rdata == tag):
                        # Hit: a non-zero mask makes it a write, which marks
                        # the set dirty and stores the merged block.
                        if req.mask.reduce_or():
                            d_wen @= True
                            data_wdata @= write
                            data_wen @= True
                        self.io.req.ready @= True
                        self.io.resp.valid @= True
                    else:
                        # Miss: write back first when the set is dirty,
                        # otherwise clear the block and start the refill.
                        if d_rdata:
                            self.io.nasti.aw.valid @= True
                            state.I @= State.WRITE
                        else:
                            data_wdata @= 0
                            data_wen @= True
                            self.io.nasti.ar.valid @= True
                            state.I @= State.READ
            elif state.O == State.WRITE:
                if w_done:
                    state.I @= State.WRITE_ACK
            elif state.O == State.WRITE_ACK:
                if self.io.nasti.b.valid:
                    # Write-back acknowledged: clear the block and refill.
                    data_wdata @= 0
                    data_wen @= True
                    self.io.nasti.ar.valid @= True
                    state.I @= State.READ
            elif state.O == State.READ:
                if self.io.nasti.r.valid:
                    # OR the incoming beat into its position in the block.
                    data_wdata @= read | (
                        m.zext_to(self.io.nasti.r.data.data, b_bits) <<
                        (m.zext_to(r_cnt, b_bits) * nasti_params.x_data_bits))
                    data_wen @= True
                if r_done:
                    tags_wen @= True
                    v_wen @= True
                    state.I @= State.IDLE
Ejemplo n.º 13
0
    class DUT(m.Circuit):
        io = m.IO(done=m.Out(m.Bit)) + m.ClockIO()
        x_len = 32
        n_sets = 256
        b_bytes = 4 * (x_len >> 3)
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)

        dut = Cache(x_len, 1, n_sets, b_bytes)()
        dut_mem = make_NastiIO(nasti_params).undirected_t(name="dut_mem")
        dut_mem.ar @= make_Queue(dut.nasti.ar, 32)
        dut_mem.aw @= make_Queue(dut.nasti.aw, 32)
        dut_mem.w @= make_Queue(dut.nasti.w, 32)
        dut.nasti.b @= make_Queue(dut_mem.b, 32)
        dut.nasti.r @= make_Queue(dut_mem.r, 32)

        gold = GoldCache(x_len, 1, n_sets, b_bytes)()
        gold_req = type(gold.req).undirected_t(name="gold_req")
        gold_resp = type(gold.resp).undirected_t(name="gold_resp")
        gold_mem = make_NastiIO(nasti_params).undirected_t(name="gold_mem")
        gold.req @= make_Queue(gold_req, 32)
        gold_resp @= make_Queue(gold.resp, 32)
        gold_mem.ar @= make_Queue(gold.nasti.ar, 32)
        gold_mem.aw @= make_Queue(gold.nasti.aw, 32)
        gold_mem.w @= make_Queue(gold.nasti.w, 32)
        gold.nasti.b @= make_Queue(gold_mem.b, 32)
        gold.nasti.r @= make_Queue(gold_mem.r, 32)

        size = m.bitutils.clog2(nasti_params.x_data_bits // 8)
        b_bits = b_bytes << 3
        data_beats = b_bits // nasti_params.x_data_bits

        mem = m.Memory(1 << 20, m.UInt[nasti_params.x_data_bits])()

        # States of the shared memory model driven by ``mem_fsm``.
        class MemState(m.Enum):
            IDLE = 0  # waiting for a write or read address on both ports
            WRITE = 1  # accepting write data beats
            WRITE_ACK = 2  # write response (b) is valid
            READ = 3  # read data (r) is valid

        mem_state = m.Register(init=MemState.IDLE)()

        write_counter = mantle.CounterModM(data_beats,
                                           data_beats.bit_length(),
                                           has_ce=True)
        write_counter.CE @= m.enable((mem_state.O == MemState.WRITE)
                                     & dut_mem.w.valid & gold_mem.w.valid)
        read_counter = mantle.CounterModM(data_beats,
                                          data_beats.bit_length(),
                                          has_ce=True)
        read_counter.CE @= m.enable((mem_state.O == MemState.READ)
                                    & dut_mem.r.ready & gold_mem.r.ready)

        dut_mem.b.valid @= mem_state.O == MemState.WRITE_ACK
        dut_mem.b.data @= NastiWriteResponseChannel(nasti_params, 0)
        dut_mem.r.valid @= mem_state.O == MemState.READ
        dut_mem.r.data @= NastiReadDataChannel(
            nasti_params, 0,
            mem.read(
                ((gold_mem.ar.data.addr) +
                 m.zext_to(read_counter.O, nasti_params.x_addr_bits))[:20]),
            read_counter.COUT)
        gold_mem.ar.ready @= dut_mem.ar.ready
        gold_mem.aw.ready @= dut_mem.aw.ready
        gold_mem.w.ready @= dut_mem.w.ready
        gold_mem.b.valid @= dut_mem.b.valid
        gold_mem.b.data @= dut_mem.b.data
        gold_mem.r.valid @= dut_mem.r.valid
        gold_mem.r.data @= dut_mem.r.data

        mem_wen0 = m.Bit(name="mem_wen0")
        mem_wdata0 = m.UInt[nasti_params.x_data_bits](name="mem_wdata0")
        mem_wen1 = m.Bit(name="mem_wen1")
        mem_wdata1 = m.UInt[nasti_params.x_data_bits](name="mem_wdata1")
        mem_waddr1 = m.UInt[20](name="mem_waddr1")
        mem.write(
            m.mux([dut_mem.w.data.data, mem_wdata1], mem_wen1),
            m.mux([((dut_mem.aw.data.addr) +
                    m.zext_to(write_counter.O, nasti_params.x_addr_bits))[:20],
                   mem_waddr1], mem_wen1), m.enable(mem_wen0 | mem_wen1))
        # m.display("mem_wen0 = %x, mem_wen1 = %x", mem_wen0,
        #           mem_wen1).when(m.posedge(io.CLK))
        # m.display("dut_mem.w.valid = %x",
        #           dut_mem.w.valid).when(m.posedge(io.CLK))
        # m.display("gold_mem.w.valid = %x",
        #           gold_mem.w.valid).when(m.posedge(io.CLK))

        f.assert_immediate(
            (mem_state.O != MemState.IDLE)
            | ~(gold_mem.aw.valid & dut_mem.aw.valid) |
            (dut_mem.aw.data.addr == gold_mem.aw.data.addr),
            failure_msg=(
                "[dut_mem.aw.data.addr] %x != [gold_mem.aw.data.addr] %x",
                dut_mem.aw.data.addr, gold_mem.aw.data.addr))

        # Read-address check.  mem_fsm accepts a read in IDLE only when no
        # write address is valid on both ports, so the comparison must be
        # active exactly then: idle, no common write address, and a read
        # address valid on both ports.  The write-address term is therefore
        # *not* negated in this disjunction.
        f.assert_immediate(
            (mem_state.O != MemState.IDLE)
            | (gold_mem.aw.valid & dut_mem.aw.valid)
            | ~(gold_mem.ar.valid & dut_mem.ar.valid) |
            (dut_mem.ar.data.addr == gold_mem.ar.data.addr),
            failure_msg=(
                "[dut_mem.ar.data.addr] %x != [gold_mem.ar.data.addr] %x",
                dut_mem.ar.data.addr, gold_mem.ar.data.addr))

        f.assert_immediate(
            (mem_state.O != MemState.WRITE)
            | ~(gold_mem.w.valid & dut_mem.w.valid) |
            (dut_mem.w.data.data == gold_mem.w.data.data),
            failure_msg=(
                "[dut_mem.w.data.data] %x != [gold_mem.w.data.data] %x",
                dut_mem.w.data.data, gold_mem.w.data.data))

        # Combinational next-state and handshake logic of the memory model.
        # Both the DUT and the golden model must present a request before the
        # model reacts; only the DUT-side ready signals are driven here.
        @m.inline_combinational()
        def mem_fsm():
            # Defaults: not ready for anything, no write, hold state.
            dut_mem.w.ready @= False
            dut_mem.aw.ready @= False
            dut_mem.ar.ready @= False

            mem_wen0 @= False

            mem_state.I @= mem_state.O

            if mem_state.O == MemState.IDLE:
                # A write address on both ports takes priority over a read.
                if gold_mem.aw.valid & dut_mem.aw.valid:
                    mem_state.I @= MemState.WRITE
                elif gold_mem.ar.valid & dut_mem.ar.valid:
                    mem_state.I @= MemState.READ
            elif mem_state.O == MemState.WRITE:
                # Store a beat whenever both ports offer write data.
                if gold_mem.w.valid & dut_mem.w.valid:
                    mem_wen0 @= True
                    dut_mem.w.ready @= True
                # The address is only acknowledged once the burst completes.
                if write_counter.COUT:
                    dut_mem.aw.ready @= True
                    mem_state.I @= MemState.WRITE_ACK
            elif mem_state.O == MemState.WRITE_ACK:
                if gold_mem.b.ready & dut_mem.b.ready:
                    mem_state.I @= MemState.IDLE
            elif mem_state.O == MemState.READ:
                # The address is only acknowledged once the burst completes.
                if read_counter.COUT:
                    dut_mem.ar.ready @= True
                    mem_state.I @= MemState.IDLE

        if TRACE:
            m.display("[%0t]: [write] mem[%x] <= %x", m.time(),
                      mem.WADDR.value(), dut_mem.w.data.data).when(
                          m.posedge(io.CLK)).if_(mem_wen0)
            m.display("[%0t]: [read] mem[%x] => %x", m.time(),
                      mem.RADDR.value(),
                      dut_mem.r.data.data).when(m.posedge(
                          io.CLK)).if_((mem_state.O == MemState.READ)
                                       & dut_mem.r.ready & gold_mem.r.ready)

        def rand_data(nasti_params):
            """Return a random BitVector as wide as one NASTI data beat.

            The value is assembled from one ``random.randint(0, 0xff)`` draw
            per byte, least-significant byte first.
            """
            width = nasti_params.x_data_bits
            word = BitVector[width](0)
            for shift in range(0, 8 * (width // 8), 8):
                word |= BitVector[width](random.randint(0, 0xff) << shift)
            return word

        def rand_mask(x_len):
            """Return a random byte mask for an ``x_len``-bit word.

            The mask is drawn from 1 .. 2**(x_len // 8) - 2, i.e. it is never
            all zeros and never all ones.
            """
            n_bytes = x_len // 8
            highest = (1 << n_bytes) - 2
            return BitVector[n_bytes](random.randint(1, highest))

        def make_test(rand_data, nasti_params, x_len):
            """Return a ``test(b_bits, tag, idx, off, mask=...)`` builder.

            The closure exists because a function defined inside a class
            namespace does not see the class-level variables, so they are
            passed in explicitly here.
            """
            def test(b_bits, tag, idx, off, mask=BitVector[x_len // 8](0)):
                # One random beat, then as many more as needed to fill a block.
                extra_beats = (b_bits // nasti_params.x_data_bits) - 1
                block = rand_data(nasti_params)
                for _ in range(extra_beats):
                    block = block.concat(rand_data(nasti_params))
                # Pack offset, index, tag, block data and mask into one vector.
                packed = m.concat(off, idx, tag, block, mask)
                return m.uint(packed)

            return test

        test = make_test(rand_data, nasti_params, x_len)

        tags = []
        for _ in range(3):
            tags.append(BitVector.random(t_len))
        idxs = []
        for _ in range(2):
            idxs.append(BitVector.random(s_len))
        offs = []
        for _ in range(6):
            offs.append(BitVector.random(b_len) & -4)

        init_addr = []
        init_data = []
        _iter = itertools.product(tags, idxs, range(0, data_beats))
        for tag, idx, off in _iter:
            init_addr.append(m.uint(m.concat(BitVector[b_len](off), idx, tag)))
            init_data.append(rand_data(nasti_params))

        test_vec = [
            test(b_bits, tags[0], idxs[0], offs[0]),  # 0: read miss
            test(b_bits, tags[0], idxs[0], offs[1]),  # 1: read hit
            test(b_bits, tags[1], idxs[0], offs[0]),  # 2: read miss
            test(b_bits, tags[1], idxs[0], offs[2]),  # 3: read hit
            test(b_bits, tags[1], idxs[0], offs[3]),  # 4: read hit
            test(b_bits, tags[1], idxs[0], offs[4],
                 rand_mask(x_len)),  # 5: write hit  # noqa
            test(b_bits, tags[1], idxs[0], offs[4]),  # 6: read hit
            test(b_bits, tags[2], idxs[0],
                 offs[5]),  # 7: read miss & write back  # noqa
            test(b_bits, tags[0], idxs[1], offs[0],
                 rand_mask(x_len)),  # 8: write miss  # noqa
            test(b_bits, tags[0], idxs[1], offs[0]),  # 9: read hit
            test(b_bits, tags[0], idxs[1], offs[1]),  # 10: read hit
            test(b_bits, tags[1], idxs[1], offs[2],
                 rand_mask(x_len)),  # 11: write miss & write back  # noqa
            test(b_bits, tags[1], idxs[1], offs[3]),  # 12: read hit
            test(b_bits, tags[2], idxs[1], offs[4]),  # 13: read write back
            test(b_bits, tags[2], idxs[1], offs[5])  # 14: read hit
        ]

        # States of the test-bench sequencer driven by `state_fsm` below.
        class TestState(m.Enum):
            # Memory is being preloaded: mem_wen1 is asserted and init_counter
            # advances through init_addr / init_data.
            INIT = 0
            # The current request is offered to the golden model
            # (gold_req.valid is high in this state).
            START = 1
            # The request is offered to the DUT (dut.cpu.req.valid is high)
            # while the timeout counter increments.
            WAIT = 2
            # The golden response is consumed (gold_resp.ready is high) and
            # test_counter advances to the next test vector.
            DONE = 3

        state = m.Register(init=TestState.INIT)()
        timeout = m.Register(m.UInt[32])()
        init_m = len(init_addr) - 1
        init_counter = mantle.CounterModM(init_m,
                                          init_m.bit_length(),
                                          has_ce=True)
        init_counter.CE @= m.enable(state.O == TestState.INIT)

        test_m = len(test_vec) - 1
        test_counter = mantle.CounterModM(test_m,
                                          test_m.bit_length(),
                                          has_ce=True)
        test_counter.CE @= m.enable(state.O == TestState.DONE)
        curr_vec = m.mux(test_vec, test_counter.O)
        mask = (curr_vec >> (b_len + s_len + t_len + b_bits))[:x_len // 8]
        data = (curr_vec >> (b_len + s_len + t_len))[:b_bits]
        tag = (curr_vec >> (b_len + s_len))[:t_len]
        idx = (curr_vec >> b_len)[:s_len]
        off = curr_vec[:b_len]

        dut.cpu.req.data.addr @= m.concat(off, idx, tag)
        # TODO: Is truncating this fine?
        req_data = data[:x_len]
        dut.cpu.req.data.data @= req_data
        dut.cpu.req.data.mask @= mask
        dut.cpu.req.valid @= state.O == TestState.WAIT
        dut.cpu.abort @= 0
        gold_req.data @= dut.cpu.req.data.value()
        gold_req.valid @= state.O == TestState.START
        gold_resp.ready @= state.O == TestState.DONE

        mem_waddr1 @= m.mux(init_addr, init_counter.O)[:20]
        mem_wdata1 @= m.mux(init_data, init_counter.O)

        check_resp_data = m.Bit()
        if TRACE:
            m.display("[%0t]: [init] mem[%x] <= %x", m.time(),
                      mem_waddr1, mem_wdata1)\
                .when(m.posedge(io.CLK))\
                .if_(state.O == TestState.INIT)

        # Next-state and output logic of the test sequencer:
        # INIT -> START -> WAIT -> DONE -> START -> ...
        # Only `#` comments are used inside the body because the decorator
        # processes the function source.
        @m.inline_combinational()
        def state_fsm():
            # Defaults: hold the registers and keep the strobes de-asserted.
            timeout.I @= timeout.O
            mem_wen1 @= m.bit(False)
            check_resp_data @= m.bit(False)
            state.I @= state.O
            if state.O == TestState.INIT:
                # Write one init_addr/init_data entry per cycle until the
                # init counter signals COUT.
                mem_wen1 @= m.bit(True)
                if init_counter.COUT:
                    state.I @= TestState.START
            elif state.O == TestState.START:
                # Wait for the golden model to accept the request, then
                # restart the timeout and hand the request to the DUT.
                if gold_req.ready:
                    timeout.I @= m.bits(0, 32)
                    state.I @= TestState.WAIT
            elif state.O == TestState.WAIT:
                # Count cycles spent waiting; an assertion outside this
                # function bounds timeout.O while in WAIT.
                timeout.I @= timeout.O + 1
                if dut.cpu.resp.valid & gold_resp.valid:
                    # Response data is only compared when the request mask is
                    # all zeros (i.e. the request was not a write).
                    if ~mask.reduce_or():
                        check_resp_data @= m.bit(True)
                    state.I @= TestState.DONE
            elif state.O == TestState.DONE:
                # One cycle in DONE, then issue the next test vector.
                state.I @= TestState.START

        f.assert_immediate((state.O != TestState.WAIT) | (timeout.O < 100))
        f.assert_immediate(
            ~check_resp_data | (dut.cpu.resp.data.data == gold_resp.data.data),
            failure_msg=("dut.cpu.resp.data.data => %x != %x",
                         dut.cpu.resp.data.data, gold_resp.data.data))
        # m.display("mem_state=%x", mem_state.O).when(m.posedge(io.CLK))
        # m.display("test_state=%x", state.O).when(m.posedge(io.CLK))
        # m.display("dut req valid = %x",
        #           dut.cpu.req.valid).when(m.posedge(io.CLK))
        # m.display("gold req valid = %x, ready = %x", gold_req.valid,
        #           gold_req.ready).when(m.posedge(io.CLK))
        # m.display("[%0t]: dut resp data = %x, gold resp data = %x", m.time(),
        #           dut.cpu.resp.data.data, gold_resp.data.data)\
        #     .when(m.posedge(io.CLK))
        io.done @= test_counter.COUT
Ejemplo n.º 14
0
import magma as m
import mantle
from loam.boards.icestick import IceStick
from uart import UART

# Board setup: turn on the clock and configure the first three J3 pins as
# outputs so the strobes generated below can be observed on the header.
icestick = IceStick()
icestick.Clock.on()
for i in range(3):
    icestick.J3[i].output().on()

main = icestick.main()

# baud for uart output
# Modulo-103 counter, 8 bits wide, with no clock enable; its COUT is used as
# the baud strobe.
# NOTE(review): presumably this divides the board clock down to the UART bit
# rate -- confirm the divisor against the actual clock frequency.
clock = mantle.CounterModM(103, 8)
baud = clock.COUT

# 5-bit counter that advances once per baud strobe.
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)

# Decode of the value 0 on the bit counter output; used as the load strobe.
load = mantle.Decode(0, 5)(bit_counter.O)

# Modulo-1000 counter that advances only when load and baud coincide.
valid_counter = mantle.CounterModM(1000, 10, has_ce=True)
m.wire(load & baud, valid_counter.CE)

# Expose the three strobes on the J3 header pins.
m.wire(baud, main.J3[0])
m.wire(load, main.J3[1])
m.wire(valid_counter.COUT, main.J3[2]) 

Ejemplo n.º 15
0
    def definition(io):
        """Wire up the downscale / threshold / row-assembly pipeline.

        Input data is registered on ``io.LOAD`` and fed to an externally
        declared ``Downscale`` circuit. Each downscaled value is compared
        against ``THRESH`` to produce one bit, bits are shifted into a 32-bit
        row register, and the row, its address and status strobes are driven
        onto ``io``. Two ``UART`` instances are loaded with the row and the
        row address.

        Relies on names defined outside this function: ``rising``,
        ``falling``, ``reverse``, ``UART``, ``buf_size``, ``wi``, ``wo``,
        ``a``, ``b``, ``width`` and ``THRESH``.
        """
        load = io.LOAD
        # One strobe on every edge (rising or falling) of SCK.
        baud = rising(io.SCK) | falling(io.SCK)

        # Counts load & baud events modulo buf_size (12-bit counter).
        valid_counter = mantle.CounterModM(buf_size, 12, has_ce=True)
        m.wire(load & baud, valid_counter.CE)

        # Counter values at which a downscaled output is treated as valid.
        valid_list = [wi * (b - 1) + i * a - 1 for i in range(1, wo + 1)]  # len = 32

        # OR together one decoder per entry of valid_list.
        valid = m.GND

        for i in valid_list:
            valid = valid | mantle.Decode(i, 12)(valid_counter.O)

        # register on input
        st_in = mantle.Register(width, has_ce=True)
        st_in(io.DATA)
        m.wire(load, st_in.CE)

        # --------------------------DOWNSCALING----------------------------- #
        # downscale the image from 352x288 to 32x32
        # Only declared here; the implementation is provided elsewhere.
        Downscale = m.DeclareCircuit(
                    'Downscale',
                    "I_0_0", m.In(m.Array(1, m.Array(1, m.Array(width, m.Bit)))),
                    "WE", m.In(m.Bit), "CLK", m.In(m.Clock),
                    "O", m.Out(m.Array(width, m.Bit)), "V", m.Out(m.Bit))

        dscale = Downscale()

        m.wire(st_in.O, dscale.I_0_0[0][0])
        # Write enable is tied high; the downscaler is clocked by `load`.
        m.wire(1, dscale.WE)
        m.wire(load, dscale.CLK)

        # The result is unused; the call is kept for its side effect.
        add16 = mantle.Add(width)  # needed for Add16 definition

        # --------------------------FILL IMG RAM--------------------------- #
        # each valid output of dscale represents an entry of 32x32 binary image
        # accumulate each group of 32 entries into a 32-bit value representing a row
        col = mantle.CounterModM(32, 6, has_ce=True) 
        col_ce = rising(valid) 
        m.wire(col_ce, col.CE)

        # shift each bit in one at a time until we get an entire row
        # NOTE(review): the comparator width is hard-coded to 16 while the
        # data path above uses `width` -- confirm they always agree.
        px_bit = mantle.ULE(16)(dscale.O, m.uint(THRESH, 16)) & valid
        row_reg = mantle.SIPO(32, has_ce=True)
        row_reg(px_bit)
        m.wire(col_ce, row_reg.CE)

        # reverse the row bits since the image is flipped
        row = reverse(row_reg.O)

        # 6-bit row counter; reaching the value 32 marks the image as full.
        rowaddr = mantle.Counter(6, has_ce=True)

        img_full = mantle.SRFF(has_ce=True)
        img_full(mantle.EQ(6)(rowaddr.O, m.bits(32, 6)), 0)
        m.wire(falling(col.COUT), img_full.CE)
        # Advance the row address on each rising col.COUT until the image
        # is full.
        row_ce = rising(col.COUT) & ~img_full.O
        m.wire(row_ce, rowaddr.CE)

        # Low 5 bits of the row counter are used as the write address.
        waddr = rowaddr.O[:5]

        # 2-bit pulse counter: starts counting on `rdy` and keeps counting
        # while its value is >= 1; `we` is high during that time.
        rdy = col.COUT & ~img_full.O
        pulse_count = mantle.Counter(2, has_ce=True)
        we = mantle.UGE(2)(pulse_count.O, m.uint(1, 2))
        pulse_count(CE=(we|rdy))

        # ---------------------------UART OUTPUT----------------------------- #

        row_load = row_ce
        # Baud strobe passed through one flip-flop.
        row_baud = mantle.FF()(baud)
        uart_row = UART(32)
        uart_row(CLK=io.CLK, BAUD=row_baud, DATA=row, LOAD=row_load)

        uart_addr = UART(5)
        uart_addr(CLK=io.CLK, BAUD=row_baud, DATA=waddr, LOAD=row_load)

        m.wire(waddr, io.WADDR)
        m.wire(img_full, io.DONE) #img_full
        m.wire(uart_row, io.UART) #uart_st
        m.wire(row, io.O)
        m.wire(we, io.VALID)

        # Test outputs: the valid strobe and the address UART.
        m.wire(valid, io.T0)
        m.wire(uart_addr, io.T1)
Ejemplo n.º 16
0
    def definition(io):
        """Wire up the downscale / threshold / row-assembly pipeline (16-bit rows).

        Input data is registered on ``io.LOAD`` and fed to an externally
        declared ``Downscale`` circuit. Each downscaled value is compared
        against ``THRESH`` to produce one bit, bits are shifted into a 16-bit
        row register, and the row, its address and status strobes are driven
        onto ``io``. Two ``UART`` instances are loaded with the row and the
        row address.

        Relies on names defined outside this function: ``rising``,
        ``falling``, ``reverse``, ``UART``, ``buf_size``, ``wi``, ``wo``,
        ``a``, ``b``, ``THRESH``, ``LOW_MASK`` and ``HIGH_MASK``.
        """
        load = io.LOAD
        # One strobe on every edge (rising or falling) of SCK.
        baud = rising(io.SCK) | falling(io.SCK)

        # Counts load & baud events modulo buf_size (13-bit counter).
        valid_counter = mantle.CounterModM(buf_size, 13, has_ce=True)
        m.wire(load & baud, valid_counter.CE)

        # Counter values at which a downscaled output is treated as valid.
        valid_list = [wi * (b - 1) + i * a - 1 for i in range(1, wo + 1)]

        # OR together one decoder per entry of valid_list.
        valid = m.GND

        for i in valid_list:
            valid = valid | mantle.Decode(i, 13)(valid_counter.O)

        # register on input
        st_in = mantle.Register(16, has_ce=True)
        st_in(io.DATA)
        m.wire(load, st_in.CE)

        # --------------------------DOWNSCALING----------------------------- #
        # downscale the image from 320x240 to 16x16
        # Only declared here; the implementation is provided elsewhere.
        Downscale = m.DeclareCircuit(
                    'Downscale',
                    "I_0_0", m.In(m.Array(1, m.Array(1, m.Array(16, m.Bit)))),
                    "WE", m.In(m.Bit), "CLK", m.In(m.Clock),
                    "O", m.Out(m.Array(16, m.Bit)), "V", m.Out(m.Bit))

        dscale = Downscale()

        m.wire(st_in.O, dscale.I_0_0[0][0])
        # Write enable is tied high; the downscaler is clocked by `load`.
        m.wire(1, dscale.WE)
        m.wire(load, dscale.CLK)

        # The result is unused; the call is kept for its side effect.
        add16 = mantle.Add(16)  # needed for Add16 definition

        # --------------------------FILL IMG RAM--------------------------- #
        # each valid output of dscale represents a pixel in 16x16 binary image
        # accumulate each group of 16 pixels into a 16-bit value representing
        # a row in the image
        col = mantle.CounterModM(16, 5, has_ce=True)
        col_ce = rising(valid)
        m.wire(col_ce, col.CE)

        # shift each bit in one at a time until we get an entire row
        px_bit = mantle.ULE(16)(dscale.O, m.uint(THRESH, 16)) & valid
        row_reg = mantle.SIPO(16, has_ce=True)
        row_reg(px_bit)
        m.wire(col_ce, row_reg.CE)

        # reverse the row bits since the image is flipped
        row = reverse(row_reg.O)

        # 5-bit row counter; reaching the value 16 marks the image as full.
        rowaddr = mantle.Counter(5, has_ce=True)

        img_full = mantle.SRFF(has_ce=True)
        img_full(mantle.EQ(5)(rowaddr.O, m.bits(16, 5)), 0)
        m.wire(falling(col.COUT), img_full.CE)
        # Advance the row address on each rising col.COUT until the image
        # is full.
        row_ce = rising(col.COUT) & ~img_full.O
        m.wire(row_ce, rowaddr.CE)

        # Low 4 bits of the row counter are used as the write address.
        waddr = rowaddr.O[:4]

        # we_counter = mantle.CounterModM(16, 5, has_ce=True)
        # m.wire(rising(valid), we_counter.CE)

        # 5-bit pulse counter: starts counting on `rdy` and keeps counting
        # while its value is >= 1; `we` is high during that time.
        rdy = col.COUT & ~img_full.O
        pulse_count = mantle.Counter(5, has_ce=True)
        we = mantle.UGE(5)(pulse_count.O, m.uint(1, 5))
        pulse_count(CE=(we | rdy))

        # ---------------------------UART OUTPUT----------------------------- #

        row_load = row_ce
        # Baud strobe passed through one flip-flop.
        row_baud = mantle.FF()(baud)
        uart_row = UART(16)
        uart_row(CLK=io.CLK, BAUD=row_baud, DATA=row, LOAD=row_load)

        # NOTE(review): uart_addr is instantiated and driven but its output
        # is not wired to any io port in this function.
        uart_addr = UART(4)
        uart_addr(CLK=io.CLK, BAUD=row_baud, DATA=waddr, LOAD=row_load)

        # split 16-bit row data into 8-bit packets so it can be parsed
        # NOTE(review): low_byte, high_byte and uart_counter are created here
        # but never wired to an output below -- looks unfinished; confirm.
        low_byte = row & LOW_MASK
        high_byte = row & HIGH_MASK
        uart_counter = mantle.CounterModM(8, 4, has_ce=True)
        m.wire(rising(valid), uart_counter.CE)

        m.wire(waddr, io.WADDR)
        m.wire(img_full, io.DONE)
        m.wire(uart_row, io.UART)
        m.wire(row, io.O)
        m.wire(we, io.VALID)
Ejemplo n.º 17
0
# Module-level parameters.
# NOTE(review): the meaning of `a` and `b` is not visible in this fragment --
# confirm against the code that consumes them.
a = 2
b = 2

# Data-path width in bits and the matching input/output array types.
width = 16
TIN = m.Array(width, m.BitIn)
TOUT = m.Array(width, m.Out(m.Bit))

# Board setup: turn on the clock and configure the first three J3 pins as
# outputs.
icestick = IceStick()
icestick.Clock.on()
for i in range(3):
    icestick.J3[i].output().on()

main = icestick.main()

# baud for uart output
# Modulo-103 counter, 8 bits wide; its COUT is used as the baud strobe.
clock = mantle.CounterModM(103, 8)
baud = clock.COUT

# 5-bit counter that advances once per baud strobe.
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)

# Decode of the value 0 on the bit counter output; used as the load strobe.
load = mantle.Decode(0, 5)(bit_counter.O)

# # "test" data
# init = [m.uint(i, 16) for i in range(16)]
# printf = mantle.Counter(4, has_ce=True)
# rom = ROM16(4, init, printf.O)
# m.wire(load & baud, printf.CE)

#---------------------------STENCILING-----------------------------#
Ejemplo n.º 18
0
    def __init__(self, x_len, n_ways: int, n_sets: int, b_bytes: int):
        """Build the cache circuit with a CPU port and a NASTI memory port.

        The cache keeps one tag per set plus valid and dirty bit vectors,
        stores each block as ``n_words`` byte-masked memories, and runs a
        seven-state controller (idle, read, write, write-back, write-ack,
        refill-ready, refill) that services CPU requests and moves whole
        blocks over the NASTI channels.

        Args:
            x_len: width in bits of a CPU word; also used as the NASTI
                address width and to derive the tag width.
            n_ways: not referenced anywhere in the visible body.
            n_sets: number of sets (depth of the tag and data memories).
            b_bytes: block size in bytes.
        """
        # Derived geometry: block bits, offset/index/tag field widths, words
        # per block, bytes per word, and the byte-within-word offset width.
        b_bits = b_bytes << 3
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        n_words = b_bits // x_len
        w_bytes = x_len // 8
        byte_offset_bits = m.bitutils.clog2(w_bytes)
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)
        # Number of NASTI data transfers needed to move one block.
        data_beats = b_bits // nasti_params.x_data_bits

        # Per-set metadata: only the tag is stored in memory; valid and dirty
        # bits live in the `v` and `d` registers below.
        class MetaData(m.Product):
            tag = m.UInt[t_len]

        self.io = m.IO(**make_cache_ports(x_len, nasti_params))
        self.io += m.ClockIO()

        # Controller states; transitions are defined in `logic` below.
        class State(m.Enum):
            IDLE = 0
            READ_CACHE = 1
            WRITE_CACHE = 2
            WRITE_BACK = 3
            WRITE_ACK = 4
            REFILL_READY = 5
            REFILL = 6

        state = m.Register(init=State.IDLE)()

        # memory
        # One valid bit and one dirty bit per set.
        v = m.Register(m.UInt[n_sets], has_enable=True)()
        d = m.Register(m.UInt[n_sets], has_enable=True)()
        meta_mem = m.Memory(n_sets,
                            MetaData,
                            read_latency=1,
                            has_read_enable=True)()
        # One byte-maskable memory per word of the block.
        data_mem = [
            ArrayMaskMem(n_sets,
                         w_bytes,
                         m.UInt[8],
                         read_latency=1,
                         has_read_enable=True)() for _ in range(n_words)
        ]

        # Registered copies of the CPU request fields; they are loaded
        # whenever cpu.resp.valid is high (see the CE wiring further down).
        addr_reg = m.Register(type(self.io.cpu.req.data.addr).undirected_t,
                              has_enable=True)()
        cpu_data = m.Register(type(self.io.cpu.req.data.data).undirected_t,
                              has_enable=True)()
        cpu_mask = m.Register(type(self.io.cpu.req.data.mask).undirected_t,
                              has_enable=True)()

        self.io.nasti.r.ready @= state.O == State.REFILL
        # Counters
        assert data_beats > 0
        if data_beats > 1:
            # Count accepted read beats; COUT marks the end of a refill.
            read_counter = mantle.CounterModM(data_beats,
                                              max(data_beats.bit_length(), 1),
                                              has_ce=True)
            read_counter.CE @= m.enable(self.io.nasti.r.fired())
            read_count, read_wrap_out = read_counter.O, read_counter.COUT

            # Count accepted write beats; its CE is wired at the very end of
            # this method.
            write_counter = mantle.CounterModM(data_beats,
                                               max(data_beats.bit_length(), 1),
                                               has_ce=True)
            write_count, write_wrap_out = write_counter.O, write_counter.COUT
        else:
            # Single-beat blocks need no counters: plain Python constants.
            # NOTE(review): `write_count[:-1]` is evaluated unconditionally
            # when building nasti.w.data below; with write_count == 0 (an
            # int) that slice would raise TypeError -- confirm whether
            # data_beats == 1 is actually supported.
            read_count, read_wrap_out = 0, 1
            write_count, write_wrap_out = 0, 1

        # Buffer collecting the beats of the block currently being refilled.
        refill_buf = m.Register(m.Array[data_beats,
                                        m.UInt[nasti_params.x_data_bits]],
                                has_enable=True)()
        if data_beats == 1:
            refill_buf.I[0] @= self.io.nasti.r.data.data
        else:
            # Replace only the entry selected by the read counter (top
            # counter bit dropped to form the index).
            refill_buf.I @= m.set_index(refill_buf.O,
                                        self.io.nasti.r.data.data,
                                        read_count[:-1])
        refill_buf.CE @= m.enable(self.io.nasti.r.fired())

        is_idle = state.O == State.IDLE
        is_read = state.O == State.READ_CACHE
        is_write = state.O == State.WRITE_CACHE
        # High while in REFILL with the read counter wrapping.
        is_alloc = (state.O == State.REFILL) & read_wrap_out
        # m.display("[%0t]: is_alloc = %x", m.time(), is_alloc)\
        #     .when(m.posedge(self.io.CLK))
        # is_alloc delayed by one register stage.
        is_alloc_reg = m.Register(m.Bit)()(is_alloc)

        # `hit` is declared here and driven further down, once rmeta exists.
        hit = m.Bit(name="hit")
        # Write the arrays either for a CPU write that hits (or directly
        # follows an allocation) and is not aborted, or for the allocation
        # itself. `&` binds tighter than `|`, so is_alloc is OR-ed last.
        wen = is_write & (hit | is_alloc_reg) & ~self.io.cpu.abort | is_alloc
        # m.display("[%0t]: wen = %x", m.time(), wen)\
        #     .when(m.posedge(self.io.CLK))
        # Read only when not writing, in IDLE/READ_CACHE, with a valid request.
        ren = m.enable(~wen & (is_idle | is_read) & self.io.cpu.req.valid)
        ren_reg = m.enable(m.Register(m.Bit)()(ren))

        # Address fields: `idx` comes from the incoming request, the *_reg
        # fields from the registered request address.
        addr = self.io.cpu.req.data.addr
        idx = addr[b_len:s_len + b_len]
        tag_reg = addr_reg.O[s_len + b_len:x_len]
        idx_reg = addr_reg.O[b_len:s_len + b_len]
        off_reg = addr_reg.O[byte_offset_bits:b_len]

        rmeta = meta_mem.read(idx, ren)
        rdata = m.concat(*(mem.read(idx, ren) for mem in data_mem))
        # Holds the memory read data; loaded when ren_reg is high.
        rdata_buf = m.Register(type(rdata), has_enable=True)()(rdata,
                                                               CE=ren_reg)

        # Block presented to the CPU / write-back path: the refill buffer
        # when is_alloc_reg is set, otherwise the fresh read data (when
        # ren_reg is set) or the buffered copy.
        read = m.mux([
            m.as_bits(m.mux([rdata_buf, rdata], ren_reg)),
            m.as_bits(refill_buf.O)
        ], is_alloc_reg)
        # m.display("is_alloc_reg=%x", is_alloc_reg)\
        #     .when(m.posedge(self.io.CLK))

        hit @= v.O[idx_reg] & (rmeta.tag == tag_reg)

        # read mux
        # Select the requested word of the block by the registered offset.
        self.io.cpu.resp.data.data @= m.array(
            [read[i * x_len:(i + 1) * x_len] for i in range(n_words)])[off_reg]
        self.io.cpu.resp.valid @= (is_idle | (is_read & hit) |
                                   (is_alloc_reg & ~cpu_mask.O.reduce_or()))
        # NOTE(review): the following m.display calls are unconditional
        # (unlike the commented-out ones around them) -- presumably leftover
        # debug output; confirm whether they should be gated or removed.
        m.display("resp.valid=%x", self.io.cpu.resp.valid.value())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: valid = %x", m.time(),
                  self.io.cpu.resp.valid.value())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: is_idle = %x, is_read = %x, hit = %x, is_alloc_reg = "
                  "%x, ~cpu_mask.O.reduce_or() = %x", m.time(), is_idle,
                  is_read, hit, is_alloc_reg, ~cpu_mask.O.reduce_or())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: refill_buf.O=%x, %x", m.time(), *refill_buf.O)\
            .when(m.posedge(self.io.CLK))\
            .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)
        m.display("[%0t]: read=%x", m.time(), read)\
            .when(m.posedge(self.io.CLK))\
            .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)

        # Capture the next request's address, data and mask whenever a
        # response is being presented.
        addr_reg.I @= addr
        addr_reg.CE @= m.enable(self.io.cpu.resp.valid.value())

        cpu_data.I @= self.io.cpu.req.data.data
        cpu_data.CE @= m.enable(self.io.cpu.resp.valid.value())

        cpu_mask.I @= self.io.cpu.req.data.mask
        cpu_mask.CE @= m.enable(self.io.cpu.resp.valid.value())

        wmeta = MetaData(name="wmeta")
        wmeta.tag @= tag_reg

        # CPU byte mask shifted to the position of the addressed word within
        # the block (shift amount = off_reg with byte_offset_bits zero bits
        # below it).
        offset_mask = (m.zext_to(cpu_mask.O, w_bytes * 8) << m.concat(
            m.bits(0, byte_offset_bits), off_reg))
        # All ones when allocating (index 0), the CPU mask otherwise.
        wmask = m.mux([m.SInt[w_bytes * 8](-1),
                       m.sint(offset_mask)], ~is_alloc)

        # Data written on allocation: the already buffered beats followed by
        # the beat currently on the NASTI read channel.
        if len(refill_buf.O) == 1:
            wdata_alloc = self.io.nasti.r.data.data
        else:
            wdata_alloc = m.concat(
                # TODO: not sure why they use `init.reverse`
                # https://github.com/ucb-bar/riscv-mini/blob/release/src/main/scala/Cache.scala#L116
                m.concat(*refill_buf.O[:-1]),
                self.io.nasti.r.data.data)
        # On a CPU write the word is replicated across the block; the mask
        # above picks the bytes that are actually written.
        wdata = m.mux([wdata_alloc,
                       m.as_bits(m.repeat(cpu_data.O, n_words))], ~is_alloc)

        # Any write marks the set valid; it is marked dirty only for CPU
        # writes (the dirty bit is cleared on allocation).
        v.I @= m.set_index(v.O, m.bit(True), idx_reg)
        v.CE @= m.enable(wen)
        d.I @= m.set_index(d.O, ~is_alloc, idx_reg)
        d.CE @= m.enable(wen)
        # m.display("[%0t]: refill_buf.O = %x", m.time(),
        #           m.concat(*refill_buf.O)).when(m.posedge(self.io.CLK)).if_(wen)
        # m.display("[%0t]: nasti.r.data.data = %x", m.time(),
        #           self.io.nasti.r.data.data).when(m.posedge(self.io.CLK)).if_(wen)

        # The tag is only rewritten on allocation.
        meta_mem.write(wmeta, idx_reg, m.enable(wen & is_alloc))
        for i, mem in enumerate(data_mem):
            # Split word i of the write data into bytes for the masked memory.
            data = [
                wdata[i * x_len + j * 8:i * x_len + (j + 1) * 8]
                for j in range(w_bytes)
            ]
            mem.write(m.array(data), idx_reg,
                      wmask[i * w_bytes:(i + 1) * w_bytes], m.enable(wen))
            # m.display("[%0t]: wdata = %x, %x, %x, %x", m.time(),
            #           *mem.WDATA.value()).when(m.posedge(self.io.CLK)).if_(wen)
            # m.display("[%0t]: wmask = %x, %x, %x, %x", m.time(),
            #           *mem.WMASK.value()).when(m.posedge(self.io.CLK)).if_(wen)

        # Read address: {tag, index} of the registered request shifted left
        # by the block-offset width.
        tag_and_idx = m.zext_to(m.concat(idx_reg, tag_reg),
                                nasti_params.x_addr_bits)
        self.io.nasti.ar.data @= NastiReadAddressChannel(
            nasti_params, 0, tag_and_idx << m.Bits[len(tag_and_idx)](b_len),
            m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

        # Write-back address: same index, but with the tag read from the
        # metadata memory instead of the request tag.
        rmeta_and_idx = m.zext_to(m.concat(idx_reg, rmeta.tag),
                                  nasti_params.x_addr_bits)
        self.io.nasti.aw.data @= NastiWriteAddressChannel(
            nasti_params, 0,
            rmeta_and_idx << m.Bits[len(rmeta_and_idx)](b_len),
            m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

        # Write data: the beat of `read` selected by the write counter; the
        # counter's wrap output is passed as the final argument.
        self.io.nasti.w.data @= NastiWriteDataChannel(
            nasti_params,
            m.array([
                read[i * nasti_params.x_data_bits:(i + 1) *
                     nasti_params.x_data_bits] for i in range(data_beats)
            ])[write_count[:-1]], None, write_wrap_out)

        is_dirty = v.O[idx_reg] & d.O[idx_reg]

        # TODO: Have to use temporary so we can invoke `fired()`
        aw_valid = m.Bit(name="aw_valid")
        self.io.nasti.aw.valid @= aw_valid

        ar_valid = m.Bit(name="ar_valid")
        self.io.nasti.ar.valid @= ar_valid

        b_ready = m.Bit(name="b_ready")
        self.io.nasti.b.ready @= b_ready

        # Controller next-state and handshake logic. Only `#` comments are
        # used inside the body because the decorator processes the function
        # source.
        @m.inline_combinational()
        def logic():
            # Defaults: stay in the current state, no NASTI handshakes.
            state.I @= state.O
            aw_valid @= False
            ar_valid @= False
            self.io.nasti.w.valid @= False
            b_ready @= False
            if state.O == State.IDLE:
                # A non-zero mask makes the request a write.
                if self.io.cpu.req.valid:
                    if self.io.cpu.req.data.mask.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.READ_CACHE
            elif state.O == State.READ_CACHE:
                if hit:
                    # Hit: accept the next request directly, or go idle.
                    if self.io.cpu.req.valid:
                        if self.io.cpu.req.data.mask.reduce_or():
                            state.I @= State.WRITE_CACHE
                        else:
                            state.I @= State.READ_CACHE
                    else:
                        state.I @= State.IDLE
                else:
                    # Miss: write back first if the set is dirty, otherwise
                    # request the refill straight away.
                    aw_valid @= is_dirty
                    ar_valid @= ~is_dirty
                    if self.io.nasti.aw.fired():
                        state.I @= State.WRITE_BACK
                    elif self.io.nasti.ar.fired():
                        state.I @= State.REFILL
            elif state.O == State.WRITE_CACHE:
                # Done on a hit, right after an allocation, or on abort.
                if hit | is_alloc_reg | self.io.cpu.abort:
                    state.I @= State.IDLE
                else:
                    # Write miss: same write-back / refill choice as a read
                    # miss.
                    aw_valid @= is_dirty
                    ar_valid @= ~is_dirty
                    if self.io.nasti.aw.fired():
                        state.I @= State.WRITE_BACK
                    elif self.io.nasti.ar.fired():
                        state.I @= State.REFILL
            elif state.O == State.WRITE_BACK:
                # Stream the block out until the write counter wraps.
                self.io.nasti.w.valid @= True
                if write_wrap_out:
                    state.I @= State.WRITE_ACK
            elif state.O == State.WRITE_ACK:
                # Wait for the write response.
                b_ready @= True
                if self.io.nasti.b.fired():
                    state.I @= State.REFILL_READY
            elif state.O == State.REFILL_READY:
                # Issue the read address for the refill.
                ar_valid @= True
                if self.io.nasti.ar.fired():
                    state.I @= State.REFILL
            elif state.O == State.REFILL:
                # Once the read counter wraps, finish a pending write
                # (registered mask non-zero) or return to idle.
                if read_wrap_out:
                    if cpu_mask.O.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.IDLE

        if data_beats > 1:
            # TODO: Have to do this at the end since the inline comb logic
            # wires up nasti.w
            write_counter.CE @= m.enable(self.io.nasti.w.fired())