def __init__(self, T, entries, pipe=False, flow=False):
    """Build a ready/valid FIFO of ``entries`` elements of type ``T``.

    The queue is a circular buffer: ``enq_ptr`` and ``deq_ptr`` are
    modulo-``entries`` counters and ``maybe_full`` disambiguates the
    "pointers equal" case (empty vs. full).

    Args:
        T: magma type of the stored elements.
        entries: number of slots in the backing memory.
        pipe: not implemented; a truthy value raises NotImplementedError.
        flow: not implemented; a truthy value raises NotImplementedError.

    Raises:
        NotImplementedError: if ``flow`` or ``pipe`` is truthy.
    """
    assert entries >= 0

    self.io = m.IO(
        # Flipped since enq/deq is from perspective of the client
        enq=m.DeqIO[T],
        deq=m.EnqIO[T],
        count=m.Out(m.UInt[m.bitutils.clog2(entries + 1)])
    ) + m.ClockIO()

    ram = m.Memory(entries, T)()
    enq_ptr = mantle.CounterModM(entries, entries.bit_length(),
                                 has_ce=True, cout=False)
    deq_ptr = mantle.CounterModM(entries, entries.bit_length(),
                                 has_ce=True, cout=False)
    maybe_full = m.Register(init=False, has_enable=True)()

    ptr_match = enq_ptr.O == deq_ptr.O
    empty = ptr_match & ~maybe_full.O
    full = ptr_match & maybe_full.O

    self.io.deq.valid @= ~empty
    self.io.enq.ready @= ~full

    do_enq = self.io.enq.fired()
    do_deq = self.io.deq.fired()

    ram.write(self.io.enq.data, enq_ptr.O[:-1], m.enable(do_enq))

    enq_ptr.CE @= m.enable(do_enq)
    deq_ptr.CE @= m.enable(do_deq)

    maybe_full.I @= m.enable(do_enq)
    # maybe_full only changes when exactly one of enq/deq fires
    maybe_full.CE @= m.enable(do_enq != do_deq)
    self.io.deq.data @= ram[deq_ptr.O[:-1]]

    if flow:
        raise NotImplementedError()
    if pipe:
        raise NotImplementedError()

    count_len = len(self.io.count)
    ptr_diff = enq_ptr.O - deq_ptr.O

    # Pointers equal: the queue is either completely empty or completely
    # full, and maybe_full tells which.
    match_count = m.mux([m.bits(0, count_len), entries], maybe_full.O)
    # Pointers differ: occupancy is the pointer distance; when the dequeue
    # pointer is ahead the enqueue pointer has wrapped, so add `entries`
    # back (the subtraction wrapped modulo the pointer width).
    diff_count = m.mux([ptr_diff, entries + ptr_diff],
                       deq_ptr.O > enq_ptr.O)

    # m.mux selects element 1 when the select is high, so the
    # pointer-match case must be the second operand.  The same formula is
    # valid for power-of-two and non-power-of-two depths, so no special
    # case is needed.
    self.io.count @= m.mux([diff_count, match_count], ptr_match)
# Self-checking circuit: a counter steps through `insts`, each instruction is
# fed to Control and ImmGen, and ImmGen's output is asserted equal to the
# immediate computed by the Python reference helpers.
class DUT(m.Circuit):
    io = m.IO(done=m.Out(m.Bit)) + m.ClockIO()

    imm = ImmGen(32)()
    ctrl = Control(32)()

    counter = mantle.CounterModM(len(insts), len(insts).bit_length())

    # Reference immediates for every format, indexed by the counter value
    i = m.mux([iimm(word) for word in insts], counter.O)
    s = m.mux([simm(word) for word in insts], counter.O)
    b = m.mux([bimm(word) for word in insts], counter.O)
    u = m.mux([uimm(word) for word in insts], counter.O)
    j = m.mux([jimm(word) for word in insts], counter.O)
    z = m.mux([zimm(word) for word in insts], counter.O)
    x = m.mux([iimm(word) & -2 for word in insts], counter.O)

    # Priority chain selecting the expected value from ctrl.imm_sel; `x` is
    # the fall-through and later entries take precedence over earlier ones.
    O = x
    for _candidate, _code in ((z, IMM_Z), (j, IMM_J), (u, IMM_U),
                              (b, IMM_B), (s, IMM_S), (i, IMM_I)):
        O = m.mux([O, _candidate], ctrl.imm_sel == _code)
    del _candidate, _code

    inst = m.mux(insts, counter.O)
    ctrl.inst @= inst
    imm.inst @= inst
    imm.sel @= ctrl.imm_sel

    io.done @= counter.COUT

    f.assert_immediate(
        imm.O == O,
        failure_msg=("Counter: %d, Type: 0x%x, O: %x ?= %x",
                     counter.O, imm.sel, imm.O, O))
    m.display("Counter: %d, Type: 0x%x, O: %x ?= %x",
              counter.O, imm.sel, imm.O, O)
def __init__(self, T, entries, with_bug=False):
    """Build a ready/valid FIFO of ``entries`` elements of type ``T``.

    Args:
        T: magma type of the stored elements.
        entries: number of slots in the backing memory.
        with_bug: when truthy, the queue never reports full, so
            ``enq.ready`` is driven constantly high.
    """
    assert entries >= 0

    self.io = m.IO(
        # Flipped since enq/deq is from perspective of the client
        enq=m.DeqIO[T],
        deq=m.EnqIO[T]
    ) + m.ClockIO()

    ram = m.Memory(entries, T)()
    enq_ptr = mantle.CounterModM(entries, entries.bit_length(),
                                 has_ce=True, cout=False)
    deq_ptr = mantle.CounterModM(entries, entries.bit_length(),
                                 has_ce=True, cout=False)
    maybe_full = m.Register(init=False, has_enable=True)()

    ptr_match = enq_ptr.O == deq_ptr.O
    empty = ptr_match & ~maybe_full.O
    if with_bug:
        # never full: drive ready with an explicit constant instead of
        # applying `~` to a Python bool (``~False`` is the int -1, and
        # bitwise inversion of bool is deprecated since Python 3.12)
        not_full = m.bit(True)
    else:
        not_full = ~(ptr_match & maybe_full.O)

    self.io.deq.valid @= ~empty
    self.io.enq.ready @= not_full

    do_enq = self.io.enq.fired()
    do_deq = self.io.deq.fired()

    ram.write(self.io.enq.data, enq_ptr.O[:-1], m.enable(do_enq))

    enq_ptr.CE @= m.enable(do_enq)
    deq_ptr.CE @= m.enable(do_deq)

    maybe_full.I @= m.enable(do_enq)
    # maybe_full only changes when exactly one of enq/deq fires
    maybe_full.CE @= m.enable(do_enq != do_deq)
    self.io.deq.data @= ram[deq_ptr.O[:-1]]
def definition(io):
    """Register the input, downscale it, threshold it and drive O/VALID."""
    strobe = io.LOAD
    # one tick on every SCK transition
    tick = rising(io.SCK) | falling(io.SCK)

    sample_counter = mantle.CounterModM(buf_size, 12, has_ce=True)
    m.wire(strobe & tick, sample_counter.CE)

    # counter values at which the output is flagged valid (wo entries)
    sample_points = [wi * (b - 1) + k * a - 1 for k in range(1, wo + 1)]

    # OR together one decoder per sample point
    valid = m.GND
    for point in sample_points:
        hit = mantle.Decode(point, 12)(sample_counter.O)
        valid = valid | hit

    # input register, captured on LOAD
    in_reg = mantle.Register(width, has_ce=True)
    in_reg(io.DATA)
    m.wire(strobe, in_reg.CE)

    # ------------------------- downscaling ------------------------- #
    # downscale the image from 352x288 to 32x32
    pixel_t = m.Array(width, m.Bit)
    Downscale = m.DeclareCircuit(
        'Downscale',
        "I_0_0", m.In(m.Array(1, m.Array(1, pixel_t))),
        "WE", m.In(m.Bit),
        "CLK", m.In(m.Clock),
        "O", m.Out(m.Array(width, m.Bit)),
        "V", m.Out(m.Bit))

    scaler = Downscale()
    m.wire(in_reg.O, scaler.I_0_0[0][0])
    m.wire(1, scaler.WE)
    m.wire(strobe, scaler.CLK)

    # instantiated only for its side effect: needed for Add16 definition
    mantle.Add(width)

    # threshold the downscale output, gated by valid
    below_thresh = mantle.ULE(16)(scaler.O, m.uint(THRESH, 16))
    px_bit = below_thresh & valid

    # --------------------------- outputs --------------------------- #
    m.wire(px_bit, io.O)
    m.wire(valid, io.VALID)
hx8kboard.Clock.on() hx8kboard.J2[9].output().on() hx8kboard.J2[10].output().on() hx8kboard.J2[11].output().on() hx8kboard.J2[12].output().on() main = hx8kboard.main() # "test" data init = [m.uint(i, 16) for i in range(16)] printf = mantle.Counter(4, has_ce=True) rom = ROM16(4, init, printf.O) # baud for uart output clock = mantle.CounterModM(103, 8) baud = clock.COUT bit_counter = mantle.Counter(5, has_ce=True) m.wire(baud, bit_counter.CE) load = mantle.Decode(0, 5)(bit_counter.O) valid_counter = mantle.CounterModM(buf_size, 13, has_ce=True) m.wire(load & baud, valid_counter.CE) valid_list = [wi * (b - 1) + i * a - 1 for i in range(1, wo + 1)] # len = 16 valid = m.GND for i in valid_list:
def definition(cam):
    """Build the camera SPI command sequencer and its UART read-out path."""
    sck_fall = falling(cam.SCK)
    sck_rise = rising(cam.SCK)

    # Command ROM, addressed by a 4-bit index counter
    cmd_index = mantle.Counter(4, has_ce=True)
    cmd_rom = ROM16(4, init, cmd_index.O)

    # Messages are 16 bits long: count bits and raise `done` after the
    # end of the message
    msg_counter = mantle.Counter(5, has_ce=True, has_reset=True)
    msg_count = msg_counter.O
    done = mantle.Decode(16, 5)(msg_count)

    # Run (enable) state machine
    run = mantle.DFF(has_ce=True)
    run_next = mantle.LUT3([0, 0, 1, 0, 1, 0, 1, 0])
    run_next(done, trigger, run)
    run(run_next)
    m.wire(sck_fall, run.CE)

    # Clear the message length counter once done
    msg_reset = mantle.LUT2(I0 | ~I1)(done, run)
    msg_counter(CE=sck_rise, RESET=msg_reset)

    # High-level state variables
    ready = mantle.LUT2(~I0 & I1)(run, sck_fall)
    start = mantle.ULE(4)(cmd_index.O, m.uint(3, 4))
    burst = mantle.UGE(4)(cmd_index.O, m.uint(9, 4))

    # Shift register holding the 16-bit command|data word to send
    mosi = mantle.PISO(16, has_ce=True)
    # SPI enable is the negation of load: never load and shift out data at
    # the same time
    enable = mantle.LUT3(I0 & ~I1 & ~I2)(trigger, run, burst)
    mosi(~burst, cmd_rom.O, enable)
    m.wire(sck_fall, mosi.CE)

    # Shift register to read in 8-bit data
    miso = mantle.SIPO(8, has_ce=True)
    miso(cam.MISO)
    miso_ce = mantle.LUT2(~I0 & I1)(enable, sck_rise)
    m.wire(miso_ce, miso.CE)

    # Capture-done state variable
    cap_done = mantle.SRFF(has_ce=True)
    cap_done(mantle.EQ(8)(miso.O, m.bits(0x08, 8)), 0)
    m.wire(enable & sck_rise, cap_done.CE)

    # Advance the command index according to the state variables
    advance = mantle.LUT4(I0 & (I1 | I2) & ~I3)(
        ready, start, cap_done, burst)
    m.wire(advance, cmd_index.CE)

    # SPI-side outputs
    m.wire(enable, cam.EN)
    m.wire(mosi.O, cam.MOSI)
    m.wire(miso.O, cam.DATA)
    m.wire(burst, cam.VALID)

    # ------------------------- UART output ------------------------- #
    # run UART at 2x SPI rate to allow it to keep up
    baud = sck_rise | sck_fall

    # reset when SPI burst read (image transfer) begins
    burst_ff = mantle.FF(has_ce=True)
    m.wire(sck_rise, burst_ff.CE)
    byte_reset = mantle.LUT2(I0 & ~I1)(burst, burst_ff(burst))

    # UART data out every 8 bits
    byte_counter = mantle.CounterModM(8, 3, has_ce=True, has_reset=True)
    byte_counter(CE=sck_rise, RESET=byte_reset)
    load = burst & rising(byte_counter.COUT)

    uart = UART(8)
    uart(CLK=cam.CLK, BAUD=baud, DATA=miso, LOAD=load)
    m.wire(uart, cam.UART)

    # Transfer-done flag
    data_count = mantle.Counter(18, has_ce=True)
    tx_done = mantle.SRFF(has_ce=True)
    # transfer has size 153600 bytes, first 2 bytes are ignored
    tx_done(mantle.EQ(18)(data_count.O, m.bits(153602, 18)), 0)
    m.wire(load, tx_done.CE)
    m.wire(load, data_count.CE)

    m.wire(tx_done, cam.DONE)
# Enable the board clock and the TX pin before building the top-level circuit
icestick.Clock.on()
icestick.TX.output().on()

main = icestick.main()

# constant fed to the run-state LUT below
valid = 1

# One 8-bit ROM word per character of the message.
# NOTE(review): ROM(4, ...) suggests 16 entries, but the string as shown has
# 15 characters - confirm the literal against the original file.
init = [m.array(int2seq(ord(c), 8)) for c in 'hello, world \r\n']

printf = mantle.Counter(4, has_ce=True)
rom = ROM(4, init, printf.O)

# 9-bit frame: the ROM byte with its bit order reversed, plus a trailing 0
data = m.array([rom.O[7], rom.O[6], rom.O[5], rom.O[4],
                rom.O[3], rom.O[2], rom.O[1], rom.O[0], 0])

# baud tick: carry-out of a modulo-103 counter
counter = mantle.CounterModM(103, 8)
baud = counter.COUT

# bit counter; `done` is asserted when it reads 15
count = mantle.Counter(4, has_ce=True, has_reset=True)
decode = mantle.Decode(15, 4)
done = decode(count.O)

# run flip-flop, clocked on baud ticks; next state comes from a 3-input LUT
run = mantle.DFF(has_ce=True)
run_n = mantle.LUT3([0,0,1,0, 1,0,1,0])
run_n(done, valid, run.O)
run(run_n, ce=baud)

# reset the bit counter when done and not running
reset = mantle.LUT2(mantle.I0&~mantle.I1)(done, run)
count(CE=baud, RESET=reset)

shift = mantle.PISO(9, has_ce=True)
from uart import UART icestick = IceStick() icestick.Clock.on() for i in range(8): icestick.J3[i].output().on() main = icestick.main() # "test" data init = [m.uint(n, 16) for n in range(16)] printf = mantle.Counter(4, has_ce=True) rom = ROM16(4, init, printf.O) # baud for uart output clock = mantle.CounterModM(103, 8) baud = clock.COUT bit_counter = mantle.Counter(5, has_ce=True) m.wire(baud, bit_counter.CE) load = mantle.Decode(0, 5)(bit_counter.O) valid_counter = mantle.CounterModM(8, 3, has_ce=True) m.wire(load & baud, valid_counter.CE) valid_list = [5, 7] valid = m.GND for i in valid_list:
# Enable the J2 header pins used as outputs
hx8kboard.J2[5].output().on()
hx8kboard.J2[8].output().on()
hx8kboard.J2[9].output().on()
hx8kboard.J2[10].output().on()
hx8kboard.J2[11].output().on()
hx8kboard.J2[12].output().on()

main = hx8kboard.main()

# "test" data: the values 0..15 as 16-bit words, read out through a ROM
# addressed by the `printf` counter
init = [m.uint(i, 16) for i in range(16)]
printf = mantle.Counter(4, has_ce=True)
rom = ROM16(4, init, printf.O)

# baud for uart output: carry-out of a modulo-16 counter
clock = mantle.CounterModM(16, 8)
baud = clock.COUT

# 5-bit counter advanced on baud ticks; `load` is high when it reads 0
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)
load = mantle.Decode(0, 5)(bit_counter.O)

# step to the next ROM word when load and baud coincide
m.wire(load & baud, printf.CE)

rescale = Rescale()

# inputs
m.wire(main.CLKIN, rescale.CLK)
m.wire(rom.O, rescale.DATA)
m.wire(baud, rescale.SCK)
def definition(io):
    """Downscale the input stream, count rows/columns and send it by UART."""
    strobe = io.LOAD
    baud = io.BAUD

    sample_counter = mantle.CounterModM(buf_size, 13, has_ce=True)
    m.wire(strobe & baud, sample_counter.CE)

    # counter values at which a sample is flagged valid (wo entries)
    sample_points = [wi * (b - 1) + k * a - 1 for k in range(1, wo + 1)]

    # OR together one decoder per sample point
    valid = m.GND
    for point in sample_points:
        hit = mantle.Decode(point, 13)(sample_counter.O)
        valid = valid | hit

    # input register, captured on LOAD
    in_reg = mantle.Register(16, has_ce=True)
    in_reg(io.DATA)
    m.wire(strobe, in_reg.CE)

    # ------------------------- downscaling ------------------------- #
    # downscale the image from 320x240 to 16x16
    word_t = m.Array(16, m.Bit)
    Downscale = m.DeclareCircuit(
        'Downscale',
        "I_0_0", m.In(m.Array(1, m.Array(1, word_t))),
        "WE", m.In(m.Bit),
        "CLK", m.In(m.Clock),
        "O", m.Out(m.Array(16, m.Bit)),
        "V", m.Out(m.Bit))

    scaler = Downscale()
    m.wire(in_reg.O, scaler.I_0_0[0][0])
    m.wire(1, scaler.WE)
    m.wire(strobe, scaler.CLK)

    # instantiated only for its side effect: needed for Add16 definition
    mantle.Add(16)

    # ------------------------- fill img ram ------------------------ #
    # each valid output of the scaler is one entry of the 16x16 binary
    # image; every 16 entries form one row

    # column counter: steps on rising V until column 15 has been latched
    col = mantle.Counter(4, has_ce=True)
    row_full = mantle.SRFF(has_ce=True)
    row_full(mantle.EQ(4)(col.O, m.bits(15, 4)), 0)
    m.wire(falling(scaler.V), row_full.CE)
    col_step = rising(scaler.V) & ~row_full.O
    m.wire(col_step, col.CE)

    # row counter: steps on the column carry until row 15 has been latched
    row = mantle.Counter(4, has_ce=True)
    img_full = mantle.SRFF(has_ce=True)
    img_full(mantle.EQ(4)(row.O, m.bits(15, 4)), 0)
    m.wire(falling(col.COUT), img_full.CE)
    row_step = rising(col.COUT) & ~img_full.O
    m.wire(row_step, row.CE)

    # ------------------------- UART output ------------------------- #
    uart_out = UART(16)
    uart_out(CLK=io.CLK, BAUD=baud, DATA=scaler.O, LOAD=strobe)

    m.wire(row.O, io.ROW)
    m.wire(img_full.O, io.DONE)
    m.wire(uart_out.O, io.UART)
from loam.boards.hx8kboard import HX8KBoard from uart import UART hx8kboard = HX8KBoard() hx8kboard.Clock.on() hx8kboard.J2[3].output().on() hx8kboard.J2[4].output().on() hx8kboard.J2[5].output().on() hx8kboard.J2[8].output().on() hx8kboard.J2[9].output().on() main = hx8kboard.main() # baud for uart output clock = mantle.CounterModM(100, 8) baud = clock.COUT bit_counter = mantle.Counter(5, has_ce=True) m.wire(baud, bit_counter.CE) load = mantle.Decode(0, 5)(bit_counter.O) valid_counter = mantle.CounterModM(4800, 13, has_ce=True) m.wire(load & baud, valid_counter.CE) valid_list = [10 * i for i in range(16)] # len = 16 valid = m.GND for i in valid_list:
def __init__(self, x_len, n_ways: int, n_sets: int, b_bytes: int):
    """Build a behavioural ("golden") direct-mapped write-back cache.

    The model keeps data, tag, valid and dirty state in plain memories
    indexed by the set index and talks to backing memory over a NASTI
    port using a four-state FSM (IDLE, WRITE, WRITE_ACK, READ).

    Args:
        x_len: address width; also the width of the response word sliced
            out of a cache line.
        n_ways: accepted for signature compatibility; not used here.
        n_sets: number of cache lines.
        b_bytes: cache line size in bytes.
    """
    nasti_params = NastiParameters(data_bits=64, addr_bits=x_len,
                                   id_bits=5)
    self.io = m.IO(
        req=m.Consumer(m.Decoupled[make_CacheReq(x_len)]),
        resp=m.Producer(m.Decoupled[make_CacheResp(x_len)]),
        nasti=make_NastiIO(nasti_params)
    ) + m.ClockIO()

    # Burst size field is log2 of the number of *bytes* per beat, hence
    # the division by 8 (data width is given in bits).
    size = m.bitutils.clog2(nasti_params.x_data_bits // 8)
    b_bits = b_bytes << 3
    b_len = m.bitutils.clog2(b_bytes)
    s_len = m.bitutils.clog2(n_sets)
    t_len = x_len - (s_len + b_len)
    data_beats = b_bits // nasti_params.x_data_bits
    length = data_beats - 1

    data = m.Memory(n_sets, m.UInt[b_bits])()
    tags = m.Memory(n_sets, m.UInt[t_len])()
    v = m.Memory(n_sets, m.Bit)()
    d = m.Memory(n_sets, m.Bit)()

    # Split the request address into tag / set index / byte offset
    req = self.io.req.data
    tag = (req.addr >> (b_len + s_len))[:t_len]
    idx = req.addr[b_len:b_len + s_len]
    off = req.addr[:b_len]

    read = data.read(idx)

    # Merged line for a write hit: for every byte of the line take the
    # request byte when the byte lies in the addressed word and its mask
    # bit is set, otherwise keep the byte currently stored.
    write = m.bits(0, b_bits)
    for i in range(b_bytes):
        write |= m.mux([
            (read & (0xff << (8 * i))),
            ((m.zext_to(req.data, b_bits) >> ((8 * (i & 0x3)))) & 0xff)
            << (8 * i)
        ], ((off // 4) == (i // 4)) & (req.mask >> (i & 0x3))[0])[:b_bits]

    class State(m.Enum):
        IDLE = 0
        WRITE = 1
        WRITE_ACK = 2
        READ = 3

    state = m.Register(init=State.IDLE)()

    # Beat counters for write-back and refill bursts
    write_counter = mantle.CounterModM(data_beats,
                                       max(data_beats.bit_length(), 1),
                                       has_ce=True)
    write_counter.CE @= m.enable(state.O == State.WRITE)
    w_cnt, w_done = write_counter.O, write_counter.COUT

    read_counter = mantle.CounterModM(data_beats,
                                      max(data_beats.bit_length(), 1),
                                      has_ce=True)
    read_counter.CE @= m.enable((state.O == State.READ) &
                                self.io.nasti.r.valid)
    r_cnt, r_done = read_counter.O, read_counter.COUT

    # Response word: slice the addressed word out of the line
    self.io.resp.data.data @= (read >> (m.zext_to(
        (off // 4), b_bits) * x_len))[:x_len]

    # Refill address: request address aligned down to the line
    self.io.nasti.ar.data @= NastiReadAddressChannel(
        nasti_params, 0, (req.addr >> b_len) << b_len, size, length)

    # Write-back address: rebuilt from the stored tag and the set index
    tags_rdata = tags.read(idx)
    self.io.nasti.aw.data @= NastiWriteAddressChannel(
        nasti_params, 0,
        m.bits(m.concat(idx, tags_rdata), nasti_params.x_addr_bits)
        << b_len, size, length)

    # Write-back data: beat `w_cnt` of the stored line
    self.io.nasti.w.data @= NastiWriteDataChannel(
        nasti_params,
        (read >> (m.zext_to(w_cnt, b_bits) *
                  nasti_params.x_data_bits))[:nasti_params.x_data_bits],
        None, w_done)

    self.io.nasti.w.valid @= state.O == State.WRITE
    self.io.nasti.b.ready @= state.O == State.WRITE_ACK
    self.io.nasti.r.ready @= state.O == State.READ

    # Write ports of the state memories; enables/data are driven by the
    # combinational FSM below.
    d_wen = m.Bit(name="d_wen")
    d.write(True, idx, m.enable(d_wen))

    data_wen = m.Bit(name="data_wen")
    data_wdata = m.UInt[b_bits](name="data_wdata")
    data.write(data_wdata, idx, m.enable(data_wen))

    v_wen = m.Bit(name="v_wen")
    v.write(True, idx, m.enable(v_wen))
    v_rdata = v.read(idx)

    tags_wen = m.Bit(name="tags_wen")
    tags.write(tag, idx, m.enable(tags_wen))

    d_rdata = d.read(idx)

    if TRACE:
        # hit with a non-zero mask: log the write
        m.display(
            "[%0t] [cache] data[%x] <= %x, off: %x, req: %x, mask: %b",
            m.time(), idx, write, off, self.io.req.data.data,
            self.io.req.data.mask)\
            .when(m.posedge(self.io.CLK))\
            .if_((state.O == State.IDLE) &
                 (self.io.req.valid & self.io.resp.ready) &
                 (v_rdata & (tags_rdata == tag)) &
                 req.mask.reduce_or())
        # hit with a zero mask: log the read
        m.display(
            "[%0t] [cache] data[%x] => %x, off: %x, resp: %x",
            m.time(), idx, write, off,
            self.io.resp.data.data.value())\
            .when(m.posedge(self.io.CLK))\
            .if_((state.O == State.IDLE) &
                 (self.io.req.valid & self.io.resp.ready) &
                 (v_rdata & (tags_rdata == tag)) &
                 ~req.mask.reduce_or())

    @m.inline_combinational()
    def logic():
        # defaults: nothing valid, no memory writes, hold state
        self.io.resp.valid @= False
        self.io.req.ready @= False
        self.io.nasti.ar.valid @= False
        self.io.nasti.aw.valid @= False
        d_wen @= False
        data_wen @= False
        data_wdata @= m.UInt[b_bits](0)
        state.I @= state.O
        tags_wen @= False
        v_wen @= False
        if state.O == State.IDLE:
            if self.io.req.valid & self.io.resp.ready:
                if v_rdata & (tags_rdata == tag):
                    # hit: a non-zero mask makes it a write
                    if req.mask.reduce_or():
                        d_wen @= True
                        data_wdata @= write
                        data_wen @= True
                    self.io.req.ready @= True
                    self.io.resp.valid @= True
                else:
                    if d_rdata:
                        # dirty miss: write the line back first
                        self.io.nasti.aw.valid @= True
                        state.I @= State.WRITE
                    else:
                        # clean miss: clear the line and start the refill
                        data_wdata @= 0
                        data_wen @= True
                        self.io.nasti.ar.valid @= True
                        state.I @= State.READ
        elif state.O == State.WRITE:
            if w_done:
                state.I @= State.WRITE_ACK
        elif state.O == State.WRITE_ACK:
            if self.io.nasti.b.valid:
                data_wdata @= 0
                data_wen @= True
                self.io.nasti.ar.valid @= True
                state.I @= State.READ
        elif state.O == State.READ:
            if self.io.nasti.r.valid:
                # OR the incoming beat into its position in the line
                data_wdata @= read | (
                    m.zext_to(self.io.nasti.r.data.data, b_bits) <<
                    (m.zext_to(r_cnt, b_bits) *
                     nasti_params.x_data_bits))
                data_wen @= True
            if r_done:
                tags_wen @= True
                v_wen @= True
                state.I @= State.IDLE
# Lock-step test harness: instantiates the cache under test and a golden
# model, connects both to one shared backing memory through queues, replays a
# fixed list of requests and asserts that the two agree on memory traffic and
# on read responses.  `done` is driven by the test-vector counter's carry-out.
class DUT(m.Circuit):
    io = m.IO(done=m.Out(m.Bit)) + m.ClockIO()

    # Cache geometry
    x_len = 32
    n_sets = 256
    b_bytes = 4 * (x_len >> 3)
    b_len = m.bitutils.clog2(b_bytes)
    s_len = m.bitutils.clog2(n_sets)
    t_len = x_len - (s_len + b_len)
    nasti_params = NastiParameters(data_bits=64, addr_bits=x_len, id_bits=5)

    # Cache under test; each NASTI channel passes through a queue
    dut = Cache(x_len, 1, n_sets, b_bytes)()
    dut_mem = make_NastiIO(nasti_params).undirected_t(name="dut_mem")
    dut_mem.ar @= make_Queue(dut.nasti.ar, 32)
    dut_mem.aw @= make_Queue(dut.nasti.aw, 32)
    dut_mem.w @= make_Queue(dut.nasti.w, 32)
    dut.nasti.b @= make_Queue(dut_mem.b, 32)
    dut.nasti.r @= make_Queue(dut_mem.r, 32)

    # Golden model; its request/response ports are queued as well
    gold = GoldCache(x_len, 1, n_sets, b_bytes)()
    gold_req = type(gold.req).undirected_t(name="gold_req")
    gold_resp = type(gold.resp).undirected_t(name="gold_resp")
    gold_mem = make_NastiIO(nasti_params).undirected_t(name="gold_mem")
    gold.req @= make_Queue(gold_req, 32)
    gold_resp @= make_Queue(gold.resp, 32)
    gold_mem.ar @= make_Queue(gold.nasti.ar, 32)
    gold_mem.aw @= make_Queue(gold.nasti.aw, 32)
    gold_mem.w @= make_Queue(gold.nasti.w, 32)
    gold.nasti.b @= make_Queue(gold_mem.b, 32)
    gold.nasti.r @= make_Queue(gold_mem.r, 32)

    size = m.bitutils.clog2(nasti_params.x_data_bits // 8)
    b_bits = b_bytes << 3
    data_beats = b_bits // nasti_params.x_data_bits

    # Shared backing memory, one NASTI data word per entry
    mem = m.Memory(1 << 20, m.UInt[nasti_params.x_data_bits])()

    class MemState(m.Enum):
        IDLE = 0
        WRITE = 1
        WRITE_ACK = 2
        READ = 3

    mem_state = m.Register(init=MemState.IDLE)()

    # Beat counters; they advance only when both caches are in step
    write_counter = mantle.CounterModM(data_beats, data_beats.bit_length(),
                                       has_ce=True)
    write_counter.CE @= m.enable((mem_state.O == MemState.WRITE) &
                                 dut_mem.w.valid & gold_mem.w.valid)
    read_counter = mantle.CounterModM(data_beats, data_beats.bit_length(),
                                      has_ce=True)
    read_counter.CE @= m.enable((mem_state.O == MemState.READ) &
                                dut_mem.r.ready & gold_mem.r.ready)

    dut_mem.b.valid @= mem_state.O == MemState.WRITE_ACK
    dut_mem.b.data @= NastiWriteResponseChannel(nasti_params, 0)
    dut_mem.r.valid @= mem_state.O == MemState.READ
    # Read data is addressed from the golden model's read address plus the
    # current beat number
    dut_mem.r.data @= NastiReadDataChannel(
        nasti_params, 0,
        mem.read(
            ((gold_mem.ar.data.addr) +
             m.zext_to(read_counter.O, nasti_params.x_addr_bits))[:20]),
        read_counter.COUT)

    # The golden model sees exactly the handshakes and data the DUT sees
    gold_mem.ar.ready @= dut_mem.ar.ready
    gold_mem.aw.ready @= dut_mem.aw.ready
    gold_mem.w.ready @= dut_mem.w.ready
    gold_mem.b.valid @= dut_mem.b.valid
    gold_mem.b.data @= dut_mem.b.data
    gold_mem.r.valid @= dut_mem.r.valid
    gold_mem.r.data @= dut_mem.r.data

    # Two write sources share the memory write port: port 0 carries cache
    # write-backs (driven by mem_fsm), port 1 carries the initial memory
    # image (driven by state_fsm) and wins the mux when enabled.
    mem_wen0 = m.Bit(name="mem_wen0")
    mem_wdata0 = m.UInt[nasti_params.x_data_bits](name="mem_wdata0")
    mem_wen1 = m.Bit(name="mem_wen1")
    mem_wdata1 = m.UInt[nasti_params.x_data_bits](name="mem_wdata1")
    mem_waddr1 = m.UInt[20](name="mem_waddr1")
    mem.write(
        m.mux([dut_mem.w.data.data, mem_wdata1], mem_wen1),
        m.mux([((dut_mem.aw.data.addr) +
                m.zext_to(write_counter.O,
                          nasti_params.x_addr_bits))[:20],
               mem_waddr1], mem_wen1),
        m.enable(mem_wen0 | mem_wen1))
    # m.display("mem_wen0 = %x, mem_wen1 = %x", mem_wen0,
    #           mem_wen1).when(m.posedge(io.CLK))
    # m.display("dut_mem.w.valid = %x",
    #           dut_mem.w.valid).when(m.posedge(io.CLK))
    # m.display("gold_mem.w.valid = %x",
    #           gold_mem.w.valid).when(m.posedge(io.CLK))

    # Idle with both write addresses valid: the addresses must agree
    f.assert_immediate(
        (mem_state.O != MemState.IDLE) |
        ~(gold_mem.aw.valid & dut_mem.aw.valid) |
        (dut_mem.aw.data.addr == gold_mem.aw.data.addr),
        failure_msg=(
            "[dut_mem.aw.data.addr] %x != [gold_mem.aw.data.addr] %x",
            dut_mem.aw.data.addr, gold_mem.aw.data.addr))
    # Read addresses must agree.
    # NOTE(review): this check is only armed while both aw.valid are high,
    # whereas mem_fsm accepts a read only when they are not - confirm the
    # polarity of the aw term.
    f.assert_immediate(
        (mem_state.O != MemState.IDLE) |
        ~(gold_mem.aw.valid & dut_mem.aw.valid) |
        ~(gold_mem.ar.valid & dut_mem.ar.valid) |
        (dut_mem.ar.data.addr == gold_mem.ar.data.addr),
        failure_msg=(
            "[dut_mem.ar.data.addr] %x != [gold_mem.ar.data.addr] %x",
            dut_mem.ar.data.addr, gold_mem.ar.data.addr))
    # During a write burst with both data beats valid: the data must agree
    f.assert_immediate(
        (mem_state.O != MemState.WRITE) |
        ~(gold_mem.w.valid & dut_mem.w.valid) |
        (dut_mem.w.data.data == gold_mem.w.data.data),
        failure_msg=(
            "[dut_mem.w.data.data] %x != [gold_mem.w.data.data] %x",
            dut_mem.w.data.data, gold_mem.w.data.data))

    # Backing-memory FSM: accepts a burst only when both caches present it
    @m.inline_combinational()
    def mem_fsm():
        dut_mem.w.ready @= False
        dut_mem.aw.ready @= False
        dut_mem.ar.ready @= False
        mem_wen0 @= False
        mem_state.I @= mem_state.O
        if mem_state.O == MemState.IDLE:
            # writes take priority over reads
            if gold_mem.aw.valid & dut_mem.aw.valid:
                mem_state.I @= MemState.WRITE
            elif gold_mem.ar.valid & dut_mem.ar.valid:
                mem_state.I @= MemState.READ
        elif mem_state.O == MemState.WRITE:
            if gold_mem.w.valid & dut_mem.w.valid:
                mem_wen0 @= True
                dut_mem.w.ready @= True
            # the address is consumed only once the last beat is counted
            if write_counter.COUT:
                dut_mem.aw.ready @= True
                mem_state.I @= MemState.WRITE_ACK
        elif mem_state.O == MemState.WRITE_ACK:
            if gold_mem.b.ready & dut_mem.b.ready:
                mem_state.I @= MemState.IDLE
        elif mem_state.O == MemState.READ:
            if read_counter.COUT:
                dut_mem.ar.ready @= True
                mem_state.I @= MemState.IDLE

    if TRACE:
        m.display("[%0t]: [write] mem[%x] <= %x", m.time(),
                  mem.WADDR.value(), dut_mem.w.data.data).when(
                      m.posedge(io.CLK)).if_(mem_wen0)

        m.display("[%0t]: [read] mem[%x] => %x", m.time(),
                  mem.RADDR.value(), dut_mem.r.data.data).when(m.posedge(
                      io.CLK)).if_((mem_state.O == MemState.READ) &
                                   dut_mem.r.ready & gold_mem.r.ready)

    # Build one random memory word, a byte at a time
    def rand_data(nasti_params):
        rand_data = BitVector[nasti_params.x_data_bits](0)
        for i in range(nasti_params.x_data_bits // 8):
            rand_data |= BitVector[nasti_params.x_data_bits](
                random.randint(0, 0xff) << (8 * i))
        return rand_data

    # Random byte mask that is neither all-zeros nor all-ones
    def rand_mask(x_len):
        return BitVector[x_len // 8](random.randint(
            1, (1 << (x_len // 8)) - 2))

    def make_test(rand_data, nasti_params, x_len):
        # Wrapper because function definition in side class namespace
        # doesn't inherit class variables
        def test(b_bits, tag, idx, off, mask=BitVector[x_len // 8](0)):
            # Pack one request as off | idx | tag | line of random data |
            # mask (least-significant field first).  The default all-zero
            # mask is what the FSM below treats as a read.
            test_data = rand_data(nasti_params)
            for i in range((b_bits // nasti_params.x_data_bits) - 1):
                test_data = test_data.concat(rand_data(nasti_params))
            return m.uint(m.concat(off, idx, tag, test_data, mask))

        return test

    test = make_test(rand_data, nasti_params, x_len)

    # Random tags / set indices / offsets used by the test vectors
    tags = []
    for _ in range(3):
        tags.append(BitVector.random(t_len))
    idxs = []
    for _ in range(2):
        idxs.append(BitVector.random(s_len))
    offs = []
    for _ in range(6):
        # clear the two low bits of each offset
        offs.append(BitVector.random(b_len) & -4)

    # Initial memory image: one random word per (tag, idx, beat)
    init_addr = []
    init_data = []
    _iter = itertools.product(tags, idxs, range(0, data_beats))
    for tag, idx, off in _iter:
        init_addr.append(m.uint(m.concat(BitVector[b_len](off), idx, tag)))
        init_data.append(rand_data(nasti_params))

    test_vec = [
        test(b_bits, tags[0], idxs[0], offs[0]),  # 0: read miss
        test(b_bits, tags[0], idxs[0], offs[1]),  # 1: read hit
        test(b_bits, tags[1], idxs[0], offs[0]),  # 2: read miss
        test(b_bits, tags[1], idxs[0], offs[2]),  # 3: read hit
        test(b_bits, tags[1], idxs[0], offs[3]),  # 4: read hit
        test(b_bits, tags[1], idxs[0], offs[4], rand_mask(x_len)),  # 5: write hit  # noqa
        test(b_bits, tags[1], idxs[0], offs[4]),  # 6: read hit
        test(b_bits, tags[2], idxs[0], offs[5]),  # 7: read miss & write back  # noqa
        test(b_bits, tags[0], idxs[1], offs[0], rand_mask(x_len)),  # 8: write miss  # noqa
        test(b_bits, tags[0], idxs[1], offs[0]),  # 9: read hit
        test(b_bits, tags[0], idxs[1], offs[1]),  # 10: read hit
        test(b_bits, tags[1], idxs[1], offs[2], rand_mask(x_len)),  # 11: write miss & write back  # noqa
        test(b_bits, tags[1], idxs[1], offs[3]),  # 12: read hit
        test(b_bits, tags[2], idxs[1], offs[4]),  # 13: read write back
        test(b_bits, tags[2], idxs[1], offs[5])  # 14: read hit
    ]

    class TestState(m.Enum):
        INIT = 0
        START = 1
        WAIT = 2
        DONE = 3

    state = m.Register(init=TestState.INIT)()
    timeout = m.Register(m.UInt[32])()

    # NOTE(review): both counters use `len(...) - 1` as their modulus; if
    # CounterModM(n, ...) counts 0..n-1 the last init word / test vector is
    # never selected - confirm against CounterModM's semantics.
    init_m = len(init_addr) - 1
    init_counter = mantle.CounterModM(init_m, init_m.bit_length(),
                                      has_ce=True)
    init_counter.CE @= m.enable(state.O == TestState.INIT)

    test_m = len(test_vec) - 1
    test_counter = mantle.CounterModM(test_m, test_m.bit_length(),
                                      has_ce=True)
    test_counter.CE @= m.enable(state.O == TestState.DONE)

    # Unpack the current vector (same field order as `test` packs it)
    curr_vec = m.mux(test_vec, test_counter.O)
    mask = (curr_vec >> (b_len + s_len + t_len + b_bits))[:x_len // 8]
    data = (curr_vec >> (b_len + s_len + t_len))[:b_bits]
    tag = (curr_vec >> (b_len + s_len))[:t_len]
    idx = (curr_vec >> b_len)[:s_len]
    off = curr_vec[:b_len]

    dut.cpu.req.data.addr @= m.concat(off, idx, tag)
    # TODO: Is truncating this fine?
    req_data = data[:x_len]
    dut.cpu.req.data.data @= req_data
    dut.cpu.req.data.mask @= mask
    dut.cpu.req.valid @= state.O == TestState.WAIT
    dut.cpu.abort @= 0

    # The golden model receives the same request payload as the DUT
    gold_req.data @= dut.cpu.req.data.value()
    gold_req.valid @= state.O == TestState.START
    gold_resp.ready @= state.O == TestState.DONE

    mem_waddr1 @= m.mux(init_addr, init_counter.O)[:20]
    mem_wdata1 @= m.mux(init_data, init_counter.O)

    check_resp_data = m.Bit()

    if TRACE:
        m.display("[%0t]: [init] mem[%x] <= %x", m.time(), mem_waddr1,
                  mem_wdata1)\
            .when(m.posedge(io.CLK))\
            .if_(state.O == TestState.INIT)

    # Test sequencer: INIT loads memory, START hands the request to the
    # golden model, WAIT waits for both responses, DONE advances the vector
    @m.inline_combinational()
    def state_fsm():
        timeout.I @= timeout.O
        mem_wen1 @= m.bit(False)
        check_resp_data @= m.bit(False)
        state.I @= state.O
        if state.O == TestState.INIT:
            mem_wen1 @= m.bit(True)
            if init_counter.COUT:
                state.I @= TestState.START
        elif state.O == TestState.START:
            if gold_req.ready:
                timeout.I @= m.bits(0, 32)
                state.I @= TestState.WAIT
        elif state.O == TestState.WAIT:
            timeout.I @= timeout.O + 1
            if dut.cpu.resp.valid & gold_resp.valid:
                # only compare response data when the mask is all zeros
                if ~mask.reduce_or():
                    check_resp_data @= m.bit(True)
                state.I @= TestState.DONE
        elif state.O == TestState.DONE:
            state.I @= TestState.START

    # A request must be answered before the timeout counter reaches 100
    f.assert_immediate((state.O != TestState.WAIT) | (timeout.O < 100))
    f.assert_immediate(
        ~check_resp_data |
        (dut.cpu.resp.data.data == gold_resp.data.data),
        failure_msg=("dut.cpu.resp.data.data => %x != %x",
                     dut.cpu.resp.data.data, gold_resp.data.data))
    # m.display("mem_state=%x", mem_state.O).when(m.posedge(io.CLK))
    # m.display("test_state=%x", state.O).when(m.posedge(io.CLK))
    # m.display("dut req valid = %x",
    #           dut.cpu.req.valid).when(m.posedge(io.CLK))
    # m.display("gold req valid = %x, ready = %x", gold_req.valid,
    #           gold_req.ready).when(m.posedge(io.CLK))
    # m.display("[%0t]: dut resp data = %x, gold resp data = %x", m.time(),
    #           dut.cpu.resp.data.data, gold_resp.data.data)\
    #     .when(m.posedge(io.CLK))
    io.done @= test_counter.COUT
import magma as m
import mantle
from loam.boards.icestick import IceStick
from uart import UART

# Board setup: enable the clock and the first three J3 pins as outputs
icestick = IceStick()
icestick.Clock.on()
for i in range(3):
    icestick.J3[i].output().on()

main = icestick.main()

# baud for uart output: carry-out of a modulo-103 counter
clock = mantle.CounterModM(103, 8)
baud = clock.COUT

# 5-bit counter advanced on baud ticks; `load` is high when it reads 0
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)
load = mantle.Decode(0, 5)(bit_counter.O)

# modulo-1000 counter advanced when load and baud coincide
valid_counter = mantle.CounterModM(1000, 10, has_ce=True)
m.wire(load & baud, valid_counter.CE)

# expose the three timing signals on the header pins
m.wire(baud, main.J3[0])
m.wire(load, main.J3[1])
m.wire(valid_counter.COUT, main.J3[2])
def definition(io):
    """Downscale and threshold the input, then assemble 32-bit image rows."""
    strobe = io.LOAD
    # one tick on every SCK transition
    baud = rising(io.SCK) | falling(io.SCK)

    sample_counter = mantle.CounterModM(buf_size, 12, has_ce=True)
    m.wire(strobe & baud, sample_counter.CE)

    # counter values at which a sample is flagged valid (wo entries)
    sample_points = [wi * (b - 1) + k * a - 1 for k in range(1, wo + 1)]

    # OR together one decoder per sample point
    valid = m.GND
    for point in sample_points:
        hit = mantle.Decode(point, 12)(sample_counter.O)
        valid = valid | hit

    # input register, captured on LOAD
    in_reg = mantle.Register(width, has_ce=True)
    in_reg(io.DATA)
    m.wire(strobe, in_reg.CE)

    # ------------------------- downscaling ------------------------- #
    # downscale the image from 352x288 to 32x32
    pixel_t = m.Array(width, m.Bit)
    Downscale = m.DeclareCircuit(
        'Downscale',
        "I_0_0", m.In(m.Array(1, m.Array(1, pixel_t))),
        "WE", m.In(m.Bit),
        "CLK", m.In(m.Clock),
        "O", m.Out(m.Array(width, m.Bit)),
        "V", m.Out(m.Bit))

    scaler = Downscale()
    m.wire(in_reg.O, scaler.I_0_0[0][0])
    m.wire(1, scaler.WE)
    m.wire(strobe, scaler.CLK)

    # instantiated only for its side effect: needed for Add16 definition
    mantle.Add(width)

    # ------------------------- fill img ram ------------------------ #
    # every valid scaler output is one entry of the 32x32 binary image;
    # 32 consecutive entries are packed into one 32-bit row
    col = mantle.CounterModM(32, 6, has_ce=True)
    col_step = rising(valid)
    m.wire(col_step, col.CE)

    # shift the thresholded bits in one at a time until a row is complete
    below_thresh = mantle.ULE(16)(scaler.O, m.uint(THRESH, 16))
    px_bit = below_thresh & valid
    row_shift = mantle.SIPO(32, has_ce=True)
    row_shift(px_bit)
    m.wire(col_step, row_shift.CE)

    # reverse the row bits since the image is flipped
    row = reverse(row_shift.O)

    row_counter = mantle.Counter(6, has_ce=True)

    # latch "image full" once the row counter reads 32
    img_full = mantle.SRFF(has_ce=True)
    img_full(mantle.EQ(6)(row_counter.O, m.bits(32, 6)), 0)
    m.wire(falling(col.COUT), img_full.CE)
    row_step = rising(col.COUT) & ~img_full.O
    m.wire(row_step, row_counter.CE)

    waddr = row_counter.O[:5]

    # write-enable pulse generator, started by the column carry
    rdy = col.COUT & ~img_full.O
    pulse_counter = mantle.Counter(2, has_ce=True)
    we = mantle.UGE(2)(pulse_counter.O, m.uint(1, 2))
    pulse_counter(CE=(we | rdy))

    # ------------------------- UART output ------------------------- #
    row_load = row_step
    row_baud = mantle.FF()(baud)

    uart_row = UART(32)
    uart_row(CLK=io.CLK, BAUD=row_baud, DATA=row, LOAD=row_load)

    uart_addr = UART(5)
    uart_addr(CLK=io.CLK, BAUD=row_baud, DATA=waddr, LOAD=row_load)

    m.wire(waddr, io.WADDR)
    m.wire(img_full, io.DONE)
    m.wire(uart_row, io.UART)
    m.wire(row, io.O)
    m.wire(we, io.VALID)

    # test outputs
    m.wire(valid, io.T0)
    m.wire(uart_addr, io.T1)
def definition(io):
    """Downscale and threshold the input, then assemble 16-bit image rows."""
    strobe = io.LOAD
    # one tick on every SCK transition
    baud = rising(io.SCK) | falling(io.SCK)

    sample_counter = mantle.CounterModM(buf_size, 13, has_ce=True)
    m.wire(strobe & baud, sample_counter.CE)

    # counter values at which a sample is flagged valid (wo entries)
    sample_points = [wi * (b - 1) + k * a - 1 for k in range(1, wo + 1)]

    # OR together one decoder per sample point
    valid = m.GND
    for point in sample_points:
        hit = mantle.Decode(point, 13)(sample_counter.O)
        valid = valid | hit

    # input register, captured on LOAD
    in_reg = mantle.Register(16, has_ce=True)
    in_reg(io.DATA)
    m.wire(strobe, in_reg.CE)

    # ------------------------- downscaling ------------------------- #
    # downscale the image from 320x240 to 16x16
    word_t = m.Array(16, m.Bit)
    Downscale = m.DeclareCircuit(
        'Downscale',
        "I_0_0", m.In(m.Array(1, m.Array(1, word_t))),
        "WE", m.In(m.Bit),
        "CLK", m.In(m.Clock),
        "O", m.Out(m.Array(16, m.Bit)),
        "V", m.Out(m.Bit))

    scaler = Downscale()
    m.wire(in_reg.O, scaler.I_0_0[0][0])
    m.wire(1, scaler.WE)
    m.wire(strobe, scaler.CLK)

    # instantiated only for its side effect: needed for Add16 definition
    mantle.Add(16)

    # ------------------------- fill img ram ------------------------ #
    # every valid scaler output is one pixel of the 16x16 binary image;
    # 16 consecutive pixels are packed into one 16-bit row
    col = mantle.CounterModM(16, 5, has_ce=True)
    col_step = rising(valid)
    m.wire(col_step, col.CE)

    # shift the thresholded bits in one at a time until a row is complete
    below_thresh = mantle.ULE(16)(scaler.O, m.uint(THRESH, 16))
    px_bit = below_thresh & valid
    row_shift = mantle.SIPO(16, has_ce=True)
    row_shift(px_bit)
    m.wire(col_step, row_shift.CE)

    # reverse the row bits since the image is flipped
    row = reverse(row_shift.O)

    row_counter = mantle.Counter(5, has_ce=True)

    # latch "image full" once the row counter reads 16
    img_full = mantle.SRFF(has_ce=True)
    img_full(mantle.EQ(5)(row_counter.O, m.bits(16, 5)), 0)
    m.wire(falling(col.COUT), img_full.CE)
    row_step = rising(col.COUT) & ~img_full.O
    m.wire(row_step, row_counter.CE)

    waddr = row_counter.O[:4]

    # write-enable pulse generator, started by the column carry
    rdy = col.COUT & ~img_full.O
    pulse_counter = mantle.Counter(5, has_ce=True)
    we = mantle.UGE(5)(pulse_counter.O, m.uint(1, 5))
    pulse_counter(CE=(we | rdy))

    # ------------------------- UART output ------------------------- #
    row_load = row_step
    row_baud = mantle.FF()(baud)

    uart_row = UART(16)
    uart_row(CLK=io.CLK, BAUD=row_baud, DATA=row, LOAD=row_load)

    uart_addr = UART(4)
    uart_addr(CLK=io.CLK, BAUD=row_baud, DATA=waddr, LOAD=row_load)

    # split 16-bit row data into 8-bit packets so it can be parsed
    low_byte = row & LOW_MASK
    high_byte = row & HIGH_MASK
    uart_counter = mantle.CounterModM(8, 4, has_ce=True)
    m.wire(rising(valid), uart_counter.CE)

    m.wire(waddr, io.WADDR)
    m.wire(img_full, io.DONE)
    m.wire(uart_row, io.UART)
    m.wire(row, io.O)
    m.wire(we, io.VALID)
# Module-level parameters; `width` is the bit width of the data arrays below
a = 2
b = 2
width = 16
TIN = m.Array(width, m.BitIn)
TOUT = m.Array(width, m.Out(m.Bit))

# Board setup: enable the clock and the first three J3 pins as outputs
icestick = IceStick()
icestick.Clock.on()
for i in range(3):
    icestick.J3[i].output().on()

main = icestick.main()

# baud for uart output: carry-out of a modulo-103 counter
clock = mantle.CounterModM(103, 8)
baud = clock.COUT

# 5-bit counter advanced on baud ticks; `load` is high when it reads 0
bit_counter = mantle.Counter(5, has_ce=True)
m.wire(baud, bit_counter.CE)
load = mantle.Decode(0, 5)(bit_counter.O)

# Disabled ROM-backed "test" data source, kept for reference:
# # "test" data
# init = [m.uint(i, 16) for i in range(16)]
# printf = mantle.Counter(4, has_ce=True)
# rom = ROM16(4, init, printf.O)
# m.wire(load & baud, printf.CE)

#---------------------------STENCILING-----------------------------#
def __init__(self, x_len, n_ways: int, n_sets: int, b_bytes: int):
    """Build a write-back cache with a CPU port and a NASTI memory port.

    The cache keeps one tag and one data block per set, plus per-set valid
    and dirty bits, and is driven by a seven-state FSM (IDLE, READ_CACHE,
    WRITE_CACHE, WRITE_BACK, WRITE_ACK, REFILL_READY, REFILL).

    Args:
        x_len: Word/address width in bits. Used as the NASTI address width,
            the CPU word size (``w_bytes = x_len // 8``) and the upper
            bound of the tag slice.
        n_ways: Accepted for interface compatibility; not read anywhere in
            this method.
        n_sets: Number of sets. Sizes the tag/data memories and the
            valid/dirty bit vectors.
        b_bytes: Block size in bytes. Must be large enough that the block
            holds at least one NASTI data beat (64 bits), otherwise the
            ``data_beats > 0`` assertion fails.

    Raises:
        AssertionError: If the block is smaller than one NASTI data beat.
    """
    # Derived geometry.
    b_bits = b_bytes << 3
    b_len = m.bitutils.clog2(b_bytes)       # block-offset bits
    s_len = m.bitutils.clog2(n_sets)        # set-index bits
    t_len = x_len - (s_len + b_len)         # tag bits
    n_words = b_bits // x_len               # CPU words per block
    w_bytes = x_len // 8                    # bytes per CPU word
    byte_offset_bits = m.bitutils.clog2(w_bytes)
    nasti_params = NastiParameters(data_bits=64, addr_bits=x_len, id_bits=5)
    data_beats = b_bits // nasti_params.x_data_bits  # NASTI beats per block

    class MetaData(m.Product):
        tag = m.UInt[t_len]

    self.io = m.IO(**make_cache_ports(x_len, nasti_params))
    self.io += m.ClockIO()

    class State(m.Enum):
        IDLE = 0
        READ_CACHE = 1
        WRITE_CACHE = 2
        WRITE_BACK = 3
        WRITE_ACK = 4
        REFILL_READY = 5
        REFILL = 6

    state = m.Register(init=State.IDLE)()

    # memory
    # v / d hold one valid / dirty bit per set.
    v = m.Register(m.UInt[n_sets], has_enable=True)()
    d = m.Register(m.UInt[n_sets], has_enable=True)()
    meta_mem = m.Memory(n_sets, MetaData, read_latency=1,
                        has_read_enable=True)()
    # One byte-maskable memory per CPU word of the block.
    data_mem = [
        ArrayMaskMem(n_sets, w_bytes, m.UInt[8], read_latency=1,
                     has_read_enable=True)() for _ in range(n_words)
    ]

    # Registered copy of the CPU request (address, write data, byte mask).
    addr_reg = m.Register(type(self.io.cpu.req.data.addr).undirected_t,
                          has_enable=True)()
    cpu_data = m.Register(type(self.io.cpu.req.data.data).undirected_t,
                          has_enable=True)()
    cpu_mask = m.Register(type(self.io.cpu.req.data.mask).undirected_t,
                          has_enable=True)()

    self.io.nasti.r.ready @= state.O == State.REFILL

    # Counters
    assert data_beats > 0
    if data_beats > 1:
        read_counter = mantle.CounterModM(data_beats,
                                          max(data_beats.bit_length(), 1),
                                          has_ce=True)
        read_counter.CE @= m.enable(self.io.nasti.r.fired())
        read_count, read_wrap_out = read_counter.O, read_counter.COUT

        write_counter = mantle.CounterModM(data_beats,
                                           max(data_beats.bit_length(), 1),
                                           has_ce=True)
        write_count, write_wrap_out = write_counter.O, write_counter.COUT
    else:
        # Single beat: no counter hardware; the count is constant 0 and the
        # "wrap" condition is always true. These are plain Python ints, so
        # they must not be sliced like counter outputs.
        read_count, read_wrap_out = 0, 1
        write_count, write_wrap_out = 0, 1

    # Accumulates the beats of a refill as they arrive on nasti.r.
    refill_buf = m.Register(m.Array[data_beats,
                                    m.UInt[nasti_params.x_data_bits]],
                            has_enable=True)()
    if data_beats == 1:
        refill_buf.I[0] @= self.io.nasti.r.data.data
    else:
        # [:-1] drops the counter's top bit to form the beat index.
        refill_buf.I @= m.set_index(refill_buf.O,
                                    self.io.nasti.r.data.data,
                                    read_count[:-1])
    refill_buf.CE @= m.enable(self.io.nasti.r.fired())

    is_idle = state.O == State.IDLE
    is_read = state.O == State.READ_CACHE
    is_write = state.O == State.WRITE_CACHE
    # High in REFILL when the read wrap condition is reached.
    is_alloc = (state.O == State.REFILL) & read_wrap_out
    # m.display("[%0t]: is_alloc = %x", m.time(), is_alloc)\
    #     .when(m.posedge(self.io.CLK))
    is_alloc_reg = m.Register(m.Bit)()(is_alloc)

    # `hit` is declared here and driven further down, once rmeta exists.
    hit = m.Bit(name="hit")
    wen = is_write & (hit | is_alloc_reg) & ~self.io.cpu.abort | is_alloc
    # m.display("[%0t]: wen = %x", m.time(), wen)\
    #     .when(m.posedge(self.io.CLK))
    ren = m.enable(~wen & (is_idle | is_read) & self.io.cpu.req.valid)
    ren_reg = m.enable(m.Register(m.Bit)()(ren))

    # Address fields: the index comes from the live request (memory read
    # address); tag / index / word offset come from the registered request.
    addr = self.io.cpu.req.data.addr
    idx = addr[b_len:s_len + b_len]
    tag_reg = addr_reg.O[s_len + b_len:x_len]
    idx_reg = addr_reg.O[b_len:s_len + b_len]
    off_reg = addr_reg.O[byte_offset_bits:b_len]

    rmeta = meta_mem.read(idx, ren)
    rdata = m.concat(*(mem.read(idx, ren) for mem in data_mem))
    # Holds the last read data while no new read is in flight.
    rdata_buf = m.Register(type(rdata), has_enable=True)()(rdata,
                                                            CE=ren_reg)

    # Block seen by the CPU / write-back path: the refill buffer when
    # is_alloc_reg is set, otherwise the (possibly buffered) memory read.
    read = m.mux([
        m.as_bits(m.mux([rdata_buf, rdata], ren_reg)),
        m.as_bits(refill_buf.O)
    ], is_alloc_reg)
    # m.display("is_alloc_reg=%x", is_alloc_reg)\
    #     .when(m.posedge(self.io.CLK))

    hit @= v.O[idx_reg] & (rmeta.tag == tag_reg)

    # read mux
    self.io.cpu.resp.data.data @= m.array(
        [read[i * x_len:(i + 1) * x_len] for i in range(n_words)])[off_reg]
    self.io.cpu.resp.valid @= (is_idle | (is_read & hit) |
                               (is_alloc_reg & ~cpu_mask.O.reduce_or()))
    m.display("resp.valid=%x", self.io.cpu.resp.valid.value())\
        .when(m.posedge(self.io.CLK))

    m.display("[%0t]: valid = %x", m.time(),
              self.io.cpu.resp.valid.value())\
        .when(m.posedge(self.io.CLK))
    m.display("[%0t]: is_idle = %x, is_read = %x, hit = %x, is_alloc_reg = "
              "%x, ~cpu_mask.O.reduce_or() = %x", m.time(), is_idle,
              is_read, hit, is_alloc_reg, ~cpu_mask.O.reduce_or())\
        .when(m.posedge(self.io.CLK))
    m.display("[%0t]: refill_buf.O=%x, %x", m.time(), *refill_buf.O)\
        .when(m.posedge(self.io.CLK))\
        .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)
    m.display("[%0t]: read=%x", m.time(), read)\
        .when(m.posedge(self.io.CLK))\
        .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)

    # Capture the next request whenever a response is being presented.
    addr_reg.I @= addr
    addr_reg.CE @= m.enable(self.io.cpu.resp.valid.value())

    cpu_data.I @= self.io.cpu.req.data.data
    cpu_data.CE @= m.enable(self.io.cpu.resp.valid.value())

    cpu_mask.I @= self.io.cpu.req.data.mask
    cpu_mask.CE @= m.enable(self.io.cpu.resp.valid.value())

    wmeta = MetaData(name="wmeta")
    wmeta.tag @= tag_reg

    # CPU writes shift the byte mask by the word offset; an allocation
    # (refill) uses an all-ones mask.
    offset_mask = (m.zext_to(cpu_mask.O, w_bytes * 8) << m.concat(
        m.bits(0, byte_offset_bits), off_reg))
    wmask = m.mux([m.SInt[w_bytes * 8](-1),
                   m.sint(offset_mask)], ~is_alloc)

    if len(refill_buf.O) == 1:
        wdata_alloc = self.io.nasti.r.data.data
    else:
        wdata_alloc = m.concat(
            # TODO: not sure why they use `init.reverse`
            # https://github.com/ucb-bar/riscv-mini/blob/release/src/main/scala/Cache.scala#L116
            m.concat(*refill_buf.O[:-1]),
            self.io.nasti.r.data.data)
    # CPU writes replicate the word across the block; the mask picks the
    # target bytes.
    wdata = m.mux([wdata_alloc,
                   m.as_bits(m.repeat(cpu_data.O, n_words))], ~is_alloc)

    v.I @= m.set_index(v.O, m.bit(True), idx_reg)
    v.CE @= m.enable(wen)
    # A line written by an allocation is clean; a CPU write makes it dirty.
    d.I @= m.set_index(d.O, ~is_alloc, idx_reg)
    d.CE @= m.enable(wen)
    # m.display("[%0t]: refill_buf.O = %x", m.time(),
    #           m.concat(*refill_buf.O)).when(m.posedge(self.io.CLK)).if_(wen)
    # m.display("[%0t]: nasti.r.data.data = %x", m.time(),
    #           self.io.nasti.r.data.data).when(m.posedge(self.io.CLK)).if_(wen)

    # The tag is only rewritten on allocation.
    meta_mem.write(wmeta, idx_reg, m.enable(wen & is_alloc))
    for i, mem in enumerate(data_mem):
        data = [
            wdata[i * x_len + j * 8:i * x_len + (j + 1) * 8]
            for j in range(w_bytes)
        ]
        mem.write(m.array(data), idx_reg,
                  wmask[i * w_bytes:(i + 1) * w_bytes], m.enable(wen))
        # m.display("[%0t]: wdata = %x, %x, %x, %x", m.time(),
        #           *mem.WDATA.value()).when(m.posedge(self.io.CLK)).if_(wen)
        # m.display("[%0t]: wmask = %x, %x, %x, %x", m.time(),
        #           *mem.WMASK.value()).when(m.posedge(self.io.CLK)).if_(wen)

    # Refill address: requested tag + index, shifted up by the block offset.
    tag_and_idx = m.zext_to(m.concat(idx_reg, tag_reg),
                            nasti_params.x_addr_bits)
    self.io.nasti.ar.data @= NastiReadAddressChannel(
        nasti_params, 0, tag_and_idx << m.Bits[len(tag_and_idx)](b_len),
        m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

    # Write-back address: the tag read from meta memory + index.
    rmeta_and_idx = m.zext_to(m.concat(idx_reg, rmeta.tag),
                              nasti_params.x_addr_bits)
    self.io.nasti.aw.data @= NastiWriteAddressChannel(
        nasti_params, 0,
        rmeta_and_idx << m.Bits[len(rmeta_and_idx)](b_len),
        m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

    write_beats = m.array([
        read[i * nasti_params.x_data_bits:(i + 1) *
             nasti_params.x_data_bits] for i in range(data_beats)
    ])
    if data_beats > 1:
        # [:-1] drops the counter's top bit to form the beat index.
        write_beat = write_beats[write_count[:-1]]
    else:
        # With a single beat `write_count` is the Python int 0, which cannot
        # be sliced (`0[:-1]` raises TypeError); select the only beat
        # directly.
        write_beat = write_beats[0]
    self.io.nasti.w.data @= NastiWriteDataChannel(nasti_params, write_beat,
                                                  None, write_wrap_out)

    is_dirty = v.O[idx_reg] & d.O[idx_reg]

    # TODO: Have to use temporary so we can invoke `fired()`
    aw_valid = m.Bit(name="aw_valid")
    self.io.nasti.aw.valid @= aw_valid

    ar_valid = m.Bit(name="ar_valid")
    self.io.nasti.ar.valid @= ar_valid

    b_ready = m.Bit(name="b_ready")
    self.io.nasti.b.ready @= b_ready

    @m.inline_combinational()
    def logic():
        # Defaults: hold the state and keep all handshake outputs low.
        state.I @= state.O
        aw_valid @= False
        ar_valid @= False
        self.io.nasti.w.valid @= False
        b_ready @= False
        if state.O == State.IDLE:
            if self.io.cpu.req.valid:
                if self.io.cpu.req.data.mask.reduce_or():
                    state.I @= State.WRITE_CACHE
                else:
                    state.I @= State.READ_CACHE
        elif state.O == State.READ_CACHE:
            if hit:
                if self.io.cpu.req.valid:
                    if self.io.cpu.req.data.mask.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.READ_CACHE
                else:
                    state.I @= State.IDLE
            else:
                # Miss: write back first if the line is dirty, else refill.
                aw_valid @= is_dirty
                ar_valid @= ~is_dirty
                if self.io.nasti.aw.fired():
                    state.I @= State.WRITE_BACK
                elif self.io.nasti.ar.fired():
                    state.I @= State.REFILL
        elif state.O == State.WRITE_CACHE:
            if hit | is_alloc_reg | self.io.cpu.abort:
                state.I @= State.IDLE
            else:
                aw_valid @= is_dirty
                ar_valid @= ~is_dirty
                if self.io.nasti.aw.fired():
                    state.I @= State.WRITE_BACK
                elif self.io.nasti.ar.fired():
                    state.I @= State.REFILL
        elif state.O == State.WRITE_BACK:
            self.io.nasti.w.valid @= True
            if write_wrap_out:
                state.I @= State.WRITE_ACK
        elif state.O == State.WRITE_ACK:
            b_ready @= True
            if self.io.nasti.b.fired():
                state.I @= State.REFILL_READY
        elif state.O == State.REFILL_READY:
            ar_valid @= True
            if self.io.nasti.ar.fired():
                state.I @= State.REFILL
        elif state.O == State.REFILL:
            if read_wrap_out:
                # A pending write (non-zero mask) is replayed after refill.
                if cpu_mask.O.reduce_or():
                    state.I @= State.WRITE_CACHE
                else:
                    state.I @= State.IDLE

    if data_beats > 1:
        # TODO: Have to do this at the end since the inline comb logic
        # wires up nasti.w
        write_counter.CE @= m.enable(self.io.nasti.w.fired())