Пример #1
0
    class Main(m.Circuit):
        """Counter plus toggle register that drives a second clock.

        A 3-bit register clocked from ``io.clocks.clk0`` stores its own value
        plus one and is exposed on ``io.count``.  A one-bit register on the
        same clock, enabled only while the counter reads 3, stores the
        inverse of its own output; that output, cast to a clock, drives
        ``io.clocks.clk1``.
        """
        io = m.IO(clocks=ClocksT, count=m.Out(m.UInt[3]))

        # Counter register: next value is the current value plus one.
        count = m.Register(m.UInt[3])()
        count.CLK @= io.clocks.clk0
        count.I @= count.O + 1
        io.count @= count.O

        # Toggle register: flips only when the enable below is asserted.
        tff = m.Register(m.Bit, has_enable=True)()
        tff.CLK @= io.clocks.clk0
        tff.CE @= m.enable(count.O == 3)
        tff.I @= tff.O ^ 1
        io.clocks.clk1 @= m.clock(tff.O)
Пример #2
0
class _Accum(m.Circuit):
    """16-bit accumulator.

    ``O`` is the sum of the stored value and ``I``; the same sum is written
    back into the register as its next value.
    """
    T = m.UInt[16]
    io = m.IO(I=m.In(T), O=m.Out(T))
    io += m.ClockIO()
    reg = m.Register(T)()
    # One adder feeds both the register input and the output port.
    accum = reg.O + io.I
    m.wire(accum, reg.I)
    m.wire(accum, io.O)
Пример #3
0
 class Foo(m.Circuit):
     """Circuit that cycles two constant tables and asserts a property.

     A 2-bit register that stores its own value plus one indexes two
     four-entry constant tables whose selected entries drive ``io.x`` and
     ``io.y``.  The contents of the ``y`` table depend on the enclosing
     ``should_pass`` flag.  One assertion named ``name_A`` is attached,
     written either as an SVA string sequence or with the operator-based
     property syntax depending on the enclosing ``use_sva`` flag.
     """
     io = m.IO(a=m.In(m.Bits[8]),
               b=m.In(m.Bits[8]),
               c=m.In(m.Bits[8]),
               x=m.Out(m.Bits[8]),
               y=m.Out(m.Bits[8])) + m.ClockIO(has_resetn=True)

     # Constant lookup tables, one 8-bit entry per counter value.
     x = [m.bits(value, 8) for value in (0, 0, 1, 0)]
     y = [m.bits(value, 8)
          for value in ((0, 1, 2, 3) if should_pass else (1, 1, 1, 1))]

     # Table index register.
     count = m.Register(m.Bits[2])()
     count.I @= count.O + 1

     io.x @= m.mux(x, count.O)
     io.y @= m.mux(y, count.O)
     m.display("io.x=%x, io.y=%x", io.x, io.y).when(m.posedge(io.CLK))

     if not use_sva:
         f.assert_(
             # Note parens matter!
             (f.not_(f.onehot(io.a)) & io.b.reduce_or() & io.x[0].value())
             | f.implies | f.delay[1] | (io.y != f.past(io.y.value(), 2)),
             name="name_A",
             on=f.posedge(io.CLK),
             disable_iff=f.not_(io.RESETN))
     else:
         f.assert_(f.sva(f.not_(f.onehot(io.a)), "&&", io.b.reduce_or(),
                         "&&", io.x[0].value(), "|=>",
                         io.y.value() != f.past(io.y.value(), 2)),
                   name="name_A",
                   on=f.posedge(io.CLK),
                   disable_iff=f.not_(io.RESETN))
Пример #4
0
    def __init__(self, n: int, T: m.Kind = m.Bit, init=None,
                 has_enable: bool = False,
                 reset_type: Optional[m.AbstractReset] = None):
        """Build a chain of ``n`` registers with serial input ``I``.

        Args:
            n: number of register stages.
            T: element type held by each stage.
            init: per-stage initial values; when ``None`` each stage gets
                ``T(*_zero_init_args(T))``.
            has_enable: forwarded to the clock interface and every register.
            reset_type: reset flavour forwarded to every register; it also
                selects which reset port the clock interface exposes.
        """
        if init is None:
            init = [T(*_zero_init_args(T)) for _ in range(n)]
        self.name = f"SIPO{n}"

        # A chain of plain bits is exposed as Bits[n] rather than an Array.
        out_type = m.Bits[n] if T is m.Bit else m.Array[n, T]
        self.io = m.IO(I=m.In(T), O=m.Out(out_type))

        # TODO: Add magma helper func for this
        self.io += m.ClockIO(
            has_enable=has_enable,
            has_async_reset=reset_type == m.AsyncReset,
            has_async_resetn=reset_type == m.AsyncResetN,
            has_reset=reset_type == m.Reset,
            has_resetn=reset_type == m.ResetN,
        )

        stages = (
            m.Register(T, init=init[i], has_enable=has_enable,
                       reset_type=reset_type)()
            for i in range(n)
        )
        # TODO: Default clock wiring logic raises warning inside scan
        chain = m.scan(stages, scanargs={"I": "O"})
        self.io.O @= chain(self.io.I)
Пример #5
0
    def __init__(self, T, entries, pipe=False, flow=False):
        """Construct a FIFO of up to ``entries`` items of type ``T``.

        The queue is built from a memory, two ``CounterModM`` pointer
        counters (enqueue / dequeue) and a ``maybe_full`` flag that
        distinguishes "empty" from "full" when the two pointers are equal.
        ``io.count`` reports the current occupancy.

        Args:
            T: element type carried on the enq/deq interfaces.
            entries: queue capacity; must be non-negative.
            pipe: not implemented; a truthy value raises.
            flow: not implemented; a truthy value raises.

        Raises:
            NotImplementedError: if ``flow`` or ``pipe`` is truthy.
        """
        assert entries >= 0
        self.io = m.IO(
            # Flipped since enq/deq is from perspective of the client
            enq=m.DeqIO[T],
            deq=m.EnqIO[T],
            count=m.Out(m.UInt[m.bitutils.clog2(entries + 1)])) + m.ClockIO()

        ram = m.Memory(entries, T)()
        enq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        deq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        maybe_full = m.Register(init=False, has_enable=True)()

        # Equal pointers mean either empty or full; maybe_full tells which.
        ptr_match = enq_ptr.O == deq_ptr.O
        empty = ptr_match & ~maybe_full.O
        full = ptr_match & maybe_full.O

        self.io.deq.valid @= ~empty
        self.io.enq.ready @= ~full

        do_enq = self.io.enq.fired()
        do_deq = self.io.deq.fired()

        # The RAM address is the pointer without its last bit.
        ram.write(self.io.enq.data, enq_ptr.O[:-1], m.enable(do_enq))

        enq_ptr.CE @= m.enable(do_enq)
        deq_ptr.CE @= m.enable(do_deq)

        # maybe_full only changes when exactly one of enq/deq fires, and
        # then records whether that event was an enqueue.
        maybe_full.I @= m.enable(do_enq)
        maybe_full.CE @= m.enable(do_enq != do_deq)
        self.io.deq.data @= ram[deq_ptr.O[:-1]]

        if flow:
            raise NotImplementedError()
        if pipe:
            raise NotImplementedError()

        # Occupancy.  m.mux picks the list element whose index equals the
        # select value, so the alternative for select == 1 goes second.
        #
        # The previous code placed the pointers-equal alternative first in
        # the outer mux, so a full queue reported 0 and a partially filled
        # queue reported 0 or `entries`.  It also special-cased
        # power-of-two sizes with a formula that ignores the pointer
        # difference (only right for a single-entry queue) behind an
        # `ispow2` helper whose `n & (n - 1) == 0` parsed as
        # `n & ((n - 1) == 0)`.  The general formula below is used for
        # every size instead.
        count_len = len(self.io.count)
        ptr_diff = enq_ptr.O - deq_ptr.O
        # Pointers equal: the queue holds either everything or nothing.
        count_if_match = m.mux([m.bits(0, count_len), entries], maybe_full.O)
        # Pointers differ: when the dequeue pointer is ahead, the enqueue
        # pointer has wrapped, so add one full lap to the difference.
        count_if_apart = m.mux([ptr_diff, entries + ptr_diff],
                               deq_ptr.O > enq_ptr.O)
        self.io.count @= m.mux([count_if_apart, count_if_match], ptr_match)
Пример #6
0
 class Main(m.Circuit):
     """8-bit register stage with a property attached to its output.

     ``io.I`` is registered and driven onto ``io.O``.  The attached
     assertion relates ``io.I`` to ``io.O`` being zero one cycle later; it is
     written as an SVA string when the enclosing ``sva`` flag is truthy
     and with the operator-based property syntax otherwise.
     """
     io = m.IO(I=m.In(m.Bits[8]), O=m.Out(m.Bits[8]))
     io += m.ClockIO()
     m.wire(m.Register(T=m.Bits[8])()(io.I), io.O)
     if not sva:
         f.assert_(io.I | f.implies | f.delay[1] | (io.O.value() == 0),
                   on=f.posedge(io.CLK))
     else:
         f.assert_(f.sva(io.I, "|-> ##1",
                         io.O.value() == 0),
                   on=f.posedge(io.CLK))
Пример #7
0
class ALUTile(m.Circuit):
    """ALU core wrapped with a 2-bit configuration register.

    ``config_data`` is captured into the configuration register while
    ``config_en`` is asserted.  The register output is passed to the ALU
    core as its third argument, alongside operands ``a`` and ``b``; the
    core's result drives ``c``.
    """
    io = m.IO(a=m.In(m.UInt[16]),
              b=m.In(m.UInt[16]),
              config_data=m.In(m.UInt[2]),
              config_en=m.In(m.Enable),
              c=m.Out(m.UInt[16]))
    io += m.ClockIO()

    # Configuration storage, written only when config_en is high.
    config_reg = m.Register(m.Bits[2], has_enable=True)()
    m.wire(io.config_en, config_reg.CE)
    m.wire(io.config_data, config_reg.I)

    alu = ALUCore()
    m.wire(alu(io.a, io.b, config_reg.O), io.c)
Пример #8
0
    def __init__(self, T, entries, with_bug=False):
        """Construct a FIFO of up to ``entries`` items of type ``T``.

        The queue is built from a memory, two ``CounterModM`` pointer
        counters (enqueue / dequeue) and a ``maybe_full`` flag that
        distinguishes "empty" from "full" when the two pointers are equal.

        Args:
            T: element type carried on the enq/deq interfaces.
            entries: queue capacity; must be non-negative.
            with_bug: when truthy the queue never reports full, i.e.
                ``enq.ready`` is tied high.
        """
        assert entries >= 0
        self.io = m.IO(
            # Flipped since enq/deq is from perspective of the client
            enq=m.DeqIO[T],
            deq=m.EnqIO[T]) + m.ClockIO()

        ram = m.Memory(entries, T)()
        enq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        deq_ptr = mantle.CounterModM(entries,
                                     entries.bit_length(),
                                     has_ce=True,
                                     cout=False)
        maybe_full = m.Register(init=False, has_enable=True)()

        # Equal pointers mean either empty or full; maybe_full tells which.
        ptr_match = enq_ptr.O == deq_ptr.O
        empty = ptr_match & ~maybe_full.O
        # never full when with_bug is set
        full = None if with_bug else ptr_match & maybe_full.O

        self.io.deq.valid @= ~empty
        # Tie ready high directly in the buggy variant.  The previous code
        # evaluated `~False` on a Python bool, which yields the int -1 and
        # is deprecated since Python 3.12.
        self.io.enq.ready @= True if with_bug else ~full

        do_enq = self.io.enq.fired()
        do_deq = self.io.deq.fired()

        # The RAM address is the pointer without its last bit.
        ram.write(self.io.enq.data, enq_ptr.O[:-1], m.enable(do_enq))

        enq_ptr.CE @= m.enable(do_enq)
        deq_ptr.CE @= m.enable(do_deq)

        # maybe_full only changes when exactly one of enq/deq fires, and
        # then records whether that event was an enqueue.
        maybe_full.I @= m.enable(do_enq)
        maybe_full.CE @= m.enable(do_enq != do_deq)
        self.io.deq.data @= ram[deq_ptr.O[:-1]]

        def ispow2(n):
            """Return True when ``n`` is a positive power of two."""
            # `==` binds tighter than `&`, so the bit test needs its own
            # parentheses; without them the expression was
            # `n & ((n - 1) == 0)`, which is truthy only for n == 1.
            return n != 0 and (n & (n - 1)) == 0
Пример #9
0
    class _ConfigRegister(magma.Circuit):
        """Address-decoded configuration register with asynchronous reset.

        The register captures the low ``width`` bits of ``config_data``
        when ``config_addr`` equals the enclosing ``addr`` value and, if the
        optional ``config_en`` port exists, that port is also high.
        """
        name = f"ConfigRegister_{width}_{addr_width}_{data_width}_{addr}"

        # Port table; config_en is only present when requested.
        ports = dict(
            clk=magma.In(magma.Clock),
            reset=magma.In(magma.AsyncReset),
            O=magma.Out(T),
            config_addr=magma.In(magma.Bits[addr_width]),
            config_data=magma.In(magma.Bits[data_width]),
        )
        if use_config_en:
            ports["config_en"] = magma.In(magma.Bit)
        io = magma.IO(**ports)

        reg = magma.Register(
            magma.Bits[width],
            has_enable=True,
            reset_type=magma.AsyncReset,
        )()
        reg.CLK @= io.clk
        # Clock-enable: address match, optionally gated by config_en.
        ce = io.config_addr == magma.bits(addr, addr_width)
        reg.ASYNCRESET @= io.reset
        if use_config_en:
            ce = ce & io.config_en
        reg.I @= io.config_data[0:width]
        reg.CE @= magma.enable(ce)
        io.O @= reg.O
Пример #10
0
    class DUT(m.Circuit):
        """Test harness that runs a `Cache` side by side with a `GoldCache`.

        Both caches are connected through queues to a single shared memory
        model, the same request vectors are issued to each, and immediate
        assertions compare their memory traffic and CPU responses.
        ``io.done`` is driven from the test counter's carry-out.
        """
        io = m.IO(done=m.Out(m.Bit)) + m.ClockIO()

        # Cache geometry: block size in bytes and the widths of the
        # offset / set-index / tag fields of an address.
        x_len = 32
        n_sets = 256
        b_bytes = 4 * (x_len >> 3)
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)

        # Device under test, with every NASTI channel passed through a queue.
        dut = Cache(x_len, 1, n_sets, b_bytes)()
        dut_mem = make_NastiIO(nasti_params).undirected_t(name="dut_mem")
        dut_mem.ar @= make_Queue(dut.nasti.ar, 32)
        dut_mem.aw @= make_Queue(dut.nasti.aw, 32)
        dut_mem.w @= make_Queue(dut.nasti.w, 32)
        dut.nasti.b @= make_Queue(dut_mem.b, 32)
        dut.nasti.r @= make_Queue(dut_mem.r, 32)

        # Reference model, with its request/response and NASTI channels
        # queued the same way.
        gold = GoldCache(x_len, 1, n_sets, b_bytes)()
        gold_req = type(gold.req).undirected_t(name="gold_req")
        gold_resp = type(gold.resp).undirected_t(name="gold_resp")
        gold_mem = make_NastiIO(nasti_params).undirected_t(name="gold_mem")
        gold.req @= make_Queue(gold_req, 32)
        gold_resp @= make_Queue(gold.resp, 32)
        gold_mem.ar @= make_Queue(gold.nasti.ar, 32)
        gold_mem.aw @= make_Queue(gold.nasti.aw, 32)
        gold_mem.w @= make_Queue(gold.nasti.w, 32)
        gold.nasti.b @= make_Queue(gold_mem.b, 32)
        gold.nasti.r @= make_Queue(gold_mem.r, 32)

        size = m.bitutils.clog2(nasti_params.x_data_bits // 8)
        b_bits = b_bytes << 3
        # Number of data-bus transfers needed to move one cache block.
        data_beats = b_bits // nasti_params.x_data_bits

        # Backing memory shared by both caches.
        mem = m.Memory(1 << 20, m.UInt[nasti_params.x_data_bits])()

        class MemState(m.Enum):
            IDLE = 0
            WRITE = 1
            WRITE_ACK = 2
            READ = 3

        mem_state = m.Register(init=MemState.IDLE)()

        # Beat counters; each advances only while both sides of the
        # corresponding data channel are ready/valid in the matching state.
        write_counter = mantle.CounterModM(data_beats,
                                           data_beats.bit_length(),
                                           has_ce=True)
        write_counter.CE @= m.enable((mem_state.O == MemState.WRITE)
                                     & dut_mem.w.valid & gold_mem.w.valid)
        read_counter = mantle.CounterModM(data_beats,
                                          data_beats.bit_length(),
                                          has_ce=True)
        read_counter.CE @= m.enable((mem_state.O == MemState.READ)
                                    & dut_mem.r.ready & gold_mem.r.ready)

        # Memory-side responses are generated once for the DUT ...
        dut_mem.b.valid @= mem_state.O == MemState.WRITE_ACK
        dut_mem.b.data @= NastiWriteResponseChannel(nasti_params, 0)
        dut_mem.r.valid @= mem_state.O == MemState.READ
        dut_mem.r.data @= NastiReadDataChannel(
            nasti_params, 0,
            mem.read(
                ((gold_mem.ar.data.addr) +
                 m.zext_to(read_counter.O, nasti_params.x_addr_bits))[:20]),
            read_counter.COUT)
        # ... and mirrored onto the reference model's channels.
        gold_mem.ar.ready @= dut_mem.ar.ready
        gold_mem.aw.ready @= dut_mem.aw.ready
        gold_mem.w.ready @= dut_mem.w.ready
        gold_mem.b.valid @= dut_mem.b.valid
        gold_mem.b.data @= dut_mem.b.data
        gold_mem.r.valid @= dut_mem.r.valid
        gold_mem.r.data @= dut_mem.r.data

        # Two write sources share the memory write port: source 0 is the
        # DUT's write channel, source 1 is the initialisation sequence
        # driven further below.  Source 1 wins the data/address muxes.
        mem_wen0 = m.Bit(name="mem_wen0")
        mem_wdata0 = m.UInt[nasti_params.x_data_bits](name="mem_wdata0")
        mem_wen1 = m.Bit(name="mem_wen1")
        mem_wdata1 = m.UInt[nasti_params.x_data_bits](name="mem_wdata1")
        mem_waddr1 = m.UInt[20](name="mem_waddr1")
        mem.write(
            m.mux([dut_mem.w.data.data, mem_wdata1], mem_wen1),
            m.mux([((dut_mem.aw.data.addr) +
                    m.zext_to(write_counter.O, nasti_params.x_addr_bits))[:20],
                   mem_waddr1], mem_wen1), m.enable(mem_wen0 | mem_wen1))
        # m.display("mem_wen0 = %x, mem_wen1 = %x", mem_wen0,
        #           mem_wen1).when(m.posedge(io.CLK))
        # m.display("dut_mem.w.valid = %x",
        #           dut_mem.w.valid).when(m.posedge(io.CLK))
        # m.display("gold_mem.w.valid = %x",
        #           gold_mem.w.valid).when(m.posedge(io.CLK))

        # Each check below is an implication written as a disjunction: the
        # final equality must hold unless one of the leading terms is true.

        # IDLE and both AW channels valid => write addresses agree.
        f.assert_immediate(
            (mem_state.O != MemState.IDLE)
            | ~(gold_mem.aw.valid & dut_mem.aw.valid) |
            (dut_mem.aw.data.addr == gold_mem.aw.data.addr),
            failure_msg=(
                "[dut_mem.aw.data.addr] %x != [gold_mem.aw.data.addr] %x",
                dut_mem.aw.data.addr, gold_mem.aw.data.addr))

        # Read addresses agree.
        # NOTE(review): as written this check is only active when BOTH the
        # AW and AR channel pairs are valid, because the negated AW term
        # satisfies the disjunction whenever AW is not valid on both sides.
        # The FSM below takes the READ path only when AW is *not* valid, so
        # the AW term may have been meant un-negated -- confirm intent.
        f.assert_immediate(
            (mem_state.O != MemState.IDLE)
            | ~(gold_mem.aw.valid & dut_mem.aw.valid)
            | ~(gold_mem.ar.valid & dut_mem.ar.valid) |
            (dut_mem.ar.data.addr == gold_mem.ar.data.addr),
            failure_msg=(
                "[dut_mem.ar.data.addr] %x != [gold_mem.ar.data.addr] %x",
                dut_mem.ar.data.addr, gold_mem.ar.data.addr))

        # WRITE and both W channels valid => write data agree.
        f.assert_immediate(
            (mem_state.O != MemState.WRITE)
            | ~(gold_mem.w.valid & dut_mem.w.valid) |
            (dut_mem.w.data.data == gold_mem.w.data.data),
            failure_msg=(
                "[dut_mem.w.data.data] %x != [gold_mem.w.data.data] %x",
                dut_mem.w.data.data, gold_mem.w.data.data))

        @m.inline_combinational()
        def mem_fsm():
            # Defaults: nothing ready, no write, hold the current state.
            dut_mem.w.ready @= False
            dut_mem.aw.ready @= False
            dut_mem.ar.ready @= False

            mem_wen0 @= False

            mem_state.I @= mem_state.O

            if mem_state.O == MemState.IDLE:
                # A write request takes priority over a read request.
                if gold_mem.aw.valid & dut_mem.aw.valid:
                    mem_state.I @= MemState.WRITE
                elif gold_mem.ar.valid & dut_mem.ar.valid:
                    mem_state.I @= MemState.READ
            elif mem_state.O == MemState.WRITE:
                if gold_mem.w.valid & dut_mem.w.valid:
                    mem_wen0 @= True
                    dut_mem.w.ready @= True
                # The address is only accepted once the counter carries out.
                if write_counter.COUT:
                    dut_mem.aw.ready @= True
                    mem_state.I @= MemState.WRITE_ACK
            elif mem_state.O == MemState.WRITE_ACK:
                if gold_mem.b.ready & dut_mem.b.ready:
                    mem_state.I @= MemState.IDLE
            elif mem_state.O == MemState.READ:
                if read_counter.COUT:
                    dut_mem.ar.ready @= True
                    mem_state.I @= MemState.IDLE

        # Optional tracing of memory writes and reads.
        if TRACE:
            m.display("[%0t]: [write] mem[%x] <= %x", m.time(),
                      mem.WADDR.value(), dut_mem.w.data.data).when(
                          m.posedge(io.CLK)).if_(mem_wen0)
            m.display("[%0t]: [read] mem[%x] => %x", m.time(),
                      mem.RADDR.value(),
                      dut_mem.r.data.data).when(m.posedge(
                          io.CLK)).if_((mem_state.O == MemState.READ)
                                       & dut_mem.r.ready & gold_mem.r.ready)

        def rand_data(nasti_params):
            """Return a BitVector of data-bus width filled byte by byte."""
            rand_data = BitVector[nasti_params.x_data_bits](0)
            for i in range(nasti_params.x_data_bits // 8):
                rand_data |= BitVector[nasti_params.x_data_bits](
                    random.randint(0, 0xff) << (8 * i))
            return rand_data

        def rand_mask(x_len):
            """Return a random byte mask that is neither all-0 nor all-1."""
            return BitVector[x_len // 8](random.randint(
                1, (1 << (x_len // 8)) - 2))

        def make_test(rand_data, nasti_params, x_len):
            # Wrapper because a function defined inside the class namespace
            # doesn't inherit class variables
            def test(b_bits, tag, idx, off, mask=BitVector[x_len // 8](0)):
                """Pack one request as concat(off, idx, tag, data, mask)."""
                test_data = rand_data(nasti_params)
                for i in range((b_bits // nasti_params.x_data_bits) - 1):
                    test_data = test_data.concat(rand_data(nasti_params))
                return m.uint(m.concat(off, idx, tag, test_data, mask))

            return test

        test = make_test(rand_data, nasti_params, x_len)

        # Random address components used to build the request vectors.
        tags = []
        for _ in range(3):
            tags.append(BitVector.random(t_len))
        idxs = []
        for _ in range(2):
            idxs.append(BitVector.random(s_len))
        offs = []
        for _ in range(6):
            # Clear the two low bits of each offset.
            offs.append(BitVector.random(b_len) & -4)

        # Memory initialisation table: one (address, data) pair for every
        # combination of tag, index and beat number.
        init_addr = []
        init_data = []
        _iter = itertools.product(tags, idxs, range(0, data_beats))
        for tag, idx, off in _iter:
            init_addr.append(m.uint(m.concat(BitVector[b_len](off), idx, tag)))
            init_data.append(rand_data(nasti_params))

        # Request sequence; a non-default mask makes the request a write.
        test_vec = [
            test(b_bits, tags[0], idxs[0], offs[0]),  # 0: read miss
            test(b_bits, tags[0], idxs[0], offs[1]),  # 1: read hit
            test(b_bits, tags[1], idxs[0], offs[0]),  # 2: read miss
            test(b_bits, tags[1], idxs[0], offs[2]),  # 3: read hit
            test(b_bits, tags[1], idxs[0], offs[3]),  # 4: read hit
            test(b_bits, tags[1], idxs[0], offs[4],
                 rand_mask(x_len)),  # 5: write hit  # noqa
            test(b_bits, tags[1], idxs[0], offs[4]),  # 6: read hit
            test(b_bits, tags[2], idxs[0],
                 offs[5]),  # 7: read miss & write back  # noqa
            test(b_bits, tags[0], idxs[1], offs[0],
                 rand_mask(x_len)),  # 8: write miss  # noqa
            test(b_bits, tags[0], idxs[1], offs[0]),  # 9: read hit
            test(b_bits, tags[0], idxs[1], offs[1]),  # 10: read hit
            test(b_bits, tags[1], idxs[1], offs[2],
                 rand_mask(x_len)),  # 11: write miss & write back  # noqa
            test(b_bits, tags[1], idxs[1], offs[3]),  # 12: read hit
            test(b_bits, tags[2], idxs[1], offs[4]),  # 13: read write back
            test(b_bits, tags[2], idxs[1], offs[5])  # 14: read hit
        ]

        class TestState(m.Enum):
            INIT = 0
            START = 1
            WAIT = 2
            DONE = 3

        state = m.Register(init=TestState.INIT)()
        timeout = m.Register(m.UInt[32])()
        # NOTE(review): both counters below use a modulus of len(...) - 1;
        # confirm against CounterModM's semantics that every table entry
        # is still visited.
        init_m = len(init_addr) - 1
        init_counter = mantle.CounterModM(init_m,
                                          init_m.bit_length(),
                                          has_ce=True)
        init_counter.CE @= m.enable(state.O == TestState.INIT)

        test_m = len(test_vec) - 1
        test_counter = mantle.CounterModM(test_m,
                                          test_m.bit_length(),
                                          has_ce=True)
        test_counter.CE @= m.enable(state.O == TestState.DONE)
        # Unpack the current request vector; fields sit in the order they
        # were concatenated by test(): off, idx, tag, data, mask.
        curr_vec = m.mux(test_vec, test_counter.O)
        mask = (curr_vec >> (b_len + s_len + t_len + b_bits))[:x_len // 8]
        data = (curr_vec >> (b_len + s_len + t_len))[:b_bits]
        tag = (curr_vec >> (b_len + s_len))[:t_len]
        idx = (curr_vec >> b_len)[:s_len]
        off = curr_vec[:b_len]

        dut.cpu.req.data.addr @= m.concat(off, idx, tag)
        # TODO: Is truncating this fine?
        req_data = data[:x_len]
        dut.cpu.req.data.data @= req_data
        dut.cpu.req.data.mask @= mask
        dut.cpu.req.valid @= state.O == TestState.WAIT
        dut.cpu.abort @= 0
        # The reference model receives the same request payload as the DUT.
        gold_req.data @= dut.cpu.req.data.value()
        gold_req.valid @= state.O == TestState.START
        gold_resp.ready @= state.O == TestState.DONE

        # Initialisation write source (source 1 of the memory write port).
        mem_waddr1 @= m.mux(init_addr, init_counter.O)[:20]
        mem_wdata1 @= m.mux(init_data, init_counter.O)

        check_resp_data = m.Bit()
        if TRACE:
            m.display("[%0t]: [init] mem[%x] <= %x", m.time(),
                      mem_waddr1, mem_wdata1)\
                .when(m.posedge(io.CLK))\
                .if_(state.O == TestState.INIT)

        @m.inline_combinational()
        def state_fsm():
            # Defaults: hold timeout and state, no init write, no check.
            timeout.I @= timeout.O
            mem_wen1 @= m.bit(False)
            check_resp_data @= m.bit(False)
            state.I @= state.O
            if state.O == TestState.INIT:
                mem_wen1 @= m.bit(True)
                if init_counter.COUT:
                    state.I @= TestState.START
            elif state.O == TestState.START:
                if gold_req.ready:
                    timeout.I @= m.bits(0, 32)
                    state.I @= TestState.WAIT
            elif state.O == TestState.WAIT:
                timeout.I @= timeout.O + 1
                if dut.cpu.resp.valid & gold_resp.valid:
                    # Response data is only compared when the mask is zero.
                    if ~mask.reduce_or():
                        check_resp_data @= m.bit(True)
                    state.I @= TestState.DONE
            elif state.O == TestState.DONE:
                state.I @= TestState.START

        # The WAIT state must not last 100 or more cycles.
        f.assert_immediate((state.O != TestState.WAIT) | (timeout.O < 100))
        # When a check is requested, DUT and reference response data agree.
        f.assert_immediate(
            ~check_resp_data | (dut.cpu.resp.data.data == gold_resp.data.data),
            failure_msg=("dut.cpu.resp.data.data => %x != %x",
                         dut.cpu.resp.data.data, gold_resp.data.data))
        # m.display("mem_state=%x", mem_state.O).when(m.posedge(io.CLK))
        # m.display("test_state=%x", state.O).when(m.posedge(io.CLK))
        # m.display("dut req valid = %x",
        #           dut.cpu.req.valid).when(m.posedge(io.CLK))
        # m.display("gold req valid = %x, ready = %x", gold_req.valid,
        #           gold_req.ready).when(m.posedge(io.CLK))
        # m.display("[%0t]: dut resp data = %x, gold resp data = %x", m.time(),
        #           dut.cpu.resp.data.data, gold_resp.data.data)\
        #     .when(m.posedge(io.CLK))
        io.done @= test_counter.COUT
Пример #11
0
    def __init__(self, x_len):
        """Build an arbiter that shares one NASTI port between two caches.

        The instruction-cache port only ever reads; the data-cache port can
        read and write.  A small state machine tracks which transaction is
        in flight, and new requests are only accepted in the IDLE state.

        Args:
            x_len: address width, passed as ``addr_bits`` to the NASTI
                parameters.
        """
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)
        self.io = m.IO(
            icache=m.Flip(make_NastiIO(nasti_params)),
            dcache=m.Flip(make_NastiIO(nasti_params)),
            nasti=make_NastiIO(nasti_params)) + m.ClockIO(has_reset=True)

        class State(m.Enum):
            IDLE = 0
            ICACHE_READ = 1
            DCACHE_READ = 2
            DCACHE_WRITE = 3
            DCACHE_ACK = 4

        state = m.Register(init=State.IDLE)()

        # write address
        # Only the dcache issues writes; its request is forwarded in IDLE.
        self.io.nasti.aw.data @= self.io.dcache.aw.data
        self.io.nasti.aw.valid @= (self.io.dcache.aw.valid &
                                   (state.O == State.IDLE))
        self.io.dcache.aw.ready @= (self.io.nasti.aw.ready &
                                    (state.O == State.IDLE))
        self.io.icache.aw.ready @= 0

        # write data
        # Forwarded only while in the DCACHE_WRITE state.
        self.io.nasti.w.data @= self.io.dcache.w.data
        self.io.nasti.w.valid @= (self.io.dcache.w.valid &
                                  (state.O == State.DCACHE_WRITE))
        self.io.dcache.w.ready @= (self.io.nasti.w.ready &
                                   (state.O == State.DCACHE_WRITE))
        self.io.icache.w.ready @= 0

        # write ack
        # Forwarded only while in the DCACHE_ACK state; the icache side of
        # the channel is tied off with zeros.
        self.io.dcache.b.data @= self.io.nasti.b.data
        self.io.dcache.b.valid @= (self.io.nasti.b.valid &
                                   (state.O == State.DCACHE_ACK))
        self.io.nasti.b.ready @= (self.io.dcache.b.ready &
                                  (state.O == State.DCACHE_ACK))
        self.io.icache.b.valid @= 0
        self.io.icache.b.data.resp @= 0
        self.io.icache.b.data.id @= 0
        self.io.icache.b.data.user @= 0

        # read address
        # Every field takes the dcache's value when the dcache request is
        # valid and the icache's value otherwise.
        self.io.nasti.ar.data @= NastiReadAddressChannel(
            nasti_params,
            m.mux([self.io.icache.ar.data.id, self.io.dcache.ar.data.id],
                  self.io.dcache.ar.valid),
            m.mux([self.io.icache.ar.data.addr, self.io.dcache.ar.data.addr],
                  self.io.dcache.ar.valid),
            m.mux([self.io.icache.ar.data.size, self.io.dcache.ar.data.size],
                  self.io.dcache.ar.valid),
            m.mux(
                [self.io.icache.ar.data.length, self.io.dcache.ar.data.length],
                self.io.dcache.ar.valid),
        )
        # A read request is only issued in IDLE and when no write address
        # is being presented on the shared port.
        self.io.nasti.ar.valid @= (
            (self.io.icache.ar.valid | self.io.dcache.ar.valid)
            & ~self.io.nasti.aw.valid.value() & (state.O == State.IDLE))
        self.io.dcache.ar.ready @= (self.io.nasti.ar.ready
                                    & ~self.io.nasti.aw.valid.value() &
                                    (state.O == State.IDLE))
        # The icache is ready under the same condition, unless the dcache
        # is also requesting.
        self.io.icache.ar.ready @= (self.io.dcache.ar.ready.value()
                                    & ~self.io.dcache.ar.valid)

        # read data
        # Data goes to both caches; valid/ready are steered by the state.
        self.io.icache.r.data @= self.io.nasti.r.data
        self.io.dcache.r.data @= self.io.nasti.r.data
        self.io.icache.r.valid @= (self.io.nasti.r.valid &
                                   (state.O == State.ICACHE_READ))
        self.io.dcache.r.valid @= (self.io.nasti.r.valid &
                                   (state.O == State.DCACHE_READ))
        self.io.nasti.r.ready @= ((self.io.icache.r.ready &
                                   (state.O == State.ICACHE_READ)) |
                                  (self.io.dcache.r.ready &
                                   (state.O == State.DCACHE_READ)))

        @m.inline_combinational()
        def logic():
            # Default: stay in the current state.
            state.I @= state.O
            if state.O == State.IDLE:
                # Checked in order: dcache write, dcache read, icache read.
                if (self.io.dcache.aw.valid & self.io.dcache.aw.ready.value()):
                    state.I @= State.DCACHE_WRITE
                elif (self.io.dcache.ar.valid
                      & self.io.dcache.ar.ready.value()):
                    state.I @= State.DCACHE_READ
                elif (self.io.icache.ar.valid
                      & self.io.icache.ar.ready.value()):
                    state.I @= State.ICACHE_READ
            elif state.O == State.ICACHE_READ:
                # A read finishes on the transfer flagged as last.
                if self.io.nasti.r.fired() & self.io.nasti.r.data.last:
                    state.I @= State.IDLE
            elif state.O == State.DCACHE_READ:
                if self.io.nasti.r.fired() & self.io.nasti.r.data.last:
                    state.I @= State.IDLE
            elif state.O == State.DCACHE_WRITE:
                # After the last write transfer, wait for the response.
                if (self.io.dcache.w.valid & self.io.dcache.w.ready.value()
                        & self.io.dcache.w.data.last):
                    state.I @= State.DCACHE_ACK
            elif state.O == State.DCACHE_ACK:
                if self.io.nasti.b.fired():
                    state.I @= State.IDLE
Пример #12
0
    class CSR_DUT(m.Circuit):
        io = m.IO(done=m.Out(m.Bit),
                  check=m.Out(m.Bit),
                  rdata=m.Out(m.UInt[x_len]),
                  expected_rdata=m.Out(m.UInt[x_len]),
                  epc=m.Out(m.UInt[x_len]),
                  expected_epc=m.Out(m.UInt[x_len]),
                  evec=m.Out(m.UInt[x_len]),
                  expected_evec=m.Out(m.UInt[x_len]),
                  expt=m.Out(m.Bit),
                  expected_expt=m.Out(m.Bit))
        io += m.ClockIO(has_reset=True)

        regs = {}
        for reg in CSR.regs:
            if reg == CSR.mcpuid:
                init = (1 << (ord('I') - ord('A')) | 1 <<
                        (ord('U') - ord('A')))
            elif reg == CSR.mstatus:
                init = (CSR.PRV_M.ext(30) << 4) | (CSR.PRV_M.ext(30) << 1)
            elif reg == CSR.mtvec:
                init = Const.PC_EVEC
            else:
                init = 0
            regs[reg] = m.Register(init=BV[32](init), reset_type=m.Reset)()

        csr = CSRGen(x_len)()
        ctrl = Control.Control(x_len)()

        counter = CounterModM(n, n.bit_length())
        inst = m.mux(insts, counter.O)
        ctrl.inst @= inst
        csr.inst @= inst
        csr_cmd = ctrl.csr_cmd
        csr.cmd @= csr_cmd
        csr.illegal @= ctrl.illegal
        csr.st_type @= ctrl.st_type
        csr.ld_type @= ctrl.ld_type
        csr.pc_check @= ctrl.pc_sel == Control.PC_ALU
        csr.pc @= m.mux(pc, counter.O)
        csr.addr @= m.mux(addr, counter.O)
        csr.I @= m.mux(data, counter.O)
        csr.stall @= False
        csr.host.fromhost.valid @= False
        csr.host.fromhost.data @= 0

        # values known statically
        _csr_addr = [csr(inst) for inst in insts]
        _rs1_addr = [rs1(inst) for inst in insts]
        _csr_ro = [((((x >> 11) & 0x1) > 0x0) & (((x >> 10) & 0x1) > 0x0)) |
                   (x == CSR.mtvec) | (x == CSR.mtdeleg) for x in _csr_addr]
        _csr_valid = [x in CSR.regs for x in _csr_addr]
        # should be <= prv in runtime
        _prv_level = [(x >> 8) & 0x3 for x in _csr_addr]
        # should consider prv in runtime
        _is_ecall = [((x & 0x1) == 0x0) & (((x >> 8) & 0x1) == 0x0)
                     for x in _csr_addr]
        _is_ebreak = [((x & 0x1) > 0x0) & (((x >> 8) & 0x1) == 0x0)
                      for x in _csr_addr]
        _is_eret = [((x & 0x1) == 0x0) & (((x >> 8) & 0x1) > 0x0)
                    for x in _csr_addr]
        # should consider pc_check in runtime
        _iaddr_invalid = [((x >> 1) & 0x1) > 0 for x in addr]
        # should consider ld_type & sd_type
        _waddr_invalid = [(((x >> 1) & 0x1) > 0) | ((x & 0x1) > 0)
                          for x in addr]
        _haddr_invalid = [(x & 0x1) > 0 for x in addr]

        # values known at runtime
        csr_addr = m.mux(_csr_addr, counter.O)
        rs1_addr = m.mux(_rs1_addr, counter.O)
        csr_ro = m.mux(_csr_ro, counter.O)
        csr_valid = m.mux(_csr_valid, counter.O)

        wen = (csr_cmd == CSR.W) | (csr_cmd[1] & (rs1_addr != 0))
        prv1 = (regs[CSR.mstatus].O >> 4) & 0x3
        ie1 = (regs[CSR.mstatus].O >> 3) & 0x1
        prv = (regs[CSR.mstatus].O >> 1) & 0x3
        ie = regs[CSR.mstatus].O & 0x1
        prv_inst = csr_cmd == CSR.P
        prv_valid = (m.uint(m.zext_to(m.mux(_prv_level, counter.O), 32)) <=
                     m.uint(prv))
        iaddr_invalid = m.mux(_iaddr_invalid, counter.O) & csr.pc_check.value()
        laddr_invalid = (m.mux(_haddr_invalid, counter.O) &
                         ((ctrl.ld_type == Control.LD_LH) |
                          (ctrl.ld_type == Control.LD_LHU))
                         | m.mux(_waddr_invalid, counter.O) &
                         (ctrl.ld_type == Control.LD_LW))
        saddr_invalid = (m.mux(_haddr_invalid, counter.O) &
                         (ctrl.st_type == Control.ST_SH)
                         | m.mux(_waddr_invalid, counter.O) &
                         (ctrl.st_type == Control.ST_SW))
        is_ecall = prv_inst & m.mux(_is_ecall, counter.O)
        is_ebreak = prv_inst & m.mux(_is_ebreak, counter.O)
        is_eret = prv_inst & m.mux(_is_eret, counter.O)
        exception = (ctrl.illegal | iaddr_invalid | laddr_invalid
                     | saddr_invalid | (((csr_cmd & 0x3) > 0) &
                                        (~csr_valid | ~prv_valid)) |
                     (csr_ro & wen) | (prv_inst & ~prv_valid) | is_ecall
                     | is_ebreak)
        instret = (inst != nop) & (~exception | is_ecall | is_ebreak)

        rdata = m.dict_lookup({key: value.O
                               for key, value in regs.items()}, csr_addr)
        wdata = m.dict_lookup(
            {
                CSR.W: csr.I.value(),
                CSR.S: (csr.I.value() | rdata),
                CSR.C: (~csr.I.value() & rdata)
            }, csr_cmd)

        # compute state
        regs[CSR.time].I @= regs[CSR.time].O + 1
        regs[CSR.timew].I @= regs[CSR.timew].O + 1
        regs[CSR.mtime].I @= regs[CSR.mtime].O + 1
        regs[CSR.cycle].I @= regs[CSR.cycle].O + 1
        regs[CSR.cyclew].I @= regs[CSR.cyclew].O + 1

        time_max = regs[CSR.time].O.reduce_and()
        # TODO: mtime has same default value as this case (from chisel code)
        # https://github.com/ucb-bar/riscv-mini/blob/release/src/test/scala/CSRTests.scala#L140
        # mtime_reg = regs[CSR.mtime]
        # mtime_reg.I @= m.mux([mtime_reg.O, mtime_reg.O + 1], time_max)

        incr_when(regs[CSR.timeh], time_max)
        incr_when(regs[CSR.timehw], time_max)

        cycle_max = regs[CSR.cycle].O.reduce_and()

        incr_when(regs[CSR.cycleh], cycle_max)
        incr_when(regs[CSR.cyclehw], cycle_max)

        incr_when(regs[CSR.instret], instret)
        incr_when(regs[CSR.instretw], instret)

        instret_max = regs[CSR.instret].O.reduce_and()
        incr_when(regs[CSR.instreth], instret & instret_max)
        incr_when(regs[CSR.instrethw], instret & instret_max)

        cond = ~exception & ~is_eret & wen
        # Assuming these are mutually exclusive, so we don't need chained
        # elsewhen
        update_when(regs[CSR.mstatus], m.zext_to(wdata[0:6], 32),
                    cond & (csr_addr == CSR.mstatus))
        update_when(regs[CSR.mip],
                    (m.bits(wdata[7], 32) << 7) | (m.bits(wdata[3], 32) << 3),
                    cond & (csr_addr == CSR.mip))
        update_when(regs[CSR.mie],
                    (m.bits(wdata[7], 32) << 7) | (m.bits(wdata[3], 32) << 3),
                    cond & (csr_addr == CSR.mie))
        update_when(regs[CSR.mepc], (wdata >> 2) << 2,
                    cond & (csr_addr == CSR.mepc))
        update_when(regs[CSR.mcause], wdata & (1 << 31 | 0xf),
                    cond & (csr_addr == CSR.mcause))
        update_when(regs[CSR.time], wdata,
                    cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
        update_when(regs[CSR.timew], wdata,
                    cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
        update_when(regs[CSR.mtime], wdata,
                    cond & ((csr_addr == CSR.timew) | (csr_addr == CSR.mtime)))
        update_when(
            regs[CSR.timeh], wdata,
            cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
        update_when(
            regs[CSR.timehw], wdata,
            cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
        update_when(
            regs[CSR.mtimeh], wdata,
            cond & ((csr_addr == CSR.timehw) | (csr_addr == CSR.mtimeh)))
        update_when(regs[CSR.cycle], wdata, cond & (csr_addr == CSR.cyclew))
        update_when(regs[CSR.cyclew], wdata, cond & (csr_addr == CSR.cyclew))
        update_when(regs[CSR.cycleh], wdata, cond & (csr_addr == CSR.cyclehw))
        update_when(regs[CSR.cyclehw], wdata, cond & (csr_addr == CSR.cyclehw))
        update_when(regs[CSR.instret], wdata,
                    cond & (csr_addr == CSR.instretw))
        update_when(regs[CSR.instretw], wdata,
                    cond & (csr_addr == CSR.instretw))
        update_when(regs[CSR.instreth], wdata,
                    cond & (csr_addr == CSR.instrethw))
        update_when(regs[CSR.instrethw], wdata,
                    cond & (csr_addr == CSR.instrethw))
        update_when(regs[CSR.mtimecmp], wdata,
                    cond & (csr_addr == CSR.mtimecmp))
        update_when(regs[CSR.mscratch], wdata,
                    cond & (csr_addr == CSR.mscratch))
        update_when(regs[CSR.mbadaddr], wdata,
                    cond & (csr_addr == CSR.mbadaddr))
        update_when(regs[CSR.mtohost], wdata, cond & (csr_addr == CSR.mtohost))
        update_when(regs[CSR.mfromhost], wdata,
                    cond & (csr_addr == CSR.mfromhost))

        # eret
        update_when(regs[CSR.mstatus],
                    (CSR.PRV_U.zext(30) << 4) | (1 << 3) | (prv1 << 1) | ie1,
                    ~exception & is_eret)

        # TODO: exception logic comes after since it has priority
        Cause = make_Cause(x_len)
        mcause = m.mux([
            m.mux([
                m.mux([
                    m.mux([
                        m.mux([Cause.IllegalInst, Cause.Breakpoint],
                              is_ebreak),
                        Cause.Ecall + prv,
                    ], is_ecall),
                    Cause.StoreAddrMisaligned,
                ], saddr_invalid),
                Cause.LoadAddrMisaligned,
            ], laddr_invalid),
            Cause.InstAddrMisaligned,
        ], iaddr_invalid)
        update_when(regs[CSR.mcause], mcause, exception)

        update_when(regs[CSR.mepc], (csr.pc.value() >> 2) << 2, exception)
        update_when(regs[CSR.mstatus],
                    (prv << 4) | (ie << 3) | (CSR.PRV_M.zext(30) << 1),
                    exception)
        update_when(
            regs[CSR.mbadaddr], csr.addr.value(),
            exception & (iaddr_invalid | laddr_invalid | saddr_invalid))

        epc = regs[CSR.mepc].O
        evec = regs[CSR.mtvec].O + (prv << 6)

        m.display("*** Counter: %d ***", counter.O)
        m.display("[in] inst: 0x%x, pc: 0x%x, addr: 0x%x, in: 0x%x", csr.inst,
                  csr.pc, csr.addr, csr.I)

        m.display(
            "     cmd: 0x%x, st_type: 0x%x, ld_type: 0x%x, illegal: %d, "
            "pc_check: %d", csr.cmd, csr.st_type, csr.ld_type, csr.illegal,
            csr.pc_check)

        m.display("[state] csr addr: %x", csr_addr)

        for reg_addr, reg in regs.items():
            m.display(f" {hex(int(reg_addr))} -> 0x%x", reg.O)

        m.display(
            "[out] read: 0x%x =? 0x%x, epc: 0x%x =? 0x%x, evec: 0x%x ?= "
            "0x%x, expt: %d ?= %d", csr.O, rdata, csr.epc, epc, csr.evec, evec,
            csr.expt, exception)
        io.check @= counter.O.reduce_or()

        io.rdata @= csr.O
        io.expected_rdata @= rdata

        io.epc @= csr.epc
        io.expected_epc @= epc

        io.evec @= csr.evec
        io.expected_evec @= evec

        io.expt @= csr.expt
        io.expected_expt @= exception

        # io.failed @= counter.O.reduce_or() & (
        #     (csr.O != rdata) |
        #     (csr.epc != epc) |
        #     (csr.evec != evec) |
        #     (csr.expt != exception)
        # )
        io.done @= counter.COUT
        for key, reg in regs.items():
            if not reg.I.driven():
                reg.I @= reg.O
Пример #13
0
def reg_init(t, init):
    """Instantiate a register of type ``t`` with initial value ``init``.

    The register is generated with ``reset_type=m.Reset`` and the generated
    definition is instantiated immediately; the instance is returned.
    """
    register_definition = m.Register(t, init=init, reset_type=m.Reset)
    return register_definition()

def sl(b, s):
    """Concatenate an ``s``-bit zero constant with ``b``.

    Builds ``m.uint(0, s)`` and calls its ``concat`` method with ``b``.
    NOTE(review): presumably a constant shift-left of ``b`` by ``s`` bits
    (zeros in the low positions) -- confirm against magma's concat ordering.
    """
    zero_padding = m.uint(0, s)
    return zero_padding.concat(b)
Пример #14
0
    def __init__(self, x_len):
        """Build the control/status register (CSR) unit for an ``x_len``-bit core.

        Declares the IO bundle (command, data in/out, exception inputs and
        outputs, host interface, clock/reset), instantiates the state
        registers backing the writable CSRs, wires the combinational read
        path (``io.O``), exception detection (``io.expt``), the exception
        vector / return address outputs (``io.evec`` / ``io.epc``) and the
        next-state logic for every register.

        Args:
            x_len: data-path width in bits; used as the width of all
                ``m.UInt[x_len]`` ports and registers.
        """
        Cause = make_Cause(x_len)

        self.io = io = m.IO(
            stall=m.In(m.Bit),
            cmd=m.In(m.UInt[3]),
            I=m.In(m.UInt[x_len]),
            O=m.Out(m.UInt[x_len]),
            # Exception
            pc=m.In(m.UInt[x_len]),
            addr=m.In(m.UInt[x_len]),
            inst=m.In(m.UInt[x_len]),
            illegal=m.In(m.Bit),
            st_type=m.In(m.UInt[2]),
            ld_type=m.In(m.UInt[3]),
            pc_check=m.In(m.Bit),
            expt=m.Out(m.Bit),
            evec=m.Out(m.UInt[x_len]),
            epc=m.Out(
                m.UInt[x_len])) + HostIO(x_len) + m.ClockIO(has_reset=True)

        # Fields sliced out of the instruction word: bits [20:32) select the
        # CSR, bits [15:20) are the rs1 register index.
        csr_addr = io.inst[20:32]
        rs1_addr = io.inst[15:20]

        # user counters
        time = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        timeh = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        cycle = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        cycleh = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        instret = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        instreth = m.Register(m.UInt[x_len], reset_type=m.Reset)()

        # Constant (read-only) identification values.
        mcpuid = m.concat(
            BV[26](
                1 << (ord('I') - ord('A')) |  # Base ISA
                1 << (ord('U') - ord('A'))),  # User Mode
            BV[x_len - 28](0),
            BV[2](0),  # RV32I
        )
        mimpid = BV[x_len](0)
        mhartid = BV[x_len](0)

        # interrupt enable stack
        # Only levels 0 and 1 are backed by registers; levels 2 and 3 are
        # constants.
        PRV = m.Register(m.UInt[len(CSR.PRV_M)],
                         init=CSR.PRV_M,
                         reset_type=m.Reset)()
        PRV1 = m.Register(m.UInt[len(CSR.PRV_M)],
                          init=CSR.PRV_M,
                          reset_type=m.Reset)()
        PRV2 = BV[2](0)
        PRV3 = BV[2](0)
        IE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        IE1 = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        IE2 = False
        IE3 = False

        # virtualization management field
        VM = BV[5](0)

        # memory privilege
        MPRV = False

        # Extension context status
        XS = BV[2](0)
        FS = BV[2](0)
        SD = BV[1](0)
        # mstatus read value assembled from the fields above.
        # NOTE(review): the write path below maps wdata[0] -> IE and
        # wdata[1:3] -> PRV, which presumably means m.concat places its first
        # argument in the least-significant bits -- confirm.
        mstatus = m.concat(IE.O, PRV.O, IE1.O, PRV1.O, IE2, PRV2, IE3, PRV3,
                           FS, XS, MPRV, VM, BV[x_len - 23](0), SD)
        mtvec = BV[x_len](Const.PC_EVEC)
        mtdeleg = BV[x_len](0)

        # interrupt registers
        # Only the machine-level timer/software bits are registers; the
        # H*/S* bits are constant False.
        MTIP = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        HTIP = False
        STIP = False
        MTIE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        HTIE = False
        STIE = False
        MSIP = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        HSIP = False
        SSIP = False
        MSIE = m.Register(m.Bit, init=False, reset_type=m.Reset)()
        HSIE = False
        SSIE = False

        mip = m.concat(Bit(False), SSIP, HSIP, MSIP.O, Bit(False), STIP, HTIP,
                       MTIP.O, BV[x_len - 8](0))
        mie = m.concat(Bit(False), SSIE, HSIE, MSIE.O, Bit(False), STIE, HTIE,
                       MTIE.O, BV[x_len - 8](0))

        mtimecmp = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        mscratch = m.Register(m.UInt[x_len], reset_type=m.Reset)()

        mepc = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        mcause = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        mbadaddr = m.Register(m.UInt[x_len], reset_type=m.Reset)()

        mtohost = m.Register(m.UInt[x_len], reset_type=m.Reset)()
        mfromhost = m.Register(m.UInt[x_len], reset_type=m.Reset)()

        io.host.tohost @= mtohost.O
        # Read map: CSR address -> value returned on io.O. Several addresses
        # alias the same register (e.g. cycle/cyclew, time/timew/mtime).
        csr_file = {
            CSR.cycle: cycle.O,
            CSR.time: time.O,
            CSR.instret: instret.O,
            CSR.cycleh: cycleh.O,
            CSR.timeh: timeh.O,
            CSR.instreth: instreth.O,
            CSR.cyclew: cycle.O,
            CSR.timew: time.O,
            CSR.instretw: instret.O,
            CSR.cyclehw: cycleh.O,
            CSR.timehw: timeh.O,
            CSR.instrethw: instreth.O,
            CSR.mcpuid: mcpuid,
            CSR.mimpid: mimpid,
            CSR.mhartid: mhartid,
            CSR.mtvec: mtvec,
            CSR.mtdeleg: mtdeleg,
            CSR.mie: mie,
            CSR.mtimecmp: mtimecmp.O,
            CSR.mtime: time.O,
            CSR.mtimeh: timeh.O,
            CSR.mscratch: mscratch.O,
            CSR.mepc: mepc.O,
            CSR.mcause: mcause.O,
            CSR.mbadaddr: mbadaddr.O,
            CSR.mip: mip,
            CSR.mtohost: mtohost.O,
            CSR.mfromhost: mfromhost.O,
            CSR.mstatus: mstatus,
        }
        out = m.dict_lookup(csr_file, csr_addr)
        io.O @= out

        # Decode of the current command / address.
        # Access is privilege-valid when address bits [8:10) do not exceed the
        # current privilege level.
        priv_valid = csr_addr[8:10] <= PRV.O
        priv_inst = io.cmd == CSR.P
        # Privileged instructions are told apart by address bits 0 and 8.
        is_E_call = priv_inst & ~csr_addr[0] & ~csr_addr[8]
        is_E_break = priv_inst & csr_addr[0] & ~csr_addr[8]
        is_E_ret = priv_inst & ~csr_addr[0] & csr_addr[8]
        # The address is valid iff it matches any key of the read map.
        csr_valid = m.reduce(operator.or_,
                             m.bits([csr_addr == key for key in csr_file]))
        # Read-only when both of address bits [10:12) are set, or for
        # mtvec / mtdeleg.
        csr_RO = (csr_addr[10:12].reduce_and() | (csr_addr == CSR.mtvec) |
                  (csr_addr == CSR.mtdeleg))
        # Write enable: an explicit write command, or a command with bit 1 set
        # whose rs1 index is non-zero.
        wen = (io.cmd == CSR.W) | io.cmd[1] & rs1_addr.reduce_or()
        # Value to write: plain write, set-bits, or clear-bits of the current
        # read value.
        wdata = m.dict_lookup(
            {
                CSR.W: io.I,
                CSR.S: out | io.I,
                CSR.C: out & ~io.I
            }, io.cmd)

        # Misalignment checks on io.addr.
        iaddr_invalid = io.pc_check & io.addr[1]

        laddr_invalid = m.dict_lookup(
            {
                Control.LD_LW: io.addr[0:2].reduce_or(),
                Control.LD_LH: io.addr[0],
                Control.LD_LHU: io.addr[0]
            }, io.ld_type)

        saddr_invalid = m.dict_lookup(
            {
                Control.ST_SW: io.addr[0:2].reduce_or(),
                Control.ST_SH: io.addr[0]
            }, io.st_type)

        # Exception is raised for any of: illegal instruction, misaligned
        # access, a CSR command (cmd[0:2] non-zero) on an invalid or
        # insufficiently privileged address, a write to a read-only CSR, a
        # privileged instruction without privilege, ecall, or ebreak.
        expt = (io.illegal | iaddr_invalid | laddr_invalid | saddr_invalid
                | io.cmd[0:2].reduce_or() & (~csr_valid | ~priv_valid)
                | wen & csr_RO | (priv_inst & ~priv_valid) | is_E_call
                | is_E_break)
        io.expt @= expt

        # Exception vector = mtvec plus the current privilege level shifted
        # left by 6.
        io.evec @= mtvec + (m.zext_to(PRV.O, x_len) << 6)
        io.epc @= mepc.O

        @m.inline_combinational()
        def logic():
            # Counters
            # Low words increment every cycle; high words increment when the
            # low word is all ones.
            time.I @= time.O + 1
            timeh.I @= timeh.O
            if time.O.reduce_and():
                timeh.I @= timeh.O + 1

            cycle.I @= cycle.O + 1
            cycleh.I @= cycleh.O
            if cycle.O.reduce_and():
                cycleh.I @= cycleh.O + 1
            instret.I @= instret.O
            # An instruction retires when it is not a NOP, the pipeline is not
            # stalled, and either no exception occurred or the exception is an
            # ecall/ebreak.
            is_inst_ret = ((io.inst != Instructions.NOP) &
                           (~expt | is_E_call | is_E_break) & ~io.stall)
            if is_inst_ret:
                instret.I @= instret.O + 1
            instreth.I @= instreth.O
            if is_inst_ret & instret.O.reduce_and():
                instreth.I @= instreth.O + 1

            # Defaults: every remaining register holds its value unless
            # overridden below.
            mbadaddr.I @= mbadaddr.O
            mepc.I @= mepc.O
            mcause.I @= mcause.O
            PRV.I @= PRV.O
            IE.I @= IE.O
            IE1.I @= IE1.O
            PRV1.I @= PRV1.O
            MTIP.I @= MTIP.O
            MSIP.I @= MSIP.O
            MTIE.I @= MTIE.O
            MSIE.I @= MSIE.O
            mtimecmp.I @= mtimecmp.O
            mscratch.I @= mscratch.O
            mtohost.I @= mtohost.O
            mfromhost.I @= mfromhost.O
            if io.host.fromhost.valid:
                mfromhost.I @= io.host.fromhost.data

            # State updates only happen when not stalled. Priority:
            # exception, then eret, then an ordinary CSR write.
            if ~io.stall:
                if expt:
                    # Save the faulting pc with its two low bits cleared.
                    mepc.I @= io.pc >> 2 << 2
                    if iaddr_invalid:
                        mcause.I @= Cause.InstAddrMisaligned
                    elif laddr_invalid:
                        mcause.I @= Cause.LoadAddrMisaligned
                    elif saddr_invalid:
                        mcause.I @= Cause.StoreAddrMisaligned
                    elif is_E_call:
                        mcause.I @= Cause.Ecall + m.zext_to(PRV.O, x_len)
                    elif is_E_break:
                        mcause.I @= Cause.Breakpoint
                    else:
                        mcause.I @= Cause.IllegalInst
                    # Push the privilege / interrupt-enable stack and enter
                    # PRV_M with interrupts disabled.
                    PRV.I @= CSR.PRV_M
                    IE.I @= False
                    PRV1.I @= PRV.O
                    IE1.I @= IE.O
                    if iaddr_invalid | laddr_invalid | saddr_invalid:
                        mbadaddr.I @= io.addr
                elif is_E_ret:
                    # Pop the privilege / interrupt-enable stack.
                    PRV.I @= PRV1.O
                    IE.I @= IE1.O
                    PRV1.I @= CSR.PRV_U
                    IE1.I @= True
                elif wen:
                    # Ordinary CSR write, dispatched on the CSR address.
                    if csr_addr == CSR.mstatus:
                        PRV1.I @= wdata[4:6]
                        IE1.I @= wdata[3]
                        PRV.I @= wdata[1:3]
                        IE.I @= wdata[0]
                    elif csr_addr == CSR.mip:
                        MTIP.I @= wdata[7]
                        MSIP.I @= wdata[3]
                    elif csr_addr == CSR.mie:
                        MTIE.I @= wdata[7]
                        MSIE.I @= wdata[3]
                    elif csr_addr == CSR.mtime:
                        time.I @= wdata
                    elif csr_addr == CSR.mtimeh:
                        timeh.I @= wdata
                    elif csr_addr == CSR.mtimecmp:
                        mtimecmp.I @= wdata
                    elif csr_addr == CSR.mscratch:
                        mscratch.I @= wdata
                    elif csr_addr == CSR.mepc:
                        # Two low bits are cleared on write.
                        mepc.I @= wdata >> 2 << 2
                    elif csr_addr == CSR.mcause:
                        # Only the top bit and the low four bits are kept.
                        mcause.I @= wdata & (1 << (x_len - 1) | 0xf)
                    elif csr_addr == CSR.mbadaddr:
                        mbadaddr.I @= wdata
                    elif csr_addr == CSR.mtohost:
                        mtohost.I @= wdata
                    elif csr_addr == CSR.mfromhost:
                        mfromhost.I @= wdata
                    elif csr_addr == CSR.cyclew:
                        cycle.I @= wdata
                    elif csr_addr == CSR.timew:
                        time.I @= wdata
                    elif csr_addr == CSR.instretw:
                        instret.I @= wdata
                    elif csr_addr == CSR.cyclehw:
                        cycleh.I @= wdata
                    elif csr_addr == CSR.timehw:
                        timeh.I @= wdata
                    elif csr_addr == CSR.instrethw:
                        instreth.I @= wdata
Пример #15
0
def reg(t):
    """Instantiate a register holding a value of type ``t``.

    Generates ``m.Register(t)`` with its default options and returns a fresh
    instance of it.
    """
    register_definition = m.Register(t)
    return register_definition()

def reg_init(t, init):
    """Instantiate a register of type ``t`` with initial value ``init``.

    The register definition is generated with ``reset_type=m.Reset`` and
    instantiated right away; the new instance is returned.
    """
    register_definition = m.Register(t, init=init, reset_type=m.Reset)
    return register_definition()
Пример #16
0
 class Main(m.Circuit):
     # One-bit circuit: output O is input I passed through a register.
     io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) + m.ClockIO()
     io.O @= m.Register(T=m.Bit)()(io.I)
     # Assumption over the input, evaluated on the rising edge of CLK.
     # NOTE(review): presumably reads as "I, then one cycle later not I" --
     # confirm against the `f` library's `delay` operator semantics.
     f.assume(io.I | f.delay[1] | ~io.I, on=f.posedge(io.CLK))
Пример #17
0
    def __init__(self,
                 x_len,
                 ALU=ALUArea,
                 ImmGen=ImmGenWire,
                 BrCond=BrCondArea):
        """Build the datapath for an ``x_len``-bit core.

        Instantiates the CSR unit, register file, ALU, immediate generator
        and branch-condition unit, then wires the fetch stage, the execute
        stage (with write-back bypass), the data-cache request, the
        execute/write-back pipeline registers, load alignment, CSR access
        and register-file write-back.

        Args:
            x_len: data-path width in bits.
            ALU: generator called as ``ALU(x_len)`` to produce the ALU
                definition.
            ImmGen: generator called as ``ImmGen(x_len)`` to produce the
                immediate-generator definition.
            BrCond: generator called as ``BrCond(x_len)`` to produce the
                branch-condition definition.
        """
        self.io = make_DatapathIO(x_len) + m.ClockIO(has_reset=True)
        csr = CSRGen(x_len)()
        reg_file = RegFile(x_len)()
        alu = ALU(x_len)()
        imm_gen = ImmGen(x_len)()
        # Honour the caller-supplied generator (previously the default
        # BrCondArea was hard-coded and the BrCond argument was ignored).
        br_cond = BrCond(x_len)()

        # Fetch / Execute Registers
        fe_inst = m.Register(init=Instructions.NOP, has_enable=True)()
        fe_pc = m.Register(m.UInt[x_len], has_enable=True)()

        # Execute / Write Back Registers
        ew_inst = m.Register(init=Instructions.NOP)()
        ew_pc = m.Register(m.UInt[x_len])()
        ew_alu = m.Register(m.UInt[x_len])()
        csr_in = m.Register(m.UInt[x_len])()

        # Control signals
        # Registered copies of the control outputs, typed after the
        # corresponding (undirected) ctrl ports.
        st_type = m.Register(type(self.io.ctrl.st_type).undirected_t)()
        ld_type = m.Register(type(self.io.ctrl.ld_type).undirected_t)()
        wb_sel = m.Register(type(self.io.ctrl.wb_sel).undirected_t)()
        wb_en = m.Register(m.Bit)()
        csr_cmd = m.Register(type(self.io.ctrl.csr_cmd).undirected_t)()
        illegal = m.Register(m.Bit)()
        pc_check = m.Register(m.Bit)()

        # Fetch
        # `started` is RESET delayed by one register stage.
        started = m.Register(m.Bit)()(m.bit(self.io.RESET))
        # Stall whenever either cache has no valid response.
        stall = ~self.io.icache.resp.valid | ~self.io.dcache.resp.valid
        pc = m.Register(init=UIntVector[x_len](Const.PC_START) -
                        UIntVector[x_len](4))()
        # Next pc as a chain of 2-input muxes; the outermost select wins:
        # stall (hold pc), then exception (csr.evec), then PC_EPC (csr.epc),
        # then PC_ALU / taken branch (ALU sum with bit 0 cleared), then PC_0
        # (hold pc), otherwise pc + 4.
        npc = m.mux([
            m.mux([
                m.mux([
                    m.mux([
                        m.mux([pc.O + m.uint(4, x_len), pc.O],
                              self.io.ctrl.pc_sel == PC_0), alu.sum_ >> 1 << 1
                    ], (self.io.ctrl.pc_sel == PC_ALU) | br_cond.taken),
                    csr.epc
                ], self.io.ctrl.pc_sel == PC_EPC), csr.evec
            ], csr.expt), pc.O
        ], stall)

        # Replace the fetched instruction with a NOP right after reset, on
        # an instruction kill, a taken branch, or an exception.
        inst = m.mux([self.io.icache.resp.data.data, Instructions.NOP], started
                     | self.io.ctrl.inst_kill | br_cond.taken | csr.expt)

        pc.I @= npc
        self.io.icache.req.data.addr @= npc
        self.io.icache.req.data.data @= 0
        self.io.icache.req.data.mask @= 0
        self.io.icache.req.valid @= ~stall
        self.io.icache.abort @= False

        # Fetch/execute registers only advance when not stalled.
        fe_pc.I @= pc.O
        fe_pc.CE @= m.enable(~stall)
        fe_inst.I @= inst
        fe_inst.CE @= m.enable(~stall)

        # Execute
        # Decode
        self.io.ctrl.inst @= fe_inst.O

        # reg_file read
        rs1_addr = fe_inst.O[15:20]
        rs2_addr = fe_inst.O[20:25]
        reg_file.raddr1 @= rs1_addr
        reg_file.raddr2 @= rs2_addr

        # gen immediates
        imm_gen.inst @= fe_inst.O
        imm_gen.sel @= self.io.ctrl.imm_sel

        # bypass
        # Forward the write-back ALU result when the instruction in
        # write-back targets a non-zero source register of the instruction
        # in execute.
        wb_rd_addr = ew_inst.O[7:12]
        rs1_hazard = wb_en.O & rs1_addr.reduce_or() & (rs1_addr == wb_rd_addr)
        rs2_hazard = wb_en.O & rs2_addr.reduce_or() & (rs2_addr == wb_rd_addr)
        rs1 = m.mux([reg_file.rdata1, ew_alu.O],
                    (wb_sel.O == WB_ALU) & rs1_hazard)
        rs2 = m.mux([reg_file.rdata2, ew_alu.O],
                    (wb_sel.O == WB_ALU) & rs2_hazard)

        # ALU operations
        alu.A @= m.mux([fe_pc.O, rs1], self.io.ctrl.A_sel == A_RS1)
        alu.B @= m.mux([imm_gen.O, rs2], self.io.ctrl.B_sel == B_RS2)
        alu.op @= self.io.ctrl.alu_op

        # Branch condition calc
        br_cond.rs1 @= rs1
        br_cond.rs2 @= rs2
        br_cond.br_type @= self.io.ctrl.br_type

        # D$ access
        # Word-aligned address; while stalled the registered ALU result is
        # used instead of the live one.
        daddr = m.mux([alu.sum_, ew_alu.O], stall) >> 2 << 2
        # Bit offset of the addressed byte within the word (8 * sum[0:2]).
        w_offset = ((m.bits(alu.sum_[1], x_len) << 4) |
                    (m.bits(alu.sum_[0], x_len) << 3))
        self.io.dcache.req.valid @= ~stall & (self.io.ctrl.st_type.reduce_or()
                                              |
                                              self.io.ctrl.ld_type.reduce_or())
        self.io.dcache.req.data.addr @= daddr
        self.io.dcache.req.data.data @= rs2 << w_offset
        # Byte-enable mask by store type; zero (no write) for anything else.
        self.io.dcache.req.data.mask @= m.dict_lookup(
            {
                ST_SW: m.bits(0b1111, 4),
                ST_SH: m.bits(0b11, 4) << m.zext(alu.sum_[0:2], 2),
                ST_SB: m.bits(0b1, 4) << m.zext(alu.sum_[0:2], 2),
            }, m.mux([self.io.ctrl.st_type, st_type.O], stall), m.bits(0, 4))

        # Pipelining
        @m.inline_combinational()
        def pipeline_logic():
            # Default: every execute/write-back register holds its value.
            ew_pc.I @= ew_pc.O
            ew_inst.I @= ew_inst.O
            ew_alu.I @= ew_alu.O
            csr_in.I @= csr_in.O
            st_type.I @= st_type.O
            ld_type.I @= ld_type.O
            wb_sel.I @= wb_sel.O
            wb_en.I @= wb_en.O
            csr_cmd.I @= csr_cmd.O
            illegal.I @= illegal.O
            pc_check.I @= pc_check.O
            # On reset or an (unstalled) exception, clear the control
            # registers so the write-back stage does nothing.
            if m.bit(self.io.RESET) | ~stall & csr.expt:
                st_type.I @= 0
                ld_type.I @= 0
                wb_en.I @= 0
                csr_cmd.I @= 0
                illegal.I @= False
                pc_check.I @= False
            # Otherwise, when not stalled, advance execute -> write-back.
            elif ~stall & ~csr.expt:
                ew_pc.I @= fe_pc.O
                ew_inst.I @= fe_inst.O
                ew_alu.I @= alu.O
                csr_in.I @= m.mux([rs1, imm_gen.O],
                                  self.io.ctrl.imm_sel == IMM_Z)
                st_type.I @= self.io.ctrl.st_type
                ld_type.I @= self.io.ctrl.ld_type
                wb_sel.I @= self.io.ctrl.wb_sel
                wb_en.I @= self.io.ctrl.wb_en
                csr_cmd.I @= self.io.ctrl.csr_cmd
                illegal.I @= self.io.ctrl.illegal
                pc_check.I @= self.io.ctrl.pc_sel == PC_ALU

        # Load
        # Shift the cache response right by the byte offset of the
        # registered address, then sign/zero-extend by load type; any other
        # load type uses the raw response word.
        l_offset = ((m.uint(ew_alu.O[1], x_len) << 4) |
                    (m.uint(ew_alu.O[0], x_len) << 3))
        l_shift = self.io.dcache.resp.data.data >> l_offset
        load = m.dict_lookup(
            {
                LD_LH: m.sext_to(m.sint(l_shift[0:16]), x_len),
                LD_LHU: m.sint(m.zext_to(l_shift[0:16], x_len)),
                LD_LB: m.sext_to(m.sint(l_shift[0:8]), x_len),
                LD_LBU: m.sint(m.zext_to(l_shift[0:8], x_len))
            }, ld_type.O, m.sint(self.io.dcache.resp.data.data))

        # CSR access
        csr.stall @= stall
        csr.I @= csr_in.O
        csr.cmd @= csr_cmd.O
        csr.inst @= ew_inst.O
        csr.pc @= ew_pc.O
        csr.addr @= ew_alu.O
        csr.illegal @= illegal.O
        csr.pc_check @= pc_check.O
        csr.ld_type @= ld_type.O
        csr.st_type @= st_type.O
        self.io.host @= csr.host

        # Regfile write
        # Write-back source by wb_sel; defaults to the registered ALU result.
        reg_write = m.dict_lookup(
            {
                WB_MEM: m.uint(load),
                WB_PC4: (ew_pc.O + 4),
                WB_CSR: csr.O
            }, wb_sel.O, ew_alu.O)

        reg_file.wen @= m.enable(wb_en.O & ~stall & ~csr.expt)
        reg_file.waddr @= wb_rd_addr
        reg_file.wdata @= reg_write

        # Abort store when there's an exception
        self.io.dcache.abort @= csr.expt
Пример #18
0
 class Main(m.Circuit):
     # 8-bit circuit: output O is input I passed through a register created
     # with reset_type=m.AsyncResetN; clock/reset ports come from the
     # `clocks` interface (ClockIntf, defined outside this block).
     io = m.IO(I=m.In(m.Bits[8]), O=m.Out(m.Bits[8]))
     io += m.IO(clocks=ClockIntf)
     io.O @= m.Register(T=m.Bits[8], reset_type=m.AsyncResetN)()(io.I)
     # Property handed to the externally defined `my_assert` helper.
     # NOTE(review): presumably "I implies O one cycle later" -- confirm
     # against the `f` library's `implies` / `delay` operator semantics.
     my_assert(io.I | f.implies | f.delay[1] | io.O)
Пример #19
0
    def __init__(self, x_len, n_ways: int, n_sets: int, b_bytes: int):
        """Build a cache with a NASTI back-end.

        The cache stores one block per set in four memories (data, tag,
        valid, dirty) and is driven by a four-state machine
        (IDLE / WRITE / WRITE_ACK / READ) that services hits directly and
        handles misses with an optional write-back followed by a refill.

        Args:
            x_len: width in bits of request addresses and of request /
                response data.
            n_ways: accepted for interface compatibility; not referenced
                anywhere in this body.
            n_sets: number of sets (depth of each memory).
            b_bytes: block size in bytes.
        """
        # Single parameter object shared by the IO type and all channel
        # constructors below.
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)

        self.io = m.IO(req=m.Consumer(m.Decoupled[make_CacheReq(x_len)]),
                       resp=m.Producer(m.Decoupled[make_CacheResp(x_len)]),
                       nasti=make_NastiIO(nasti_params)) + m.ClockIO()
        size = m.bitutils.clog2(nasti_params.x_data_bits)
        b_bits = b_bytes << 3
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        # Number of NASTI data beats needed to transfer one block.
        data_beats = b_bits // nasti_params.x_data_bits
        length = data_beats - 1

        data = m.Memory(n_sets, m.UInt[b_bits])()
        tags = m.Memory(n_sets, m.UInt[t_len])()
        v = m.Memory(n_sets, m.Bit)()
        d = m.Memory(n_sets, m.Bit)()

        # Split the request address into tag / set index / block offset.
        req = self.io.req.data
        tag = (req.addr >> (b_len + s_len))[:t_len]
        idx = req.addr[b_len:b_len + s_len]
        off = req.addr[:b_len]
        read = data.read(idx)
        # Merged block for a write hit, built byte by byte: byte i comes from
        # the request when it lies in the addressed word and its mask bit is
        # set, otherwise from the current block contents.
        write = m.bits(0, b_bits)
        for i in range(b_bytes):
            write |= m.mux([(read & (0xff << (8 * i))),
                            ((m.zext_to(req.data, b_bits) >>
                              ((8 * (i & 0x3)))) & 0xff) << (8 * i)],
                           ((off // 4) == (i // 4)) & (req.mask >>
                                                       (i & 0x3))[0])[:b_bits]

        class State(m.Enum):
            IDLE = 0
            WRITE = 1
            WRITE_ACK = 2
            READ = 3

        state = m.Register(init=State.IDLE)()

        # Beat counters: the write counter advances every cycle spent in
        # WRITE, the read counter on every valid read beat in READ.
        write_counter = mantle.CounterModM(data_beats,
                                           max(data_beats.bit_length(), 1),
                                           has_ce=True)
        write_counter.CE @= m.enable(state.O == State.WRITE)
        w_cnt, w_done = write_counter.O, write_counter.COUT

        read_counter = mantle.CounterModM(data_beats,
                                          max(data_beats.bit_length(), 1),
                                          has_ce=True)
        read_counter.CE @= m.enable((state.O == State.READ)
                                    & self.io.nasti.r.valid)
        r_cnt, r_done = read_counter.O, read_counter.COUT

        # Response data: the addressed word selected out of the block.
        self.io.resp.data.data @= (read >> (m.zext_to(
            (off // 4), b_bits) * x_len))[:x_len]
        # Refill address: request address with the block offset cleared.
        self.io.nasti.ar.data @= NastiReadAddressChannel(
            nasti_params, 0, (req.addr >> b_len) << b_len, size, length)
        tags_rdata = tags.read(idx)
        # Write-back address: stored tag and set index, shifted up by the
        # block-offset width.
        self.io.nasti.aw.data @= NastiWriteAddressChannel(
            nasti_params, 0,
            m.bits(m.concat(idx, tags_rdata), nasti_params.x_addr_bits) <<
            b_len, size, length)
        # Write-back data: the beat of the block selected by w_cnt; w_done
        # is passed as the final constructor argument.
        self.io.nasti.w.data @= NastiWriteDataChannel(
            nasti_params,
            (read >> (m.zext_to(w_cnt, b_bits) *
                      nasti_params.x_data_bits))[:nasti_params.x_data_bits],
            None, w_done)
        self.io.nasti.w.valid @= state.O == State.WRITE
        self.io.nasti.b.ready @= state.O == State.WRITE_ACK
        self.io.nasti.r.ready @= state.O == State.READ

        # Named wires driven by the combinational logic below and used as
        # write enables / write data of the memories.
        d_wen = m.Bit(name="d_wen")
        d.write(True, idx, m.enable(d_wen))

        data_wen = m.Bit(name="data_wen")
        data_wdata = m.UInt[b_bits](name="data_wdata")
        data.write(data_wdata, idx, m.enable(data_wen))
        # m.display("data_wdata=%x", data_wdata).when(m.posedge(self.io.CLK))

        v_wen = m.Bit(name="v_wen")
        v.write(True, idx, m.enable(v_wen))
        v_rdata = v.read(idx)

        tags_wen = m.Bit(name="tags_wen")
        tags.write(tag, idx, m.enable(tags_wen))

        d_rdata = d.read(idx)

        # m.display("gold_state=%x", state.O).when(m.posedge(self.io.CLK))
        # m.display("gold_w_done=%x", w_done).when(m.posedge(self.io.CLK))
        # m.display("gold_b_valid=%x",
        #           self.io.nasti.b.valid).when(m.posedge(self.io.CLK))

        # Optional tracing of write hits and read hits.
        if TRACE:
            m.display(
                "[%0t] [cache] data[%x] <= %x, off: %x, req: %x, mask: %b",
                m.time(), idx, write, off, self.io.req.data.data,
                self.io.req.data.mask)\
                .when(m.posedge(self.io.CLK))\
                .if_((state.O == State.IDLE) &
                     (self.io.req.valid & self.io.resp.ready) &
                     (v_rdata & (tags_rdata == tag)) & req.mask.reduce_or())

            m.display(
                "[%0t] [cache] data[%x] => %x, off: %x, resp: %x", m.time(),
                idx, write, off, self.io.resp.data.data.value())\
                .when(m.posedge(self.io.CLK))\
                .if_((state.O == State.IDLE) &
                     (self.io.req.valid & self.io.resp.ready) &
                     (v_rdata & (tags_rdata == tag)) & ~req.mask.reduce_or())

        @m.inline_combinational()
        def logic():
            # Defaults: no handshakes, no memory writes, hold state.
            self.io.resp.valid @= False
            self.io.req.ready @= False
            self.io.nasti.ar.valid @= False
            self.io.nasti.aw.valid @= False

            d_wen @= False

            data_wen @= False
            data_wdata @= m.UInt[b_bits](0)
            state.I @= state.O

            tags_wen @= False
            v_wen @= False

            if state.O == State.IDLE:
                if self.io.req.valid & self.io.resp.ready:
                    if v_rdata & (tags_rdata == tag):
                        # Hit: a non-zero mask makes it a write, which merges
                        # the request into the block and marks it dirty.
                        if req.mask.reduce_or():
                            d_wen @= True
                            data_wdata @= write
                            data_wen @= True
                        self.io.req.ready @= True
                        self.io.resp.valid @= True
                    else:
                        # Miss: write back first if the line is dirty,
                        # otherwise clear the line and start the refill.
                        if d_rdata:
                            self.io.nasti.aw.valid @= True
                            state.I @= State.WRITE
                        else:
                            data_wdata @= 0
                            data_wen @= True
                            self.io.nasti.ar.valid @= True
                            state.I @= State.READ
            elif state.O == State.WRITE:
                if w_done:
                    state.I @= State.WRITE_ACK
            elif state.O == State.WRITE_ACK:
                if self.io.nasti.b.valid:
                    # Write-back acknowledged: clear the line, start refill.
                    data_wdata @= 0
                    data_wen @= True
                    self.io.nasti.ar.valid @= True
                    state.I @= State.READ
            elif state.O == State.READ:
                if self.io.nasti.r.valid:
                    # OR the incoming beat into its slot of the (previously
                    # cleared) line.
                    data_wdata @= read | (
                        m.zext_to(self.io.nasti.r.data.data, b_bits) <<
                        (m.zext_to(r_cnt, b_bits) * nasti_params.x_data_bits))
                    data_wen @= True
                if r_done:
                    # Refill complete: record the tag and mark the line valid.
                    tags_wen @= True
                    v_wen @= True
                    state.I @= State.IDLE
Пример #20
0
        class DUT(m.Circuit):
            """Core testbench: preload the memories, then let the core run.

            While ``state`` is INIT the core is held in reset and a counter
            steps through the ``loadmem`` image, writing each entry into both
            ``imem`` and ``dmem``.  When the counter wraps, ``state`` switches
            to RUN: cache responses become valid, core stores are merged into
            ``dmem`` and ``done`` goes high once ``tohost`` is non-zero.
            """
            io = m.IO(done=m.Out(m.Bit)) + m.ClockIO(has_reset=True)
            core = Core(
                x_len, data_path_kwargs=m.generator.ParamDict(ImmGen=ImmGen))()
            core.host.fromhost.data.undriven()
            core.host.fromhost.valid @= False

            # reverse concat because we're using utils with chisel ordering
            _hex = [concat(*reversed(x)) for x in loadmem]
            imem = RegFileBuilder("imem",
                                  1 << 20,
                                  x_len,
                                  write_forward=False,
                                  reset_type=m.Reset,
                                  backend="verilog")
            dmem = RegFileBuilder("dmem",
                                  1 << 20,
                                  x_len,
                                  write_forward=False,
                                  reset_type=m.Reset,
                                  backend="verilog")

            # Two-state controller encoded as a single bit.
            INIT, RUN = False, True

            state = m.Register(init=INIT)()
            cycle = m.Register(m.UInt[32])()

            # Load counter: only advances while the image is being loaded.
            n = len(_hex)
            counter = CounterModM(n, n.bit_length(), has_ce=True)
            counter.CE @= m.enable(state.O == INIT)
            cntr, done = counter.O, counter.COUT

            # Convert the core's byte addresses to word indices and keep the
            # low 20 bits, matching the 1 << 20 entry memories.
            iaddr = (core.icache.req.data.addr // (x_len // 8))[:20]
            daddr = (core.dcache.req.data.addr // (x_len // 8))[:20]

            dmem_data = dmem[daddr]
            imem_data = imem[iaddr]

            # Byte-lane merge for stores: for every lane take the core's store
            # data when the request is valid and that mask bit is set,
            # otherwise keep the byte currently held in dmem.
            # NOTE(review): the extension width 32 is hard-coded; presumably it
            # should track x_len -- confirm before using other word sizes.
            write = 0
            for i in range(x_len // 8):
                write |= m.zext_to(
                    m.mux([dmem_data, core.dcache.req.data.data],
                          core.dcache.req.valid
                          & core.dcache.req.data.mask[i])[8 * i:8 * (i + 1)],
                    32) << (8 * i)

            # Hold the core in reset for the whole load phase.
            core.RESET @= m.reset(state.O == INIT)

            core.icache.resp.valid @= state.O == RUN
            core.dcache.resp.valid @= state.O == RUN

            # Read data is returned to the core through a register stage.
            core.icache.resp.data.data @= m.Register(
                m.UInt[x_len])()(imem_data)
            core.dcache.resp.data.data @= m.Register(
                m.UInt[x_len])()(dmem_data)

            # Image entry selected by the load counter.
            chunk = m.mux(_hex, cntr)

            imem.write(m.zext_to(cntr, 20), chunk, m.enable(state.O == INIT))

            # m.mux picks element 1 when its select is high, so the INIT-phase
            # operands (load counter address, image chunk) must come second
            # and the RUN-phase operands (core store address/data) first.
            dmem.write(
                m.mux([daddr, m.zext_to(cntr, 20)], state.O == INIT),
                m.mux([write, chunk], state.O == INIT),
                m.enable((state.O == INIT)
                         | (core.dcache.req.valid
                            & core.dcache.req.data.mask.reduce_or())))

            # Next-state logic: leave INIT when the load counter wraps and
            # count cycles while running.
            @m.inline_combinational()
            def logic():
                state.I @= state.O
                cycle.I @= cycle.O
                if state.O == INIT:
                    if done:
                        state.I @= RUN
                if state.O == RUN:
                    cycle.I @= cycle.O + 1

            # Optional simulation trace; flip to True when debugging.
            debug = False
            if debug:
                m.display("LOADMEM[%x] <= %x", cntr * (x_len // 8),
                          chunk).when(m.posedge(io.CLK)).if_(state.O == INIT)

                # Instruction fetches report the instruction memory word.
                m.display("INST[%x] => %x",
                          iaddr * (x_len // 8), imem_data).when(
                              m.posedge(io.CLK)).if_((state.O == RUN)
                                                     & core.icache.req.valid)

                m.display("MEM[%x] <= %x", daddr * (x_len // 8), write).when(
                    m.posedge(
                        io.CLK)).if_((state.O == RUN) & core.dcache.req.valid
                                     & core.dcache.req.data.mask.reduce_or())

                m.display(
                    "MEM[%x] => %x", daddr * (x_len // 8),
                    dmem_data).when(m.posedge(
                        io.CLK)).if_((state.O == RUN) & core.dcache.req.valid
                                     & ~core.dcache.req.data.mask.reduce_or())

                m.display("cycles: %d", cycle.O).when(m.posedge(
                    io.CLK)).if_(io.done.value() == 1)

            # Fail the run if it exceeds the test's cycle budget.
            f.assert_immediate(cycle.O < test.maxcycles)
            io.done @= core.host.tohost != 0
            # Any tohost value with bits above bit 0 set is reported as a
            # failure together with the offending value.
            f.assert_immediate(
                (core.host.tohost >> 1) == 0,
                failure_msg=("* tohost: %d *", core.host.tohost))
Пример #21
0
    def __init__(self, x_len: int, n_ways: int, n_sets: int, b_bytes: int):
        """Build the cache: CPU port, tag/data storage, and NASTI port logic.

        Args:
            x_len: Width in bits of a CPU word; also used as the NASTI
                address width and to derive the tag width.
            n_ways: Not referenced anywhere in this constructor body.
            n_sets: Number of sets; sizes the valid/dirty vectors and the
                meta/data memories.
            b_bytes: Cache block size in bytes.
        """
        b_bits = b_bytes << 3  # block size in bits
        b_len = m.bitutils.clog2(b_bytes)  # address bits for byte-in-block
        s_len = m.bitutils.clog2(n_sets)  # address bits for the set index
        t_len = x_len - (s_len + b_len)  # remaining upper bits form the tag
        n_words = b_bits // x_len  # x_len-bit words per block
        w_bytes = x_len // 8  # bytes per word
        byte_offset_bits = m.bitutils.clog2(w_bytes)  # bits for byte-in-word
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)
        # Number of NASTI data transfers needed to move one block.
        data_beats = b_bits // nasti_params.x_data_bits

        class MetaData(m.Product):
            tag = m.UInt[t_len]

        self.io = m.IO(**make_cache_ports(x_len, nasti_params))
        self.io += m.ClockIO()

        class State(m.Enum):
            IDLE = 0
            READ_CACHE = 1
            WRITE_CACHE = 2
            WRITE_BACK = 3
            WRITE_ACK = 4
            REFILL_READY = 5
            REFILL = 6

        state = m.Register(init=State.IDLE)()

        # memory
        # v / d hold one bit per set (combined below as `is_dirty`).
        v = m.Register(m.UInt[n_sets], has_enable=True)()
        d = m.Register(m.UInt[n_sets], has_enable=True)()
        meta_mem = m.Memory(n_sets,
                            MetaData,
                            read_latency=1,
                            has_read_enable=True)()
        # One memory per word of the block, each entry an array of w_bytes
        # bytes so writes can be masked per byte.
        data_mem = [
            ArrayMaskMem(n_sets,
                         w_bytes,
                         m.UInt[8],
                         read_latency=1,
                         has_read_enable=True)() for _ in range(n_words)
        ]

        # Registers that capture the CPU request fields; their enables are
        # wired to resp.valid further down.
        addr_reg = m.Register(type(self.io.cpu.req.data.addr).undirected_t,
                              has_enable=True)()
        cpu_data = m.Register(type(self.io.cpu.req.data.data).undirected_t,
                              has_enable=True)()
        cpu_mask = m.Register(type(self.io.cpu.req.data.mask).undirected_t,
                              has_enable=True)()

        # Read data from the NASTI side is only accepted while refilling.
        self.io.nasti.r.ready @= state.O == State.REFILL
        # Counters
        assert data_beats > 0
        if data_beats > 1:
            read_counter = mantle.CounterModM(data_beats,
                                              max(data_beats.bit_length(), 1),
                                              has_ce=True)
            read_counter.CE @= m.enable(self.io.nasti.r.fired())
            read_count, read_wrap_out = read_counter.O, read_counter.COUT

            write_counter = mantle.CounterModM(data_beats,
                                               max(data_beats.bit_length(), 1),
                                               has_ce=True)
            write_count, write_wrap_out = write_counter.O, write_counter.COUT
        else:
            # Single-beat blocks need no counters: plain Python constants
            # stand in for "count is 0" and "always wrapping".
            # NOTE(review): `write_count` is the int 0 here, yet it is later
            # subscripted as `write_count[:-1]` when building nasti.w.data,
            # which a Python int does not support -- confirm that
            # data_beats == 1 is actually a usable configuration.
            read_count, read_wrap_out = 0, 1
            write_count, write_wrap_out = 0, 1

        # Buffer that collects the beats of an incoming refill.
        refill_buf = m.Register(m.Array[data_beats,
                                        m.UInt[nasti_params.x_data_bits]],
                                has_enable=True)()
        if data_beats == 1:
            refill_buf.I[0] @= self.io.nasti.r.data.data
        else:
            # Replace only the entry addressed by the read counter (its top
            # bit is dropped to form the index).
            refill_buf.I @= m.set_index(refill_buf.O,
                                        self.io.nasti.r.data.data,
                                        read_count[:-1])
        refill_buf.CE @= m.enable(self.io.nasti.r.fired())

        is_idle = state.O == State.IDLE
        is_read = state.O == State.READ_CACHE
        is_write = state.O == State.WRITE_CACHE
        # High on the refill cycle in which the read counter wraps.
        is_alloc = (state.O == State.REFILL) & read_wrap_out
        # m.display("[%0t]: is_alloc = %x", m.time(), is_alloc)\
        #     .when(m.posedge(self.io.CLK))
        # One-cycle delayed copy of is_alloc.
        is_alloc_reg = m.Register(m.Bit)()(is_alloc)

        # `hit` is declared here as a named temporary and driven later, once
        # the meta read data exists.
        hit = m.Bit(name="hit")
        # `&` binds tighter than `|`: write on (CPU write that hits or follows
        # an allocation, and is not aborted) OR on allocation itself.
        wen = is_write & (hit | is_alloc_reg) & ~self.io.cpu.abort | is_alloc
        # m.display("[%0t]: wen = %x", m.time(), wen)\
        #     .when(m.posedge(self.io.CLK))
        # Reads are issued only when no write is happening this cycle.
        ren = m.enable(~wen & (is_idle | is_read) & self.io.cpu.req.valid)
        ren_reg = m.enable(m.Register(m.Bit)()(ren))

        # Address field slices: `idx` comes from the live request, the
        # *_reg slices from the registered request address.
        addr = self.io.cpu.req.data.addr
        idx = addr[b_len:s_len + b_len]
        tag_reg = addr_reg.O[s_len + b_len:x_len]
        idx_reg = addr_reg.O[b_len:s_len + b_len]
        off_reg = addr_reg.O[byte_offset_bits:b_len]

        rmeta = meta_mem.read(idx, ren)
        rdata = m.concat(*(mem.read(idx, ren) for mem in data_mem))
        # Holds the last read data; updated when ren_reg is high.
        rdata_buf = m.Register(type(rdata), has_enable=True)()(rdata,
                                                               CE=ren_reg)

        # Block data presented to the CPU / write-back path: the refill buffer
        # right after an allocation, otherwise the memory read data (fresh
        # when ren_reg is high, buffered copy otherwise).
        read = m.mux([
            m.as_bits(m.mux([rdata_buf, rdata], ren_reg)),
            m.as_bits(refill_buf.O)
        ], is_alloc_reg)
        # m.display("is_alloc_reg=%x", is_alloc_reg)\
        #     .when(m.posedge(self.io.CLK))

        hit @= v.O[idx_reg] & (rmeta.tag == tag_reg)

        # read mux
        # Split the block into words and select the one addressed by off_reg.
        self.io.cpu.resp.data.data @= m.array(
            [read[i * x_len:(i + 1) * x_len] for i in range(n_words)])[off_reg]
        self.io.cpu.resp.valid @= (is_idle | (is_read & hit) |
                                   (is_alloc_reg & ~cpu_mask.O.reduce_or()))
        # NOTE(review): the m.display calls below are unconditional or only
        # lightly gated and print on every rising CLK edge; they look like
        # leftover debug output -- confirm whether they should stay.
        m.display("resp.valid=%x", self.io.cpu.resp.valid.value())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: valid = %x", m.time(),
                  self.io.cpu.resp.valid.value())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: is_idle = %x, is_read = %x, hit = %x, is_alloc_reg = "
                  "%x, ~cpu_mask.O.reduce_or() = %x", m.time(), is_idle,
                  is_read, hit, is_alloc_reg, ~cpu_mask.O.reduce_or())\
            .when(m.posedge(self.io.CLK))
        # NOTE(review): the format string has exactly two %x fields for the
        # unpacked refill_buf.O, which only matches when data_beats == 2.
        m.display("[%0t]: refill_buf.O=%x, %x", m.time(), *refill_buf.O)\
            .when(m.posedge(self.io.CLK))\
            .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)
        m.display("[%0t]: read=%x", m.time(), read)\
            .when(m.posedge(self.io.CLK))\
            .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)

        # Capture the request fields whenever a response is valid.
        addr_reg.I @= addr
        addr_reg.CE @= m.enable(self.io.cpu.resp.valid.value())

        cpu_data.I @= self.io.cpu.req.data.data
        cpu_data.CE @= m.enable(self.io.cpu.resp.valid.value())

        cpu_mask.I @= self.io.cpu.req.data.mask
        cpu_mask.CE @= m.enable(self.io.cpu.resp.valid.value())

        wmeta = MetaData(name="wmeta")
        wmeta.tag @= tag_reg

        # Shift the CPU byte mask left by off_reg * 2**byte_offset_bits (the
        # concat appends byte_offset_bits zero bits below off_reg).
        offset_mask = (m.zext_to(cpu_mask.O, w_bytes * 8) << m.concat(
            m.bits(0, byte_offset_bits), off_reg))
        # Select is ~is_alloc: all-ones mask when allocating, the shifted CPU
        # mask otherwise.
        wmask = m.mux([m.SInt[w_bytes * 8](-1),
                       m.sint(offset_mask)], ~is_alloc)

        if len(refill_buf.O) == 1:
            wdata_alloc = self.io.nasti.r.data.data
        else:
            # All buffered beats except the last, followed by the beat that
            # is arriving on the bus this cycle.
            wdata_alloc = m.concat(
                # TODO: not sure why they use `init.reverse`
                # https://github.com/ucb-bar/riscv-mini/blob/release/src/main/scala/Cache.scala#L116
                m.concat(*refill_buf.O[:-1]),
                self.io.nasti.r.data.data)
        # Select is ~is_alloc: refill data when allocating, otherwise the CPU
        # word replicated across every word position of the block.
        wdata = m.mux([wdata_alloc,
                       m.as_bits(m.repeat(cpu_data.O, n_words))], ~is_alloc)

        # Any write sets the set's v bit; the d bit is written with ~is_alloc
        # (0 on allocation, 1 on a CPU write).
        v.I @= m.set_index(v.O, m.bit(True), idx_reg)
        v.CE @= m.enable(wen)
        d.I @= m.set_index(d.O, ~is_alloc, idx_reg)
        d.CE @= m.enable(wen)
        # m.display("[%0t]: refill_buf.O = %x", m.time(),
        #           m.concat(*refill_buf.O)).when(m.posedge(self.io.CLK)).if_(wen)
        # m.display("[%0t]: nasti.r.data.data = %x", m.time(),
        #           self.io.nasti.r.data.data).when(m.posedge(self.io.CLK)).if_(wen)

        # The tag is only rewritten on allocation.
        meta_mem.write(wmeta, idx_reg, m.enable(wen & is_alloc))
        for i, mem in enumerate(data_mem):
            # Bytes of word i of the write data, one array element per byte.
            data = [
                wdata[i * x_len + j * 8:i * x_len + (j + 1) * 8]
                for j in range(w_bytes)
            ]
            mem.write(m.array(data), idx_reg,
                      wmask[i * w_bytes:(i + 1) * w_bytes], m.enable(wen))
            # m.display("[%0t]: wdata = %x, %x, %x, %x", m.time(),
            #           *mem.WDATA.value()).when(m.posedge(self.io.CLK)).if_(wen)
            # m.display("[%0t]: wmask = %x, %x, %x, %x", m.time(),
            #           *mem.WMASK.value()).when(m.posedge(self.io.CLK)).if_(wen)

        # Refill (read) address: registered index and tag, shifted left by
        # b_len.
        tag_and_idx = m.zext_to(m.concat(idx_reg, tag_reg),
                                nasti_params.x_addr_bits)
        self.io.nasti.ar.data @= NastiReadAddressChannel(
            nasti_params, 0, tag_and_idx << m.Bits[len(tag_and_idx)](b_len),
            m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

        # Write-back address: same construction, but using the tag read from
        # meta_mem instead of the request's tag.
        rmeta_and_idx = m.zext_to(m.concat(idx_reg, rmeta.tag),
                                  nasti_params.x_addr_bits)
        self.io.nasti.aw.data @= NastiWriteAddressChannel(
            nasti_params, 0,
            rmeta_and_idx << m.Bits[len(rmeta_and_idx)](b_len),
            m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

        # Write-back data: slice `read` into bus-width beats and send the one
        # selected by the write counter; write_wrap_out is passed as the
        # final argument.
        self.io.nasti.w.data @= NastiWriteDataChannel(
            nasti_params,
            m.array([
                read[i * nasti_params.x_data_bits:(i + 1) *
                     nasti_params.x_data_bits] for i in range(data_beats)
            ])[write_count[:-1]], None, write_wrap_out)

        is_dirty = v.O[idx_reg] & d.O[idx_reg]

        # TODO: Have to use temporary so we can invoke `fired()`
        aw_valid = m.Bit(name="aw_valid")
        self.io.nasti.aw.valid @= aw_valid

        ar_valid = m.Bit(name="ar_valid")
        self.io.nasti.ar.valid @= ar_valid

        b_ready = m.Bit(name="b_ready")
        self.io.nasti.b.ready @= b_ready

        # Cache state machine. Defaults hold the current state and deassert
        # every NASTI handshake output; each state overrides what it needs.
        # On a miss (READ_CACHE/WRITE_CACHE without hit) a dirty line goes
        # WRITE_BACK -> WRITE_ACK -> REFILL_READY -> REFILL, a clean one goes
        # straight to REFILL once the read address is accepted. After the last
        # refill beat a pending write (non-zero cpu_mask) goes to WRITE_CACHE,
        # anything else returns to IDLE.
        @m.inline_combinational()
        def logic():
            state.I @= state.O
            aw_valid @= False
            ar_valid @= False
            self.io.nasti.w.valid @= False
            b_ready @= False
            if state.O == State.IDLE:
                if self.io.cpu.req.valid:
                    if self.io.cpu.req.data.mask.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.READ_CACHE
            elif state.O == State.READ_CACHE:
                if hit:
                    if self.io.cpu.req.valid:
                        if self.io.cpu.req.data.mask.reduce_or():
                            state.I @= State.WRITE_CACHE
                        else:
                            state.I @= State.READ_CACHE
                    else:
                        state.I @= State.IDLE
                else:
                    aw_valid @= is_dirty
                    ar_valid @= ~is_dirty
                    if self.io.nasti.aw.fired():
                        state.I @= State.WRITE_BACK
                    elif self.io.nasti.ar.fired():
                        state.I @= State.REFILL
            elif state.O == State.WRITE_CACHE:
                if hit | is_alloc_reg | self.io.cpu.abort:
                    state.I @= State.IDLE
                else:
                    aw_valid @= is_dirty
                    ar_valid @= ~is_dirty
                    if self.io.nasti.aw.fired():
                        state.I @= State.WRITE_BACK
                    elif self.io.nasti.ar.fired():
                        state.I @= State.REFILL
            elif state.O == State.WRITE_BACK:
                self.io.nasti.w.valid @= True
                if write_wrap_out:
                    state.I @= State.WRITE_ACK
            elif state.O == State.WRITE_ACK:
                b_ready @= True
                if self.io.nasti.b.fired():
                    state.I @= State.REFILL_READY
            elif state.O == State.REFILL_READY:
                ar_valid @= True
                if self.io.nasti.ar.fired():
                    state.I @= State.REFILL
            elif state.O == State.REFILL:
                if read_wrap_out:
                    if cpu_mask.O.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.IDLE

        if data_beats > 1:
            # TODO: Have to do this at the end since the inline comb logic
            # wires up nasti.w
            write_counter.CE @= m.enable(self.io.nasti.w.fired())