Exemplo n.º 1
0
def materialize_graph(
        defn: m.DefineCircuitKind) -> (Iterable[NodeType], Iterable[EdgeType]):
    """Collect every node reachable from ``defn``'s port bits, plus edges.

    Performs a breadth-first traversal over ``SimpleDirectedGraphViewBase(defn)``,
    seeded with one ``BitPortNode`` per bit of each interface port of
    ``defn``, and following both incoming and outgoing connections.

    Args:
        defn: The circuit definition to walk.

    Returns:
        A ``(nodes, edges)`` pair. ``nodes`` is a list of the visited nodes
        in the order they were first dequeued, so the ordering is
        deterministic for a given graph. ``edges`` is a list of
        ``(src, dst)`` tuples, one per incoming connection of each visited
        node, in visitation order.
    """
    from collections import deque

    # An insertion-ordered dict serves both as the "seen" set and as the
    # record of visitation order, which gives consistently ordered nodes.
    visit_order = {}
    edges = []
    queue = deque(
        BitPortNode(ScopedBit(bit, Scope(defn)))
        for port in defn.interface.ports.values()
        for bit in m.as_bits(port))
    graph = SimpleDirectedGraphViewBase(defn)
    while queue:
        node = queue.popleft()  # O(1), unlike list.pop(0)
        if node in visit_order:
            continue
        visit_order[node] = len(visit_order)
        for i in graph.incoming(node):
            edges.append((i, node))
            if i not in visit_order:
                queue.append(i)
        for o in graph.outgoing(node):
            if o not in visit_order:
                queue.append(o)
    return list(visit_order), edges
Exemplo n.º 2
0
 def _get_primitive_drivees(
     self, primitive: m.DefineCircuitKind,
     inst_bit: m.In(m.Bit)) -> Iterable[m.Bit]:
     """Lazily yield every bit on ``primitive``'s interface that is an input.

     ``inst_bit`` is required to be an input bit. It is only checked, not
     otherwise consulted, so a given ``primitive`` always yields the same
     bits: each port is flattened with ``m.as_bits`` and its input bits are
     emitted in port order.
     """
     assert inst_bit.is_input()
     yield from (
         candidate
         for port in primitive.interface.ports.values()
         for candidate in m.as_bits(port)
         if candidate.is_input()
     )
Exemplo n.º 3
0
	def __init__(self, ARES_design_type, depth):
		"""Build a RAM-backed FIFO behind a single ``ARES_design`` port.

		Args:
			ARES_design_type: magma type of the ``ARES_design`` port. The code
				below uses its WData, RData, RESET, Write, Read, Full and Empty
				fields -- TODO confirm field directions against the type.
			depth: requested FIFO depth. The RAM holds ``2**clog2(depth)``
				words (presumably depth rounded up to a power of two), each
				``WData.flat_length()`` bits wide.

		Side effects:
			Prints the computed address width to stdout (debug output).
		"""
		# Original author's note: both self.name and self.io must be defined.
		# NOTE(review): self.name is commented out below -- confirm whether
		# it is still required.
		#self.name = "ARES_FIFO_DESIGN"
		self.io = io = m.IO(ARES_design = ARES_design_type)
		# NOTE(review): no explicit clock IO is added; presumably the
		# registers' clocks are wired implicitly -- confirm.
		#io += m.ClockIO()
		
		addr_width = m.bitutils.clog2(depth)
		print ("ARES addr width : " + str(addr_width) )
		buffer = mantle.RAM(2**addr_width, io.ARES_design.WData.flat_length())
		
		# Data path: the flattened WData feeds the RAM write port; the RAM
		# read port drives RData.
		buffer.WDATA @= m.as_bits(io.ARES_design.WData)
		io.ARES_design.RData @= buffer.RDATA
		
		# Each pointer has one bit more than the RAM address. Only the low
		# addr_width bits address the RAM; the extra MSB distinguishes a full
		# FIFO from an empty one.
		read_pointer = mantle.Register(addr_width + 1)
		write_pointer = mantle.Register(addr_width + 1)
		buffer.RADDR @= read_pointer.O[:addr_width]
		buffer.WADDR @= write_pointer.O[:addr_width]
		
		reset = io.ARES_design.RESET

		# Full: same RAM address but the MSBs differ (writer lapped reader).
		full = \
			(read_pointer.O[:addr_width] == write_pointer.O[:addr_width]) \
			& \
			(read_pointer.O[addr_width] != write_pointer.O[addr_width])
		
		# Empty: pointers identical including the MSB.
		empty = read_pointer.O == write_pointer.O
		# Requests are ignored when they would overflow or underflow.
		write_valid = io.ARES_design.Write & ~full
		read_valid = io.ARES_design.Read & ~empty
	
		io.ARES_design.Full @= full
		
		buffer.WE @= write_valid

		# Advance the write pointer only on an accepted write.
		write_p = mantle.mux([
			write_pointer.O, m.uint(write_pointer.O) + 1
		], write_valid)
		
		# RESET (used as the mux select) loads 0 into the pointer.
		write_pointer.I @= mantle.mux([
			write_p, 0
		], reset)
	
		io.ARES_design.Empty @= empty
		
		# Advance the read pointer only on an accepted read.
		read_p = mantle.mux([
			read_pointer.O, m.uint(read_pointer.O) + 1
		], read_valid)

		# RESET (used as the mux select) loads 0 into the pointer.
		read_pointer.I @= mantle.mux([
			read_p, 0 
		], reset)
Exemplo n.º 4
0
	class FIFO(m.Circuit):
		# RAM-backed FIFO behind a single ``ARES_design`` port.
		# ARES_design_type and depth come from the enclosing scope (not shown
		# here). The RAM holds 2**clog2(depth) words of WData.flat_length()
		# bits each.
		io = m.IO(ARES_design = ARES_design_type)
		# NOTE(review): no explicit clock IO is added; presumably the
		# registers' clocks are wired implicitly -- confirm.
		#io += m.ClockIO()
		
		addr_width = m.bitutils.clog2(depth)
		buffer = mantle.RAM(2**addr_width, io.ARES_design.WData.flat_length())
		
		# Data path: the flattened WData feeds the RAM write port; the RAM
		# read port drives RData.
		buffer.WDATA @= m.as_bits(io.ARES_design.WData)
		io.ARES_design.RData @= buffer.RDATA
		
		# Each pointer has one bit more than the RAM address. Only the low
		# addr_width bits address the RAM; the extra MSB distinguishes a full
		# FIFO from an empty one.
		read_pointer = mantle.Register(addr_width + 1)
		write_pointer = mantle.Register(addr_width + 1)
		buffer.RADDR @= read_pointer.O[:addr_width]
		buffer.WADDR @= write_pointer.O[:addr_width]
		
		reset = io.ARES_design.RESET

		# Full: same RAM address but the MSBs differ (writer lapped reader).
		full = \
			(read_pointer.O[:addr_width] == write_pointer.O[:addr_width]) \
			& \
			(read_pointer.O[addr_width] != write_pointer.O[addr_width])
		
		# Empty: pointers identical including the MSB.
		empty = read_pointer.O == write_pointer.O
		# Requests are ignored when they would overflow or underflow.
		write_valid = io.ARES_design.Write & ~full
		read_valid = io.ARES_design.Read & ~empty
	
		io.ARES_design.Full @= full
		
		buffer.WE @= write_valid

		# Advance the write pointer only on an accepted write.
		write_p = mantle.mux([
			write_pointer.O, m.uint(write_pointer.O) + 1
		], write_valid)
		
		# RESET (used as the mux select) loads 0 into the pointer.
		write_pointer.I @= mantle.mux([
			write_p, 0
		], reset)
	
		io.ARES_design.Empty @= empty
		
		# Advance the read pointer only on an accepted read.
		read_p = mantle.mux([
			read_pointer.O, m.uint(read_pointer.O) + 1
		], read_valid)

		# RESET (used as the mux select) loads 0 into the pointer.
		read_pointer.I @= mantle.mux([
			read_p, 0 
		], reset)
Exemplo n.º 5
0
def _lift_instance_inputs(
    reconstructor: _CircuitReconstructor
) -> (Tuple[List[BitPortNode], List[BitPortNode]]):
    """Connect every instance port bit to a bit of the reconstructed circuit.

    For each ``(scoped_inst, inst)`` pair in ``reconstructor.instance_map``,
    every bit of every port of ``scoped_inst.inst`` is processed:

    * Input bits (other than clock-typed bits) not yet in
      ``reconstructor.node_to_bit`` get a bit from
      ``reconstructor.add_or_get_bit``, which then drives the matching bit
      of ``inst``. Inputs already mapped must already be driven.
    * Output bits not yet mapped get a bit from
      ``reconstructor.add_or_get_bit`` that is driven by the matching bit
      of ``inst``.

    Args:
        reconstructor: Holds the instance map and the node-to-bit mapping
            that is extended here.

    Returns:
        ``(pi, po)``: the ``BitPortNode``s newly lifted as inputs and as
        outputs respectively, in processing order.

    Raises:
        NotImplementedError: for a bit that is neither an input nor an output.
    """
    pi = []
    po = []

    def _process_value(value, scoped_inst, inst):
        # Lift a single port bit of ``scoped_inst`` onto ``inst``.
        node = BitPortNode(ScopedBit(value, scoped_inst.scope))
        sel = InstSelector(m.value_utils.make_selector(node.bit.value),
                           node.bit.ref.name)
        if value.is_input():
            if isinstance(value, m.ClockTypes):
                # Clock inputs are not lifted.
                return
            if node in reconstructor.node_to_bit:
                # Already lifted: it must already have a driver.
                assert reconstructor.node_to_bit[node].driven()
                return
            new_value = reconstructor.add_or_get_bit(node)
            inst_value = sel.select(inst)
            assert not inst_value.driven()
            inst_value @= new_value
            pi.append(node)
        elif value.is_output():
            if node in reconstructor.node_to_bit:
                return
            new_value = reconstructor.add_or_get_bit(node)
            inst_value = sel.select(inst)
            assert not inst_value.driving()
            new_value @= inst_value
            po.append(node)
        else:
            raise NotImplementedError(value, type(value))

    for scoped_inst, inst in reconstructor.instance_map.items():
        for port in scoped_inst.inst.interface.ports.values():
            for bit in m.as_bits(port):
                _process_value(bit, scoped_inst, inst)

    return pi, po
Exemplo n.º 6
0
  def __init__(self, inputStartAddr: int, outputStartAddr: int, busWidth: int,
                     wordWidth: int, numWordsPerGroup: int, metricWidth: int):
    """Build the feature-pair accelerator with ready/valid memory ports.

    Args:
      inputStartAddr: byte address of the input stream. The first line read
        there supplies the input length (its low 32 bits). Main data starts
        one bus line (bytesInLine bytes) later.
      outputStartAddr: byte address where output words are written.
      busWidth: memory bus width in bits; asserted to be >= 64.
      wordWidth: width in bits of each feature word.
      numWordsPerGroup: number of feature words per group. One FeaturePair is
        instantiated for every (i, j) pair, i.e. numWordsPerGroup**2 of them.
      metricWidth: width in bits of the metric field in each input line.
    """
    self.inputStartAddr = inputStartAddr
    self.outputStartAddr = outputStartAddr
    self.busWidth = busWidth
    self.io = io = m.IO(
      inputMemAddr = m.Out(m.UInt[64]),
      inputMemAddrValid = m.Out(m.Bit),
      inputMemAddrLen = m.Out(m.UInt[8]),
      inputMemAddrReady = m.In(m.Bit),
      inputMemBlock = m.In(m.UInt[busWidth]),
      inputMemBlockValid = m.In(m.Bit),
      inputMemBlockReady = m.Out(m.Bit),
      outputMemAddr = m.Out(m.UInt[64]),
      outputMemAddrValid = m.Out(m.Bit),
      outputMemAddrLen = m.Out(m.UInt[8]),
      outputMemAddrId = m.Out(m.UInt[16]),
      outputMemAddrReady = m.In(m.Bit),
      outputMemBlock = m.Out(m.UInt[busWidth]),
      outputMemBlockValid = m.Out(m.Bit),
      outputMemBlockLast = m.Out(m.Bit),
      outputMemBlockReady = m.In(m.Bit),
      finished = m.Out(m.Bit)
    ) + m.ClockIO(has_reset = True)

    assert(busWidth >= 64)
    numFeaturePairs = numWordsPerGroup * numWordsPerGroup
    # Output words are 64 bits wide.
    outputWordsInLine = busWidth // 64
    # Each feature pair contributes 2**(2*wordWidth) output words (presumably
    # one per combination of the two feature values -- confirm in FeaturePair).
    numOutputWords = numFeaturePairs * (1 << (2 * wordWidth))
    # round up to nearest full line
    numOutputWords = (numOutputWords + outputWordsInLine - 1) // outputWordsInLine * \
      outputWordsInLine
    bytesInLine = busWidth // 8

    # Top-level sequence: request length line -> read length -> stream input
    # -> one-cycle pause -> write results -> finished.
    class TopState(m.Enum):
      inputLengthAddr = 0
      loadInputLength = 1
      mainLoop = 2
      pause = 3
      writeOutput = 4
      finished = 5

    # Per-line output sub-states used while in TopState.writeOutput.
    class OutputState(m.Enum):
      sendingAddr = 0
      fillingLine = 1
      sendingLine = 2

    # reg / reg_init are helpers defined elsewhere in this file (not visible).
    state = reg_init(TopState, TopState.inputLengthAddr)
    inputLength = reg(m.UInt[32])
    inputAddrLineCount = reg_init(m.UInt[32], 0)
    inputDataLineCount = reg_init(m.UInt[32], 0)
    outputState = reg_init(OutputState, OutputState.sendingAddr)
    outputWordCounter = reg_init(m.UInt[m.bitutils.clog2(numOutputWords + 1)], 0)
    outputLine = reg(m.Array[outputWordsInLine, m.UInt[64]])

    # Input line layout (bit offsets, w = wordWidth, n = numWordsPerGroup):
    #   feature one i : [i*w, (i+1)*w)           for i < n
    #   feature two j : [(j+n)*w, (j+1+n)*w)     for j < n
    #   metric        : [2*n*w, 2*n*w + metricWidth)
    featurePairs = []
    for i in range(numWordsPerGroup):
      for j in range(numWordsPerGroup):
        idx = i * numWordsPerGroup + j
        featurePair = FeaturePair(wordWidth, metricWidth, idx)()
        featurePairs.append(featurePair)
        featurePair.inputMetric @= io.inputMemBlock[2 * numWordsPerGroup * wordWidth:
          2 * numWordsPerGroup * wordWidth + metricWidth]
        featurePair.inputFeatureOne @= io.inputMemBlock[i * wordWidth:(i + 1) * wordWidth]
        featurePair.inputFeatureTwo @= io.inputMemBlock[(j + numWordsPerGroup) * wordWidth:
          (j + 1 + numWordsPerGroup) * wordWidth]
        featurePair.inputValid @= io.inputMemBlockValid & (state.O == TopState.mainLoop)
        featurePair.shiftMode @= state.O == TopState.writeOutput
        featurePair.doShift @= (state.O == TopState.writeOutput) & (outputState.O == OutputState.fillingLine)
    # Bits above both the 32-bit length field and the feature/metric fields
    # are never read.
    io.inputMemBlock[max(32, 2 * numWordsPerGroup * wordWidth + metricWidth):].unused()
    # Chain the feature pairs: each one's neighbor input is the next pair's
    # output, and the last pair's neighbor input is tied to 0.
    for i in range(numFeaturePairs):
      if i == numFeaturePairs - 1:
        featurePairs[i].neighborOutputIn @= 0
      else:
        featurePairs[i].neighborOutputIn @= featurePairs[i + 1].out

    io.inputMemAddrValid @= (state.O == TopState.inputLengthAddr) | \
      ((state.O == TopState.mainLoop) & (inputAddrLineCount.O != inputLength.O))
    io.inputMemBlockReady @= (state.O == TopState.loadInputLength) | (state.O == TopState.mainLoop)
    # sl is defined elsewhere; presumably a left shift, so this is
    # outputWordCounter * 8 bytes + outputStartAddr -- confirm.
    io.outputMemAddr @= m.zext_to(sl(outputWordCounter.O, 3), 64) + outputStartAddr
    io.outputMemAddrValid @= (state.O == TopState.writeOutput) & (outputState.O == OutputState.sendingAddr)
    io.outputMemAddrLen @= 0
    io.outputMemAddrId @= 0
    io.outputMemBlock @= m.as_bits(outputLine.O)
    io.outputMemBlockValid @= (state.O == TopState.writeOutput) & (outputState.O == OutputState.sendingLine)
    io.outputMemBlockLast @= True
    io.finished @= state.O == TopState.finished

    # hard to put this inside the inline comb
    # While filling a line, outputLine acts as a shift register: each word
    # takes the next word's value and the last slot takes featurePairs[0].out.
    cond = (state.O == TopState.writeOutput) & (outputState.O == OutputState.fillingLine)
    outputLine.I[outputWordsInLine - 1] @= \
      m.mux([outputLine.O[outputWordsInLine - 1], featurePairs[0].out], cond)
    for i in range(len(outputLine.I) - 1):
      outputLine.I[i] @= m.mux([outputLine.O[i], outputLine.O[i + 1]], cond)

    # Next-state logic. inline_combinational rewrites this function's AST,
    # so only plain comments are used inside it.
    @m.inline_combinational()
    def logic():
      io.inputMemAddr @= inputStartAddr
      io.inputMemAddrLen @= 0
      # default values required
      state.I @= state.O
      inputAddrLineCount.I @= inputAddrLineCount.O
      inputDataLineCount.I @= inputDataLineCount.O
      outputState.I @= outputState.O
      outputWordCounter.I @= outputWordCounter.O
      if state.O == TopState.inputLengthAddr:
        # Request the length line at inputStartAddr.
        if io.inputMemAddrReady:
          state.I @= TopState.loadInputLength
      elif state.O == TopState.loadInputLength:
        # Latch the input length (in lines) from the low 32 bits.
        if io.inputMemBlockValid:
          inputLength.I @= io.inputMemBlock[:32]
          state.I @= TopState.mainLoop
      elif state.O == TopState.mainLoop:
        # Issue address bursts of at most 64 lines (len field capped at 63).
        io.inputMemAddr @= m.zext_to(sl(inputAddrLineCount.O, m.bitutils.clog2(bytesInLine)), 64) + \
          (inputStartAddr + bytesInLine) # final term is start offset of main data stream
        remainingAddrLen = inputLength.O - inputAddrLineCount.O - 1
        io.inputMemAddrLen @= 63 if remainingAddrLen > 63 else remainingAddrLen[:8]
        if io.inputMemAddrReady:
          # The conditional expression binds looser than '+': this is
          # (O + 64) if remainingAddrLen > 63 else inputLength.O.
          inputAddrLineCount.I @= inputAddrLineCount.O + 64 if remainingAddrLen > 63 else inputLength.O
        if io.inputMemBlockValid:
          inputDataLineCount.I @= inputDataLineCount.O + 1
          if inputDataLineCount.O == inputLength.O - 1:
            state.I @= TopState.pause
      elif state.O == TopState.pause:
        # required to flush FeaturePair pipeline before shiftMode is set
        state.I @= TopState.writeOutput
      elif state.O == TopState.writeOutput:
        if outputState.O == OutputState.sendingAddr:
          if io.outputMemAddrReady:
            outputState.I @= OutputState.fillingLine
        elif outputState.O == OutputState.fillingLine:
          # NOTE(review): this slices the register instance itself rather
          # than outputWordCounter.O as used elsewhere -- confirm it
          # resolves to the register output.
          wordInLine = 0 if m.bit(outputWordsInLine == 1) else \
            outputWordCounter[:max(1, m.bitutils.clog2(outputWordsInLine))]
          if m.bit(wordInLine == outputWordsInLine - 1): # TODO figure out why m.bit is needed here
            outputState.I @= OutputState.sendingLine
          outputWordCounter.I @= outputWordCounter.O + 1
        else: # outputState is sendingLine
          if io.outputMemBlockReady:
            if outputWordCounter.O == numOutputWords:
              state.I @= TopState.finished
            else:
              outputState.I @= OutputState.sendingAddr
Exemplo n.º 7
0
    def __init__(self, x_len, n_ways: int, n_sets: int, b_bytes: int):
        """Build a write-back cache with a CPU port and a NASTI memory port.

        Args:
            x_len: CPU data/address width in bits.
            n_ways: NOTE(review): accepted but never used below. v, d,
                meta_mem and data_mem are indexed by set index only, so the
                cache behaves as direct-mapped.
            n_sets: number of sets (cache lines).
            b_bytes: block size in bytes.

        Address split: offset = low b_len bits, index = next s_len bits,
        tag = remaining t_len = x_len - (s_len + b_len) bits.
        """
        b_bits = b_bytes << 3
        b_len = m.bitutils.clog2(b_bytes)
        s_len = m.bitutils.clog2(n_sets)
        t_len = x_len - (s_len + b_len)
        n_words = b_bits // x_len
        w_bytes = x_len // 8
        byte_offset_bits = m.bitutils.clog2(w_bytes)
        nasti_params = NastiParameters(data_bits=64,
                                       addr_bits=x_len,
                                       id_bits=5)
        # Number of memory-bus beats needed to transfer one block.
        data_beats = b_bits // nasti_params.x_data_bits

        class MetaData(m.Product):
            tag = m.UInt[t_len]

        self.io = m.IO(**make_cache_ports(x_len, nasti_params))
        self.io += m.ClockIO()

        class State(m.Enum):
            IDLE = 0
            READ_CACHE = 1
            WRITE_CACHE = 2
            WRITE_BACK = 3
            WRITE_ACK = 4
            REFILL_READY = 5
            REFILL = 6

        state = m.Register(init=State.IDLE)()

        # memory
        # v / d: one valid / dirty bit per set.
        v = m.Register(m.UInt[n_sets], has_enable=True)()
        d = m.Register(m.UInt[n_sets], has_enable=True)()
        meta_mem = m.Memory(n_sets,
                            MetaData,
                            read_latency=1,
                            has_read_enable=True)()
        # One byte-maskable memory per x_len-bit word of a block.
        data_mem = [
            ArrayMaskMem(n_sets,
                         w_bytes,
                         m.UInt[8],
                         read_latency=1,
                         has_read_enable=True)() for _ in range(n_words)
        ]

        # Registered copy of the CPU request, captured when resp.valid is high.
        addr_reg = m.Register(type(self.io.cpu.req.data.addr).undirected_t,
                              has_enable=True)()
        cpu_data = m.Register(type(self.io.cpu.req.data.data).undirected_t,
                              has_enable=True)()
        cpu_mask = m.Register(type(self.io.cpu.req.data.mask).undirected_t,
                              has_enable=True)()

        self.io.nasti.r.ready @= state.O == State.REFILL
        # Counters
        assert data_beats > 0
        if data_beats > 1:
            read_counter = mantle.CounterModM(data_beats,
                                              max(data_beats.bit_length(), 1),
                                              has_ce=True)
            read_counter.CE @= m.enable(self.io.nasti.r.fired())
            read_count, read_wrap_out = read_counter.O, read_counter.COUT

            write_counter = mantle.CounterModM(data_beats,
                                               max(data_beats.bit_length(), 1),
                                               has_ce=True)
            write_count, write_wrap_out = write_counter.O, write_counter.COUT
        else:
            # NOTE(review): with a single beat, write_count is the int 0, yet
            # the nasti.w wiring below slices write_count[:-1], which would
            # fail on an int -- confirm data_beats == 1 is unsupported.
            read_count, read_wrap_out = 0, 1
            write_count, write_wrap_out = 0, 1

        # Collects incoming refill beats, one slot per beat.
        refill_buf = m.Register(m.Array[data_beats,
                                        m.UInt[nasti_params.x_data_bits]],
                                has_enable=True)()
        if data_beats == 1:
            refill_buf.I[0] @= self.io.nasti.r.data.data
        else:
            refill_buf.I @= m.set_index(refill_buf.O,
                                        self.io.nasti.r.data.data,
                                        read_count[:-1])
        refill_buf.CE @= m.enable(self.io.nasti.r.fired())

        is_idle = state.O == State.IDLE
        is_read = state.O == State.READ_CACHE
        is_write = state.O == State.WRITE_CACHE
        # High on the last refill beat.
        is_alloc = (state.O == State.REFILL) & read_wrap_out
        # m.display("[%0t]: is_alloc = %x", m.time(), is_alloc)\
        #     .when(m.posedge(self.io.CLK))
        is_alloc_reg = m.Register(m.Bit)()(is_alloc)

        # Declared here, driven below once rmeta is available.
        hit = m.Bit(name="hit")
        # '&' binds tighter than '|':
        # (is_write & (hit | is_alloc_reg) & ~abort) | is_alloc
        wen = is_write & (hit | is_alloc_reg) & ~self.io.cpu.abort | is_alloc
        # m.display("[%0t]: wen = %x", m.time(), wen)\
        #     .when(m.posedge(self.io.CLK))
        # Reads are only issued when not writing.
        ren = m.enable(~wen & (is_idle | is_read) & self.io.cpu.req.valid)
        ren_reg = m.enable(m.Register(m.Bit)()(ren))

        addr = self.io.cpu.req.data.addr
        idx = addr[b_len:s_len + b_len]
        tag_reg = addr_reg.O[s_len + b_len:x_len]
        idx_reg = addr_reg.O[b_len:s_len + b_len]
        off_reg = addr_reg.O[byte_offset_bits:b_len]

        rmeta = meta_mem.read(idx, ren)
        rdata = m.concat(*(mem.read(idx, ren) for mem in data_mem))
        # Hold the last read data when no new read was issued.
        rdata_buf = m.Register(type(rdata), has_enable=True)()(rdata,
                                                               CE=ren_reg)

        # Right after allocation, serve the refill buffer; otherwise the
        # fresh memory read (ren_reg) or the held copy.
        read = m.mux([
            m.as_bits(m.mux([rdata_buf, rdata], ren_reg)),
            m.as_bits(refill_buf.O)
        ], is_alloc_reg)
        # m.display("is_alloc_reg=%x", is_alloc_reg)\
        #     .when(m.posedge(self.io.CLK))

        hit @= v.O[idx_reg] & (rmeta.tag == tag_reg)

        # read mux
        self.io.cpu.resp.data.data @= m.array(
            [read[i * x_len:(i + 1) * x_len] for i in range(n_words)])[off_reg]
        self.io.cpu.resp.valid @= (is_idle | (is_read & hit) |
                                   (is_alloc_reg & ~cpu_mask.O.reduce_or()))
        # NOTE(review): the following m.display calls are active debug traces
        # printed on every clock edge -- consider removing or commenting out.
        m.display("resp.valid=%x", self.io.cpu.resp.valid.value())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: valid = %x", m.time(),
                  self.io.cpu.resp.valid.value())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: is_idle = %x, is_read = %x, hit = %x, is_alloc_reg = "
                  "%x, ~cpu_mask.O.reduce_or() = %x", m.time(), is_idle,
                  is_read, hit, is_alloc_reg, ~cpu_mask.O.reduce_or())\
            .when(m.posedge(self.io.CLK))
        m.display("[%0t]: refill_buf.O=%x, %x", m.time(), *refill_buf.O)\
            .when(m.posedge(self.io.CLK))\
            .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)
        m.display("[%0t]: read=%x", m.time(), read)\
            .when(m.posedge(self.io.CLK))\
            .if_(self.io.cpu.resp.valid.value() & is_alloc_reg)

        # Capture a new request whenever a response is being returned.
        addr_reg.I @= addr
        addr_reg.CE @= m.enable(self.io.cpu.resp.valid.value())

        cpu_data.I @= self.io.cpu.req.data.data
        cpu_data.CE @= m.enable(self.io.cpu.resp.valid.value())

        cpu_mask.I @= self.io.cpu.req.data.mask
        cpu_mask.CE @= m.enable(self.io.cpu.resp.valid.value())

        wmeta = MetaData(name="wmeta")
        wmeta.tag @= tag_reg

        # CPU byte mask shifted to the addressed word within the block.
        offset_mask = (m.zext_to(cpu_mask.O, w_bytes * 8) << m.concat(
            m.bits(0, byte_offset_bits), off_reg))
        # On allocation (is_alloc) every byte is written (all-ones mask).
        wmask = m.mux([m.SInt[w_bytes * 8](-1),
                       m.sint(offset_mask)], ~is_alloc)

        if len(refill_buf.O) == 1:
            wdata_alloc = self.io.nasti.r.data.data
        else:
            # Earlier buffered beats plus the beat arriving this cycle.
            wdata_alloc = m.concat(
                # TODO: not sure why they use `init.reverse`
                # https://github.com/ucb-bar/riscv-mini/blob/release/src/main/scala/Cache.scala#L116
                m.concat(*refill_buf.O[:-1]),
                self.io.nasti.r.data.data)
        # On allocation write refill data; otherwise the CPU word repeated
        # across the block (the mask selects the target bytes).
        wdata = m.mux([wdata_alloc,
                       m.as_bits(m.repeat(cpu_data.O, n_words))], ~is_alloc)

        # Any write marks the set valid; CPU writes mark it dirty, refills
        # mark it clean.
        v.I @= m.set_index(v.O, m.bit(True), idx_reg)
        v.CE @= m.enable(wen)
        d.I @= m.set_index(d.O, ~is_alloc, idx_reg)
        d.CE @= m.enable(wen)
        # m.display("[%0t]: refill_buf.O = %x", m.time(),
        #           m.concat(*refill_buf.O)).when(m.posedge(self.io.CLK)).if_(wen)
        # m.display("[%0t]: nasti.r.data.data = %x", m.time(),
        #           self.io.nasti.r.data.data).when(m.posedge(self.io.CLK)).if_(wen)

        # Tags are only rewritten on allocation.
        meta_mem.write(wmeta, idx_reg, m.enable(wen & is_alloc))
        for i, mem in enumerate(data_mem):
            data = [
                wdata[i * x_len + j * 8:i * x_len + (j + 1) * 8]
                for j in range(w_bytes)
            ]
            mem.write(m.array(data), idx_reg,
                      wmask[i * w_bytes:(i + 1) * w_bytes], m.enable(wen))
            # m.display("[%0t]: wdata = %x, %x, %x, %x", m.time(),
            #           *mem.WDATA.value()).when(m.posedge(self.io.CLK)).if_(wen)
            # m.display("[%0t]: wmask = %x, %x, %x, %x", m.time(),
            #           *mem.WMASK.value()).when(m.posedge(self.io.CLK)).if_(wen)

        # Refill read address: the requested block address, i.e. tag and
        # index shifted left by b_len.
        tag_and_idx = m.zext_to(m.concat(idx_reg, tag_reg),
                                nasti_params.x_addr_bits)
        self.io.nasti.ar.data @= NastiReadAddressChannel(
            nasti_params, 0, tag_and_idx << m.Bits[len(tag_and_idx)](b_len),
            m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

        # Write-back address: built from the tag stored in meta memory.
        rmeta_and_idx = m.zext_to(m.concat(idx_reg, rmeta.tag),
                                  nasti_params.x_addr_bits)
        self.io.nasti.aw.data @= NastiWriteAddressChannel(
            nasti_params, 0,
            rmeta_and_idx << m.Bits[len(rmeta_and_idx)](b_len),
            m.bitutils.clog2(nasti_params.x_data_bits // 8), data_beats - 1)

        # Write-back data: one beat-sized slice of `read` per write_count,
        # with the last flag from write_wrap_out.
        self.io.nasti.w.data @= NastiWriteDataChannel(
            nasti_params,
            m.array([
                read[i * nasti_params.x_data_bits:(i + 1) *
                     nasti_params.x_data_bits] for i in range(data_beats)
            ])[write_count[:-1]], None, write_wrap_out)

        is_dirty = v.O[idx_reg] & d.O[idx_reg]

        # TODO: Have to use temporary so we can invoke `fired()`
        aw_valid = m.Bit(name="aw_valid")
        self.io.nasti.aw.valid @= aw_valid

        ar_valid = m.Bit(name="ar_valid")
        self.io.nasti.ar.valid @= ar_valid

        b_ready = m.Bit(name="b_ready")
        self.io.nasti.b.ready @= b_ready

        # Cache controller FSM. inline_combinational rewrites this function's
        # AST, so only plain comments are used inside it.
        @m.inline_combinational()
        def logic():
            # Defaults: hold state, no memory-side handshakes.
            state.I @= state.O
            aw_valid @= False
            ar_valid @= False
            self.io.nasti.w.valid @= False
            b_ready @= False
            if state.O == State.IDLE:
                # A non-zero mask means a write request.
                if self.io.cpu.req.valid:
                    if self.io.cpu.req.data.mask.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.READ_CACHE
            elif state.O == State.READ_CACHE:
                if hit:
                    if self.io.cpu.req.valid:
                        if self.io.cpu.req.data.mask.reduce_or():
                            state.I @= State.WRITE_CACHE
                        else:
                            state.I @= State.READ_CACHE
                    else:
                        state.I @= State.IDLE
                else:
                    # Miss: write back a dirty line first, else refill.
                    aw_valid @= is_dirty
                    ar_valid @= ~is_dirty
                    if self.io.nasti.aw.fired():
                        state.I @= State.WRITE_BACK
                    elif self.io.nasti.ar.fired():
                        state.I @= State.REFILL
            elif state.O == State.WRITE_CACHE:
                if hit | is_alloc_reg | self.io.cpu.abort:
                    state.I @= State.IDLE
                else:
                    # Miss: write back a dirty line first, else refill.
                    aw_valid @= is_dirty
                    ar_valid @= ~is_dirty
                    if self.io.nasti.aw.fired():
                        state.I @= State.WRITE_BACK
                    elif self.io.nasti.ar.fired():
                        state.I @= State.REFILL
            elif state.O == State.WRITE_BACK:
                # Stream write-back beats until the write counter wraps.
                self.io.nasti.w.valid @= True
                if write_wrap_out:
                    state.I @= State.WRITE_ACK
            elif state.O == State.WRITE_ACK:
                b_ready @= True
                if self.io.nasti.b.fired():
                    state.I @= State.REFILL_READY
            elif state.O == State.REFILL_READY:
                ar_valid @= True
                if self.io.nasti.ar.fired():
                    state.I @= State.REFILL
            elif state.O == State.REFILL:
                # After the last beat, finish a pending write or go idle.
                if read_wrap_out:
                    if cpu_mask.O.reduce_or():
                        state.I @= State.WRITE_CACHE
                    else:
                        state.I @= State.IDLE

        if data_beats > 1:
            # TODO: Have to do this at the end since the inline comb logic
            # wires up nasti.w
            write_counter.CE @= m.enable(self.io.nasti.w.fired())