def elab(self, m):
    """Gather four successive input payloads per pixel into one output word.

    Payloads arrive round-robin across the pixels. The first three
    payloads of each pixel are shifted into a 24 bit register. When the
    fourth arrives it is concatenated above the register content and
    presented on the output in the same cycle.
    """
    source = self.input
    sink = self.output
    total_pixels = self._num_pixels

    # One 24 bit shift register per pixel buffers three incoming values
    shift_registers = Array(
        [Signal(24, name=f"sr_{i}") for i in range(total_pixels)])
    # Pixel that receives the next incoming value
    pixel_index = Signal(range(total_pixels))
    # Number of values every pixel has received so far
    byte_count = Signal(range(4))

    # Index of the final pixel in use. Half mode uses only the first half
    # of the pixels. This is a registered copy of the selection.
    last_pixel_index = Signal.like(pixel_index)
    m.d.sync += last_pixel_index.eq(
        Mux(self.half_mode, total_pixels // 2 - 1, total_pixels - 1))

    next_pixel_index = Signal.like(pixel_index)
    m.d.comb += [
        next_pixel_index.eq(
            Mux(pixel_index == last_pixel_index, 0, pixel_index + 1)),
        # Input is never stalled
        source.ready.eq(1),
    ]

    with m.If(source.is_transferring()):
        sr = shift_registers[pixel_index]
        with m.If(byte_count == 3):
            # Fourth value: emit buffered values plus the incoming one
            m.d.comb += [
                sink.valid.eq(1),
                sink.payload.eq(Cat(sr, source.payload)),
            ]
        with m.Else():
            # Store incoming value at the top, shift older values down
            m.d.sync += [
                sr[-8:].eq(source.payload),
                sr[:-8].eq(sr[8:]),
            ]

        # Move on to the next pixel
        m.d.sync += pixel_index.eq(next_pixel_index)
        with m.If(pixel_index == last_pixel_index):
            # Completed a pass over all pixels: wrap, count one more value
            m.d.sync += [
                pixel_index.eq(0),
                byte_count.eq(byte_count + 1),  # allow rollover
            ]
def build_param_store(self, m):
    """Build the post process parameter memory with its writer and reader.

    A ParamWriter fills the memory from ``self.post_process_params`` and a
    ReadingProducer reads it back. The reader is reset whenever a
    parameter is transferred into the writer.

    Returns the reader's ``output_data``.
    """
    # Create memory for post process params
    # Use SinglePortMemory here?
    param_mem = Memory(
        width=POST_PROCESS_PARAMS_WIDTH,
        depth=Constants.MAX_CHANNEL_DEPTH)
    read_port = param_mem.read_port(transparent=False)
    write_port = param_mem.write_port()
    m.submodules['param_rp'] = read_port
    m.submodules['param_wp'] = write_port

    # Writer: takes the parameter stream and drives the write port
    writer = ParamWriter()
    m.submodules['param_writer'] = writer
    m.d.comb += connect(self.post_process_params, writer.input_data)
    m.d.comb += [
        writer.reset.eq(self.reset),
        write_port.en.eq(writer.mem_we),
        write_port.addr.eq(writer.mem_addr),
        write_port.data.eq(writer.mem_data),
    ]

    # Reader: steps through the memory via the read port
    producer = ReadingProducer()
    m.submodules['param_reader'] = producer
    repeat_count = Mux(
        self.config.mode,
        Constants.SYS_ARRAY_HEIGHT,
        Constants.SYS_ARRAY_HEIGHT // 2)
    m.d.comb += [
        # Restart reading whenever new parameters are written
        producer.reset.eq(writer.input_data.is_transferring()),
        producer.sizes.depth.eq(self.config.output_channel_depth),
        producer.sizes.repeats.eq(repeat_count),
        read_port.addr.eq(producer.mem_addr),
        producer.mem_data.eq(read_port.data),
    ]
    return producer.output_data
def elab(self, m: Module):
    """Pass the input stream to the output, buffering one stalled value.

    While nothing is buffered the input is connected straight through.
    If a value arrives when the output cannot take it, the value is held
    in a register and offered from there until the output accepts it.
    """
    source = self.input
    sink = self.output

    buffering = Signal()  # True if there is a value being buffered
    buffered_value = Signal.like(Value.cast(source.payload))

    # Handshake and payload selection
    m.d.comb += [
        source.ready.eq(~buffering | sink.ready),
        sink.valid.eq(buffering | source.valid),
        sink.payload.eq(Mux(buffering, buffered_value, source.payload)),
    ]

    with m.If(~buffering & ~sink.ready & source.valid):
        # Incoming value cannot be output right now: keep it
        m.d.sync += [
            buffering.eq(True),
            buffered_value.eq(source.payload),
        ]
    with m.Elif(buffering & sink.ready):
        # Buffered value leaves this cycle
        with m.If(source.valid):
            # ...and is replaced immediately by the next value
            m.d.sync += buffered_value.eq(source.payload)
        with m.Else():
            m.d.sync += buffering.eq(False)

    # Reset takes priority over everything above
    with m.If(self.reset):
        m.d.sync += [
            buffering.eq(False),
            buffered_value.eq(0),
        ]
def transform(self, m, in_value, out_value):
    """Rounding divide of ``in_value.dividend`` by ``2 ** in_value.shift``.

    The calculation is spread over registered stages:
    inputs are registered, the rounded quotient is calculated and
    registered, then the result is registered into ``out_value``.
    Only shift amounts 2 to 12 inclusive have a case in the switch.
    """
    # Cycle 0: register inputs
    dividend = Signal(signed(32))
    shift = Signal(4)
    m.d.sync += [
        dividend.eq(in_value.dividend),
        shift.eq(in_value.shift),
    ]

    # Cycle 1: calculate
    result = Signal(signed(32))
    remainder = Signal(signed(32))
    # The threshold is a single one bit at position (shift - 1), e.g.
    # 010, 0100, 01000. For negative dividends bit zero is set as well,
    # giving 011, 0101, 01001.
    threshold = Signal(signed(32))
    quotient = Signal(signed(32))
    negative = Signal()
    m.d.comb += negative.eq(dividend < 0)

    with m.Switch(shift):
        for amount in range(2, 13):
            with m.Case(amount):
                m.d.comb += [
                    remainder.eq(dividend & ((1 << amount) - 1)),
                    threshold[1:].eq(1 << (amount - 2)),
                    quotient.eq(dividend >> amount),
                ]
    m.d.comb += threshold[0].eq(negative)

    round_up = Mux(remainder >= threshold, 1, 0)
    m.d.sync += result.eq(quotient + round_up)

    # Cycle 2: send output
    m.d.sync += out_value.eq(result)
def elab(self, m):
    """Drive ``self.result`` with ``self.x`` divided by 2**exponent, rounded.

    The discarded low bits are compared with a threshold of half the
    mask, raised by one when bit 31 of ``x`` is set. One is added to the
    shifted value when the low bits exceed the threshold.
    """
    # TODO: reimplement as a function that returns an expression
    low_bits_mask = (1 << self.exponent) - 1
    discarded = self.x & low_bits_mask
    half = (low_bits_mask >> 1) + self.x[31]
    shifted = self.x >> self.exponent
    m.d.comb += self.result.eq(shifted + Mux(discarded > half, 1, 0))
def elab(self, m):
    """Generate memory addresses and hold the data that was read.

    ``restart`` forces address zero. In the cycle after ``restart`` or
    ``next`` was asserted, the address advances (wrapping at ``limit``)
    and ``data`` takes the value on ``mem_data``. Otherwise address and
    data keep their previous values.
    """
    # Registered copies of last cycle's controls and outputs
    was_restart = Signal()
    was_next = Signal()
    last_mem_addr = Signal.like(self.mem_addr)
    last_data = Signal.like(self.data)
    m.d.sync += [
        was_restart.eq(self.restart),
        was_next.eq(self.next),
        last_mem_addr.eq(self.mem_addr),
        last_data.eq(self.data),
    ]

    # Address following the previous one, wrapping back to zero
    incremented_addr = Signal.like(self.limit)
    m.d.comb += incremented_addr.eq(
        Mux(last_mem_addr == self.limit - 1, 0, last_mem_addr + 1))

    # Address to present now (determines data available next cycle)
    with m.If(self.restart):
        m.d.comb += self.mem_addr.eq(0)
    with m.Elif(was_next | was_restart):
        m.d.comb += self.mem_addr.eq(incremented_addr)
    with m.Else():
        m.d.comb += self.mem_addr.eq(last_mem_addr)

    # Data to present now
    with m.If(was_restart | was_next):
        m.d.comb += self.data.eq(self.mem_data)
    with m.Else():
        m.d.comb += self.data.eq(last_data)
def handle_reading(self, m, memory, num_words, index, reset):
    """Step through memory, feeding read data to ``self.data_output``.

    ``index`` advances by four per read, wrapping to zero after
    ``num_words - 4``. The memory address is ``index`` without its two
    low bits. ``reset`` clears the internal read tracking flags.
    """
    sink = self.data_output
    m.d.comb += memory.read_addr.eq(index[2:])

    read_port_valid = Signal()
    read_started = Signal()

    # The output register may be loaded when it is empty or is being
    # emptied this cycle.
    with m.If(~sink.valid | sink.is_transferring()):
        m.d.sync += [
            sink.payload.eq(memory.read_data),
            sink.valid.eq(read_port_valid),
            read_port_valid.eq(0),
        ]
        with m.If(~read_port_valid & ~read_started):
            # Nothing in flight: begin the next read
            m.d.sync += [
                index.eq(Mux(index == num_words - 4, 0, index + 4)),
                read_started.eq(1),
            ]
        with m.Else():
            m.d.sync += read_started.eq(0)

    with m.If(read_started):
        # Read port is valid next cycle if we started a new read this cycle
        m.d.sync += read_port_valid.eq(1)

    # Reset has the final say on the tracking flags
    with m.If(reset):
        m.d.sync += [
            read_port_valid.eq(0),
            read_started.eq(0),
        ]
def rounding_divide_by_pot(x, exponent):
    """Implements gemmlowp::RoundingDivideByPOT

    Builds an expression for x divided by 2**exponent, rounded to the
    nearest whole number. Bit 31 of x is used as the sign when choosing
    the rounding threshold.
    """
    low_mask = (1 << exponent) - 1
    # Round up when the discarded bits exceed half, where "half" is one
    # higher if bit 31 of x is set.
    round_up = (x & low_mask) > ((low_mask >> 1) + x[31])
    return (x >> exponent) + Mux(round_up, 1, 0)
def elab(self, m):
    """Count ``en`` pulses from zero up to ``max - 1``, then wrap.

    ``done`` is driven, while ``en`` is high, with whether the count is
    at its last value. ``restart`` returns the count to zero and
    overrides counting.
    """
    count = Signal.like(self.max)
    at_last = count == self.max - 1

    with m.If(self.en):
        m.d.comb += self.done.eq(at_last)
        m.d.sync += count.eq(Mux(at_last, 0, count + 1))

    with m.If(self.restart):
        m.d.sync += count.eq(0)
def max_(word0, word1):
    """Return the byte-wise signed maximum of two 32 bit words.

    Each of the four 8 bit lanes is compared as a signed value. The lane
    from word1 is chosen only when it is strictly greater. Statements
    are added to the combinational domain of the enclosing ``m``.
    """
    result = []
    for lane, low in enumerate(range(0, 32, 8)):
        b0 = word0[low:low + 8]
        b1 = word1[low:low + 8]

        # Signed views of both lanes, used only for the comparison
        sb0 = Signal(signed(8))
        sb1 = Signal(signed(8))
        r = Signal(8, name=f"result{lane}")
        m.d.comb += [
            sb0.eq(b0),
            sb1.eq(b1),
            r.eq(Mux(sb1 > sb0, b1, b0)),
        ]
        result.append(r)
    return Cat(*result)
def transform(self, m, in_value, out_value):
    """Pipelined multiply of a and b, keeping the nudged high bits.

    The magnitude of ``a`` is multiplied by ``b``. A rounding nudge is
    added and bits 31 and up are taken. The sign of ``a`` is then
    applied again before the result is registered into ``out_value``.
    """
    a = in_value.a
    a_not_negative = a >= 0

    # Cycle 0: register inputs, using the magnitude of a
    reg_a = Signal(signed(32))
    reg_b = Signal(signed(32))
    m.d.sync += [
        reg_a.eq(Mux(a_not_negative, a, -a)),
        reg_b.eq(in_value.b),
    ]

    # Cycle 1: multiply to register
    # NOTE(review): the original author states both operands are positive
    # so the product is positive; only a is made positive here, so this
    # presumably relies on b never being negative - confirm with callers.
    reg_ab = Signal(signed(63))
    m.d.sync += reg_ab.eq(reg_a * reg_b)

    # Cycle 2: nudge, take high bits and sign
    # Sign of the input, delayed to line up with reg_ab
    positive_2 = self.delay(m, 2, a_not_negative)
    nudge = Mux(positive_2, (1 << 30), (1 << 30) - 1)
    high_bits = Signal(signed(32))
    m.d.comb += high_bits.eq((reg_ab + nudge)[31:])
    m.d.sync += out_value.eq(Mux(positive_2, high_bits, -high_bits))
def increment_to_limit(value, limit):
    """Builds a statement that performs circular increments.

    Parameters
    ----------
    value: Signal(n)
        The value to be incremented. It is also the target of the
        returned assignment.
    limit: Signal(n)
        The count at which the value wraps. When ``value`` equals
        ``limit - 1`` the next value is zero, so ``value`` ranges over
        0 to ``limit - 1``.

    Returns
    -------
    The assignment statement. It is not added to any domain here.
    """
    at_end = value == limit - 1
    following = Mux(at_end, 0, value + 1)
    return value.eq(following)
def build_input_fetcher(self, m, stop):
    """Build both input fetchers and the RamMux they share.

    ``self.config.mode`` selects which fetcher drives the RamMux and
    which fetcher's outputs are returned: the mode 1 fetcher when set,
    the mode 0 fetcher otherwise.

    Returns a tuple ``(first, last, data)`` where ``data`` is a list of
    four values.
    """
    cfg = self.config
    mode = cfg.mode

    def by_mode(mode1_value, mode0_value):
        # Pick between the two fetchers according to the current mode
        return Mux(mode, mode1_value, mode0_value)

    # Create fetchers
    f0 = self.create_fetcher(m, stop, 'f0', Mode0InputFetcher)
    f1 = self.create_fetcher(m, stop, 'f1', Mode1InputFetcher)

    # Additional config for fetcher1
    m.d.comb += [
        f1.num_pixels_x.eq(cfg.num_pixels_x),
        f1.pixel_advance_x.eq(cfg.pixel_advance_x),
        f1.pixel_advance_y.eq(cfg.pixel_advance_y),
        f1.depth.eq(cfg.input_channel_depth >> 4),
        f1.num_repeats.eq(
            cfg.output_channel_depth // Const(Constants.SYS_ARRAY_WIDTH)),
    ]

    # Create RamMux and connect to LRAMs
    m.submodules['ram_mux'] = ram_mux = RamMux()
    for i in range(4):
        m.d.comb += [
            # LRAM side of the mux
            self.lram_addr[i].eq(ram_mux.lram_addr[i]),
            ram_mux.lram_data[i].eq(self.lram_data[i]),
            # Both fetchers see the mux data
            f0.ram_mux_data[i].eq(ram_mux.data_out[i]),
            f1.ram_mux_data[i].eq(ram_mux.data_out[i]),
            # Only the selected fetcher supplies addresses
            ram_mux.addr_in[i].eq(
                by_mode(f1.ram_mux_addr[i], f0.ram_mux_addr[i])),
        ]

    # phase input depends on mode
    m.d.comb += ram_mux.phase.eq(
        by_mode(f1.ram_mux_phase, f0.ram_mux_phase))

    # Route fetcher outputs depending on mode
    return (
        by_mode(f1.first, f0.first),
        by_mode(f1.last, f0.last),
        [by_mode(f1.data_out[i], f0.data_out[i]) for i in range(4)],
    )
def elab(self, m):
    """Build the filter memories and read them out continuously.

    ``read_index`` is the read address for memory 0. ``self.start`` sets
    it to zero and starts it incrementing, wrapping at ``self.size``.
    Memory i is read at ``read_index - i``, wrapped into the range 0 to
    ``self.size - 1``. Nothing in this method clears ``running`` again.
    """
    read_index = Signal(range(Constants.FILTER_WORDS_PER_STORE))
    running = Signal()
    with m.If(running):
        m.d.sync += read_index.eq(
            Mux(read_index == self.size - 1, 0, read_index + 1))
    with m.If(self.start):
        m.d.sync += [
            running.eq(True),
            read_index.eq(0),
        ]

    inp = self.write_input
    # Always ready to receive more data
    m.d.comb += inp.ready.eq(1)

    # set up each memory
    for i in range(Constants.NUM_FILTER_STORES):
        mem = SinglePortMemory(
            data_width=32, depth=Constants.FILTER_WORDS_PER_STORE)
        m.submodules[f"mem{i}"] = mem

        # A write goes only to the store named in the payload
        m.d.comb += [
            mem.write_enable.eq(inp.valid & (inp.payload.store == i)),
            mem.write_addr.eq(inp.payload.addr),
            mem.write_data.eq(inp.payload.data),
        ]

        # Do reads continuously
        if i == 0:
            # i == 0 treated separately to avoid verilator warnings
            # (and save a few LUTs and FFs)
            m.d.comb += mem.read_addr.eq(read_index)
        else:
            # Signed intermediate, as read_index - i may be below zero
            addr = Signal(range(-i, Constants.FILTER_WORDS_PER_STORE),
                          name=f"addr_{i}")
            m.d.comb += [
                addr.eq(read_index - i),
                mem.read_addr.eq(Mux(addr >= 0, addr, addr + self.size)),
            ]
        m.d.comb += self.values_out[i].eq(mem.read_data)
def elab(self, m):
    """Drive ``r_addr``, the address presented to memory this cycle.

    ``restart`` forces address zero and has priority. ``next`` advances
    from the previous address, returning to zero once the previous
    address has reached ``limit - 1`` or more. With neither asserted the
    previous address is repeated.
    """
    # Address that was presented in the previous cycle
    last_addr = Signal.like(self.r_addr)
    m.d.sync += last_addr.eq(self.r_addr)

    with m.If(self.restart):
        m.d.comb += self.r_addr.eq(0)
    with m.Elif(self.next):
        m.d.comb += self.r_addr.eq(
            Mux(last_addr >= self.limit - 1, 0, last_addr + 1))
    with m.Else():
        m.d.comb += self.r_addr.eq(last_addr)
def elab(self, m):
    """Count ``next`` pulses in ``self.value``, wrapping after ``count``.

    While ``reset`` is high the value is cleared and ``last`` is not
    driven. Otherwise ``last`` is high when the value is at
    ``count - 1``, and a ``next`` pulse then returns the value to zero.
    """
    value_p1 = Signal.like(self.count)
    next_value = Signal.like(self.value)

    with m.If(self.reset):
        m.d.sync += self.value.eq(0)
    with m.Else():
        # These are only driven when not in reset
        m.d.comb += value_p1.eq(self.value + 1)
        m.d.comb += self.last.eq(value_p1 == self.count)
        m.d.comb += next_value.eq(Mux(self.last, 0, value_p1))
        with m.If(self.next):
            m.d.sync += self.value.eq(next_value)
def elab(self, m: Module):
    """Join several input streams into one output stream of named fields.

    Each input has a one-entry buffer. The output is valid once every
    field has either a buffered value or a valid input. Inputs that
    arrive before the others are buffered until the output transfers.
    """
    # One signal for each input stream, indicating whether
    # there is a value being buffered:
    buffering = {name: Signal(name=f'buffering_{name}')
                 for name in self.field_names}

    # A buffer for each input stream:
    buffered_values = {
        name: Signal(self.field_shapes[name], name=f'buffered_{name}')
        for name in self.field_names}

    transferring = self.output.is_transferring()

    # For each field of the concatenated output, present either the
    # buffered value if we have one, or else plumb through the input.
    for name in self.field_names:
        m.d.comb += self.output.payload[name].eq(
            Mux(buffering[name], buffered_values[name],
                self.inputs[name].payload))

    # The output is valid if we have either a buffered value or a valid
    # input for every slice in the output. The check is built per field
    # name so it does not depend on self.inputs and buffering iterating
    # in the same order.
    valid_or_buffering = Cat(*[
        self.inputs[name].valid | buffering[name]
        for name in self.field_names])
    m.d.comb += self.output.valid.eq(valid_or_buffering.all())

    # "endpoint" rather than "input", which would shadow the builtin
    for name, endpoint in self.inputs.items():
        # We can accept an input if the buffer is not occupied,
        # or if we can output this cycle.
        m.d.comb += endpoint.ready.eq(~buffering[name] | transferring)

        with m.If(endpoint.valid):
            # Buffer it if the buffer is not occupied and we can't output.
            with m.If(~buffering[name] & ~transferring):
                m.d.sync += buffering[name].eq(True)
                m.d.sync += buffered_values[name].eq(endpoint.payload)
            # Buffer it if the buffer is occupied but we are outputting.
            with m.If(transferring):
                m.d.sync += buffered_values[name].eq(endpoint.payload)
        with m.Else():
            # Nothing arrives to refill the buffer, so it empties
            with m.If(transferring):
                m.d.sync += buffering[name].eq(False)

    # Reset all state
    with m.If(self.reset):
        for flag in buffering.values():
            m.d.sync += flag.eq(False)
        for held in buffered_values.values():
            m.d.sync += held.eq(0)
def elab(self, m):
    """Generate a repeating sequence of four row addresses.

    A two bit counter runs freely once started. At count zero
    ``addr_out`` is ``start_addr``. At other counts it is the previous
    ``addr_out`` plus ``INCREMENT_Y``. ``first`` and ``last`` mark counts
    zero and three while running.
    """
    running = Signal()
    count = Signal(2)
    next_row = Signal(18)

    with m.If(running):
        m.d.sync += count.eq(count + 1)  # Allow rollover

    m.d.comb += [
        self.addr_out.eq(Mux(count == 0, self.start_addr, next_row)),
        self.first.eq(running & (count == 0)),
        self.last.eq(running & (count == 3)),
    ]
    # Address for the following row, ready for the next cycle
    m.d.sync += next_row.eq(self.addr_out + self.INCREMENT_Y)

    with m.If(self.start):
        m.d.sync += [
            running.eq(True),
            count.eq(0),
        ]
    # Reset is listed last, so it wins over start
    with m.If(self.reset):
        m.d.sync += [
            running.eq(False),
            count.eq(0),
        ]
def elaborate(self, platform):
    """Add or subtract ``src1`` and ``src2``, with carry and overflow.

    ``sub`` selects subtraction. The bit above the result width is
    written to ``carry``. ``overflow`` is raised when the result's top
    bit differs from that of ``src1`` and, for subtraction, the operands'
    top bits differ or, for addition, they are equal.
    """
    m = Module()

    # Assigning to the concatenation places the extra bit in carry
    difference = self.src1 - self.src2
    total = self.src1 + self.src2
    m.d.comb += Cat(self.res, self.carry).eq(
        Mux(self.sub, difference, total))

    src1_top = self.src1[-1]
    operand_tops_differ = src1_top != self.src2[-1]
    result_top_changed = src1_top != self.res[-1]

    # Which operand relationship can overflow depends on the operation
    risky_operands = Mux(
        self.sub, operand_tops_differ, ~operand_tops_differ)
    with m.If(risky_operands & result_top_changed):
        m.d.comb += self.overflow.eq(1)
    return m
def build_multipliers(self, m, accumulator):
    """Multiply corresponding words of the two inputs and accumulate.

    ``input_a`` and ``input_b`` each hold ``self._n`` packed words. The
    products are registered, summed, and added into ``accumulator``. The
    accumulation restarts from zero when the delayed ``input_first`` is
    set.
    """
    # Pipeline cycle 0: calculate products
    a_width = self._a_shape.width
    b_width = self._b_shape.width
    products = []
    for i in range(self._n):
        # Give each extracted word its declared shape
        a = Signal(self._a_shape, name=f"a_{i}")
        b = Signal(self._b_shape, name=f"b_{i}")
        m.d.comb += [
            a.eq(self.input_a.word_select(i, a_width)),
            b.eq(self.input_b.word_select(i, b_width)),
        ]
        product = a * b
        ab = Signal.like(product)
        m.d.sync += ab.eq(product)
        products.append(ab)

    # Pipeline cycle 1: accumulate
    total = tree_sum(products)
    product_sum = Signal.like(total)
    m.d.comb += product_sum.eq(total)
    first_delayed = delay(m, self.input_first, 1)[-1]
    m.d.sync += accumulator.eq(
        Mux(first_delayed, 0, accumulator) + product_sum)
def elab(self, m):
    """Fetch data via RAM Mux channels 0 and 3 and select the output.

    There are 8 cases, determined by bits 1, 2 and 3 of ``self.addr``.
    Bits 2 and 3 set the ram_mux phase, which selects the two words
    holding the required data onto channels 0 and 3. Bit 1, registered
    for one cycle, then selects between the words themselves and a mix
    of the upper half of the first and the lower half of the second.
    """
    # Uses just two of the mux channels - 0 and 3
    # For convenience, tie the unused addresses to zero
    m.d.comb += [
        self.ram_mux_addr[1].eq(0),
        self.ram_mux_addr[2].eq(0),
    ]

    # Use phase to select the two required words to channels 0 & 3
    m.d.comb += self.ram_mux_phase.eq(self.addr[2:4])

    # Calculate block addresses of the two words - second word may cross
    # 16 byte block boundary
    block = Signal(14)
    m.d.comb += [
        block.eq(self.addr[4:]),
        self.ram_mux_addr[0].eq(block),
        self.ram_mux_addr[3].eq(
            Mux(self.ram_mux_phase == 3, block + 1, block)),
    ]

    # Selection is made on the cycle after the address was received,
    # when the data is available.
    byte_sel = Signal(1)
    m.d.sync += byte_sel.eq(self.addr[1])

    d0 = self.ram_mux_data[0]
    d3 = self.ram_mux_data[3]
    dmix = Signal(32)
    m.d.comb += dmix.eq(Cat(d0[16:], d3[:16]))

    # data_out[0] is combinational, data_out[1] is registered
    with m.If(byte_sel):
        m.d.comb += self.data_out[0].eq(dmix)
        m.d.sync += self.data_out[1].eq(d3)
    with m.Else():
        m.d.comb += self.data_out[0].eq(d0)
        m.d.sync += self.data_out[1].eq(dmix)
def elab(self, m):
    """Pipelined multiply of ``a`` and ``b`` returning nudged high bits.

    The result is bits 31 and up of the product plus ``1 << 30``. When
    both registered operands equal INT32_MIN the result is INT32_MAX
    instead.
    """
    areg = Signal.like(self.a)
    breg = Signal.like(self.b)
    ab = Signal(signed(64))
    overflow = Signal()

    # for some reason negative nudge is not used
    nudge = 1 << 30

    # cycle 0, register a and b
    m.d.sync += areg.eq(self.a)
    m.d.sync += breg.eq(self.b)

    # cycle 1, decide if this is an overflow and multiply
    both_minimum = (areg == INT32_MIN) & (breg == INT32_MIN)
    m.d.sync += overflow.eq(both_minimum)
    m.d.sync += ab.eq(areg * breg)

    # cycle 2, apply nudge determine result
    nudged_high = (ab + nudge)[31:]
    m.d.sync += self.result.eq(Mux(overflow, INT32_MAX, nudged_high))
def elaborate(self, platform):
    """Build the memory arbiter, with optional address translation.

    Port enables are latched into a priority encoder while the bus is
    free, and the selected port is connected to ``self.generic_bus``.
    When address translation is enabled, a page-walk FSM first computes
    ``phys_addr`` for the latched request, and the request is then issued
    with that address.

    Returns the constructed Module.
    """
    m = Module()
    sync = m.d.sync
    comb = m.d.comb
    cfg = self.mem_config

    # TODO XXX self.no_match on decoder
    m.submodules.bridge = GenericInterfaceToWishboneMasterBridge(
        generic_bus=self.generic_bus, wb_bus=self.wb_bus)
    self.decoder = m.submodules.decoder = WishboneBusAddressDecoder(
        wb_bus=self.wb_bus, word_size=cfg.word_size)
    self.initialize_mmio_devices(self.decoder, m)
    pe = m.submodules.pe = self.pe = PriorityEncoder(width=len(self.ports))
    # Ports ordered by their priority key
    sorted_ports = [port for priority, port in sorted(self.ports.items())]

    # force 'elaborate' invocation for all mmio modules.
    for mmio_module, addr_space in self.mmio_cfg:
        setattr(m.submodules, addr_space.basename, mmio_module)

    addr_translation_en = self.addr_translation_en = Signal()
    bus_free_to_latch = self.bus_free_to_latch = Signal(reset=1)

    # Translation applies only in user mode with satp.mode set
    if self.with_addr_translation:
        m.d.comb += addr_translation_en.eq(self.csr_unit.satp.mode & (
            self.exception_unit.current_priv_mode == PrivModeBits.USER))
    else:
        m.d.comb += addr_translation_en.eq(False)

    with m.If(~addr_translation_en):
        # when translation enabled, 'bus_free_to_latch' is low during page-walk.
        # with no translation it's simpler - just look at the main bus.
        m.d.comb += bus_free_to_latch.eq(~self.generic_bus.busy)

    with m.If(bus_free_to_latch):
        # no transaction in-progress
        for i, p in enumerate(sorted_ports):
            m.d.sync += pe.i[i].eq(p.en)

    # Registered copy of the request being translated
    virtual_req_bus_latch = LoadStoreInterface()
    phys_addr = self.phys_addr = Signal(32)

    # translation-unit controller signals.
    start_translation = self.start_translation = Signal()
    translation_ack = self.translation_ack = Signal()
    gb = self.generic_bus

    # An active bus cycle that matches no address range is reported to
    # the exception unit as a store, fetch or load error.
    with m.If(self.decoder.no_match & self.wb_bus.cyc):
        m.d.comb += self.exception_unit.badaddr.eq(gb.addr)
        with m.If(gb.store):
            m.d.comb += self.exception_unit.m_store_error.eq(1)
        with m.Elif(gb.is_fetch):
            m.d.comb += self.exception_unit.m_fetch_error.eq(1)
        with m.Else():
            m.d.comb += self.exception_unit.m_load_error.eq(1)

    with m.If(~pe.none):
        # transaction request occured
        # NOTE: 'priority' below is the port object itself and is unused;
        # the port is looked up again by index.
        for i, priority in enumerate(sorted_ports):
            with m.If(pe.o == i):
                bus_owner_port = sorted_ports[i]
                with m.If(~addr_translation_en):
                    # simple case, no need to calculate physical address
                    comb += gb.connect(bus_owner_port)
                with m.Else():
                    # page-walk performs multiple memory operations - will reconnect 'generic_bus' multiple times
                    with m.FSM():
                        first = self.first = Signal(
                        )  # TODO get rid of that
                        with m.State("TRANSLATE"):
                            comb += [
                                start_translation.eq(1),
                                bus_free_to_latch.eq(0),
                            ]
                            # Latch the request, leaving out the fields
                            # whose direction is DIR_FANOUT.
                            sync += virtual_req_bus_latch.connect(
                                bus_owner_port,
                                exclude=[
                                    name
                                    for name, _, dir in generic_bus_layout
                                    if dir == DIR_FANOUT
                                ])
                            with m.If(
                                    translation_ack
                            ):  # wait for 'phys_addr' to be set by page-walk algorithm.
                                m.next = "REQ"
                                sync += first.eq(1)
                        with m.State("REQ"):
                            comb += gb.connect(bus_owner_port,
                                               exclude=["addr"])
                            comb += gb.addr.eq(
                                phys_addr)  # found by page-walk
                            with m.If(first):
                                sync += first.eq(0)
                            with m.Else():
                                # without 'first' signal '~gb.busy' would be high at the very beginning
                                with m.If(~gb.busy):
                                    comb += bus_free_to_latch.eq(1)
                                    m.next = "TRANSLATE"

    req_is_write = Signal()
    pte = self.pte = Record(pte_layout)
    vaddr = Record(virt_addr_layout)
    comb += [
        req_is_write.eq(virtual_req_bus_latch.store),
        vaddr.eq(virtual_req_bus_latch.addr),
    ]

    # Reasons a page-walk can find fault with a page table entry
    @unique
    class Issue(IntEnum):
        OK = 0
        PAGE_INVALID = 1
        WRITABLE_NOT_READABLE = 2
        LACK_PERMISSIONS = 3
        FIRST_ACCESS = 4
        MISALIGNED_SUPERPAGE = 5
        LEAF_IS_NO_LEAF = 6

    self.error_code = Signal(Issue)

    def error(code: Issue):
        # Record the issue. When several are raised in the same cycle,
        # the one raised last is the one stored.
        m.d.sync += self.error_code.eq(code)

    # Code below implements algorithm 4.3.2 in Risc-V Privileged specification, v1.10
    # sv32_i is 1 while walking the first level table, 0 for the second
    sv32_i = Signal(reset=1)
    root_ppn = self.root_ppn = Signal(22)

    # Without translation the page-walk FSM is not built at all
    if not self.with_addr_translation:
        return m

    with m.FSM():
        with m.State("IDLE"):
            with m.If(start_translation):
                sync += sv32_i.eq(1)
                sync += root_ppn.eq(self.csr_unit.satp.ppn)
                m.next = "TRANSLATE"
        with m.State("TRANSLATE"):
            # Table index for the current level
            vpn = self.vpn = Signal(10)
            comb += vpn.eq(Mux(
                sv32_i,
                vaddr.vpn1,
                vaddr.vpn0,
            ))
            # Read the page table entry: word index vpn in table root_ppn
            comb += [
                gb.en.eq(1),
                gb.addr.eq(Cat(Const(0, 2), vpn, root_ppn)),
                gb.store.eq(0),
                gb.mask.eq(0b1111),  # TODO use -1
            ]
            with m.If(gb.ack):
                sync += pte.eq(gb.read_data)
                m.next = "PROCESS_PTE"
        with m.State("PROCESS_PTE"):
            with m.If(~pte.v):
                error(Issue.PAGE_INVALID)
            with m.If(pte.w & ~pte.r):
                error(Issue.WRITABLE_NOT_READABLE)

            # An entry is a leaf when it is readable or executable
            is_leaf = lambda pte: pte.r | pte.x
            with m.If(is_leaf(pte)):
                with m.If(~pte.u & (self.exception_unit.current_priv_mode
                                    == PrivModeBits.USER)):
                    error(Issue.LACK_PERMISSIONS)
                with m.If(~pte.a | (req_is_write & ~pte.d)):
                    error(Issue.FIRST_ACCESS)
                with m.If(sv32_i.bool() & pte.ppn0.bool()):
                    error(Issue.MISALIGNED_SUPERPAGE)
                # phys_addr could be 34 bits long, but our interconnect is 32-bit long.
                # The assignment below keeps the low 32 bits of the
                # concatenation; any bits above those are dropped.
                sync += phys_addr.eq(
                    Cat(vaddr.page_offset, pte.ppn0, pte.ppn1))
            with m.Else():  # not a leaf
                with m.If(sv32_i == 0):
                    error(Issue.LEAF_IS_NO_LEAF)
                sync += root_ppn.eq(
                    Cat(pte.ppn0,
                        pte.ppn1))  # pte is a pointer to the next level
            m.next = "NEXT"
        with m.State("NEXT"):
            # Note that we cannot check 'sv32_i == 0', because superpages can be present.
            with m.If(is_leaf(pte)):
                sync += sv32_i.eq(1)
                comb += translation_ack.eq(
                    1)  # notify that 'phys_addr' signal is set
                m.next = "IDLE"
            with m.Else():
                sync += sv32_i.eq(0)
                m.next = "TRANSLATE"
    return m
def elaborate(self, platform) -> Module:
    """Build the ADAT receiver module.

    The input is NRZI decoded and the decoded bits are fed to a state
    machine with three states:

    * WAIT_SYNC: wait for a run of zero bits from the decoder.
    * READ_FRAME: shift in the frame, skipping the first bit of every
      five bit group, and output user bits and channel samples.
    * READ_SYNC: output the last sample and check the sync bits before
      the next frame.

    A bit with an unexpected value marks the frame invalid and returns
    to WAIT_SYNC, as does the decoder no longer running.
    """
    m = Module()
    sync = m.d.sync
    comb = m.d.comb

    nrzidecoder = NRZIDecoder(self.clk_freq)
    m.submodules.nrzi_decoder = nrzidecoder

    framedata_shifter = InputShiftRegister(24)
    m.submodules.framedata_shifter = framedata_shifter

    output_pulser = EdgeToPulse()
    m.submodules.output_pulser = output_pulser

    # channel whose sample is output next
    active_channel = Signal(3)
    # counts the number of bits output
    bit_counter = Signal(8)
    # counts the bit position inside a nibble
    nibble_counter = Signal(3)
    # counts, how many 0 bits it got in a row
    sync_bit_counter = Signal(4)

    comb += [
        nrzidecoder.nrzi_in.eq(self.adat_in),
        self.synced_out.eq(nrzidecoder.running),
        self.recovered_clock_out.eq(nrzidecoder.recovered_clock_out),
    ]

    with m.FSM():
        # wait for SYNC
        with m.State("WAIT_SYNC"):
            # reset invalid frame bit to be able to start again
            with m.If(nrzidecoder.invalid_frame_in):
                sync += nrzidecoder.invalid_frame_in.eq(0)

            with m.If(nrzidecoder.running):
                sync += [
                    bit_counter.eq(0),
                    nibble_counter.eq(0),
                    active_channel.eq(0),
                    output_pulser.edge_in.eq(0)
                ]

                with m.If(nrzidecoder.data_out_en):
                    # a one bit restarts the count of consecutive zeros
                    m.d.sync += sync_bit_counter.eq(Mux(nrzidecoder.data_out, 0, sync_bit_counter + 1))
                    with m.If(sync_bit_counter == 9):
                        m.d.sync += sync_bit_counter.eq(0)
                        m.next = "READ_FRAME"

        with m.State("READ_FRAME"):
            # at which bit of bit_counter to output sample data at
            output_at = Signal(8)

            # user bits have been read
            with m.If(bit_counter == 5):
                sync += [
                    # output user bits
                    self.user_data_out.eq(framedata_shifter.value_out[0:4]),
                    # at bit 35 the first channel has been read
                    output_at.eq(35)
                ]

            # when each channel has been read, output the channel's sample
            with m.If((bit_counter > 5) & (bit_counter == output_at)):
                sync += [
                    self.output_enable.eq(1),
                    self.addr_out.eq(active_channel),
                    self.sample_out.eq(framedata_shifter.value_out),
                    # the next channel completes 30 bits later
                    output_at.eq(output_at + 30),
                    active_channel.eq(active_channel + 1)
                ]
            with m.Else():
                sync += self.output_enable.eq(0)

            # we work and count only when we get
            # a new bit from the NRZI decoder
            with m.If(nrzidecoder.data_out_en):
                comb += [
                    framedata_shifter.bit_in.eq(nrzidecoder.data_out),
                    # skip sync bit, which is first
                    framedata_shifter.enable_in.eq(~(nibble_counter == 0))
                ]
                sync += [
                    nibble_counter.eq(nibble_counter + 1),
                    bit_counter.eq(bit_counter + 1),
                ]

                # check 4b/5b sync bit
                with m.If((nibble_counter == 0) & ~nrzidecoder.data_out):
                    sync += nrzidecoder.invalid_frame_in.eq(1)
                    m.next = "WAIT_SYNC"
                with m.Else():
                    sync += nrzidecoder.invalid_frame_in.eq(0)

                # wrap after five bits; overrides the increment above
                with m.If(nibble_counter >= 4):
                    sync += nibble_counter.eq(0)

                # 239 channel bits and 5 user bits (including sync bits)
                with m.If(bit_counter >= (239 + 5)):
                    sync += [
                        bit_counter.eq(0),
                        output_pulser.edge_in.eq(1)
                    ]
                    m.next = "READ_SYNC"

            with m.Else():
                comb += framedata_shifter.enable_in.eq(0)

            # listed last, so this transition wins over those above
            with m.If(~nrzidecoder.running):
                m.next = "WAIT_SYNC"

        # read the sync bits
        with m.State("READ_SYNC"):
            # output the sample of the last channel, enabled by the pulser
            sync += [
                self.output_enable.eq(output_pulser.pulse_out),
                self.addr_out.eq(active_channel),
                self.sample_out.eq(framedata_shifter.value_out),
            ]

            with m.If(nrzidecoder.data_out_en):
                sync += [
                    nibble_counter.eq(0),
                    bit_counter.eq(bit_counter + 1),
                ]

                with m.If(bit_counter == 9):
                    comb += [
                        framedata_shifter.enable_in.eq(0),
                        framedata_shifter.clear_in.eq(1),
                    ]

                # check last sync bit before sync trough
                with m.If((bit_counter == 0) & ~nrzidecoder.data_out):
                    sync += nrzidecoder.invalid_frame_in.eq(1)
                    m.next = "WAIT_SYNC"
                # check all the null bits in the sync trough
                with m.Elif((bit_counter > 0) & nrzidecoder.data_out):
                    sync += nrzidecoder.invalid_frame_in.eq(1)
                    m.next = "WAIT_SYNC"
                # sync trough complete: prepare for the next frame
                with m.Elif((bit_counter == 10) & ~nrzidecoder.data_out):
                    sync += [
                        bit_counter.eq(0),
                        nibble_counter.eq(0),
                        active_channel.eq(0),
                        output_pulser.edge_in.eq(0),
                        nrzidecoder.invalid_frame_in.eq(0)
                    ]
                    m.next = "READ_FRAME"
                with m.Else():
                    sync += nrzidecoder.invalid_frame_in.eq(0)

            # listed last, so this transition wins over those above
            with m.If(~nrzidecoder.running):
                m.next = "WAIT_SYNC"

    return m
def elab(self, m):
    """Output the input while capturing, else the last captured value.

    While ``capture`` is high the input is passed straight to the output
    and also stored. While it is low the stored value is output.
    """
    captured = Signal.like(self.input)

    # Default: present the stored value
    m.d.comb += self.output.eq(captured)
    with m.If(self.capture):
        # Pass through now, and remember for later
        m.d.comb += self.output.eq(self.input)
        m.d.sync += captured.eq(self.input)
def elab(self, m):
    """Build the address generators and sequence them with a counter.

    A pixel address generator feeds a repeater, which supplies the start
    address to four value address generators. A cycle counter, started
    by ``start`` and stopped by ``reset``, starts one value address
    generator on each of its first four counts and provides the RamMux
    phase from its two low bits.
    """
    m.submodules["pixel_ag"] = pixel_ag = PixelAddressGenerator()
    m.submodules["repeater"] = repeater = PixelAddressRepeater()
    value_ags = [ValueAddressGenerator() for _ in range(4)]

    # Connect pixel address generator and repeater
    m.d.comb += [
        pixel_ag.base_addr.eq(self.base_addr >> 4),
        pixel_ag.num_pixels_x.eq(self.num_pixels_x),
        pixel_ag.num_blocks_x.eq(self.pixel_advance_x),
        pixel_ag.num_blocks_y.eq(self.pixel_advance_y),
        repeater.repeats.eq(self.num_repeats),
        pixel_ag.next.eq(repeater.gen_next),
        repeater.gen_addr.eq(pixel_ag.addr),
        pixel_ag.start.eq(self.start),
        repeater.start.eq(self.start),
    ]

    # cycle_counter counts cycles through reading input data
    max_cycle_counter = Signal(9)
    cycle_counter = Signal(9)
    m.d.comb += max_cycle_counter.eq((self.depth << 6) - 1)

    # Stop running on reset. Start running on start
    running = Signal()
    with m.If(self.reset):
        m.d.sync += running.eq(0)
    with m.Elif(self.start):
        m.d.sync += [
            running.eq(1),
            cycle_counter.eq(0),
        ]
    with m.Elif(running):
        rollover = cycle_counter == max_cycle_counter
        m.d.sync += cycle_counter.eq(Mux(rollover, 0, cycle_counter + 1))

    # Register and connect each value address generator. Generator i is
    # started on count i. Any start also requests the next pixel address.
    next_pixel = 0
    for i, v in enumerate(value_ags):
        m.submodules[f"value_ag{i}"] = v
        start_gen = Signal()
        m.d.comb += [
            v.start_addr.eq(repeater.addr),
            v.depth.eq(self.depth),
            v.num_blocks_y.eq(self.pixel_advance_y),
            start_gen.eq(running & (cycle_counter == i)),
            v.start.eq(start_gen),
            # Connect to RamMux
            self.ram_mux_addr[i].eq(v.addr_out),
            self.data_out[i].eq(self.ram_mux_data[i]),
        ]
        next_pixel |= start_gen
    m.d.comb += repeater.next.eq(next_pixel)

    # Generate first and last signals (registered)
    m.d.sync += [
        self.first.eq(running & (cycle_counter == 0)),
        self.last.eq(running & (cycle_counter == max_cycle_counter)),
    ]

    m.d.comb += self.ram_mux_phase.eq(cycle_counter[:2])
def elaborate(self, platform):
    """Build the CPU module: functional units, register file, timer and FSM.

    The FSM moves FETCH -> DECODE -> EXECUTE -> WRITEBACK -> FETCH. Traps
    and ``mret`` go back to FETCH with a new program counter.

    Args:
        platform: not referenced anywhere in this method.

    Returns:
        The populated ``Module`` (also stored as ``self.m``).

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a whitespace-collapsed listing -- confirm against the canonical
    file before relying on it.
    """
    self.m = m = Module()

    comb = m.d.comb
    sync = m.d.sync

    # CPU units used.
    logic = m.submodules.logic = LogicUnit()
    adder = m.submodules.adder = AdderUnit()
    shifter = m.submodules.shifter = ShifterUnit()
    compare = m.submodules.compare = CompareUnit()

    # Current privilege mode; resets to MACHINE.
    self.current_priv_mode = Signal(PrivModeBits, reset=PrivModeBits.MACHINE)

    csr_unit = self.csr_unit = m.submodules.csr_unit = CsrUnit(
        # TODO does '==' below produce the same synth result as .all()?
        in_machine_mode=self.current_priv_mode == PrivModeBits.MACHINE)
    exception_unit = self.exception_unit = m.submodules.exception_unit = ExceptionUnit(
        csr_unit=csr_unit, current_priv_mode=self.current_priv_mode)
    arbiter = self.arbiter = m.submodules.arbiter = MemoryArbiter(
        mem_config=self.mem_config,
        with_addr_translation=True,
        csr_unit=csr_unit,  # SATP register
        exception_unit=exception_unit,  # current privilege mode
    )
    # Arbiter ports: load/store unit at priority=0, instruction fetch at
    # priority=2 and (optionally) the debug bus at priority=1.
    mem_unit = m.submodules.mem_unit = MemoryUnit(mem_port=arbiter.port(
        priority=0))

    ibus = arbiter.port(priority=2)

    if self.with_debug:
        m.submodules.debug = self.debug = DebugUnit(self)
        self.debug_bus = arbiter.port(priority=1)

    # Current decoding state signals.
    instr = self.instr = Signal(32)
    funct3 = self.funct3 = Signal(3)
    funct7 = self.funct7 = Signal(7)
    rd = self.rd = Signal(5)
    rs1 = Signal(5)
    rs2 = Signal(5)
    rs1val = Signal(32)
    rs2val = Signal(32)
    rdval = Signal(32)  # calculated by unit, stored to register file
    imm = Signal(signed(12))
    csr_idx = Signal(12)
    uimm = Signal(20)
    opcode = self.opcode = Signal(InstrType)
    pc = self.pc = Signal(32, reset=CODE_START_ADDR)

    # at most one active_unit at any time
    active_unit = ActiveUnit()

    # Register file. Contains two read ports (for rs1, rs2) and one write port.
    regs = Memory(width=32, depth=32, init=self.reg_init)
    reg_read_port1 = m.submodules.reg_read_port1 = regs.read_port()
    reg_read_port2 = m.submodules.reg_read_port2 = regs.read_port()
    reg_write_port = (self.reg_write_port
                      ) = m.submodules.reg_write_port = regs.write_port()

    # Timer management. mtime increments every sync cycle and is exposed to
    # the CSR unit.
    mtime = self.mtime = Signal(32)
    sync += mtime.eq(mtime + 1)
    comb += csr_unit.mtime.eq(mtime)

    # 'halt' is raised when mtime equals mtimecmp while both mstatus.mie and
    # mie.mtie are set; the FETCH state turns it into a timer interrupt trap.
    self.halt = Signal()
    with m.If(csr_unit.mstatus.mie & csr_unit.mie.mtie):
        with m.If(mtime == csr_unit.mtimecmp):
            # 'halt' signal needs to be cleared when CPU jumps to trap handler.
            sync += [
                self.halt.eq(1),
            ]

    # Give the exception unit the current instruction and program counter.
    comb += [
        exception_unit.m_instruction.eq(instr),
        exception_unit.m_pc.eq(pc),
        # TODO more
    ]

    # TODO
    # DebugModule is able to read and write GPR values.
    # if self.with_debug:
    #     comb += self.halt.eq(self.debug.HALT)
    # else:
    #     comb += self.halt.eq(0)

    # with m.If(self.halt):
    #     comb += [
    #         reg_read_port1.addr.eq(self.gprf_debug_addr),
    #         reg_write_port.addr.eq(self.gprf_debug_addr),
    #         reg_write_port.en.eq(self.gprf_debug_write_en)
    #     ]

    #     with m.If(self.gprf_debug_write_en):
    #         comb += reg_write_port.data.eq(self.gprf_debug_data)
    #     with m.Else():
    #         comb += self.gprf_debug_data.eq(reg_read_port1.data)
    # The condition is the constant 0, so only the Else branch below takes
    # effect; presumably a placeholder for the disabled debug path above.
    with m.If(0):
        pass
    with m.Else():
        comb += [
            reg_read_port1.addr.eq(rs1),
            reg_read_port2.addr.eq(rs2),
            reg_write_port.addr.eq(rd),
            reg_write_port.data.eq(rdval),
            # reg_write_port.en set later
            rs1val.eq(reg_read_port1.data),
            rs2val.eq(reg_read_port2.data),
        ]

    comb += [
        # following is not true for all instructions, but in specific cases will be overwritten later
        imm.eq(instr[20:32]),
        csr_idx.eq(instr[20:32]),
        uimm.eq(instr[12:]),
    ]

    # drive input signals of actually used unit.
    with m.If(active_unit.logic):
        comb += [
            logic.funct3.eq(funct3),
            logic.src1.eq(rs1val),
            logic.src2.eq(Mux(opcode == InstrType.OP_IMM, imm, rs2val)),
        ]
    with m.Elif(active_unit.adder):
        comb += [
            adder.src1.eq(rs1val),
            adder.src2.eq(Mux(opcode == InstrType.OP_IMM, imm, rs2val)),
        ]
    with m.Elif(active_unit.shifter):
        comb += [
            shifter.funct3.eq(funct3),
            shifter.funct7.eq(funct7),
            shifter.src1.eq(rs1val),
            # Shift amount: low five bits of the immediate or of rs2.
            shifter.shift.eq(
                Mux(opcode == InstrType.OP_IMM, imm[0:5].as_unsigned(),
                    rs2val[0:5])),
        ]
    with m.Elif(active_unit.mem_unit):
        comb += [
            mem_unit.en.eq(1),
            mem_unit.funct3.eq(funct3),
            mem_unit.src1.eq(rs1val),
            mem_unit.src2.eq(rs2val),
            mem_unit.store.eq(opcode == InstrType.STORE),
            # Loads use imm directly; otherwise the offset is assembled from
            # the rd field (low bits) and imm[5:12] (high bits).
            mem_unit.offset.eq(
                Mux(opcode == InstrType.LOAD, imm, Cat(rd, imm[5:12]))),
        ]
    with m.Elif(active_unit.compare):
        comb += [
            compare.funct3.eq(funct3),
            # Compare Unit uses Adder for carry and overflow flags.
            adder.src1.eq(rs1val),
            adder.src2.eq(Mux(opcode == InstrType.OP_IMM, imm, rs2val)),
            # adder.sub set somewhere below
        ]
    with m.Elif(active_unit.branch):
        comb += [
            compare.funct3.eq(funct3),
            # Compare Unit uses Adder for carry and overflow flags.
            adder.src1.eq(rs1val),
            adder.src2.eq(rs2val),
            # adder.sub set somewhere below
        ]
    with m.Elif(active_unit.csr):
        comb += [
            csr_unit.func3.eq(funct3),
            csr_unit.csr_idx.eq(csr_idx),
            csr_unit.rs1.eq(rs1),
            csr_unit.rs1val.eq(rs1val),
            csr_unit.rd.eq(rd),
            csr_unit.en.eq(1),
        ]

    # The compare unit's flags are always derived from the adder's outputs.
    comb += [
        compare.negative.eq(adder.res[-1]),
        compare.overflow.eq(adder.overflow),
        compare.carry.eq(adder.carry),
        compare.zero.eq(adder.res == 0),
    ]

    # Decoding state (with redundancy - instr. type not known yet).
    # We use 'ibus.read_data' instead of 'instr' (that is driven by sync domain)
    # for getting registers to save 1 cycle.
    # NOTE(review): the assignments below slice 'instr', not
    # 'ibus.read_data' -- the comment above looks stale; confirm.
    comb += [
        opcode.eq(instr[0:7]),
        rd.eq(instr[7:12]),
        funct3.eq(instr[12:15]),
        rs1.eq(instr[15:20]),
        rs2.eq(instr[20:25]),
        funct7.eq(instr[25:32]),
    ]

    def fetch_with_new_pc(pc: Signal):
        """Go to FETCH, clear active_unit and load self.pc with ``pc``.

        The ``pc`` parameter shadows the enclosing ``pc`` signal; the
        register that is written is ``self.pc``.
        """
        m.next = "FETCH"
        m.d.sync += active_unit.eq(0)
        m.d.sync += self.pc.eq(pc)

    def trap(cause: Optional[Union[TrapCause, IrqCause]], interrupt=False):
        """Jump to the trap vector and, if ``cause`` is given, flag it.

        The new PC is mtvec.base with two zero bits appended below it.
        With ``cause=None`` only the jump happens. Otherwise the matching
        entry of the exception unit's irq_cause_map (``interrupt=True``)
        or trap_cause_map is driven to 1.
        """
        fetch_with_new_pc(Cat(Const(0, 2), self.csr_unit.mtvec.base))
        if cause is None:
            return
        assert isinstance(cause, TrapCause) or isinstance(cause, IrqCause)
        e = exception_unit
        notifiers = e.irq_cause_map if interrupt else e.trap_cause_map
        m.d.comb += notifiers[cause].eq(1)

    self.fetch = Signal()
    # Any store, fetch or load error reported by the exception unit.
    interconnect_error = Signal()
    comb += interconnect_error.eq(exception_unit.m_store_error
                                  | exception_unit.m_fetch_error
                                  | exception_unit.m_load_error)
    with m.FSM():
        with m.State("FETCH"):
            with m.If(self.halt):
                # Pending timer event: clear 'halt' and take the interrupt.
                sync += self.halt.eq(0)
                trap(IrqCause.M_TIMER_INTERRUPT, interrupt=True)
            with m.Else():
                # A PC with either of its two low bits set is misaligned.
                with m.If(pc & 0b11):
                    trap(TrapCause.FETCH_MISALIGNED)
                with m.Else():
                    comb += [
                        ibus.en.eq(1),
                        ibus.store.eq(0),
                        ibus.addr.eq(pc),
                        ibus.mask.eq(0b1111),
                        ibus.is_fetch.eq(1),
                    ]
                with m.If(interconnect_error):
                    trap(cause=None)
                with m.If(ibus.ack):
                    # Latch the fetched word and move on to decoding.
                    sync += [
                        instr.eq(ibus.read_data),
                    ]
                    m.next = "DECODE"
        with m.State("DECODE"):
            comb += self.fetch.eq(
                1
            )  # only for simulation, notify that 'instr' ready to use.
            # Default transition; the trap() calls below override it.
            m.next = "EXECUTE"
            # here, we have registers already fetched into rs1val, rs2val.
            # Instructions whose two low bits are not 0b11 are rejected.
            with m.If(instr & 0b11 != 0b11):
                trap(TrapCause.ILLEGAL_INSTRUCTION)
            # Select the unit that will execute the instruction.
            with m.If(match_logic_unit(opcode, funct3, funct7)):
                sync += [
                    active_unit.logic.eq(1),
                ]
            with m.Elif(match_adder_unit(opcode, funct3, funct7)):
                sync += [
                    active_unit.adder.eq(1),
                    adder.sub.eq((opcode == InstrType.ALU)
                                 & (funct7 == Funct7.SUB)),
                ]
            with m.Elif(match_shifter_unit(opcode, funct3, funct7)):
                sync += [
                    active_unit.shifter.eq(1),
                ]
            with m.Elif(match_loadstore_unit(opcode, funct3, funct7)):
                sync += [
                    active_unit.mem_unit.eq(1),
                ]
            with m.Elif(match_compare_unit(opcode, funct3, funct7)):
                sync += [
                    active_unit.compare.eq(1),
                    adder.sub.eq(1),
                ]
            with m.Elif(match_lui(opcode, funct3, funct7)):
                sync += [
                    active_unit.lui.eq(1),
                ]
                comb += [
                    reg_read_port1.addr.eq(rd),
                    # rd will be available in next cycle in rs1val
                ]
            with m.Elif(match_auipc(opcode, funct3, funct7)):
                sync += [
                    active_unit.auipc.eq(1),
                ]
            with m.Elif(match_jal(opcode, funct3, funct7)):
                sync += [
                    active_unit.jal.eq(1),
                ]
            with m.Elif(match_jalr(opcode, funct3, funct7)):
                sync += [
                    active_unit.jalr.eq(1),
                ]
            with m.Elif(match_branch(opcode, funct3, funct7)):
                sync += [
                    active_unit.branch.eq(1),
                    adder.sub.eq(1),
                ]
            with m.Elif(match_csr(opcode, funct3, funct7)):
                sync += [active_unit.csr.eq(1)]
            with m.Elif(match_mret(opcode, funct3, funct7)):
                sync += [active_unit.mret.eq(1)]
            with m.Elif(match_sfence_vma(opcode, funct3, funct7)):
                pass  # sfence.vma
            with m.Elif(opcode == 0b0001111):
                pass  # fence
            with m.Else():
                trap(TrapCause.ILLEGAL_INSTRUCTION)
        with m.State("EXECUTE"):
            # Capture the active unit's result into rdval.
            with m.If(active_unit.logic):
                sync += [
                    rdval.eq(logic.res),
                ]
            with m.Elif(active_unit.adder):
                sync += [
                    rdval.eq(adder.res),
                ]
            with m.Elif(active_unit.shifter):
                sync += [
                    rdval.eq(shifter.res),
                ]
            with m.Elif(active_unit.mem_unit):
                sync += [
                    rdval.eq(mem_unit.res),
                ]
            with m.Elif(active_unit.compare):
                sync += [
                    rdval.eq(compare.condition_met),
                ]
            with m.Elif(active_unit.lui):
                # uimm placed above twelve zero bits.
                sync += [
                    rdval.eq(Cat(Const(0, 12), uimm)),
                ]
            with m.Elif(active_unit.auipc):
                sync += [
                    rdval.eq(pc + Cat(Const(0, 12), uimm)),
                ]
            with m.Elif(active_unit.jal | active_unit.jalr):
                # Link value: address of the following instruction.
                sync += [
                    rdval.eq(pc + 4),
                ]
            with m.Elif(active_unit.csr):
                sync += [rdval.eq(csr_unit.rd_val)]

            # control flow mux - all traps need to be here, otherwise it will overwrite m.next statement.
            with m.If(active_unit.mem_unit):
                # Stay in EXECUTE until the memory unit acknowledges.
                with m.If(mem_unit.ack):
                    m.next = "WRITEBACK"
                    sync += active_unit.eq(0)
                with m.Else():
                    m.next = "EXECUTE"
                with m.If(interconnect_error):
                    # NOTE:
                    # the order of that 'If' is important.
                    # In case of error overwrite m.next above.
                    trap(cause=None)
            with m.Elif(active_unit.csr):
                with m.If(csr_unit.illegal_insn):
                    trap(TrapCause.ILLEGAL_INSTRUCTION)
                with m.Else():
                    # Stay in EXECUTE until the CSR unit reports 'vld'.
                    with m.If(csr_unit.vld):
                        m.next = "WRITEBACK"
                        sync += active_unit.eq(0)
                    with m.Else():
                        m.next = "EXECUTE"
            with m.Elif(active_unit.mret):
                comb += exception_unit.m_mret.eq(1)
                fetch_with_new_pc(exception_unit.mepc)
            with m.Else():
                # all units not specified by default take 1 cycle
                m.next = "WRITEBACK"
                sync += active_unit.eq(0)

            # 21-bit signed jump offset assembled from instruction fields,
            # with bit 0 fixed at zero.
            jal_offset = Signal(signed(21))
            comb += jal_offset.eq(
                Cat(
                    Const(0, 1),
                    instr[21:31],
                    instr[20],
                    instr[12:20],
                    instr[31],
                ).as_signed())

            # Amount added to pc in WRITEBACK: the jump offset for jal,
            # otherwise 4 (replaced below for taken branches).
            pc_addend = Signal(signed(32))
            sync += pc_addend.eq(Mux(active_unit.jal, jal_offset, 4))

            # 13-bit signed branch offset, bit 0 fixed at zero.
            branch_addend = Signal(signed(13))
            comb += branch_addend.eq(
                Cat(
                    Const(0, 1),
                    instr[8:12],
                    instr[25:31],
                    instr[7],
                    instr[31],
                ).as_signed()  # TODO is it ok that it's signed?
            )

            # This later assignment overrides the default pc_addend above
            # when the branch condition is met.
            with m.If(active_unit.branch):
                with m.If(compare.condition_met):
                    sync += pc_addend.eq(branch_addend)

            new_pc = Signal(32)
            is_jalr_latch = Signal()  # that's bad workaround
            # For jalr the target (rs1val + imm) is stored in new_pc and a
            # latch tells WRITEBACK to use it instead of pc + pc_addend.
            with m.If(active_unit.jalr):
                sync += is_jalr_latch.eq(1)
                sync += new_pc.eq(rs1val.as_signed() + imm)

        with m.State("WRITEBACK"):
            # Update the program counter, then clear the jalr latch.
            with m.If(is_jalr_latch):
                sync += pc.eq(new_pc)
            with m.Else():
                sync += pc.eq(pc + pc_addend)
            sync += is_jalr_latch.eq(0)

            # Here, rdval is already calculated. If necessary, put it into register file.
            should_write_rd = self.should_write_rd = Signal()
            writeback = self.writeback = Signal()
            # for riscv-dv simulation:
            # detect that instruction does not perform register write to avoid infinite loop
            # by checking writeback & should_write_rd
            # TODO it will break for trap-causing instructions.
            comb += writeback.eq(1)
            # Write rd only for the instruction kinds listed here, and
            # never when rd is register 0.
            comb += should_write_rd.eq(
                reduce(
                    or_,
                    [
                        match_shifter_unit(opcode, funct3, funct7),
                        match_adder_unit(opcode, funct3, funct7),
                        match_logic_unit(opcode, funct3, funct7),
                        match_load(opcode, funct3, funct7),
                        match_compare_unit(opcode, funct3, funct7),
                        match_lui(opcode, funct3, funct7),
                        match_auipc(opcode, funct3, funct7),
                        match_jal(opcode, funct3, funct7),
                        match_jalr(opcode, funct3, funct7),
                        match_csr(opcode, funct3, funct7),
                    ],
                ) & (rd != 0))

            with m.If(should_write_rd):
                comb += reg_write_port.en.eq(True)
            m.next = "FETCH"

    return m
def clamped(value, min_bound, max_bound):
    """Build a nested ``Mux`` expression limiting ``value`` to the bounds.

    The result selects ``min_bound`` when ``value < min_bound``,
    ``max_bound`` when ``value > max_bound``, and ``value`` otherwise.
    The lower-bound test is the outer selector, so it wins if both
    comparisons hold.
    """
    below_min = value < min_bound
    above_max = value > max_bound
    upper_limited = Mux(above_max, max_bound, value)
    return Mux(below_min, min_bound, upper_limited)