Exemplo n.º 1
0
    def elab(self, m):
        # Assemble 32-bit words from a byte stream interleaved across
        # self._num_pixels channels. Each transfer carries one byte (stored
        # into an 8-bit slot below) for the channel at pixel_index; the index
        # then advances, wrapping after last_pixel_index, and byte_count
        # advances on each wrap. So every channel receives byte k before any
        # channel receives byte k + 1.
        # Bytes 0-2 of a channel are held in its shift register; when byte 3
        # arrives, the three held bytes (oldest in the lowest bits) and the
        # new byte (top 8 bits) are presented on self.output for that cycle.
        # NOTE(review): self.input.ready is tied high and self.output.ready
        # is never consulted - this assumes the consumer always accepts.

        # number of bytes received already
        byte_count = Signal(range(4))
        # output channel for next byte received
        pixel_index = Signal(range(self._num_pixels))
        # shift registers to buffer 3 incoming values
        shift_registers = Array(
            [Signal(24, name=f"sr_{i}") for i in range(self._num_pixels)])

        # Index of the last active channel: half of the channels when
        # half_mode is set, otherwise all of them. Registered, so a change of
        # half_mode takes effect one cycle later.
        last_pixel_index = Signal.like(pixel_index)
        m.d.sync += last_pixel_index.eq(Mux(self.half_mode,
                                            self._num_pixels // 2 - 1,
                                            self._num_pixels - 1))
        # pixel_index + 1, wrapping to 0 after last_pixel_index
        next_pixel_index = Signal.like(pixel_index)
        m.d.comb += next_pixel_index.eq(Mux(pixel_index == last_pixel_index,
                                            0, pixel_index + 1))

        m.d.comb += self.input.ready.eq(1)
        with m.If(self.input.is_transferring()):
            # Shift register belonging to the channel this byte is for
            sr = shift_registers[pixel_index]
            with m.If(byte_count == 3):
                # Output current value and shift register
                m.d.comb += self.output.valid.eq(1)
                payload = Cat(sr, self.input.payload)
                m.d.comb += self.output.payload.eq(payload)
            with m.Else():
                # Save input to shift register: the new byte enters the top
                # 8 bits and the older bytes move down by one byte
                m.d.sync += sr[-8:].eq(self.input.payload)
                m.d.sync += sr[:-8].eq(sr[8:])

            # Increment pixel index
            m.d.sync += pixel_index.eq(next_pixel_index)
            with m.If(pixel_index == last_pixel_index):
                # next_pixel_index is already 0 here, so the pixel_index
                # assignment is redundant; what matters is advancing
                # byte_count once per full pass over the channels.
                m.d.sync += pixel_index.eq(0)
                m.d.sync += byte_count.eq(byte_count + 1)  # allow rollover
Exemplo n.º 2
0
    def build_param_store(self, m):
        """Store post-process parameters and read them back for the pipeline.

        Incoming ``self.post_process_params`` are written into a memory by a
        ``ParamWriter``. A ``ReadingProducer`` reads them back through a
        non-transparent read port. Its ``sizes.repeats`` is
        ``Constants.SYS_ARRAY_HEIGHT`` when ``self.config.mode`` is set and
        half that otherwise.

        Returns
        -------
        The reader's ``output_data``.
        """
        # Parameter memory. NOTE: a SinglePortMemory might be usable here.
        param_mem = Memory(width=POST_PROCESS_PARAMS_WIDTH,
                           depth=Constants.MAX_CHANNEL_DEPTH)
        rp = param_mem.read_port(transparent=False)
        m.submodules['param_rp'] = rp
        wp = param_mem.write_port()
        m.submodules['param_wp'] = wp

        # Writer side: incoming params -> memory write port
        pw = ParamWriter()
        m.submodules['param_writer'] = pw
        m.d.comb += connect(self.post_process_params, pw.input_data)
        m.d.comb += pw.reset.eq(self.reset)
        m.d.comb += [
            wp.en.eq(pw.mem_we),
            wp.addr.eq(pw.mem_addr),
            wp.data.eq(pw.mem_data),
        ]

        # Reader side: memory read port -> output stream
        reader = ReadingProducer()
        m.submodules['param_reader'] = reader
        height = Constants.SYS_ARRAY_HEIGHT
        m.d.comb += [
            # Restart the reader whenever new parameters are being written
            reader.reset.eq(pw.input_data.is_transferring()),
            reader.sizes.depth.eq(self.config.output_channel_depth),
            reader.sizes.repeats.eq(
                Mux(self.config.mode, height, height // 2)),
            rp.addr.eq(reader.mem_addr),
            reader.mem_data.eq(rp.data),
        ]
        return reader.output_data
Exemplo n.º 3
0
    def elab(self, m: Module):
        # Single-entry skid buffer between self.input and self.output.
        # While the buffer is empty, input.ready is asserted regardless of
        # output.ready; a value offered while the output is stalled is
        # captured here and presented on the output from the next cycle.
        buffering = Signal()  # True if there is a value being buffered
        buffered_value = Signal.like(Value.cast(self.input.payload))

        # Pipe valid and ready back and forth
        m.d.comb += [
            self.input.ready.eq(~buffering | self.output.ready),
            self.output.valid.eq(buffering | self.input.valid),
            # A buffered value, when present, was accepted earlier than the
            # current input, so it is output first.
            self.output.payload.eq(
                Mux(buffering, buffered_value, self.input.payload))
        ]

        # Buffer when have incoming value but cannot output just now
        with m.If(~buffering & ~self.output.ready & self.input.valid):
            m.d.sync += buffering.eq(True)
            m.d.sync += buffered_value.eq(self.input.payload)

        # Handle cases when transferring out from buffer
        with m.If(buffering & self.output.ready):
            with m.If(self.input.valid):
                # The buffered value leaves and the incoming value (accepted,
                # because input.ready is high) takes its place.
                m.d.sync += buffered_value.eq(self.input.payload)
            with m.Else():
                m.d.sync += buffering.eq(False)

        # Reset all state (last in order, so it overrides the above)
        with m.If(self.reset):
            m.d.sync += buffering.eq(False)
            m.d.sync += buffered_value.eq(0)
Exemplo n.º 4
0
    def transform(self, m, in_value, out_value):
        # Divide in_value.dividend by 2**in_value.shift, rounding to nearest
        # with ties away from zero: floor(dividend / 2**shift), plus one when
        # the remainder reaches the threshold described below.
        # Only shift values 2..12 are decoded by the Switch. For any other
        # shift the comb signals quotient, remainder and threshold[1:] keep
        # their default of 0, so the output is not a meaningful quotient.

        # Cycle 0: register inputs
        dividend = Signal(signed(32))
        shift = Signal(4)
        m.d.sync += dividend.eq(in_value.dividend)
        m.d.sync += shift.eq(in_value.shift)

        # Cycle 1: calculate
        result = Signal(signed(32))
        remainder = Signal(signed(32))
        # Our threshold looks like 010, 0100, 01000 etc for positive values and
        # 011, 0101, 01001 etc for negative values.
        # That is threshold = 2**(shift - 1) + (1 if negative else 0): a
        # negative dividend needs a strictly-greater-than-half remainder to
        # round up, which makes exact halves round away from zero.
        threshold = Signal(signed(32))
        quotient = Signal(signed(32))
        negative = Signal()
        m.d.comb += negative.eq(dividend < 0)
        with m.Switch(shift):
            for n in range(2, 13):
                with m.Case(n):
                    mask = (1 << n) - 1
                    m.d.comb += remainder.eq(dividend & mask)
                    m.d.comb += threshold[1:].eq(1 << (n - 2))
                    # Shift of a signed value: rounds toward -infinity
                    m.d.comb += quotient.eq(dividend >> n)
        m.d.comb += threshold[0].eq(negative)
        m.d.sync += result.eq(quotient + Mux(remainder >= threshold, 1, 0))

        # Cycle 2: send output
        m.d.sync += out_value.eq(result)
Exemplo n.º 5
0
 def elab(self, m):
     """Drive self.result with self.x / 2**self.exponent, rounded to nearest.

     Follows gemmlowp's RoundingDivideByPOT: exact halves round away from
     zero. Bit 31 of self.x is used as its sign bit, so self.x is assumed to
     be a 32-bit signed value.
     """
     # TODO: reimplement as a function that returns an expression
     x = self.x
     mask = (1 << self.exponent) - 1
     # Negative inputs (x[31] set) need one more than half before rounding
     # up, so that ties move away from zero in both directions.
     threshold = (mask >> 1) + x[31]
     round_up = Mux((x & mask) > threshold, 1, 0)
     m.d.comb += self.result.eq((x >> self.exponent) + round_up)
Exemplo n.º 6
0
    def elab(self, m):
        """Sequence memory addresses and hold the data read back.

        ``self.mem_addr`` is the address presented this cycle:

        * 0 while ``self.restart`` is asserted;
        * otherwise, if restart or next was asserted on the previous cycle,
          the previous address plus one, wrapping to 0 after
          ``self.limit - 1``;
        * otherwise the previous address.

        ``self.data`` shows ``self.mem_data`` on the cycle after restart or
        next, and otherwise holds its previous value.
        """
        # Registered copies of last cycle's control inputs
        was_restart = Signal()
        was_next = Signal()
        m.d.sync += [
            was_restart.eq(self.restart),
            was_next.eq(self.next),
        ]
        stepped = was_next | was_restart

        # Address selection (determines data available next cycle)
        last_mem_addr = Signal.like(self.mem_addr)
        m.d.sync += last_mem_addr.eq(self.mem_addr)
        incremented_addr = Signal.like(self.limit)
        at_end = last_mem_addr == self.limit - 1
        m.d.comb += incremented_addr.eq(Mux(at_end, 0, last_mem_addr + 1))
        m.d.comb += self.mem_addr.eq(
            Mux(self.restart, 0,
                Mux(stepped, incremented_addr, last_mem_addr)))

        # Data selection: fresh memory data after a step, else hold
        last_data = Signal.like(self.data)
        m.d.sync += last_data.eq(self.data)
        m.d.comb += self.data.eq(Mux(stepped, self.mem_data, last_data))
Exemplo n.º 7
0
    def handle_reading(self, m, memory, num_words, index, reset):
        """Handle stepping through memory on read.

        Advances ``index`` in steps of 4, wrapping to 0 after
        ``num_words - 4``, and streams the words read from ``memory`` out on
        ``self.data_output``. ``memory.read_addr`` is ``index`` with its low
        two bits dropped. Asserting ``reset`` clears the two read-tracking
        flags; it does not touch ``index`` here.

        NOTE(review): the flag timing below appears to assume that
        ``memory.read_data`` is valid the cycle after ``read_addr`` changes -
        confirm against the memory implementation.
        """
        m.d.comb += memory.read_addr.eq(index[2:])

        # read_port_valid: memory.read_data holds a word not yet copied into
        #                  data_output.
        # read_started:    index was advanced on the previous cycle.
        read_port_valid = Signal()
        read_started = Signal()

        # Output register is empty or being consumed: load it from the read
        # port (valid only if the read port held a pending word).
        with m.If(~self.data_output.valid
                  | self.data_output.is_transferring()):
            m.d.sync += self.data_output.payload.eq(memory.read_data)
            m.d.sync += self.data_output.valid.eq(read_port_valid)
            m.d.sync += read_port_valid.eq(0)

        # Start the next read when nothing is pending or in flight.
        with m.If(~read_port_valid & ~read_started):
            m.d.sync += index.eq(Mux(index == num_words - 4, 0, index + 4))
            m.d.sync += read_started.eq(1)
        with m.Else():
            m.d.sync += read_started.eq(0)

        # Appears after the clear above, so it takes priority over it.
        with m.If(read_started):
            # Read port is valid next cycle if we started a new read this cycle
            m.d.sync += read_port_valid.eq(1)

        with m.If(reset):
            m.d.sync += read_port_valid.eq(0)
            m.d.sync += read_started.eq(0)
Exemplo n.º 8
0
def rounding_divide_by_pot(x, exponent):
    """Implements gemmlowp::RoundingDivideByPOT

    Return an expression for ``x / 2**exponent`` rounded to the nearest
    whole number, with exact halves rounded away from zero.

    ``x[31]`` is used as the sign bit, so ``x`` is assumed to be a 32-bit
    signed value.
    """
    mask = (1 << exponent) - 1
    sign_bit = x[31]
    # Half the divisor minus one; a negative x needs one more than that
    # before rounding up, which sends ties away from zero.
    threshold = (mask >> 1) + sign_bit
    round_up = Mux((x & mask) > threshold, 1, 0)
    truncated = x >> exponent
    return truncated + round_up
Exemplo n.º 9
0
    def elab(self, m):
        """Count 0 .. self.max - 1 while self.en is asserted, then wrap.

        ``self.restart`` returns the count to 0 and takes priority over
        ``self.en``. ``self.done`` is high on the final count, but only in
        cycles where ``self.en`` is asserted.
        """
        count = Signal.like(self.max)
        at_last = count == self.max - 1

        with m.If(self.en):
            m.d.comb += self.done.eq(at_last)

        with m.If(self.restart):
            m.d.sync += count.eq(0)
        with m.Elif(self.en):
            m.d.sync += count.eq(Mux(at_last, 0, count + 1))
Exemplo n.º 10
0
 def max_(word0, word1):
     """Return the lane-wise signed maximum of two 32-bit words.

     Each word is split into four 8-bit lanes (lane 0 in the low bits), and
     each lane is compared as a signed byte. Comb statements are added to
     ``m``, which must be in scope from the enclosing function.
     """
     result = [Signal(8, name=f"result{i}") for i in range(4)]
     for lane, r in enumerate(result):
         lo = lane * 8
         b0 = word0[lo:lo + 8]
         b1 = word1[lo:lo + 8]
         # Reinterpret both bytes as signed for the comparison
         sb0 = Signal(signed(8))
         m.d.comb += sb0.eq(b0)
         sb1 = Signal(signed(8))
         m.d.comb += sb1.eq(b1)
         m.d.comb += r.eq(Mux(sb1 > sb0, b1, b0))
     return Cat(*result)
Exemplo n.º 11
0
    def transform(self, m, in_value, out_value):
        # Three-stage rounding high multiply computed on the magnitude of a:
        #   a >= 0:  out = (|a| * b + 2**30) >> 31
        #   a <  0:  out = -((|a| * b + 2**30 - 1) >> 31)
        # NOTE(review): this appears to mirror gemmlowp's
        # SaturatingRoundingDoublingHighMul without its saturation case -
        # confirm intended semantics.

        # Cycle 0: register inputs
        a = in_value.a
        reg_a = Signal(signed(32))
        reg_b = Signal(signed(32))
        # NOTE(review): if a can be INT32_MIN, -a does not fit in reg_a
        # (signed 32) and wraps back to INT32_MIN - confirm it cannot occur.
        m.d.sync += reg_a.eq(Mux(a >= 0, a, -a))
        m.d.sync += reg_b.eq(in_value.b)

        # Cycle 1: multiply to register
        # both operands are positive, so result always positive
        # NOTE(review): only a is made non-negative above; this relies on
        # callers guaranteeing b >= 0.
        reg_ab = Signal(signed(63))
        m.d.sync += reg_ab.eq(reg_a * reg_b)

        # Cycle 2: nudge, take high bits and sign
        positive_2 = self.delay(m, 2, a >= 0)  # Whether input positive
        # The smaller nudge for negative inputs makes the rounding of the
        # negated magnitude match rounding of the signed product.
        nudged = reg_ab + Mux(positive_2, (1 << 30), (1 << 30) - 1)
        high_bits = Signal(signed(32))
        m.d.comb += high_bits.eq(nudged[31:])
        with_sign = Mux(positive_2, high_bits, -high_bits)
        m.d.sync += out_value.eq(with_sign)
Exemplo n.º 12
0
def increment_to_limit(value, limit):
    """Build a statement that increments ``value``, wrapping at ``limit``.

    When ``value`` equals ``limit - 1`` it is set back to 0; otherwise it is
    incremented. Starting from any value below ``limit``, ``value`` therefore
    cycles through ``0 .. limit - 1``.

    Parameters
    ----------
    value: Signal(n)
        The value to be incremented.
    limit: Signal(n)
        The wrap point: one more than the largest value ``value`` takes.

    Returns
    -------
    The assignment statement ``value.eq(...)``, to be added to a domain.
    """
    wraps = value == limit - 1
    return value.eq(Mux(wraps, 0, value + 1))
Exemplo n.º 13
0
    def build_input_fetcher(self, m, stop):
        """Create both input fetchers and share the LRAMs between them.

        Builds a Mode0InputFetcher ('f0') and a Mode1InputFetcher ('f1') and
        connects both to the LRAMs through a RamMux. ``self.config.mode``
        decides which fetcher drives the RamMux address and phase inputs.

        Returns
        -------
        (first, last, data)
            The selected fetcher's ``first`` and ``last`` flags and a list of
            its four ``data_out`` values: from f1 when mode is set, else f0.
        """
        f0 = self.create_fetcher(m, stop, 'f0', Mode0InputFetcher)
        f1 = self.create_fetcher(m, stop, 'f1', Mode1InputFetcher)
        cfg = self.config

        # Extra configuration needed only by the mode 1 fetcher
        repeats = (cfg.output_channel_depth //
                   Const(Constants.SYS_ARRAY_WIDTH))
        m.d.comb += [
            f1.num_pixels_x.eq(cfg.num_pixels_x),
            f1.pixel_advance_x.eq(cfg.pixel_advance_x),
            f1.pixel_advance_y.eq(cfg.pixel_advance_y),
            f1.depth.eq(cfg.input_channel_depth >> 4),
            f1.num_repeats.eq(repeats),
        ]

        m.submodules['ram_mux'] = ram_mux = RamMux()
        mode = cfg.mode

        def by_mode(mode1_value, mode0_value):
            # Pick the mode 1 fetcher's value when mode is set
            return Mux(mode, mode1_value, mode0_value)

        # Wire each of the four LRAM channels through the RamMux
        for i in range(4):
            m.d.comb += [
                self.lram_addr[i].eq(ram_mux.lram_addr[i]),
                ram_mux.lram_data[i].eq(self.lram_data[i]),
                f0.ram_mux_data[i].eq(ram_mux.data_out[i]),
                f1.ram_mux_data[i].eq(ram_mux.data_out[i]),
                ram_mux.addr_in[i].eq(
                    by_mode(f1.ram_mux_addr[i], f0.ram_mux_addr[i])),
            ]
        m.d.comb += ram_mux.phase.eq(
            by_mode(f1.ram_mux_phase, f0.ram_mux_phase))

        # Route fetcher outputs depending on mode
        data = [by_mode(f1.data_out[i], f0.data_out[i]) for i in range(4)]
        return (by_mode(f1.first, f0.first), by_mode(f1.last, f0.last), data)
Exemplo n.º 14
0
    def elab(self, m):
        # Filter store: NUM_FILTER_STORES memories, each written from
        # self.write_input when payload.store selects it, and read every
        # cycle. Memory i is read i entries behind memory 0, wrapping modulo
        # self.size, and its read data drives self.values_out[i].
        # read_index tracks the address for memory 0.
        # self.start (re)starts the address incrementing from 0.
        # NOTE(review): nothing in this method clears `running`; presumably
        # the domain reset is what stops it - confirm.
        read_index = Signal(range(Constants.FILTER_WORDS_PER_STORE))
        running = Signal()
        with m.If(running):
            # Step through 0 .. self.size - 1, wrapping
            m.d.sync += read_index.eq(
                Mux(read_index == self.size - 1, 0, read_index + 1))
        with m.If(self.start):
            m.d.sync += running.eq(True)
            m.d.sync += read_index.eq(0)

        # set up each memory
        for i in range(Constants.NUM_FILTER_STORES):
            m.submodules[f"mem{i}"] = mem = SinglePortMemory(
                data_width=32, depth=Constants.FILTER_WORDS_PER_STORE)
            # Handle writes to memory
            inp = self.write_input
            m.d.comb += [
                mem.write_enable.eq(inp.valid & (inp.payload.store == i)),
                mem.write_addr.eq(inp.payload.addr),
                mem.write_data.eq(inp.payload.data),
            ]
            # Do reads continuously
            if i == 0:
                # i == 0 treated separately to avoid verilator warnings
                # (and save a few LUTs and FFs)
                m.d.comb += mem.read_addr.eq(read_index)
            else:
                # read_index - i can go as low as -i; wrap negative values
                # back into range by adding self.size
                addr = Signal(range(-i, Constants.FILTER_WORDS_PER_STORE),
                              name=f"addr_{i}")
                m.d.comb += addr.eq(read_index - i)
                m.d.comb += mem.read_addr.eq(
                    Mux(addr >= 0, addr, addr + self.size))
            m.d.comb += self.values_out[i].eq(mem.read_data)

        # Always ready to receive more data
        m.d.comb += self.write_input.ready.eq(1)
Exemplo n.º 15
0
    def elab(self, m):
        """Drive self.r_addr, the address presented to memory this cycle.

        By default the address holds its previous value, but this can be
        modified by the restart or next signals:

        * ``self.restart`` forces it to 0 (and takes priority);
        * otherwise ``self.next`` advances it by one, wrapping to 0 once the
          previous address is at or beyond ``self.limit - 1``.
        """
        last_addr = Signal.like(self.r_addr)
        m.d.sync += last_addr.eq(self.r_addr)

        advanced = Mux(last_addr >= self.limit - 1, 0, last_addr + 1)
        with m.If(self.restart):
            m.d.comb += self.r_addr.eq(0)
        with m.Elif(self.next):
            m.d.comb += self.r_addr.eq(advanced)
        with m.Else():
            m.d.comb += self.r_addr.eq(last_addr)
Exemplo n.º 16
0
    def elab(self, m):
        # Wrapping counter: each cycle self.next is asserted, self.value
        # advances by one, returning to 0 after it reaches self.count - 1.
        # self.last is high while value + 1 == self.count.
        # self.reset clears value. While reset is asserted the Else branch is
        # inactive, so self.last (driven only there) stays at its default.

        with m.If(self.reset):
            m.d.sync += self.value.eq(0)

        with m.Else():
            # value + 1, sized like count (presumably so it can be compared
            # against count without overflowing value's width)
            value_p1 = Signal.like(self.count)
            next_value = Signal.like(self.value)
            m.d.comb += [
                value_p1.eq(self.value + 1),
                self.last.eq(value_p1 == self.count),
                next_value.eq(Mux(self.last, 0, value_p1)),
            ]
            with m.If(self.next):
                m.d.sync += self.value.eq(next_value)
Exemplo n.º 17
0
    def elab(self, m: Module):
        """Concatenate several input streams into one output stream.

        Each field of ``self.output.payload`` comes from the input stream of
        the same name in ``self.inputs``. Every input has a one-entry buffer,
        so it can be accepted while the output is stalled. The output is
        valid once every field has either a buffered value or a valid input.
        ``self.reset`` clears all buffers.
        """
        # One signal for each input stream, indicating whether
        # there is a value being buffered:
        buffering = {name: Signal(name=f'buffering_{name}')
                     for name in self.field_names}
        # A buffer for each input stream:
        buffered_values = \
            {name: Signal(self.field_shapes[name], name=f'buffered_{name}')
             for name in self.field_names}

        # For each field of the concatenated output, present either the
        # buffered value if we have one, or else plumb through the input.
        for name in self.field_names:
            m.d.comb += self.output.payload[name].eq(
                    Mux(buffering[name],
                        buffered_values[name],
                        self.inputs[name].payload))

        # The output is valid if we have either a buffered value or a valid
        # input for every field. Each input's valid is paired with its own
        # buffering flag by name, so this does not depend on self.inputs and
        # self.field_names iterating in the same order.
        field_available = [self.inputs[name].valid | buffering[name]
                           for name in self.field_names]
        m.d.comb += self.output.valid.eq(Cat(*field_available).all())

        for name, stream in self.inputs.items():
            # We can accept an input if the buffer is not occupied,
            # or if we can output this cycle.
            m.d.comb += stream.ready.eq(~buffering[name] |
                                        self.output.is_transferring())

            with m.If(stream.valid):
                # Buffer it if the buffer is not occupied and we can't output.
                with m.If(~buffering[name] & ~self.output.is_transferring()):
                    m.d.sync += buffering[name].eq(True)
                    m.d.sync += buffered_values[name].eq(stream.payload)
                # While outputting, refresh the buffer with the new input.
                # This only matters when buffering is set; otherwise the
                # buffered value is ignored by the output Mux.
                with m.If(self.output.is_transferring()):
                    m.d.sync += buffered_values[name].eq(stream.payload)
            with m.Else():
                # Buffer drained by the output and nothing new arrived.
                with m.If(self.output.is_transferring()):
                    m.d.sync += buffering[name].eq(False)

        # Reset all state
        with m.If(self.reset):
            for b in buffering.values():
                m.d.sync += b.eq(False)
            for bv in buffered_values.values():
                m.d.sync += bv.eq(0)
Exemplo n.º 18
0
    def elab(self, m):
        # Generate groups of four addresses: start_addr, then start_addr plus
        # 1x, 2x and 3x INCREMENT_Y (assuming start_addr is held steady).
        # first/last flag the first and fourth address of a group while
        # running. Once started, groups repeat (count wraps) until
        # self.reset clears running.
        running = Signal()
        count = Signal(2)
        with m.If(running):
            m.d.sync += count.eq(count + 1)  # Allow rollover

        # next_row holds the previous cycle's address plus INCREMENT_Y
        next_row = Signal(18)
        m.d.comb += self.addr_out.eq(Mux(count == 0, self.start_addr,
                                         next_row))
        m.d.sync += next_row.eq(self.addr_out + self.INCREMENT_Y)
        m.d.comb += self.first.eq(running & (count == 0))
        m.d.comb += self.last.eq(running & (count == 3))

        with m.If(self.start):
            m.d.sync += running.eq(True)
            m.d.sync += count.eq(0)
        # Reset comes last, so it overrides start.
        with m.If(self.reset):
            m.d.sync += running.eq(False)
            m.d.sync += count.eq(0)
Exemplo n.º 19
0
    def elaborate(self, platform):
        """Build an adder/subtractor with carry and signed-overflow outputs.

        ``self.res`` receives ``src1 + src2``, or ``src1 - src2`` when
        ``self.sub`` is set. The bit above the result goes to ``self.carry``.
        ``self.overflow`` is asserted on two's-complement overflow.
        """
        m = Module()

        src1_neg = self.src1[-1]
        src2_neg = self.src2[-1]
        res_neg = self.res[-1]

        # neat way of setting carry flag: res takes the low bits of the
        # result, carry the bit just above them
        outcome = Mux(self.sub, self.src1 - self.src2, self.src1 + self.src2)
        m.d.comb += Cat(self.res, self.carry).eq(outcome)

        # Overflow iff the result's sign differs from src1's, and the operand
        # signs allow overflow: different signs for a subtraction, equal
        # signs for an addition.
        operand_signs_risky = Mux(self.sub,
                                  src1_neg != src2_neg,
                                  src1_neg == src2_neg)
        with m.If(operand_signs_risky & (src1_neg != res_neg)):
            m.d.comb += self.overflow.eq(1)

        return m
Exemplo n.º 20
0
    def build_multipliers(self, m, accumulator):
        """Build a two-stage multiply-accumulate pipeline.

        Stage 0 registers the product of each of the ``self._n`` lane pairs
        sliced from ``self.input_a`` and ``self.input_b``. Stage 1 sums those
        products and adds the sum into ``accumulator``. When the delayed
        first flag (``delay(m, self.input_first, 1)[-1]``) is set, the sum
        replaces the accumulator's value instead.
        """
        a_width = self._a_shape.width
        b_width = self._b_shape.width

        # Pipeline cycle 0: one registered product per lane
        products = []
        for i in range(self._n):
            a = Signal(self._a_shape, name=f"a_{i}")
            b = Signal(self._b_shape, name=f"b_{i}")
            m.d.comb += a.eq(self.input_a.word_select(i, a_width))
            m.d.comb += b.eq(self.input_b.word_select(i, b_width))
            ab = Signal.like(a * b)
            m.d.sync += ab.eq(a * b)
            products.append(ab)

        # Pipeline cycle 1: reduce the products and accumulate
        total = tree_sum(products)
        product_sum = Signal.like(total)
        m.d.comb += product_sum.eq(total)
        first_delayed = delay(m, self.input_first, 1)[-1]
        m.d.sync += accumulator.eq(
            Mux(first_delayed, 0, accumulator) + product_sum)
Exemplo n.º 21
0
    def elab(self, m):
        # This code covers 8 cases, determined by bits 1, 2 and 3 of self.addr.
        # First, bit 2 and 3 are used to select the appropriate ram_mux phase
        # and addresses in order to read the two words containing the required
        # data via channels 0 and 3 of the RAM Mux. Once the two words have been
        # retrieved, six bytes are selected from those two words based on the
        # value of bit 1 of self.addr.

        # Uses just two of the mux channels - 0 and 3
        # For convenience, tie the unused addresses to zero
        m.d.comb += self.ram_mux_addr[1].eq(0)
        m.d.comb += self.ram_mux_addr[2].eq(0)

        # Calculate block addresses of the two words - second word may cross 16
        # byte block boundary
        block = Signal(14)
        m.d.comb += block.eq(self.addr[4:])
        m.d.comb += self.ram_mux_addr[0].eq(block)
        # In phase 3 the second word lies in the next 16-byte block
        m.d.comb += self.ram_mux_addr[3].eq(
            Mux(self.ram_mux_phase == 3, block + 1, block))

        # Use phase to select the two required words to channels 0 & 3
        m.d.comb += self.ram_mux_phase.eq(self.addr[2:4])

        # Select correct three half words when data is available, on cycle after
        # address received.
        # byte_sel is addr bit 1 delayed by one cycle, to line up with the
        # returned data.
        byte_sel = Signal(1)
        m.d.sync += byte_sel.eq(self.addr[1])
        d0 = self.ram_mux_data[0]
        d3 = self.ram_mux_data[3]
        # Upper half word of d0 followed by lower half word of d3
        dmix = Signal(32)
        m.d.comb += dmix.eq(Cat(d0[16:], d3[:16]))
        # NOTE(review): data_out[0] is combinational but data_out[1] is
        # registered, so data_out[1] appears one cycle later - presumably
        # intentional staggering; confirm with the consumer.
        with m.If(byte_sel == 0):
            m.d.comb += self.data_out[0].eq(d0)
            m.d.sync += self.data_out[1].eq(dmix)
        with m.Else():
            m.d.comb += self.data_out[0].eq(dmix)
            m.d.sync += self.data_out[1].eq(d3)
Exemplo n.º 22
0
    def elab(self, m):
        # Three-cycle rounding doubling high multiply of self.a and self.b:
        #   result = floor((a * b + 2**30) / 2**31)
        # i.e. a * b / 2**31 rounded to nearest with ties rounded up. The
        # result saturates to INT32_MAX when both inputs are INT32_MIN; for
        # signed 32-bit inputs that is the only product whose result does
        # not fit.
        areg = Signal.like(self.a)
        breg = Signal.like(self.b)
        ab = Signal(signed(64))
        overflow = Signal()

        # gemmlowp uses a nudge of 1 - 2**30 for negative products and then
        # divides, rounding toward zero. Here bits [31:] of the signed sum are
        # taken, which is an arithmetic shift (rounds toward -infinity).
        # For negative ab,
        #   floor((ab + 2**30) / 2**31) == trunc((ab + 1 - 2**30) / 2**31),
        # so the single positive nudge gives the same result.
        nudge = 1 << 30

        # cycle 0, register a and b
        m.d.sync += [
            areg.eq(self.a),
            breg.eq(self.b),
        ]
        # cycle 1, decide if this is an overflow and multiply
        m.d.sync += [
            overflow.eq((areg == INT32_MIN) & (breg == INT32_MIN)),
            ab.eq(areg * breg),
        ]
        # cycle 2, apply nudge determine result
        m.d.sync += [
            self.result.eq(Mux(overflow, INT32_MAX, (ab + nudge)[31:])),
        ]
Exemplo n.º 23
0
    def elaborate(self, platform):
        """Build the memory-bus arbiter and optional Sv32 address translation.

        Requests from ``self.ports`` are arbitrated by a PriorityEncoder over
        the ports sorted by their key. The winning port is connected to
        ``self.generic_bus``, which is bridged to the Wishbone bus.

        Translation applies only when ``self.with_addr_translation`` is set
        and ``satp.mode`` is set while in USER mode. In that case a page walk
        (RISC-V privileged spec v1.10, algorithm 4.3.2) computes
        ``phys_addr`` before the request is forwarded.

        Wishbone cycles that match no decoded device raise a store, fetch or
        load error flag on ``self.exception_unit``.
        """
        m = Module()
        sync = m.d.sync
        comb = m.d.comb

        cfg = self.mem_config
        # TODO XXX self.no_match on decoder
        m.submodules.bridge = GenericInterfaceToWishboneMasterBridge(
            generic_bus=self.generic_bus, wb_bus=self.wb_bus)
        self.decoder = m.submodules.decoder = WishboneBusAddressDecoder(
            wb_bus=self.wb_bus, word_size=cfg.word_size)
        self.initialize_mmio_devices(self.decoder, m)
        pe = m.submodules.pe = self.pe = PriorityEncoder(width=len(self.ports))
        # Ports ordered by their key in self.ports; position i feeds input i
        # of the priority encoder.
        sorted_ports = [port for priority, port in sorted(self.ports.items())]

        # force 'elaborate' invocation for all mmio modules.
        for mmio_module, addr_space in self.mmio_cfg:
            setattr(m.submodules, addr_space.basename, mmio_module)

        addr_translation_en = self.addr_translation_en = Signal()
        # When high, port requests are (re)latched into the priority encoder.
        bus_free_to_latch = self.bus_free_to_latch = Signal(reset=1)

        if self.with_addr_translation:
            m.d.comb += addr_translation_en.eq(self.csr_unit.satp.mode & (
                self.exception_unit.current_priv_mode == PrivModeBits.USER))
        else:
            m.d.comb += addr_translation_en.eq(False)

        with m.If(~addr_translation_en):
            # when translation enabled, 'bus_free_to_latch' is low during page-walk.
            # with no translation it's simpler - just look at the main bus.
            m.d.comb += bus_free_to_latch.eq(~self.generic_bus.busy)

        with m.If(bus_free_to_latch):
            # no transaction in-progress
            for i, p in enumerate(sorted_ports):
                m.d.sync += pe.i[i].eq(p.en)

        # Latched copy of the requesting port; its addr/store fields feed the
        # page walk below.
        virtual_req_bus_latch = LoadStoreInterface()
        phys_addr = self.phys_addr = Signal(32)

        # translation-unit controller signals.
        start_translation = self.start_translation = Signal()
        translation_ack = self.translation_ack = Signal()
        gb = self.generic_bus

        # Wishbone cycle to an address no device decodes: report the address
        # and raise the store/fetch/load error flag as appropriate.
        with m.If(self.decoder.no_match & self.wb_bus.cyc):
            m.d.comb += self.exception_unit.badaddr.eq(gb.addr)
            with m.If(gb.store):
                m.d.comb += self.exception_unit.m_store_error.eq(1)
            with m.Elif(gb.is_fetch):
                m.d.comb += self.exception_unit.m_fetch_error.eq(1)
            with m.Else():
                m.d.comb += self.exception_unit.m_load_error.eq(1)

        with m.If(~pe.none):
            # transaction request occurred
            # NOTE(review): the loop variable 'priority' actually holds a port
            # object and is unused; sorted_ports[i] is used instead.
            for i, priority in enumerate(sorted_ports):
                with m.If(pe.o == i):
                    bus_owner_port = sorted_ports[i]
                    with m.If(~addr_translation_en):
                        # simple case, no need to calculate physical address
                        comb += gb.connect(bus_owner_port)
                    with m.Else():
                        # page-walk performs multiple memory operations - will reconnect 'generic_bus' multiple times
                        # NOTE(review): a separate FSM (and a new 'first'
                        # signal, overwriting self.first) is built per port.
                        with m.FSM():
                            first = self.first = Signal(
                            )  # TODO get rid of that
                            with m.State("TRANSLATE"):
                                # Request a page walk, block re-arbitration
                                # and latch the owner port's request.
                                comb += [
                                    start_translation.eq(1),
                                    bus_free_to_latch.eq(0),
                                ]
                                sync += virtual_req_bus_latch.connect(
                                    bus_owner_port,
                                    exclude=[
                                        name
                                        for name, _, dir in generic_bus_layout
                                        if dir == DIR_FANOUT
                                    ])
                                with m.If(
                                        translation_ack
                                ):  # wait for 'phys_addr' to be set by page-walk algorithm.
                                    m.next = "REQ"
                                sync += first.eq(1)
                            with m.State("REQ"):
                                # Forward the port's request with the
                                # translated address.
                                comb += gb.connect(bus_owner_port,
                                                   exclude=["addr"])
                                comb += gb.addr.eq(
                                    phys_addr)  # found by page-walk
                                with m.If(first):
                                    sync += first.eq(0)
                                with m.Else():
                                    # without 'first' signal '~gb.busy' would be high at the very beginning
                                    with m.If(~gb.busy):
                                        comb += bus_free_to_latch.eq(1)
                                        m.next = "TRANSLATE"

        req_is_write = Signal()
        pte = self.pte = Record(pte_layout)
        vaddr = Record(virt_addr_layout)
        comb += [
            req_is_write.eq(virtual_req_bus_latch.store),
            vaddr.eq(virtual_req_bus_latch.addr),
        ]

        @unique
        class Issue(IntEnum):
            OK = 0
            PAGE_INVALID = 1
            WRITABLE_NOT_READABLE = 2
            LACK_PERMISSIONS = 3
            FIRST_ACCESS = 4
            MISALIGNED_SUPERPAGE = 5
            LEAF_IS_NO_LEAF = 6

        self.error_code = Signal(Issue)

        # Record a translation issue in self.error_code.
        # NOTE(review): recorded errors do not change the page-walk FSM's
        # control flow below.
        def error(code: Issue):
            m.d.sync += self.error_code.eq(code)

        # Code below implements algorithm 4.3.2 in Risc-V Privileged specification, v1.10
        # sv32_i is 1 while walking the first level (vpn1), 0 for the second (vpn0).
        sv32_i = Signal(reset=1)
        root_ppn = self.root_ppn = Signal(22)

        # Without translation support no page-walk FSM is built.
        if not self.with_addr_translation:
            return m

        with m.FSM():
            with m.State("IDLE"):
                with m.If(start_translation):
                    sync += sv32_i.eq(1)
                    sync += root_ppn.eq(self.csr_unit.satp.ppn)
                    m.next = "TRANSLATE"
            with m.State("TRANSLATE"):
                # Fetch the PTE at (root_ppn << 12) | (vpn << 2).
                vpn = self.vpn = Signal(10)
                comb += vpn.eq(Mux(
                    sv32_i,
                    vaddr.vpn1,
                    vaddr.vpn0,
                ))
                comb += [
                    gb.en.eq(1),
                    gb.addr.eq(Cat(Const(0, 2), vpn, root_ppn)),
                    gb.store.eq(0),
                    gb.mask.eq(0b1111),  # TODO use -1
                ]
                with m.If(gb.ack):
                    sync += pte.eq(gb.read_data)
                    m.next = "PROCESS_PTE"
            with m.State("PROCESS_PTE"):
                with m.If(~pte.v):
                    error(Issue.PAGE_INVALID)
                with m.If(pte.w & ~pte.r):
                    error(Issue.WRITABLE_NOT_READABLE)

                # A PTE is a leaf if it is readable or executable. Defined
                # here but also used by the NEXT state below.
                is_leaf = lambda pte: pte.r | pte.x
                with m.If(is_leaf(pte)):
                    with m.If(~pte.u & (self.exception_unit.current_priv_mode
                                        == PrivModeBits.USER)):
                        error(Issue.LACK_PERMISSIONS)
                    with m.If(~pte.a | (req_is_write & ~pte.d)):
                        error(Issue.FIRST_ACCESS)
                    # A first-level (super)page leaf must have ppn0 == 0.
                    with m.If(sv32_i.bool() & pte.ppn0.bool()):
                        error(Issue.MISALIGNED_SUPERPAGE)
                    # phys_addr could be 34 bits long, but our interconnect is 32-bit long.
                    # below statement cuts lowest two bits of r-value.
                    sync += phys_addr.eq(
                        Cat(vaddr.page_offset, pte.ppn0, pte.ppn1))
                with m.Else():  # not a leaf
                    with m.If(sv32_i == 0):
                        error(Issue.LEAF_IS_NO_LEAF)
                    sync += root_ppn.eq(
                        Cat(pte.ppn0,
                            pte.ppn1))  # pte a is pointer to the next level
                m.next = "NEXT"
            with m.State("NEXT"):
                # Note that we cannot check 'sv32_i == 0', because superpages can be present.
                with m.If(is_leaf(pte)):
                    sync += sv32_i.eq(1)
                    comb += translation_ack.eq(
                        1)  # notify that 'phys_addr' signal is set
                    m.next = "IDLE"
                with m.Else():
                    # Descend one level. NOTE(review): a non-leaf found at the
                    # second level also returns to TRANSLATE, so the walk
                    # repeats at that level - confirm intended.
                    sync += sv32_i.eq(0)
                    m.next = "TRANSLATE"
        return m
Exemplo n.º 24
0
    def elaborate(self, platform) -> Module:
        """Build the ADAT receiver module.

        ``self.adat_in`` is decoded by an ``NRZIDecoder``. Each decoded bit is
        checked against 4b/5b framing: the first bit of every 5-bit nibble
        must be 1 and is not shifted in. The remaining bits are shifted into a
        24-bit ``InputShiftRegister``.

        FSM states:

        * ``WAIT_SYNC``  - once the decoder is running, count consecutive 0
          bits; after 9 have been counted, the next decoded bit moves the FSM
          to ``READ_FRAME``.
        * ``READ_FRAME`` - after 5 bits the user bits are latched. Then, each
          time ``bit_counter`` equals ``output_at`` (35, 65, 95, ... i.e. every
          30 bits), the shifter contents are presented on ``sample_out`` with
          ``addr_out = active_channel`` and ``output_enable`` asserted. After
          the bit with ``bit_counter == 244`` the FSM goes to ``READ_SYNC``.
        * ``READ_SYNC``  - ``output_enable`` follows ``output_pulser.pulse_out``
          to present the channel whose ``output_at`` (245) is never reached
          inside ``READ_FRAME``. It then checks for a 1 followed by ten 0 bits
          before returning to ``READ_FRAME``.

        A framing violation sets ``nrzidecoder.invalid_frame_in`` and returns
        to ``WAIT_SYNC``. Losing ``nrzidecoder.running`` also returns there.
        """
        m = Module()
        sync = m.d.sync
        comb = m.d.comb

        # recovers the bit stream (and a bit clock) from the NRZI-encoded input
        nrzidecoder = NRZIDecoder(self.clk_freq)
        m.submodules.nrzi_decoder = nrzidecoder

        # collects the payload bits of the current channel / user-bit field
        framedata_shifter = InputShiftRegister(24)
        m.submodules.framedata_shifter = framedata_shifter

        # turns the edge_in level set at the end of a frame into the
        # output_enable used for the last channel in READ_SYNC
        # (presumably a one-cycle pulse, judging by the name - confirm)
        output_pulser = EdgeToPulse()
        m.submodules.output_pulser = output_pulser

        # index of the channel whose sample is output next (3 bits -> 8 channels)
        active_channel = Signal(3)
        # counts the decoded bits within the current frame (READ_FRAME)
        # or within the sync section (READ_SYNC)
        bit_counter      = Signal(8)
        # counts the bit position inside a 5-bit nibble (0 = 4b/5b sync bit)
        nibble_counter   = Signal(3)
        # counts how many 0 bits it got in a row (used only in WAIT_SYNC)
        sync_bit_counter = Signal(4)

        comb += [
            nrzidecoder.nrzi_in.eq(self.adat_in),
            self.synced_out.eq(nrzidecoder.running),
            self.recovered_clock_out.eq(nrzidecoder.recovered_clock_out),
        ]

        with m.FSM():
            # wait for SYNC
            with m.State("WAIT_SYNC"):
                # reset invalid frame bit to be able to start again
                with m.If(nrzidecoder.invalid_frame_in):
                    sync += nrzidecoder.invalid_frame_in.eq(0)

                with m.If(nrzidecoder.running):
                    # prepare a clean state for the first frame
                    sync += [
                        bit_counter.eq(0),
                        nibble_counter.eq(0),
                        active_channel.eq(0),
                        output_pulser.edge_in.eq(0)
                    ]

                    with m.If(nrzidecoder.data_out_en):
                        # a 1 bit restarts the count of consecutive zeros
                        m.d.sync += sync_bit_counter.eq(Mux(nrzidecoder.data_out, 0, sync_bit_counter + 1))
                        # NOTE(review): the transition happens on the bit that
                        # arrives while the counter is 9, whatever its value;
                        # this relies on runs of >= 9 zeros only occurring in
                        # the sync pattern - confirm.
                        with m.If(sync_bit_counter == 9):
                            m.d.sync += sync_bit_counter.eq(0)
                            m.next = "READ_FRAME"

            with m.State("READ_FRAME"):
                # at which bit of bit_counter to output sample data at
                output_at = Signal(8)

                # user bits have been read
                with m.If(bit_counter == 5):
                    sync += [
                        # output user bits
                        self.user_data_out.eq(framedata_shifter.value_out[0:4]),
                        # at bit 35 the first channel has been read
                        output_at.eq(35)
                    ]

                # when each channel has been read, output the channel's sample
                # (channels are 30 bits apart: six 5-bit nibbles per channel)
                with m.If((bit_counter > 5) & (bit_counter == output_at)):
                    sync += [
                        self.output_enable.eq(1),
                        self.addr_out.eq(active_channel),
                        self.sample_out.eq(framedata_shifter.value_out),
                        output_at.eq(output_at + 30),
                        active_channel.eq(active_channel + 1)
                    ]
                with m.Else():
                    sync += self.output_enable.eq(0)

                # we work and count only when we get
                # a new bit from the NRZI decoder
                with m.If(nrzidecoder.data_out_en):
                    comb += [
                        framedata_shifter.bit_in.eq(nrzidecoder.data_out),
                        # skip sync bit, which is first
                        framedata_shifter.enable_in.eq(~(nibble_counter == 0))
                    ]
                    sync += [
                        nibble_counter.eq(nibble_counter + 1),
                        bit_counter.eq(bit_counter + 1),
                    ]

                    # check 4b/5b sync bit: it must be 1, otherwise the
                    # frame is invalid and we resynchronize
                    with m.If((nibble_counter == 0) & ~nrzidecoder.data_out):
                        sync += nrzidecoder.invalid_frame_in.eq(1)
                        m.next = "WAIT_SYNC"
                    with m.Else():
                        sync += nrzidecoder.invalid_frame_in.eq(0)

                    # wrap the nibble position after its 5th bit
                    with m.If(nibble_counter >= 4):
                        sync += nibble_counter.eq(0)

                    # 239 channel bits and 5 user bits (including sync bits);
                    # the pre-increment count is compared, so this fires on
                    # the 245th bit of the frame (5 user + 8 * 30 channel bits)
                    with m.If(bit_counter >= (239 + 5)):
                        sync += [
                            bit_counter.eq(0),
                            # triggers output of the last channel in READ_SYNC
                            output_pulser.edge_in.eq(1)
                        ]
                        m.next = "READ_SYNC"

                with m.Else():
                    comb += framedata_shifter.enable_in.eq(0)

                with m.If(~nrzidecoder.running):
                    m.next = "WAIT_SYNC"

            # read the sync bits
            with m.State("READ_SYNC"):
                # present the last channel of the frame, which was not
                # output in READ_FRAME (its output_at is never reached there)
                sync += [
                    self.output_enable.eq(output_pulser.pulse_out),
                    self.addr_out.eq(active_channel),
                    self.sample_out.eq(framedata_shifter.value_out),
                ]

                with m.If(nrzidecoder.data_out_en):
                    sync += [
                        nibble_counter.eq(0),
                        bit_counter.eq(bit_counter + 1),
                    ]

                    # clear the shifter before the next frame starts
                    with m.If(bit_counter == 9):
                        comb += [
                            framedata_shifter.enable_in.eq(0),
                            framedata_shifter.clear_in.eq(1),
                        ]

                    # check last sync bit before sync trough (must be 1)
                    with m.If((bit_counter == 0) & ~nrzidecoder.data_out):
                        sync += nrzidecoder.invalid_frame_in.eq(1)
                        m.next = "WAIT_SYNC"
                    # check all the null bits in the sync trough (must be 0)
                    with m.Elif((bit_counter > 0) & nrzidecoder.data_out):
                        sync += nrzidecoder.invalid_frame_in.eq(1)
                        m.next = "WAIT_SYNC"
                    # ten 0 bits seen: start the next frame with fresh counters
                    with m.Elif((bit_counter == 10) & ~nrzidecoder.data_out):
                        sync += [
                            bit_counter.eq(0),
                            nibble_counter.eq(0),
                            active_channel.eq(0),
                            output_pulser.edge_in.eq(0),
                            nrzidecoder.invalid_frame_in.eq(0)
                        ]
                        m.next = "READ_FRAME"
                    with m.Else():
                        sync += nrzidecoder.invalid_frame_in.eq(0)

                with m.If(~nrzidecoder.running):
                    m.next = "WAIT_SYNC"

        return m
Exemplo n.º 25
0
 def elab(self, m):
     """Pass ``input`` through while ``capture`` is high, else replay the last capture.

     While ``capture`` is asserted, ``output`` follows ``input``
     combinationally and ``input`` is registered. While it is deasserted,
     ``output`` shows the value registered during the most recent capture
     cycle.
     """
     captured = Signal.like(self.input)
     with m.If(self.capture):
         # transparent path plus register update in the same cycle
         m.d.comb += self.output.eq(self.input)
         m.d.sync += captured.eq(self.input)
     with m.Else():
         # hold: replay the last captured value
         m.d.comb += self.output.eq(captured)
Exemplo n.º 26
0
    def elab(self, m):
        """Sequence pixel and value address generation for one read pass.

        Wires a ``PixelAddressGenerator`` into a ``PixelAddressRepeater`` and
        feeds the repeated pixel address to four ``ValueAddressGenerator``
        lanes.

        While running, ``cycle_counter`` walks 0 .. ``(depth << 6) - 1`` and
        wraps. Value generator ``i`` is started on cycle ``i`` (0..3), and
        the repeater is asked for the next pixel on any of those cycles.

        ``first`` and ``last`` are registered flags for cycle 0 and the final
        cycle. ``ram_mux_phase`` is the low two bits of ``cycle_counter``.
        ``start`` (re)starts the pass; ``reset`` stops it.
        """
        # --- sub-generators ---------------------------------------------
        m.submodules["pixel_ag"] = pixel_ag = PixelAddressGenerator()
        m.submodules["repeater"] = repeater = PixelAddressRepeater()
        value_ags = []
        for i in range(4):
            value_ag = ValueAddressGenerator()
            m.submodules[f"value_ag{i}"] = value_ag
            value_ags.append(value_ag)

        # --- pixel generator <-> repeater ------------------------------------
        m.d.comb += [
            pixel_ag.base_addr.eq(self.base_addr >> 4),
            pixel_ag.num_pixels_x.eq(self.num_pixels_x),
            pixel_ag.num_blocks_x.eq(self.pixel_advance_x),
            pixel_ag.num_blocks_y.eq(self.pixel_advance_y),
            pixel_ag.start.eq(self.start),
            pixel_ag.next.eq(repeater.gen_next),
        ]
        m.d.comb += [
            repeater.repeats.eq(self.num_repeats),
            repeater.gen_addr.eq(pixel_ag.addr),
            repeater.start.eq(self.start),
        ]

        # --- every value generator starts from the repeated pixel address ----
        for value_ag in value_ags:
            m.d.comb += value_ag.start_addr.eq(repeater.addr)
            m.d.comb += value_ag.depth.eq(self.depth)
            m.d.comb += value_ag.num_blocks_y.eq(self.pixel_advance_y)

        # --- cycle counter over the input data --------------------------------
        max_cycle_counter = Signal(9)
        cycle_counter = Signal(9)
        m.d.comb += max_cycle_counter.eq((self.depth << 6) - 1)

        # reset has priority over start; start also rewinds the counter
        running = Signal()
        with m.If(self.reset):
            m.d.sync += running.eq(0)
        with m.Elif(self.start):
            m.d.sync += [running.eq(1), cycle_counter.eq(0)]
        with m.Elif(running):
            at_end = cycle_counter == max_cycle_counter
            m.d.sync += cycle_counter.eq(Mux(at_end, 0, cycle_counter + 1))

        # --- staggered value-generator starts; any start requests a pixel ----
        start_gens = []
        for i, value_ag in enumerate(value_ags):
            start_gen = Signal()
            m.d.comb += [
                start_gen.eq(running & (cycle_counter == i)),
                value_ag.start.eq(start_gen),
            ]
            start_gens.append(start_gen)
        next_pixel = 0
        for start_gen in start_gens:
            next_pixel |= start_gen
        m.d.comb += repeater.next.eq(next_pixel)

        # --- registered first/last markers ------------------------------------
        m.d.sync += self.first.eq(running & (cycle_counter == 0))
        m.d.sync += self.last.eq(running & (cycle_counter == max_cycle_counter))

        # --- RamMux hookup ------------------------------------------------------
        m.d.comb += self.ram_mux_phase.eq(cycle_counter[:2])
        for lane, value_ag in enumerate(value_ags):
            m.d.comb += [
                self.ram_mux_addr[lane].eq(value_ag.addr_out),
                self.data_out[lane].eq(self.ram_mux_data[lane]),
            ]
Exemplo n.º 27
0
    def elaborate(self, platform):
        """Build the CPU core module.

        Instantiates:

        * the execution units (logic, adder, shifter, compare),
        * the CSR unit and the exception unit,
        * a memory arbiter with address translation, plus a memory unit,
        * a 32 x 32-bit register file with two read ports and one write port.

        Instructions are processed by a multi-cycle FSM:

        * FETCH     - take a pending timer interrupt (``self.halt``), trap on
                      a misaligned ``pc``, otherwise request the instruction
                      from the arbiter and latch it into ``instr`` on ack.
        * DECODE    - select one functional unit in ``active_unit`` (or trap
                      with ILLEGAL_INSTRUCTION).
        * EXECUTE   - capture the selected unit's result into ``rdval`` and
                      compute the next-pc increment. The memory and CSR units
                      stay in this state until they acknowledge.
        * WRITEBACK - update ``pc`` and, if the instruction writes a register
                      and ``rd != 0``, store ``rdval`` into the register file.

        Traps go through ``trap()``, which sets ``pc`` to ``mtvec.base``
        shifted left by two (word aligned).
        """
        self.m = m = Module()

        comb = m.d.comb
        sync = m.d.sync

        # CPU units used.
        logic = m.submodules.logic = LogicUnit()
        adder = m.submodules.adder = AdderUnit()
        shifter = m.submodules.shifter = ShifterUnit()
        compare = m.submodules.compare = CompareUnit()

        self.current_priv_mode = Signal(PrivModeBits,
                                        reset=PrivModeBits.MACHINE)

        csr_unit = self.csr_unit = m.submodules.csr_unit = CsrUnit(
            # TODO does '==' below produces the same synth result as .all()?
            in_machine_mode=self.current_priv_mode == PrivModeBits.MACHINE)
        exception_unit = self.exception_unit = m.submodules.exception_unit = ExceptionUnit(
            csr_unit=csr_unit, current_priv_mode=self.current_priv_mode)
        arbiter = self.arbiter = m.submodules.arbiter = MemoryArbiter(
            mem_config=self.mem_config,
            with_addr_translation=True,
            csr_unit=csr_unit,  # SATP register
            exception_unit=exception_unit,  # current privilege mode
        )
        # data accesses get arbiter priority 0, instruction fetch priority 2
        mem_unit = m.submodules.mem_unit = MemoryUnit(mem_port=arbiter.port(
            priority=0))

        ibus = arbiter.port(priority=2)

        if self.with_debug:
            m.submodules.debug = self.debug = DebugUnit(self)
            self.debug_bus = arbiter.port(priority=1)

        # Current decoding state signals.
        instr = self.instr = Signal(32)
        funct3 = self.funct3 = Signal(3)
        funct7 = self.funct7 = Signal(7)
        rd = self.rd = Signal(5)
        rs1 = Signal(5)
        rs2 = Signal(5)
        rs1val = Signal(32)
        rs2val = Signal(32)
        rdval = Signal(32)  # calculated by unit, stored to register file
        imm = Signal(signed(12))
        csr_idx = Signal(12)
        uimm = Signal(20)
        opcode = self.opcode = Signal(InstrType)
        pc = self.pc = Signal(32, reset=CODE_START_ADDR)

        # at most one active_unit at any time
        active_unit = ActiveUnit()

        # Register file. Contains two read ports (for rs1, rs2) and one write port.
        regs = Memory(width=32, depth=32, init=self.reg_init)
        reg_read_port1 = m.submodules.reg_read_port1 = regs.read_port()
        reg_read_port2 = m.submodules.reg_read_port2 = regs.read_port()
        reg_write_port = (self.reg_write_port
                          ) = m.submodules.reg_write_port = regs.write_port()

        # Timer management: mtime increments every clock cycle.
        mtime = self.mtime = Signal(32)
        sync += mtime.eq(mtime + 1)
        comb += csr_unit.mtime.eq(mtime)

        # 'halt' flags a pending machine timer interrupt; FETCH consumes it.
        self.halt = Signal()
        with m.If(csr_unit.mstatus.mie & csr_unit.mie.mtie):
            with m.If(mtime == csr_unit.mtimecmp):
                # 'halt' signal needs to be cleared when CPU jumps to trap handler.
                sync += [
                    self.halt.eq(1),
                ]

        comb += [
            exception_unit.m_instruction.eq(instr),
            exception_unit.m_pc.eq(pc),
            # TODO more
        ]

        # TODO
        # DebugModule is able to read and write GPR values.
        # if self.with_debug:
        #     comb += self.halt.eq(self.debug.HALT)
        # else:
        #     comb += self.halt.eq(0)

        # with m.If(self.halt):
        #     comb += [
        #         reg_read_port1.addr.eq(self.gprf_debug_addr),
        #         reg_write_port.addr.eq(self.gprf_debug_addr),
        #         reg_write_port.en.eq(self.gprf_debug_write_en)
        #     ]

        #     with m.If(self.gprf_debug_write_en):
        #         comb += reg_write_port.data.eq(self.gprf_debug_data)
        #     with m.Else():
        #         comb += self.gprf_debug_data.eq(reg_read_port1.data)
        # Placeholder for the disabled debug path above; the Else branch is
        # always taken.
        with m.If(0):
            pass
        with m.Else():
            comb += [
                reg_read_port1.addr.eq(rs1),
                reg_read_port2.addr.eq(rs2),
                reg_write_port.addr.eq(rd),
                reg_write_port.data.eq(rdval),
                # reg_write_port.en set later
                rs1val.eq(reg_read_port1.data),
                rs2val.eq(reg_read_port2.data),
            ]

        comb += [
            # following is not true for all instrutions, but in specific cases will be overwritten later
            imm.eq(instr[20:32]),
            csr_idx.eq(instr[20:32]),
            uimm.eq(instr[12:]),
        ]

        # drive input signals of actually used unit.
        with m.If(active_unit.logic):
            comb += [
                logic.funct3.eq(funct3),
                logic.src1.eq(rs1val),
                logic.src2.eq(Mux(opcode == InstrType.OP_IMM, imm, rs2val)),
            ]
        with m.Elif(active_unit.adder):
            comb += [
                adder.src1.eq(rs1val),
                adder.src2.eq(Mux(opcode == InstrType.OP_IMM, imm, rs2val)),
            ]
        with m.Elif(active_unit.shifter):
            comb += [
                shifter.funct3.eq(funct3),
                shifter.funct7.eq(funct7),
                shifter.src1.eq(rs1val),
                shifter.shift.eq(
                    Mux(opcode == InstrType.OP_IMM, imm[0:5].as_unsigned(),
                        rs2val[0:5])),
            ]
        with m.Elif(active_unit.mem_unit):
            comb += [
                mem_unit.en.eq(1),
                mem_unit.funct3.eq(funct3),
                mem_unit.src1.eq(rs1val),
                mem_unit.src2.eq(rs2val),
                mem_unit.store.eq(opcode == InstrType.STORE),
                # S-type offset: imm[4:0] = instr[11:7] (the 'rd' field),
                # imm[11:5] = instr[31:25] (= imm[5:12] of the I-type slice)
                mem_unit.offset.eq(
                    Mux(opcode == InstrType.LOAD, imm, Cat(rd, imm[5:12]))),
            ]
        with m.Elif(active_unit.compare):
            comb += [
                compare.funct3.eq(funct3),
                # Compare Unit uses Adder for carry and overflow flags.
                adder.src1.eq(rs1val),
                adder.src2.eq(Mux(opcode == InstrType.OP_IMM, imm, rs2val)),
                # adder.sub set somewhere below
            ]
        with m.Elif(active_unit.branch):
            comb += [
                compare.funct3.eq(funct3),
                # Compare Unit uses Adder for carry and overflow flags.
                adder.src1.eq(rs1val),
                adder.src2.eq(rs2val),
                # adder.sub set somewhere below
            ]
        with m.Elif(active_unit.csr):
            comb += [
                csr_unit.func3.eq(funct3),
                csr_unit.csr_idx.eq(csr_idx),
                csr_unit.rs1.eq(rs1),
                csr_unit.rs1val.eq(rs1val),
                csr_unit.rd.eq(rd),
                csr_unit.en.eq(1),
            ]

        # Flags derived from the adder result, consumed by compare/branch.
        comb += [
            compare.negative.eq(adder.res[-1]),
            compare.overflow.eq(adder.overflow),
            compare.carry.eq(adder.carry),
            compare.zero.eq(adder.res == 0),
        ]

        # Decoding state (with redundancy - instr. type not known yet).
        # Fields are decoded combinationally from the latched 'instr'
        # register (updated in FETCH when ibus acknowledges).
        comb += [
            opcode.eq(instr[0:7]),
            rd.eq(instr[7:12]),
            funct3.eq(instr[12:15]),
            rs1.eq(instr[15:20]),
            rs2.eq(instr[20:25]),
            funct7.eq(instr[25:32]),
        ]

        def fetch_with_new_pc(pc: Signal):
            # Leave the current instruction: clear the unit selection and
            # continue fetching at 'pc'.
            m.next = "FETCH"
            m.d.sync += active_unit.eq(0)
            m.d.sync += self.pc.eq(pc)

        def trap(cause: Optional[Union[TrapCause, IrqCause]], interrupt=False):
            # Jump to the trap vector. With cause=None only the jump is
            # performed (the cause is presumably reported elsewhere, e.g.
            # interconnect errors); otherwise the matching exception-unit
            # notifier is pulsed.
            fetch_with_new_pc(Cat(Const(0, 2), self.csr_unit.mtvec.base))
            if cause is None:
                return
            assert isinstance(cause, TrapCause) or isinstance(cause, IrqCause)
            e = exception_unit
            notifiers = e.irq_cause_map if interrupt else e.trap_cause_map
            m.d.comb += notifiers[cause].eq(1)

        self.fetch = Signal()
        interconnect_error = Signal()
        comb += interconnect_error.eq(exception_unit.m_store_error
                                      | exception_unit.m_fetch_error
                                      | exception_unit.m_load_error)
        with m.FSM():
            with m.State("FETCH"):
                with m.If(self.halt):
                    # pending timer interrupt: consume it and trap
                    sync += self.halt.eq(0)
                    trap(IrqCause.M_TIMER_INTERRUPT, interrupt=True)
                with m.Else():
                    with m.If(pc & 0b11):
                        trap(TrapCause.FETCH_MISALIGNED)
                    with m.Else():
                        comb += [
                            ibus.en.eq(1),
                            ibus.store.eq(0),
                            ibus.addr.eq(pc),
                            ibus.mask.eq(0b1111),
                            ibus.is_fetch.eq(1),
                        ]
                    with m.If(interconnect_error):
                        trap(cause=None)
                    with m.If(ibus.ack):
                        sync += [
                            instr.eq(ibus.read_data),
                        ]
                        m.next = "DECODE"
            with m.State("DECODE"):
                comb += self.fetch.eq(
                    1
                )  # only for simulation, notify that 'instr' ready to use.
                m.next = "EXECUTE"
                # here, we have registers already fetched into rs1val, rs2val.
                # NOTE(review): the register file read ports are addressed
                # from 'instr' combinationally; if read_port() is synchronous
                # the data becomes valid one cycle later (in EXECUTE) - confirm.
                # RV32 base instructions have bits [1:0] == 0b11.
                with m.If(instr & 0b11 != 0b11):
                    trap(TrapCause.ILLEGAL_INSTRUCTION)
                with m.If(match_logic_unit(opcode, funct3, funct7)):
                    sync += [
                        active_unit.logic.eq(1),
                    ]
                with m.Elif(match_adder_unit(opcode, funct3, funct7)):
                    sync += [
                        active_unit.adder.eq(1),
                        adder.sub.eq((opcode == InstrType.ALU)
                                     & (funct7 == Funct7.SUB)),
                    ]
                with m.Elif(match_shifter_unit(opcode, funct3, funct7)):
                    sync += [
                        active_unit.shifter.eq(1),
                    ]
                with m.Elif(match_loadstore_unit(opcode, funct3, funct7)):
                    sync += [
                        active_unit.mem_unit.eq(1),
                    ]
                with m.Elif(match_compare_unit(opcode, funct3, funct7)):
                    sync += [
                        active_unit.compare.eq(1),
                        adder.sub.eq(1),
                    ]
                with m.Elif(match_lui(opcode, funct3, funct7)):
                    sync += [
                        active_unit.lui.eq(1),
                    ]
                    comb += [
                        reg_read_port1.addr.eq(rd),
                        # rd will be available in next cycle in rs1val
                    ]
                with m.Elif(match_auipc(opcode, funct3, funct7)):
                    sync += [
                        active_unit.auipc.eq(1),
                    ]
                with m.Elif(match_jal(opcode, funct3, funct7)):
                    sync += [
                        active_unit.jal.eq(1),
                    ]
                with m.Elif(match_jalr(opcode, funct3, funct7)):
                    sync += [
                        active_unit.jalr.eq(1),
                    ]
                with m.Elif(match_branch(opcode, funct3, funct7)):
                    sync += [
                        active_unit.branch.eq(1),
                        adder.sub.eq(1),
                    ]
                with m.Elif(match_csr(opcode, funct3, funct7)):
                    sync += [active_unit.csr.eq(1)]
                with m.Elif(match_mret(opcode, funct3, funct7)):
                    sync += [active_unit.mret.eq(1)]
                with m.Elif(match_sfence_vma(opcode, funct3, funct7)):
                    pass  # sfence.vma
                with m.Elif(opcode == 0b0001111):
                    pass  # fence
                with m.Else():
                    trap(TrapCause.ILLEGAL_INSTRUCTION)
            with m.State("EXECUTE"):
                # capture the active unit's result for WRITEBACK
                with m.If(active_unit.logic):
                    sync += [
                        rdval.eq(logic.res),
                    ]
                with m.Elif(active_unit.adder):
                    sync += [
                        rdval.eq(adder.res),
                    ]
                with m.Elif(active_unit.shifter):
                    sync += [
                        rdval.eq(shifter.res),
                    ]
                with m.Elif(active_unit.mem_unit):
                    sync += [
                        rdval.eq(mem_unit.res),
                    ]
                with m.Elif(active_unit.compare):
                    sync += [
                        rdval.eq(compare.condition_met),
                    ]
                with m.Elif(active_unit.lui):
                    sync += [
                        rdval.eq(Cat(Const(0, 12), uimm)),
                    ]
                with m.Elif(active_unit.auipc):
                    sync += [
                        rdval.eq(pc + Cat(Const(0, 12), uimm)),
                    ]
                with m.Elif(active_unit.jal | active_unit.jalr):
                    sync += [
                        rdval.eq(pc + 4),
                    ]
                with m.Elif(active_unit.csr):
                    sync += [rdval.eq(csr_unit.rd_val)]

                # control flow mux - all traps need to be here, otherwise it will overwrite m.next statement.
                with m.If(active_unit.mem_unit):
                    with m.If(mem_unit.ack):
                        m.next = "WRITEBACK"
                        sync += active_unit.eq(0)
                    with m.Else():
                        m.next = "EXECUTE"
                    with m.If(interconnect_error):
                        # NOTE:
                        # the order of that 'If' is important.
                        # In case of error overwrite m.next above.
                        trap(cause=None)
                with m.Elif(active_unit.csr):
                    with m.If(csr_unit.illegal_insn):
                        trap(TrapCause.ILLEGAL_INSTRUCTION)
                    with m.Else():
                        with m.If(csr_unit.vld):
                            m.next = "WRITEBACK"
                            sync += active_unit.eq(0)
                        with m.Else():
                            m.next = "EXECUTE"
                with m.Elif(active_unit.mret):
                    comb += exception_unit.m_mret.eq(1)
                    fetch_with_new_pc(exception_unit.mepc)
                with m.Else():
                    # all units not specified by default take 1 cycle
                    m.next = "WRITEBACK"
                    sync += active_unit.eq(0)

                # J-type immediate: imm[20|10:1|11|19:12], LSB is always 0.
                jal_offset = Signal(signed(21))
                comb += jal_offset.eq(
                    Cat(
                        Const(0, 1),
                        instr[21:31],
                        instr[20],
                        instr[12:20],
                        instr[31],
                    ).as_signed())

                # pc increment applied in WRITEBACK (non-jalr instructions)
                pc_addend = Signal(signed(32))
                sync += pc_addend.eq(Mux(active_unit.jal, jal_offset, 4))

                # B-type immediate: imm[12|10:5|4:1|11], LSB is always 0.
                branch_addend = Signal(signed(13))
                comb += branch_addend.eq(
                    Cat(
                        Const(0, 1),
                        instr[8:12],
                        instr[25:31],
                        instr[7],
                        instr[31],
                    ).as_signed()  # signed: RISC-V branch offsets are sign-extended
                )

                with m.If(active_unit.branch):
                    with m.If(compare.condition_met):
                        sync += pc_addend.eq(branch_addend)

                new_pc = Signal(32)
                # active_unit is cleared when leaving EXECUTE, so WRITEBACK
                # cannot test active_unit.jalr; this flag carries it over.
                is_jalr_latch = Signal()  # that's bad workaround
                with m.If(active_unit.jalr):
                    sync += is_jalr_latch.eq(1)
                    # JALR target = rs1 + sign-extended imm, with the least
                    # significant bit cleared (RISC-V unprivileged ISA).
                    jalr_target = rs1val.as_signed() + imm
                    sync += new_pc.eq(Cat(Const(0, 1), jalr_target[1:32]))

            with m.State("WRITEBACK"):
                with m.If(is_jalr_latch):
                    sync += pc.eq(new_pc)
                with m.Else():
                    sync += pc.eq(pc + pc_addend)
                sync += is_jalr_latch.eq(0)

                # Here, rdval is already calculated. If neccessary, put it into register file.
                should_write_rd = self.should_write_rd = Signal()
                writeback = self.writeback = Signal()
                # for riscv-dv simulation:
                # detect that instruction does not perform register write to avoid infinite loop
                # by checking writeback & should_write_rd
                # TODO it will break for trap-causing instructions.
                comb += writeback.eq(1)
                # x0 is hard-wired to zero, so writes to rd == 0 are dropped.
                comb += should_write_rd.eq(
                    reduce(
                        or_,
                        [
                            match_shifter_unit(opcode, funct3, funct7),
                            match_adder_unit(opcode, funct3, funct7),
                            match_logic_unit(opcode, funct3, funct7),
                            match_load(opcode, funct3, funct7),
                            match_compare_unit(opcode, funct3, funct7),
                            match_lui(opcode, funct3, funct7),
                            match_auipc(opcode, funct3, funct7),
                            match_jal(opcode, funct3, funct7),
                            match_jalr(opcode, funct3, funct7),
                            match_csr(opcode, funct3, funct7),
                        ],
                    )
                    & (rd != 0))

                with m.If(should_write_rd):
                    comb += reg_write_port.en.eq(True)
                m.next = "FETCH"

        return m
Exemplo n.º 28
0
def clamped(value, min_bound, max_bound):
    """Return an expression that limits ``value`` to ``[min_bound, max_bound]``.

    The result is built from nested ``Mux`` selections, so ``value`` and the
    bounds may be HDL expressions. The lower bound is checked first: if
    ``value < min_bound`` the result is ``min_bound`` even when the bounds
    are inverted.
    """
    below_min = value < min_bound
    above_max = value > max_bound
    upper_limited = Mux(above_max, max_bound, value)
    return Mux(below_min, min_bound, upper_limited)