예제 #1
0
    def __init__(self, data_width, max_line_length):
        """Construct the ``agg_aligner`` generator.

        Declares clock/reset, a data/valid input pair, the ``line_length``
        configuration input, the matching data/valid outputs and an
        ``align`` output. Data and valid are wired straight from input to
        output, and the ``update_cnt`` and ``set_align`` code blocks are
        registered with the generator.

        Args:
            data_width: Bit width of ``in_dat`` and ``out_dat``.
            max_line_length: Value whose ``clog2`` sizes the internal
                counter and the ``line_length`` input.
        """
        super().__init__("agg_aligner")

        # Generation parameters stored on the instance
        self.data_width = data_width
        self.max_line_length = max_line_length
        self.counter_width = clog2(self.max_line_length)

        # Clock / reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # Data-path inputs
        self._in_dat = self.input("in_dat", self.data_width)
        self._in_valid = self.input("in_valid", 1)

        # Line length input, tagged as a configuration register
        self._line_length = self.input("line_length", self.counter_width)
        line_length_attr = ConfigRegAttr("Line Length/Image Width for alignment")
        self._line_length.add_attribute(line_length_attr)

        # Outputs
        self._out_dat = self.output("out_dat", self.data_width)
        self._out_valid = self.output("out_valid", 1)
        self._out_align = self.output("align", 1)

        # Internal counter, same width as line_length
        self._cnt = self.var("cnt", self.counter_width)

        # Register the behavioural blocks in their original order
        for code_block in (self.update_cnt, self.set_align):
            self.add_code(code_block)

        # Data and valid pass through untouched
        passthrough = (
            (self._out_dat, self._in_dat),
            (self._out_valid, self._in_valid),
        )
        for sink, source in passthrough:
            self.wire(sink, source)
예제 #2
0
파일: passes.py 프로젝트: StanfordAHA/lake
        def visit(self, node):
            """Lift configuration-register ports of ``node`` up the hierarchy.

            For every port of a ``_kratos.Generator`` node that carries
            exactly one ``ConfigRegAttr``, a port is created on each ancestor
            generator (named ``<child instance name>_<child port name>``) and
            wired to the port one level below. The port created on the
            outermost ancestor receives a fresh ``ConfigRegAttr`` with the
            original documentation and, if the original port carried a
            ``FormalAttr``, a ``FormalAttr`` with the same annotation type.
            Ports of a generator that has no parent are left untouched.

            Args:
                node: The node being visited; anything that is not a
                    ``_kratos.Generator`` is ignored.
            """
            if not isinstance(node, _kratos.Generator):
                return

            port_names = node.get_port_names()
            for port_name in port_names:
                curr_port = node.get_port(port_name)
                attrs = curr_port.find_attribute(
                    lambda a: isinstance(a, ConfigRegAttr))
                annotation_attr = curr_port.find_attribute(
                    lambda a: isinstance(a, FormalAttr))

                # Debug trace. Compare by value: `is` against a string literal
                # tests object identity and only matched by accident of
                # string interning.
                if port_name == "mode":
                    print("Found mode...")
                    print(attrs)

                # Only ports with exactly one config-register attribute are lifted
                if len(attrs) != 1:
                    continue

                doc = attrs[0].get_documentation()

                # Walk from the node up to the root, creating and wiring a
                # port on every ancestor along the way.
                child_gen = node
                child_port = curr_port
                parent_gen = node.parent_generator()
                top_lvl_cfg = parent_gen is None
                while parent_gen is not None:
                    # create a port based on the target's definition
                    new_name = child_gen.instance_name + "_" + child_port.name
                    lifted_port = parent_gen.port(child_port, new_name, False)
                    parent_gen.wire(child_port, lifted_port)
                    # move up the hierarchy
                    child_port = lifted_port
                    child_gen = parent_gen
                    parent_gen = parent_gen.parent_generator()

                # Only add the attribute if this is a newly created port,
                # not a top-level cfg reg
                if top_lvl_cfg:
                    continue

                child_port_cra = ConfigRegAttr()
                child_port_cra.set_documentation(doc)
                child_port.add_attribute(child_port_cra)
                if annotation_attr:
                    annot_type = annotation_attr[0].get_formal_ann()
                    child_port.add_attribute(
                        FormalAttr(f"{child_port}", annot_type))
예제 #3
0
    def __init__(self, name, data_width, init_val):
        """Construct a generator named ``config_reg_<name>_<data_width>``.

        The generator itself is tagged with a ``ConfigRegAttr``. It exposes
        an ``o_data_out`` output plus ``i_clk`` / ``i_rst_n``, and registers
        the ``set_out`` code block.

        Args:
            name: Label embedded in the generator name.
            data_width: Bit width of ``o_data_out``.
            init_val: Stored on the instance as ``init_val``.
        """
        module_name = "config_reg_{}_{}".format(name, data_width)
        super().__init__(module_name)

        # Attach the config-register attribute to the generator as a whole
        cfg_attr = ConfigRegAttr()
        self.add_attribute(cfg_attr)

        # Generation parameters stored on the instance
        self.data_width = data_width
        self.init_val = init_val

        # Ports
        self._out = self.output("o_data_out", data_width)
        self._clk = self.clock("i_clk")
        self._rst_n = self.reset("i_rst_n")

        # Behavioural block
        self.add_code(self.set_out)
예제 #4
0
    def __init__(self, data_width, interconnect_output_ports):
        """Construct the ``Chain`` generator (created with ``debug=True``).

        Declares the ``chain_en`` configuration input, two packed data-array
        inputs (``curr_tile_data_out`` and ``chain_data_in``), the
        ``accessor_output`` input and the ``data_out_tile`` output array,
        then registers the ``set_data_out`` code block.

        Args:
            data_width: Bit width of each element of the data arrays.
            interconnect_output_ports: Number of elements in each data array
                and bit width of ``accessor_output``.
        """
        super().__init__("Chain", debug=True)

        # Generation parameters stored on the instance
        self.data_width = data_width
        self.interconnect_output_ports = interconnect_output_ports

        # Chain enable: config register with a formal SET0 constraint
        self._chain_en = self.input("chain_en", 1)
        cfg_attr = ConfigRegAttr("Signal indicating whether to enable chaining")
        self._chain_en.add_attribute(cfg_attr)
        formal_attr = FormalAttr(self._chain_en.name, FormalSignalConstraint.SET0)
        self._chain_en.add_attribute(formal_attr)

        # Every data array shares the same shape options
        array_opts = dict(size=self.interconnect_output_ports,
                          packed=True,
                          explicit_array=True)

        # Inputs
        self._curr_tile_data_out = self.input(
            "curr_tile_data_out", self.data_width, **array_opts)
        self._chain_data_in = self.input(
            "chain_data_in", self.data_width, **array_opts)
        self._accessor_output = self.input(
            "accessor_output", self.interconnect_output_ports)

        # Output
        self._data_out_tile = self.output(
            "data_out_tile", self.data_width, **array_opts)

        # Behavioural block
        self.add_code(self.set_data_out)
예제 #5
0
파일: memory.py 프로젝트: StanfordAHA/lake
    def __init__(self, mem_params, word_width):
        """Build the ``lake_mem`` memory generator.

        Reads the memory description from ``mem_params``, derives the storage
        dimensions and address widths, declares the data/control ports and
        then either instantiates an ``SRAM`` child (``use_macro`` truthy) or
        declares a ``memory`` variable together with its addressing signals
        and the write/read data blocks.

        Args:
            mem_params: Mapping indexed with the keys ``name``, ``capacity``,
                ``rw_same_cycle``, ``use_macro``, ``macro_name``,
                ``num_read_write_ports``, ``num_read_ports``,
                ``num_write_ports``, ``write_info``, ``read_info``,
                ``read_write_info``, ``chaining`` and ``num_chain``, plus
                ``write_port_width`` / ``read_port_width`` when there are no
                read/write ports, or ``read_write_port_width`` otherwise
                (the macro branch always reads ``read_write_port_width``).
            word_width: Bit width of a single data word.

        Raises:
            AssertionError: If ``capacity`` is not a multiple of the write or
                read port width, or, in the non-macro branch with separate
                read and write ports, if the first write port's ``latency``
                is not greater than 0.
        """

        super().__init__("lake_mem", debug=True)

        ################################################################
        # PARAMETERS
        ################################################################
        # print("MEM PARAMS ", mem_params)

        # basic parameters
        self.word_width = word_width

        # general memory parameters
        self.mem_name = mem_params["name"]
        self.capacity = mem_params["capacity"]
        self.rw_same_cycle = mem_params["rw_same_cycle"]
        self.use_macro = mem_params["use_macro"]
        self.macro_name = mem_params["macro_name"]

        # number of port types
        self.num_read_write_ports = mem_params["num_read_write_ports"]
        self.num_read_only_ports = mem_params["num_read_ports"]
        self.num_write_only_ports = mem_params["num_write_ports"]
        # totals count the shared read/write ports on both sides
        self.num_read_ports = self.num_read_only_ports + self.num_read_write_ports
        self.num_write_ports = self.num_write_only_ports + self.num_read_write_ports

        # info for port types
        self.write_info = mem_params["write_info"]
        self.read_info = mem_params["read_info"]
        self.read_write_info = mem_params["read_write_info"]

        # TODO change - for now, we assume you cannot have read/write and read or write ports
        # should be the max of write vs read_write and need to handle more general case
        if self.num_read_write_ports == 0:
            self.write_width = mem_params["write_port_width"]
            self.read_width = mem_params["read_port_width"]
        else:
            self.write_width = mem_params["read_write_port_width"]
            self.read_width = mem_params["read_write_port_width"]

        assert self.capacity % self.write_width == 0, \
            "Memory capacity is not a multiple of the port width for writes"
        assert self.capacity % self.read_width == 0, \
            "Memory capacity is not a multiple of the port width for reads"

        # innermost dimension for size of memory is the size of whichever port
        # type has a wider width between reads and writes
        self.mem_size = max(self.read_width, self.write_width)
        # this assert has to be true if previous two asserts are true
        assert self.capacity % self.mem_size == 0
        # this is the last dimension for size of memory - equal to the number
        # of the port type with wider width addresses can fit in the memory
        self.mem_last_dim = int(self.capacity / self.mem_size)

        # bit counts for the two dimensions, never less than one bit
        self.mem_size_bits = max(1, clog2(self.mem_size))
        self.mem_last_dim_bits = max(1, clog2(self.mem_last_dim))

        # chaining parameters and config regs
        self.chaining = mem_params["chaining"]
        self.num_chain = mem_params["num_chain"]
        self.num_chain_bits = clog2(self.num_chain)
        if self.chaining:
            self.chain_index = self.var("chain_index",
                                        width=self.num_chain_bits)
            self.chain_index.add_attribute(
                ConfigRegAttr("Chain index for chaining"))
            self.chain_index.add_attribute(
                FormalAttr(self.chain_index.name, FormalSignalConstraint.SET0))

        # minimum required widths for address signals
        if self.mem_size == self.write_width and self.mem_size == self.read_width:
            self.write_addr_width = self.mem_last_dim_bits + self.num_chain_bits
            self.read_addr_width = self.mem_last_dim_bits + self.num_chain_bits
        elif self.mem_size == self.write_width:
            self.write_addr_width = self.mem_last_dim_bits + self.num_chain_bits
            self.read_addr_width = self.mem_size_bits + self.mem_last_dim_bits + self.num_chain_bits
        elif self.mem_size == self.read_width:
            self.write_addr_width = self.mem_size_bits + self.mem_last_dim_bits + self.num_chain_bits
            self.read_addr_width = self.mem_last_dim_bits + self.num_chain_bits
        else:
            # NOTE(review): this branch only prints, leaving write_addr_width
            # and read_addr_width unset; it looks unreachable because
            # mem_size is the max of the two widths - confirm.
            print("Error occurred! Memory size does not make sense.")

        ################################################################
        # I/O INTERFACE (WITHOUT ADDRESSING) + MEMORY
        ################################################################
        self.clk = self.clock("clk")
        # active low asynchronous reset
        self.rst_n = self.reset("rst_n", 1)

        self.data_in = self.input("data_in",
                                  width=self.word_width,
                                  size=(self.num_write_ports,
                                        self.write_width),
                                  explicit_array=True,
                                  packed=True)

        self.chain_en = self.input("chain_en", 1)

        # write enable (high: write, low: read when rw_same_cycle = False, else
        # only indicates write)
        self.write = self.input("write", width=1, size=self.num_write_ports)

        self.data_out = self.output("data_out",
                                    width=self.word_width,
                                    size=(self.num_read_ports,
                                          self.read_width),
                                    explicit_array=True,
                                    packed=True)

        # write enable after the chaining logic has been applied
        self.write_chain = self.var("write_chain",
                                    width=1,
                                    size=self.num_write_ports)

        if self.use_macro:

            # NOTE(review): self.addr_width is not assigned anywhere in this
            # method before this point - confirm it is provided elsewhere
            # (e.g. by the base class), otherwise this raises AttributeError.
            self.read_write_addr = self.input("read_write_addr",
                                              width=self.addr_width,
                                              size=self.num_read_write_ports,
                                              explicit_array=True)

            sram = SRAM(not self.use_macro, self.macro_name, word_width,
                        mem_params["read_write_port_width"],
                        mem_params["capacity"],
                        mem_params["num_read_write_ports"],
                        mem_params["num_read_write_ports"],
                        clog2(mem_params["capacity"]), 0, 1)

            self.add_child(
                "SRAM_" + mem_params["name"],
                sram,
                clk=self.clk,
                clk_en=1,
                mem_data_in_bank=self.data_in,
                mem_data_out_bank=self.data_out,
                mem_addr_in_bank=self.read_write_addr,
                # TODO adjust
                mem_cen_in_bank=1,
                mem_wen_in_bank=self.write_chain,
                wtsel=0,
                rtsel=1)

        else:
            # memory variable (not I/O)
            self.memory = self.var("memory",
                                   width=self.word_width,
                                   size=(self.mem_last_dim, self.mem_size),
                                   explicit_array=True,
                                   packed=True)

            ################################################################
            # ADDRESSING I/O AND SIGNALS
            ################################################################

            # I/O is different depending on whether we have read and write ports or
            # read/write ports

            # we keep address width at 16 to avoid unpacked
            # safe_wire errors for addr in hw_top_lake - can change by changing
            # default_config_width for those addr gens while accounting for muxing
            # bits, but the extra bits are unused anyway

            if self.rw_same_cycle:
                self.read = self.input("read",
                                       width=1,
                                       size=self.num_read_ports)
            else:
                # no external read enable: every read port is tied high
                self.read = self.var("read", width=1, size=self.num_read_ports)
                for i in range(self.num_read_ports):
                    self.wire(self.read[i], 1)

            # TODO change later - same read/write or read and write assumption as above
            if self.num_write_only_ports != 0 and self.num_read_only_ports != 0:
                # writes
                self.write_addr = self.input(
                    "write_addr",
                    width=16,  # self.write_addr_width,
                    size=self.num_write_ports,
                    explicit_array=True)

                # NOTE(review): the message says "greater than 1" but the
                # condition only requires latency > 0 - confirm which is meant.
                assert self.write_info[0]["latency"] > 0, \
                    "Latency for write ports must be greater than 1 clock cycle."

                # reads
                self.read_addr = self.input(
                    "read_addr",
                    width=16,  # self.read_addr_width,
                    size=self.num_read_ports,
                    explicit_array=True)

                # TODO for now assuming all read ports have same latency
                # TODO also should add support for other latencies

            # rw_same_cycle is not valid here because read/write share the same port
            elif self.num_read_write_ports != 0:
                self.read_write_addr = self.input(
                    "read_write_addr",
                    width=
                    16,  # max(self.read_addr_width, self.write_addr_width),
                    size=self.num_read_write_ports,
                    explicit_array=True)

                # writes
                self.write_addr = self.var(
                    "write_addr",
                    width=16,  # self.write_addr_width,
                    size=self.num_read_write_ports,
                    explicit_array=True)

                # the shared address drives the write address of each port
                for p in range(self.num_read_write_ports):
                    safe_wire(gen=self,
                              w_to=self.write_addr[p],
                              w_from=self.read_write_addr[p])

                # reads
                # NOTE(review): unlike write_addr (fixed at 16 bits) this uses
                # read_addr_width - confirm the asymmetry is intended.
                self.read_addr = self.var("read_addr",
                                          width=self.read_addr_width,
                                          size=self.num_read_write_ports,
                                          explicit_array=True)
                # the shared address drives the read address of each port
                for p in range(self.num_read_write_ports):
                    safe_wire(gen=self,
                              w_to=self.read_addr[p],
                              w_from=self.read_write_addr[p])

                # TODO in self.read_write_info we should allow for different read
                # and write latencies?
                self.read_info = self.read_write_info

            # NOTE(review): when neither branch above is taken (for example
            # only write ports or only read ports), no address signals are
            # created here - confirm that configuration is unsupported.

            # TODO just doing chaining for SRAM
            if self.chaining and self.num_read_write_ports > 0:
                # NOTE(review): the slice bounds are derived from
                # write_addr_width and num_chain_bits and index
                # read_write_addr without a port subscript - verify the
                # intended bit range against the 16-bit address declared above.
                self.wire(
                    self.write_chain,
                    # chaining not enabled
                    (
                        ~self.chain_en |
                        # chaining enabled
                        (self.chain_en &
                         (self.chain_index == self.read_write_addr[
                             self.write_addr_width + self.num_chain_bits,
                             self.write_addr_width]))) & self.write)
                # chaining not supported
            else:
                self.wire(self.write_chain, self.write)

            # NOTE(review): this check sits inside the non-macro branch, so
            # use_macro is falsy here and the body looks unreachable; `sram`
            # is also only bound in the macro branch above - confirm.
            if self.use_macro:
                self.wire(sram.ports.mem_wen_in_bank, self.write_chain)

            self.add_write_data_block()
            self.add_read_data_block()
예제 #6
0
    def __init__(self,
                 data_width=16,
                 banks=1,
                 memory_width=64,
                 rw_same_cycle=False,
                 read_delay=1,
                 addr_width=9):
        """Build the ``strg_fifo`` generator.

        Declares the push/pop FIFO interface and the signals exchanged with
        the storage banks, instantiates two ``RegFIFO`` children
        (``front_rf`` and ``back_rf``), wires them together and registers the
        code blocks that drive bank selection, addresses, occupancy and the
        final ``empty`` / ``full`` flags.

        Args:
            data_width: Bit width of a data word; also used as the width of
                the ``num_words_mem``, ``num_items`` and ``fifo_depth``
                signals.
            banks: Number of storage banks.
            memory_width: Divided by ``data_width`` to obtain ``fw_int``,
                the number of words per storage entry.
            rw_same_cycle: When truthy, separate ``wen_addr_out`` and
                ``ren_addr_out`` outputs are created; otherwise a single
                ``addr_out`` is shared.
            read_delay: When equal to 1, ``ren_delay`` and ``prev_bank_rd``
                are created and used in place of the undelayed signals.
            addr_width: Bit width of each per-bank address.
        """
        super().__init__("strg_fifo")

        # Generation parameters
        self.banks = banks
        self.data_width = data_width
        self.memory_width = memory_width
        self.rw_same_cycle = rw_same_cycle
        self.read_delay = read_delay
        self.addr_width = addr_width
        # number of data words per storage entry (truncated to an int)
        self.fw_int = int(self.memory_width / self.data_width)

        # assert banks > 1 or rw_same_cycle is True or self.fw_int > 1, \
        #     "Can't sustain throughput with this setup. Need potential bandwidth for " + \
        #     "1 write and 1 read in a cycle - try using more banks or a macro that supports 1R1W"

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # Inputs + Outputs
        self._push = self.input("push", 1)
        self._data_in = self.input("data_in", self.data_width)
        self._pop = self.input("pop", 1)

        self._data_out = self.output("data_out", self.data_width)
        self._valid_out = self.output("valid_out", 1)
        self._empty = self.output("empty", 1)
        self._full = self.output("full", 1)

        # get relevant signals from the storage banks
        self._data_from_strg = self.input("data_from_strg", self.data_width,
                                          size=(self.banks,
                                                self.fw_int),
                                          explicit_array=True,
                                          packed=True)

        # per-bank write and read addresses
        self._wen_addr = self.var("wen_addr", self.addr_width,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        self._ren_addr = self.var("ren_addr", self.addr_width,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        # fw_int words taken from the front FIFO's parallel output
        self._front_combined = self.var("front_combined", self.data_width,
                                        size=self.fw_int,
                                        explicit_array=True,
                                        packed=True)

        self._data_to_strg = self.output("data_to_strg", self.data_width,
                                         size=(self.banks,
                                               self.fw_int),
                                         explicit_array=True,
                                         packed=True)

        # one write-enable / read-enable bit per bank
        self._wen_to_strg = self.output("wen_to_strg", self.banks)
        self._ren_to_strg = self.output("ren_to_strg", self.banks)

        self._num_words_mem = self.var("num_words_mem", self.data_width)

        # With a single bank the bank selectors are constant zero
        if self.banks == 1:
            self._curr_bank_wr = self.var("curr_bank_wr", 1)
            self.wire(self._curr_bank_wr, kts.const(0, 1))
            self._curr_bank_rd = self.var("curr_bank_rd", 1)
            self.wire(self._curr_bank_rd, kts.const(0, 1))
        else:
            self._curr_bank_wr = self.var("curr_bank_wr", kts.clog2(self.banks))
            self._curr_bank_rd = self.var("curr_bank_rd", kts.clog2(self.banks))

        self._write_queue = self.var("write_queue", self.data_width,
                                     size=(self.banks,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        # Lets us know if the bank has a write queued up
        self._queued_write = self.var("queued_write", self.banks)

        # Signals connected to the front RegFIFO
        self._front_data_out = self.var("front_data_out", self.data_width)
        self._front_pop = self.var("front_pop", 1)
        self._front_empty = self.var("front_empty", 1)
        self._front_full = self.var("front_full", 1)
        self._front_valid = self.var("front_valid", 1)
        self._front_par_read = self.var("front_par_read", 1)
        self._front_par_out = self.var("front_par_out", self.data_width,
                                       size=(self.fw_int,
                                             1),
                                       explicit_array=True,
                                       packed=True)

        self._front_rd_ptr = self.var("front_rd_ptr", max(1, clog2(self.fw_int)))

        # Push into the front FIFO when not full, or when a pop happens too
        self._front_push = self.var("front_push", 1)
        self.wire(self._front_push, self._push & (~self._full | self._pop))

        self._front_rf = RegFIFO(data_width=self.data_width,
                                 width_mult=1,
                                 depth=self.fw_int,
                                 parallel=True,
                                 break_out_rd_ptr=True)

        # This one breaks out the read pointer so we can properly
        # reorder the data to storage

        self.add_child("front_rf", self._front_rf,
                       clk=self._clk,
                       clk_en=kts.const(1, 1),
                       rst_n=self._rst_n,
                       push=self._front_push,
                       pop=self._front_pop,
                       empty=self._front_empty,
                       full=self._front_full,
                       valid=self._front_valid,
                       parallel_read=self._front_par_read,
                       parallel_load=kts.const(0, 1),  # We don't need to parallel load the front
                       parallel_in=0,  # Same reason as above
                       parallel_out=self._front_par_out,
                       num_load=0,
                       rd_ptr_out=self._front_rd_ptr)
        self.wire(self._front_rf.ports.data_in[0], self._data_in)
        self.wire(self._front_data_out, self._front_rf.ports.data_out[0])

        # Signals connected to the back RegFIFO
        self._back_data_in = self.var("back_data_in", self.data_width)
        self._back_data_out = self.var("back_data_out", self.data_width)
        self._back_push = self.var("back_push", 1)
        self._back_empty = self.var("back_empty", 1)
        self._back_full = self.var("back_full", 1)
        self._back_valid = self.var("back_valid", 1)
        self._back_pl = self.var("back_pl", 1)
        self._back_par_in = self.var("back_par_in", self.data_width,
                                     size=(self.fw_int,
                                           1),
                                     explicit_array=True,
                                     packed=True)
        self._back_num_load = self.var("back_num_load", clog2(self.fw_int) + 1)

        # Occupancy counters for the two small FIFOs
        self._back_occ = self.var("back_occ", clog2(self.fw_int) + 1)
        self._front_occ = self.var("front_occ", clog2(self.fw_int) + 1)

        self._back_rf = RegFIFO(data_width=self.data_width,
                                width_mult=1,
                                depth=self.fw_int,
                                parallel=True,
                                break_out_rd_ptr=False)

        # Constant flag: 1 when each storage entry holds exactly one word
        self._fw_is_1 = self.var("fw_is_1", 1)
        self.wire(self._fw_is_1, kts.const(self.fw_int == 1, 1))

        # When fw_int == 1 the back pop is additionally suppressed while
        # back_pl is asserted
        self._back_pop = self.var("back_pop", 1)
        if self.fw_int == 1:
            self.wire(self._back_pop, self._pop & (~self._empty | self._push) & ~self._back_pl)
        else:
            self.wire(self._back_pop, self._pop & (~self._empty | self._push))

        self.add_child("back_rf", self._back_rf,
                       clk=self._clk,
                       clk_en=kts.const(1, 1),
                       rst_n=self._rst_n,
                       push=self._back_push,
                       pop=self._back_pop,
                       empty=self._back_empty,
                       full=self._back_full,
                       valid=self._back_valid,
                       parallel_read=kts.const(0, 1),
                       # Only do back load when data is going there
                       parallel_load=self._back_pl & self._back_num_load.r_or(),
                       parallel_in=self._back_par_in,
                       num_load=self._back_num_load)
        self.wire(self._back_rf.ports.data_in[0], self._back_data_in)
        self.wire(self._back_data_out, self._back_rf.ports.data_out[0])
        # send the writes through when a read isn't happening
        for i in range(self.banks):
            self.add_code(self.send_writes, idx=i)
            self.add_code(self.send_reads, idx=i)

        # Set the parallel load to back bank - if no delay it's immediate
        # if not, it's delayed :)
        if self.read_delay == 1:
            self._ren_delay = self.var("ren_delay", 1)
            self.add_code(self.set_parallel_ld_delay_1)
            self.wire(self._back_pl, self._ren_delay)
        else:
            self.wire(self._back_pl, self._ren_to_strg.r_or())

        # Combine front end data - just the items + incoming
        # this data is actually based on the rd_ptr from the front fifo
        for i in range(self.fw_int):
            self.wire(self._front_combined[i], self._front_par_out[self._front_rd_ptr + i])
        # This is always true
        # self.wire(self._front_combined[self.fw_int - 1], self._data_in)

        # prioritize queued writes, otherwise send combined data
        for i in range(self.banks):
            self.wire(self._data_to_strg[i],
                      kts.ternary(self._queued_write[i], self._write_queue[i], self._front_combined))

        # Wire the thin output from front to thin input to back
        self.wire(self._back_data_in, self._front_data_out)
        self.wire(self._back_push, self._front_valid)
        self.add_code(self.set_front_pop)

        # Queue writes
        for i in range(self.banks):
            self.add_code(self.set_write_queue, idx=i)

        # Track number of words in memory
        # if self.read_delay == 1:
        #     self.add_code(self.set_num_words_mem_delay)
        # else:
        self.add_code(self.set_num_words_mem)

        # Track occupancy of the two small fifos
        self.add_code(self.set_front_occ)
        self.add_code(self.set_back_occ)

        # Bank selectors only need driving logic when there are several banks
        if self.banks > 1:
            self.add_code(self.set_curr_bank_wr)
            self.add_code(self.set_curr_bank_rd)
        if self.read_delay == 1:
            self._prev_bank_rd = self.var("prev_bank_rd", max(1, kts.clog2(self.banks)))
            self.add_code(self.set_prev_bank_rd)

        # Parallel load data to back - based on num_load
        # With read_delay == 1 the bank is selected by prev_bank_rd instead
        # of curr_bank_rd
        index_into = self._curr_bank_rd
        if self.read_delay == 1:
            index_into = self._prev_bank_rd
        for i in range(self.fw_int - 1):
            # Shift data over if you bypassed from the memory output
            self.wire(self._back_par_in[i],
                      kts.ternary(self._back_num_load == self.fw_int,
                                  self._data_from_strg[index_into][i],
                                  self._data_from_strg[index_into][i + 1]))
        # The last slot has nothing to shift in, so it gets zero unless a
        # full fw_int-word load is happening
        self.wire(self._back_par_in[self.fw_int - 1],
                  kts.ternary(self._back_num_load == self.fw_int,
                              self._data_from_strg[index_into][self.fw_int - 1],
                              kts.const(0, self.data_width)))

        # Set the parallel read to the front fifo - analogous with trying to write to the memory
        self.add_code(self.set_front_par_read)

        # Set the number being parallely loaded into the register
        self.add_code(self.set_back_num_load)

        # Data out and valid out are (in the general case) just the data and valid from the back fifo
        # In the case where we have a fresh memory read, it would be from that
        bank_idx_read = self._curr_bank_rd
        if self.read_delay == 1:
            bank_idx_read = self._prev_bank_rd
        self.wire(self._data_out,
                  kts.ternary(self._back_pl, self._data_from_strg[bank_idx_read][0], self._back_data_out))
        self.wire(self._valid_out, kts.ternary(self._back_pl, self._pop, self._back_valid))

        # Set addresses to storage
        for i in range(self.banks):
            self.add_code(self.set_wen_addr, idx=i)
            self.add_code(self.set_ren_addr, idx=i)
        # Now deal with a shared address vs separate addresses
        if self.rw_same_cycle:
            # Separate
            self._wen_addr_out = self.output("wen_addr_out", self.addr_width,
                                             size=self.banks,
                                             explicit_array=True,
                                             packed=True)
            self._ren_addr_out = self.output("ren_addr_out", self.addr_width,
                                             size=self.banks,
                                             explicit_array=True,
                                             packed=True)
            self.wire(self._wen_addr_out, self._wen_addr)
            self.wire(self._ren_addr_out, self._ren_addr)
        else:
            self._addr_out = self.output("addr_out", self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
            # Shared address: one address output per bank.
            # NOTE(review): the original comment said the read address has
            # priority, but the select below picks wen_addr whenever
            # wen_to_strg[i] is set (assuming kts.ternary(cond, a, b) yields
            # a when cond is true) - confirm which is intended.
            for i in range(self.banks):
                self.wire(self._addr_out[i],
                          kts.ternary(self._wen_to_strg[i], self._wen_addr[i], self._ren_addr[i]))

        # Do final empty/full
        self._num_items = self.var("num_items", self.data_width)
        self.add_code(self.set_num_items)
        # fifo_depth is a configuration register input
        self._fifo_depth = self.input("fifo_depth", self.data_width)
        self._fifo_depth.add_attribute(ConfigRegAttr("Fifo depth..."))
        self.wire(self._empty, self._num_items == 0)
        self.wire(self._full, self._num_items == (self._fifo_depth))
예제 #7
0
    def __init__(self,
                 word_width,
                 input_ports,
                 output_ports,
                 memories,
                 edges):
        """Build the top-level Lake generator from a memory/edge description.

        Args:
            word_width: bit width of one data word on ``data_in``/``data_out``
                and of the per-memory data-out buses.
            input_ports: number of entries in the ``data_in`` port array.
            output_ports: number of entries in the ``data_out`` port array.
            memories: dict mapping memory name -> parameter dict. Keys read
                here: "chaining", "is_input", "is_output",
                "num_read_write_ports", "capacity", "rw_same_cycle" (output
                memories), "input_edge_params"/"input_port" (input memories),
                "output_edge_params"/"output_port" (output memories),
                "read_port_width" or "read_write_port_width", and
                "read_info" or "read_write_info" (for the edge latency).
            edges: iterable of edge dicts. Keys read here: "from_signal" and
                "to_signal" (lists of memory names) and "dim".
        """

        super().__init__("LakeTop", debug=True)

        # parameters
        self.word_width = word_width
        self.input_ports = input_ports
        self.output_ports = output_ports

        self.default_config_width = 16
        self.cycle_count_width = 16

        self.stencil_valid = False

        # objects
        self.memories = memories
        self.edges = edges

        # tile enable and clock
        self.tile_en = self.input("tile_en", 1)
        self.tile_en.add_attribute(ConfigRegAttr("Tile logic enable manifested as clock gate"))
        self.tile_en.add_attribute(FormalAttr(self.tile_en.name, FormalSignalConstraint.SET1))

        self.clk_mem = self.clock("clk")
        self.clk_mem.add_attribute(FormalAttr(self.clk_mem.name, FormalSignalConstraint.CLK))

        # chaining: expose chain_en as a config input only if at least one
        # memory asks for it, otherwise tie an internal chain_en low
        chain_supported = any(self.memories[mem_name]["chaining"]
                              for mem_name in self.memories)

        if chain_supported:
            self.chain_en = self.input("chain_en", 1)
            self.chain_en.add_attribute(ConfigRegAttr("Chaining enable"))
            self.chain_en.add_attribute(FormalAttr(self.chain_en.name, FormalSignalConstraint.SET0))
        else:
            self.chain_en = self.var("chain_en", 1)
            self.wire(self.chain_en, 0)

        # gate clock with tile_en
        gclk = self.var("gclk", 1)
        self.gclk = kts.util.clock(gclk)
        self.wire(gclk, self.clk_mem & self.tile_en)

        self.clk_en = self.clock_en("clk_en", 1)

        # active low asynchronous reset
        self.rst_n = self.reset("rst_n", 1)
        self.rst_n.add_attribute(FormalAttr(self.rst_n.name, FormalSignalConstraint.RSTN))

        # data in and out of top level Lake memory object
        self.data_in = self.input("data_in",
                                  width=self.word_width,
                                  size=self.input_ports,
                                  explicit_array=True,
                                  packed=True)
        self.data_in.add_attribute(FormalAttr(self.data_in.name, FormalSignalConstraint.SEQUENCE))

        self.data_out = self.output("data_out",
                                    width=self.word_width,
                                    size=self.output_ports,
                                    explicit_array=True,
                                    packed=True)
        self.data_out.add_attribute(FormalAttr(self.data_out.name, FormalSignalConstraint.SEQUENCE))

        # global cycle count for accessor comparison
        self._cycle_count = self.var("cycle_count", self.cycle_count_width)

        @always_ff((posedge, self.gclk), (negedge, "rst_n"))
        def increment_cycle_count(self):
            if ~self.rst_n:
                self._cycle_count = 0
            else:
                self._cycle_count = self._cycle_count + 1

        self.add_always(increment_cycle_count)

        # info about memories
        subscript_mems = list(self.memories.keys())

        # list of the data out from each memory (same order as self.memories);
        # the array size comes from "read_port_width" when present, otherwise
        # from "read_write_port_width"
        self.mem_data_outs = [
            self.var(f"mem_data_out_{mem_name}",
                     width=self.word_width,
                     size=self.memories[mem_name]
                     ["read_port_width" if "read_port_width" in self.memories[mem_name]
                      else "read_write_port_width"],
                     explicit_array=True, packed=True)
            for mem_name in subscript_mems]

        # keep track of write, read_addr, and write_addr vars for read/write memories
        # to later check whether there is a write and what to use for the shared port
        # (memory name -> {"write": ..., "write_addr": ..., "read_addr": ...})
        self.mem_read_write_addrs = {}

        # create memory instance for each memory
        self.mem_insts = {}
        for mem_name, mem_data_out in zip(subscript_mems, self.mem_data_outs):
            m = mem_inst(self.memories[mem_name], self.word_width)
            self.mem_insts[mem_name] = m

            self.add_child(mem_name,
                           m,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           # put data out in memory data out list
                           data_out=mem_data_out,
                           chain_en=self.chain_en)

        # get input and output memories
        is_input = [mem_name for mem_name in self.memories
                    if self.memories[mem_name]["is_input"]]
        is_output = [mem_name for mem_name in self.memories
                     if self.memories[mem_name]["is_output"]]

        # TODO direct connection to write doesn't work (?), so have to do this...
        self.low = self.var("low", 1)
        self.wire(self.low, 0)

        # TODO adding multiple ports to 1 memory after talking about mux with compiler team

        # set up input memories
        for in_mem in is_input:
            # input addressor / accessor parameters
            input_dim = self.memories[in_mem]["input_edge_params"]["dim"]
            input_range = self.memories[in_mem]["input_edge_params"]["max_range"]
            input_stride = self.memories[in_mem]["input_edge_params"]["max_stride"]
            # input port associated with memory
            input_port_index = self.memories[in_mem]["input_port"]

            self.valid = self.var(
                f"input_port{input_port_index}_2{in_mem}_accessor_valid", 1)
            self.wire(self.mem_insts[in_mem].ports.write, self.valid)

            # hook up data from the specified input port to the memory
            safe_wire(self, self.mem_insts[in_mem].ports.data_in[0],
                      self.data_in[input_port_index])

            if self.memories[in_mem]["num_read_write_ports"] > 0:
                self.mem_read_write_addrs[in_mem] = {"write": self.valid}

            # create IteratorDomain, AddressGenerator, and ScheduleGenerator
            # for writes to this input memory
            forloop = ForLoop(iterator_support=input_dim,
                              config_width=max(1, clog2(input_range)))  # self.default_config_width)
            # NOTE(review): loop_itr / loop_wth are not used below; the calls
            # are kept unchanged in case the getters matter to ForLoop.
            loop_itr = forloop.get_iter()
            loop_wth = forloop.get_cfg_width()

            self.add_child(f"input_port{input_port_index}_2{in_mem}_forloop",
                           forloop,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           step=self.valid)

            newAG = AddrGen(iterator_support=input_dim,
                            config_width=max(1, clog2(input_stride)))  # self.default_config_width)
            self.add_child(f"input_port{input_port_index}_2{in_mem}_write_addr_gen",
                           newAG,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           step=self.valid,
                           mux_sel=forloop.ports.mux_sel_out,
                           restart=forloop.ports.restart)

            if self.memories[in_mem]["num_read_write_ports"] == 0:
                safe_wire(self, self.mem_insts[in_mem].ports.write_addr[0], newAG.ports.addr_out)
            else:
                self.mem_read_write_addrs[in_mem]["write_addr"] = newAG.ports.addr_out

            newSG = SchedGen(iterator_support=input_dim,
                             config_width=self.cycle_count_width)
            self.add_child(f"input_port{input_port_index}_2{in_mem}_write_sched_gen",
                           newSG,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           mux_sel=forloop.ports.mux_sel_out,
                           finished=forloop.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self.valid)

        # set up output memories
        for out_mem in is_output:
            # output addressor / accessor parameters
            output_dim = self.memories[out_mem]["output_edge_params"]["dim"]
            output_range = self.memories[out_mem]["output_edge_params"]["max_range"]
            output_stride = self.memories[out_mem]["output_edge_params"]["max_stride"]
            # output port associated with memory
            output_port_index = self.memories[out_mem]["output_port"]

            # hook up data from the memory to the specified output port
            self.wire(self.data_out[output_port_index],
                      self.mem_insts[out_mem].ports.data_out[0][0])
            # self.mem_data_outs[subscript_mems.index(out_mem)][0])

            self.valid = self.var(f"{out_mem}2output_port{output_port_index}_accessor_valid", 1)
            if self.memories[out_mem]["rw_same_cycle"]:
                self.wire(self.mem_insts[out_mem].ports.read, self.valid)

            # create IteratorDomain, AddressGenerator, and ScheduleGenerator
            # for reads from this output memory
            forloop = ForLoop(iterator_support=output_dim,
                              config_width=max(1, clog2(output_range)))  # self.default_config_width)
            loop_itr = forloop.get_iter()
            loop_wth = forloop.get_cfg_width()

            self.add_child(f"{out_mem}2output_port{output_port_index}_forloop",
                           forloop,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           step=self.valid)

            newAG = AddrGen(iterator_support=output_dim,
                            config_width=max(1, clog2(output_stride)))  # self.default_config_width)
            self.add_child(f"{out_mem}2output_port{output_port_index}_read_addr_gen",
                           newAG,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           step=self.valid,
                           mux_sel=forloop.ports.mux_sel_out,
                           restart=forloop.ports.restart)

            if self.memories[out_mem]["num_read_write_ports"] == 0:
                safe_wire(self, self.mem_insts[out_mem].ports.read_addr[0], newAG.ports.addr_out)
            else:
                # Record the read address under the *output* memory being
                # configured. setdefault covers an output memory that is not
                # also an input memory and so has no entry yet.
                self.mem_read_write_addrs.setdefault(out_mem, {})["read_addr"] = newAG.ports.addr_out

            newSG = SchedGen(iterator_support=output_dim,
                             config_width=self.cycle_count_width)  # self.default_config_width)
            self.add_child(f"{out_mem}2output_port{output_port_index}_read_sched_gen",
                           newSG,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           mux_sel=forloop.ports.mux_sel_out,
                           finished=forloop.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self.valid)

        # create shared IteratorDomains and accessors as well as
        # read/write addressors for memories connected by each edge
        for edge in self.edges:

            # see how many signals need to be selected between for
            # from and to signals for edge
            num_mux_from = len(edge["from_signal"])
            num_mux_to = len(edge["to_signal"])

            # get unique edge_name identifier for hardware modules
            edge_name = get_edge_name(edge)

            # create forloop and accessor valid output signal
            self.valid = self.var(edge_name + "_accessor_valid", 1)

            forloop = ForLoop(iterator_support=edge["dim"])
            self.forloop = forloop
            loop_itr = forloop.get_iter()
            loop_wth = forloop.get_cfg_width()

            self.add_child(edge_name + "_forloop",
                           forloop,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           step=self.valid)

            # create input addressor
            readAG = AddrGen(iterator_support=edge["dim"],
                             config_width=self.default_config_width)
            self.add_child(f"{edge_name}_read_addr_gen",
                           readAG,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           step=self.valid,
                           mux_sel=forloop.ports.mux_sel_out,
                           restart=forloop.ports.restart)

            # assign read address to all from memories
            if self.memories[edge["from_signal"][0]]["num_read_write_ports"] == 0:
                # can assign same read addrs to all the memories
                for from_mem in edge["from_signal"]:
                    safe_wire(self, self.mem_insts[from_mem].ports.read_addr[0], readAG.ports.addr_out)
            else:
                for from_mem in edge["from_signal"]:
                    # merge into any existing entry rather than requiring one
                    self.mem_read_write_addrs.setdefault(from_mem, {})["read_addr"] = readAG.ports.addr_out

            # if needing to mux, choose which from memory we get data
            # from for to memory data in
            if num_mux_from > 1:
                num_mux_bits = clog2(num_mux_from)
                self.mux_sel = self.var(f"{edge_name}_mux_sel",
                                        width=num_mux_bits)

                read_addr_width = max(1, clog2(self.memories[edge["from_signal"][0]]["capacity"]))
                # decide which memory to get data from for to memory's data in
                # NOTE(review): the slice below is num_mux_from bits wide while
                # mux_sel is num_mux_bits wide; presumably safe_wire reconciles
                # the widths - confirm.
                safe_wire(self, self.mux_sel,
                          readAG.ports.addr_out[read_addr_width + num_mux_from - 1, read_addr_width])

                comb_mux_from = self.combinational()
                # for i in range(num_mux_from):
                # TODO want to use a switch statement here, but get add_fn_ln issue
                # Only from_signal[0] and from_signal[1] are selected between here.
                if_mux_sel = IfStmt(self.mux_sel == 0)
                for to_mem in edge["to_signal"]:
                    if_mux_sel.then_(self.mem_insts[to_mem].ports.data_in.assign(self.mem_insts[edge["from_signal"][0]].ports.data_out))
                    if_mux_sel.else_(self.mem_insts[to_mem].ports.data_in.assign(self.mem_insts[edge["from_signal"][1]].ports.data_out))
                comb_mux_from.add_stmt(if_mux_sel)

            # no muxing from, data_out from the one and only memory
            # goes to all to memories (valid determines whether it is
            # actually written)
            else:
                for to_mem in edge["to_signal"]:
                    safe_wire(self,
                              self.mem_insts[to_mem].ports.data_in,
                              # only one memory to read from
                              self.mem_insts[edge["from_signal"][0]].ports.data_out)

            # create output addressor
            writeAG = AddrGen(iterator_support=edge["dim"],
                              config_width=self.default_config_width)
            # step, mux_sel, restart may need delayed signals (assigned later)
            self.add_child(f"{edge_name}_write_addr_gen",
                           writeAG,
                           clk=self.gclk,
                           rst_n=self.rst_n)

            # set write addr for to memories
            if self.memories[edge["to_signal"][0]]["num_read_write_ports"] == 0:
                for to_mem in edge["to_signal"]:
                    safe_wire(self, self.mem_insts[to_mem].ports.write_addr[0], writeAG.ports.addr_out)
            else:
                for to_mem in edge["to_signal"]:
                    # Update in place so a "read_addr" recorded earlier (e.g. by
                    # the output-memory loop) is not discarded.
                    # NOTE(review): "write" is the undelayed valid even when
                    # the write port itself is driven by the delayed write.
                    self.mem_read_write_addrs.setdefault(to_mem, {}).update(
                        {"write": self.valid, "write_addr": writeAG.ports.addr_out})

            # calculate necessary delay between from_signal to to_signal
            # TODO this may need to be more sophisticated and based on II as well
            # TODO just need to add for loops for all the ports
            if self.memories[edge["from_signal"][0]]["num_read_write_ports"] == 0:
                self.delay = self.memories[edge["from_signal"][0]]["read_info"][0]["latency"]
            else:
                self.delay = self.memories[edge["from_signal"][0]]["read_write_info"][0]["latency"]

            if self.delay > 0:
                # signals that need to be delayed due to edge latency
                self.delayed_writes = self.var(f"{edge_name}_delayed_writes",
                                               width=self.delay)
                self.delayed_mux_sels = self.var(f"{edge_name}_delayed_mux_sels",
                                                 width=self.forloop.ports.mux_sel_out.width,
                                                 size=self.delay,
                                                 explicit_array=True,
                                                 packed=True)
                self.delayed_restarts = self.var(f"{edge_name}_delayed_restarts",
                                                 width=self.delay)

                # delay in valid between read from memory and write to next memory
                @always_ff((posedge, self.gclk), (negedge, "rst_n"))
                def get_delayed_write(self):
                    if ~self.rst_n:
                        self.delayed_writes = 0
                        self.delayed_mux_sels = 0
                        self.delayed_restarts = 0
                    else:
                        for i in range(self.delay - 1):
                            self.delayed_writes[i + 1] = self.delayed_writes[i]
                            self.delayed_mux_sels[i + 1] = self.delayed_mux_sels[i]
                            self.delayed_restarts[i + 1] = self.delayed_restarts[i]
                        self.delayed_writes[0] = self.valid
                        self.delayed_mux_sels[0] = self.forloop.ports.mux_sel_out
                        self.delayed_restarts[0] = self.forloop.ports.restart

                self.add_always(get_delayed_write)

            # if we have a mux for the destination memories,
            # choose which mux to write to
            if num_mux_to > 1:
                num_mux_bits = clog2(num_mux_to)
                self.mux_sel_to = self.var(f"{edge_name}_mux_sel_to",
                                           width=num_mux_bits)

                write_addr_width = max(1, clog2(self.memories[edge["to_signal"][0]]["capacity"]))
                # decide which destination memory gets written to
                safe_wire(self, self.mux_sel_to,
                          writeAG.ports.addr_out[write_addr_width + num_mux_to - 1, write_addr_width])

                # wire the write (or if needed, delayed write) signal to the selected destination memory
                # and set write enable low for all other destination memories
                comb_mux_to = self.combinational()
                for i, to_mem in enumerate(edge["to_signal"]):
                    if_mux_sel_to = IfStmt(self.mux_sel_to == i)
                    if self.delay == 0:
                        if_mux_sel_to.then_(self.mem_insts[to_mem].ports.write.assign(self.valid))
                    else:
                        if_mux_sel_to.then_(self.mem_insts[to_mem].ports.write.assign(self.delayed_writes[self.delay - 1]))

                    if_mux_sel_to.else_(self.mem_insts[to_mem].ports.write.assign(self.low))
                    comb_mux_to.add_stmt(if_mux_sel_to)

            # no muxing to, just write to the one destination memory
            else:
                if self.delay == 0:
                    self.wire(self.mem_insts[edge["to_signal"][0]].ports.write, self.valid)
                else:
                    self.wire(self.mem_insts[edge["to_signal"][0]].ports.write, self.delayed_writes[self.delay - 1])

            # assign delayed signals for write addressor if needed
            if self.delay == 0:
                self.wire(writeAG.ports.step, self.valid)
                self.wire(writeAG.ports.mux_sel, self.forloop.ports.mux_sel_out)
                self.wire(writeAG.ports.restart, self.forloop.ports.restart)
            else:
                self.wire(writeAG.ports.step, self.delayed_writes[self.delay - 1])
                self.wire(writeAG.ports.mux_sel, self.delayed_mux_sels[self.delay - 1])
                self.wire(writeAG.ports.restart, self.delayed_restarts[self.delay - 1])

            # create accessor for edge
            newSG = SchedGen(iterator_support=edge["dim"],
                             config_width=self.cycle_count_width)  # self.default_config_width)

            self.add_child(edge_name + "_sched_gen",
                           newSG,
                           clk=self.gclk,
                           rst_n=self.rst_n,
                           mux_sel=forloop.ports.mux_sel_out,
                           finished=forloop.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self.valid)

        # for read write memories, choose either read or write address based on whether
        # we are writing to the memory (whether write enable is high)
        read_write_addr_comb = self.combinational()
        for mem_name in self.memories:
            if mem_name in self.mem_read_write_addrs:
                mem_info = self.mem_read_write_addrs[mem_name]
                if_write = IfStmt(mem_info["write"] == 1)
                # truncate both candidate addresses to the shared port's width
                addr_width = self.mem_insts[mem_name].ports.read_write_addr[0].width
                if_write.then_(self.mem_insts[mem_name].ports.read_write_addr[0].assign(mem_info["write_addr"][addr_width - 1, 0]))
                if_write.else_(self.mem_insts[mem_name].ports.read_write_addr[0].assign(mem_info["read_addr"][addr_width - 1, 0]))
                read_write_addr_comb.add_stmt(if_write)

        # clock enable and flush passes
        kts.passes.auto_insert_clock_enable(self.internal_generator)
        clk_en_port = self.internal_generator.get_port("clk_en")
        clk_en_port.add_attribute(FormalAttr(clk_en_port.name, FormalSignalConstraint.SET1))

        self.add_attribute("sync-reset=flush")
        kts.passes.auto_insert_sync_reset(self.internal_generator)
        flush_port = self.internal_generator.get_port("flush")

        # bring config registers up to top level
        lift_config_reg(self.internal_generator)
예제 #8
0
    def __init__(self, fetch_width, data_width, int_out_ports):
        """Declare the ports, local signals and code blocks of ``sync_groups``.

        Args:
            fetch_width: memory fetch width; asserted to have at most one bit
                set (power-of-two check).
            data_width: width of one data word.
            int_out_ports: number of output ports; also used as the number of
                synchronization groups.
        """
        assert not (fetch_width &
                    (fetch_width - 1)), "Memory width needs to be a power of 2"

        super().__init__("sync_groups")

        # Capture the parameters on the object
        self.fetch_width = fetch_width
        self.data_width = data_width
        self.fw_int = int(self.fetch_width / self.data_width)
        self.int_out_ports = int_out_ports
        self.groups = self.int_out_ports

        # Keyword bundles shared by the array-shaped signals declared below:
        # a ports x ports bit matrix, and a ports x fw_int array of data words
        port_matrix = dict(size=self.int_out_ports,
                           explicit_array=True,
                           packed=True)
        word_array = dict(size=(self.int_out_ports, self.fw_int),
                          explicit_array=True,
                          packed=True)

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # Inputs
        self._ack_in = self.input("ack_in", self.int_out_ports)
        self._data_in = self.input("data_in", self.data_width, **word_array)
        self._mem_valid_data = self.input("mem_valid_data", self.int_out_ports)
        self._mem_valid_data_out = self.output("mem_valid_data_out",
                                               self.int_out_ports)
        self._valid_in = self.input("valid_in", self.int_out_ports)

        # Configuration: which port belongs to which synchronization group
        self._sync_group = self.input("sync_group",
                                      self.int_out_ports,
                                      **port_matrix)
        sync_group_doc = ("This array of one hot vectors"
                          " is used to denote which ports are synchronized to eachother."
                          " If multiple ports should output data relative to eachother"
                          " one should put them in the same group.")
        self._sync_group.add_attribute(ConfigRegAttr(sync_group_doc))

        # Outputs
        self._data_out = self.output("data_out", self.data_width, **word_array)
        self._valid_out = self.output("valid_out", self.int_out_ports)

        # Locals
        self._sync_agg = self.var("sync_agg", self.int_out_ports, **port_matrix)
        self._sync_valid = self.var("sync_valid", self.int_out_ports)
        self._data_reg = self.var("data_reg", self.data_width, **word_array)
        self._valid_reg = self.var("valid_reg", self.int_out_ports)

        # ren_in lets the output address controller orchestrate the
        # synchronization groups
        self._ren_in = self.input("ren_in", self.int_out_ports)
        self._ren_int = self.var("ren_int", self.int_out_ports)
        self._rd_sync_gate = self.output("rd_sync_gate", self.int_out_ports)

        self._local_gate_bus = self.var("local_gate_bus",
                                        self.int_out_ports,
                                        **port_matrix)
        self._local_gate_bus_n = self.var("local_gate_bus_n",
                                          self.int_out_ports,
                                          **port_matrix)
        # local_gate_bus is the bitwise inverse of local_gate_bus_n
        self.wire(self._local_gate_bus, ~self._local_gate_bus_n)

        self._local_gate_bus_tpose = self.var("local_gate_bus_tpose",
                                              self.int_out_ports,
                                              **port_matrix)
        self._local_gate_reduced = self.var("local_gate_reduced",
                                            self.int_out_ports)
        self._local_gate_mask = self.var("local_gate_mask",
                                         self.int_out_ports,
                                         **port_matrix)

        self._group_finished = self.var("group_finished", self.int_out_ports)
        self._grp_fin_large = self.var("grp_fin_large",
                                       self.int_out_ports,
                                       **port_matrix)

        self._done = self.var("done", self.int_out_ports)
        self._done_alt = self.var("done_alt", self.int_out_ports)

        # Data and the read gate go straight out; the internal read enable is
        # the incoming read enable masked by the reduced local gate
        self.wire(self._data_out, self._data_reg)
        self.wire(self._rd_sync_gate, self._local_gate_reduced)
        self.wire(self._ren_int, self._ren_in & self._local_gate_reduced)

        # Register the code blocks (same registration order as declared)
        self.add_code(self.set_sync_agg, unroll_for=True)
        self.add_code(self.set_sync_valid)
        for port in range(self.int_out_ports):
            self.add_code(self.set_sync_stage, idx=port)
        self.add_code(self.set_out_valid)
        self.add_code(self.set_reduce_gate)
        for group in range(self.groups):
            self.add_code(self.set_rd_gates, idx=group)
        self.add_code(self.set_tpose)
        self.add_code(self.set_finished)
        self.add_code(self.next_gate_mask, unroll_for=True)
        self.add_code(self.set_grp_fin, unroll_for=True)
예제 #9
0
    def __init__(self, agg_height, data_width, mem_width, max_agg_schedule):
        """Declare the ports, child aggregators and code blocks of the buffer.

        Args:
            agg_height: number of ``Aggregator`` children instantiated.
            data_width: width of one incoming data element.
            mem_width: width of the ``data_out`` word.
            max_agg_schedule: maximum length of the input/output schedules.
        """
        super().__init__("aggregation_buffer")

        self.agg_height = agg_height
        self.data_width = data_width
        self.mem_width = mem_width
        # Maximum length of the schedule
        self.max_agg_schedule = max_agg_schedule
        self.fw_int = int(self.mem_width / self.data_width)

        # Widths reused by several declarations below
        agg_sel_width = max(1, clog2(self.agg_height))
        sched_ptr_width = clog2(self.max_agg_schedule)

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # Inputs: one element at a time, with valid signaling
        self._data_in = self.input("data_in", self.data_width)
        self._valid_in = self.input("valid_in", 1)
        self._align = self.input("align", 1)

        # Outputs
        self._data_out = self.output("data_out", self.mem_width)
        self._valid_out = self.output("valid_out", 1)

        # Expose each data_width-wide slice of data_out as its own output
        self._data_out_chop = []
        for idx in range(self.fw_int):
            chop = self.output(f"data_out_chop_{idx}", self.data_width)
            self._data_out_chop.append(chop)
            lo = self.data_width * idx
            hi = (self.data_width * (idx + 1)) - 1
            self.add_stmt(chop.assign(self._data_out[hi, lo]))

        # CONFIG:
        # in_sched says where to write successive elements; it is stored as a
        # fixed-size array of max_agg_schedule entries with a separate period.
        self._in_schedule = self.input("in_sched",
                                       agg_sel_width,
                                       size=self.max_agg_schedule,
                                       explicit_array=True,
                                       packed=True)
        in_sched_doc = "Input schedule for aggregation buffer. Enumerate which" + \
                       f" of {self.agg_height} buffers to write to."
        self._in_schedule.add_attribute(ConfigRegAttr(in_sched_doc))

        self._in_period = self.input("in_period", sched_ptr_width)
        in_period_doc = "Input period for aggregation buffer. 1 is a reasonable" + \
                        " setting for most applications"
        self._in_period.add_attribute(ConfigRegAttr(in_period_doc))

        # ...and which order to output the blocks
        self._out_schedule = self.input("out_sched",
                                        agg_sel_width,
                                        size=self.max_agg_schedule,
                                        explicit_array=True,
                                        packed=True)
        out_sched_doc = "Output schedule for aggregation buffer. Enumerate which" + \
                        f" of {self.agg_height} buffers to write to SRAM from."
        self._out_schedule.add_attribute(ConfigRegAttr(out_sched_doc))

        self._out_period = self.input("out_period", sched_ptr_width)
        self._out_period.add_attribute(
            ConfigRegAttr("Output period for aggregation buffer"))

        self._in_sched_ptr = self.var("in_sched_ptr", sched_ptr_width)
        self._out_sched_ptr = self.var("out_sched_ptr", sched_ptr_width)

        # Local Signals
        self._aggs_out = self.var("aggs_out",
                                  self.mem_width,
                                  size=self.agg_height,
                                  packed=True,
                                  explicit_array=True)

        self._aggs_sep = [
            self.var(f"aggs_sep_{idx}",
                     self.data_width,
                     size=self.fw_int,
                     packed=True)
            for idx in range(self.agg_height)
        ]
        self._valid_demux = self.var("valid_demux", self.agg_height)
        self._align_demux = self.var("align_demux", self.agg_height)
        self._next_full = self.var("next_full", self.agg_height)
        self._valid_out_mux = self.var("valid_out_mux", self.agg_height)

        # One child aggregator per row, each wired to its own bit of the
        # demux/mux vectors and its own aggs_sep array
        for idx in range(self.agg_height):
            aggregator = Aggregator(self.data_width,
                                    mem_word_width=self.fw_int)
            self.add_child(f"agg_{idx}",
                           aggregator,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           in_pixels=self._data_in,
                           valid_in=self._valid_demux[idx],
                           agg_out=self._aggs_sep[idx],
                           valid_out=self._valid_out_mux[idx],
                           next_full=self._next_full[idx],
                           align=self._align_demux[idx])
            if self.fw_int == 1:
                self.wire(self._aggs_out[idx], self._aggs_sep[idx])
                continue
            # Concatenate the words from the highest index down to index 0
            words = [self._aggs_sep[idx][self.fw_int - 1 - word]
                     for word in range(self.fw_int)]
            self.wire(self._aggs_out[idx], kts.concat(*words))

        # Sequential code blocks
        self.add_code(self.update_in_sched_ptr)
        self.add_code(self.update_out_sched_ptr)

        # Combinational code blocks
        self.add_code(self.valid_demux_comb)
        self.add_code(self.align_demux_comb)
        self.add_code(self.valid_out_comb)
        self.add_code(self.output_data_comb)
예제 #10
0
    def __init__(self, iterator_support=6, config_width=16, use_enable=True):
        """Declare the ports, signals and child of the schedule generator.

        Args:
            iterator_support: Number of loop levels. Sizes the ``mux_sel``
                input and is forwarded to the child ``AddrGen``.
            config_width: Bit width of the ``cycle_count`` input and of the
                internal ``addr_out`` signal; forwarded to ``AddrGen``.
            use_enable: If true, ``enable`` is an input port tagged with
                ``ConfigRegAttr`` and ``FormalAttr``. If false, ``enable``
                is an internal signal wired to constant 1.
        """

        super().__init__(f"sched_gen_{iterator_support}_{config_width}")

        self.iterator_support = iterator_support
        self.config_width = config_width
        self.use_enable = use_enable

        # PORT DEFS: begin

        # INPUTS
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # OUTPUTS
        self._valid_output = self.output("valid_output", 1)

        # VARS, interleaved with the remaining INPUTS (cycle_count, mux_sel).
        # NOTE(review): declaration order presumably drives the order of the
        # generated ports, so the statements are not regrouped - confirm
        # before moving any of them.
        self._valid_out = self.var("valid_out", 1)
        self._cycle_count = self.input("cycle_count", self.config_width)
        self._mux_sel = self.input("mux_sel",
                                   max(clog2(self.iterator_support), 1))
        self._addr_out = self.var("addr_out", self.config_width)

        # Receive signal on last iteration of looping structure and
        # gate the output...
        self._finished = self.input("finished", 1)
        self._valid_gate_inv = self.var("valid_gate_inv", 1)
        self._valid_gate = self.var("valid_gate", 1)
        # valid_gate is simply the complement of the registered flag below.
        self.wire(self._valid_gate, ~self._valid_gate_inv)

        # Since dim = 0 is not sufficient, we need a way to prevent
        # the controllers from firing on the starting offset
        if self.use_enable:
            self._enable = self.input("enable", 1)
            self._enable.add_attribute(
                ConfigRegAttr("Disable the controller so it never fires..."))
            self._enable.add_attribute(
                FormalAttr(f"{self._enable.name}",
                           FormalSignalConstraint.SOLVE))
        # Otherwise we set it as a 1 and leave it up to synthesis...
        else:
            self._enable = self.var("enable", 1)
            # Same `const` helper as the restart tie-off further down, so
            # this branch relies only on a name the method already uses.
            self.wire(self._enable, const(1, 1))

        # Sticky flag: cleared by reset, set when `finished` is observed and
        # not cleared by anything else in this block.
        @always_ff((posedge, "clk"), (negedge, "rst_n"))
        def valid_gate_inv_ff():
            if ~self._rst_n:
                self._valid_gate_inv = 0
            # If we are finishing the looping structure, turn this off to implement one-shot
            elif self._finished:
                self._valid_gate_inv = 1

        self.add_code(valid_gate_inv_ff)

        # Compare based on minimum of addr + global cycle...
        self.c_a_cmp = min(self._cycle_count.width, self._addr_out.width)

        # PORT DEFS: end

        # The child address generator advances on valid_out and drives
        # addr_out; its restart input is tied to constant 0.
        self.add_child("sched_addr_gen",
                       AddrGen(iterator_support=self.iterator_support,
                               config_width=self.config_width),
                       clk=self._clk,
                       rst_n=self._rst_n,
                       step=self._valid_out,
                       mux_sel=self._mux_sel,
                       addr_out=self._addr_out,
                       restart=const(0, 1))

        self.add_code(self.set_valid_out)
        self.add_code(self.set_valid_output)
예제 #11
0
    def __init__(self, iterator_support=6, config_width=16):
        """Set up the ports, internal signals and code blocks of the for loop.

        Args:
            iterator_support: Number of loop levels. Sets the array size of
                ``ranges`` and ``dim_counter`` and the width of the
                per-level vectors (``max_value``, ``clear``, ``inc``).
            config_width: Bit width of every range entry and counter entry.
        """
        super().__init__(f"for_loop_{iterator_support}_{config_width}",
                         debug=True)

        self.iterator_support = iterator_support
        self.config_width = config_width

        # Sizes that are reused by several declarations below.
        levels = self.iterator_support
        sel_width = max(clog2(levels), 1)
        # Shape shared by the arrays that hold one entry per loop level.
        per_level = dict(size=levels, packed=True, explicit_array=True)

        # Mirror the generation arguments as module parameters.
        self.iterator_support_par = self.param(
            "ITERATOR_SUPPORT",
            clog2(iterator_support) + 1,
            value=self.iterator_support)
        self.config_width_par = self.param(
            "CONFIG_WIDTH",
            clog2(config_width) + 1,
            value=self.config_width)

        # ---- Ports: clock / reset ----
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # ---- Ports: configuration inputs ----
        self._ranges = self.input("ranges", self.config_width, **per_level)
        self._ranges.add_attribute(
            ConfigRegAttr("Ranges of address generator"))
        self._ranges.add_attribute(
            FormalAttr(f"{self._ranges.name}", FormalSignalConstraint.SOLVE))

        self._dimensionality = self.input("dimensionality",
                                          clog2(levels) + 1)
        self._dimensionality.add_attribute(
            ConfigRegAttr("Dimensionality of address generator"))
        self._dimensionality.add_attribute(
            FormalAttr(f"{self._dimensionality.name}",
                       FormalSignalConstraint.SOLVE))

        self._step = self.input("step", 1)

        # ---- Internal signals ----
        self._dim_counter = self.var("dim_counter",
                                     self.config_width,
                                     **per_level)

        self._strt_addr = self.var("strt_addr", self.config_width)

        self._counter_update = self.var("counter_update", 1)
        self._calc_addr = self.var("calc_addr", self.config_width)

        self._max_value = self.var("max_value", levels)
        self._mux_sel = self.var("mux_sel", sel_width)
        # The selected level is also exported on an output port.
        self._mux_sel_out = self.output("mux_sel_out", sel_width)
        self.wire(self._mux_sel_out, self._mux_sel)

        # ---- Generation logic ----
        self._done = self.var("done", 1)
        self._clear = self.var("clear", levels)
        self._inc = self.var("inc", levels)

        # Counter of the currently selected level, plus one.
        self._inced_cnt = self.var("inced_cnt", self._dim_counter.width)
        self.wire(self._inced_cnt, self._dim_counter[self._mux_sel] + 1)

        # High when the selected counter equals its range while that level
        # is being incremented.
        self._maxed_value = self.var("maxed_value", 1)
        selected_at_range = (
            self._dim_counter[self._mux_sel] == self._ranges[self._mux_sel])
        self.wire(self._maxed_value,
                  selected_at_range & self._inc[self._mux_sel])

        self.add_code(self.set_mux_sel)
        # Each level registers the same four code blocks, in this order.
        for level in range(levels):
            for block in (self.set_clear,
                          self.set_inc,
                          self.dim_counter_update,
                          self.max_value_update):
                self.add_code(block, idx=level)

        # ---- Restart output ----
        self._restart = self.output("restart", 1)
        self.wire(self._restart, self._step & (~self._done))
예제 #12
0
    def __init__(
            self,
            # number of bits in a word
            word_width,
            # number of words that can be stored at an address in SRAM
            # note: fetch_width must be a power of 2
            fetch_width,
            # total number of transpose buffers
            num_tb,
            # height of this particular transpose buffer
            max_tb_height,
            # maximum value for range parameters in nested for loop
            # (and as a result, maximum length of indices input vector)
            # specifying inner for loop values for output column
            # addressing
            max_range,
            max_range_inner,
            max_stride,
            tb_iterator_support):
        """Declare the ports, local signals and code blocks of the buffer.

        Args:
            word_width: Bit width of each element of ``input_data``,
                ``col_pixels`` and the ``tb`` storage.
            fetch_width: Number of words in ``input_data``. When it is 1 the
                data port and ``tb`` are declared without the extra
                dimension, and the ``output_index`` signals and their code
                blocks are omitted.
            num_tb: Number of transpose buffers; only used here to size
                index widths.
            max_tb_height: Number of entries in ``col_pixels``; ``tb`` and
                ``tb_valid`` hold twice as many rows.
            max_range: Upper bound used to size ``range_outer`` and
                ``index_outer``.
            max_range_inner: Upper bound used to size ``range_inner`` and
                ``index_inner``; also the number of entries in ``indices``.
            max_stride: Upper bound used to size ``stride``.
            tb_iterator_support: Used to size the ``dimensionality`` input.
        """
        super().__init__("transpose_buffer", debug=True)

        #########################
        # GENERATION PARAMETERS #
        #########################

        self.word_width = word_width
        self.fetch_width = fetch_width
        self.num_tb = num_tb
        self.max_tb_height = max_tb_height
        self.max_range = max_range
        self.max_range_inner = max_range_inner
        self.max_stride = max_stride
        self.tb_iterator_support = tb_iterator_support

        ##################################
        # BITS FOR GENERATION PARAMETERS #
        ##################################

        # Every width is clamped to at least 1 bit so that a parameter value
        # of 1 (clog2 == 0) still yields a declarable signal.
        self.fetch_width_bits = max(1, clog2(self.fetch_width))
        self.num_tb_bits = max(1, clog2(self.num_tb))
        self.max_range_bits = max(1, clog2(self.max_range))
        self.max_range_inner_bits = max(1, clog2(self.max_range_inner))
        self.max_stride_bits = max(1, clog2(self.max_stride))
        # NOTE(review): tb_col_index_bits is not referenced again inside
        # this constructor - check whether other methods still need it.
        self.tb_col_index_bits = 2 * max(self.fetch_width_bits,
                                         self.num_tb_bits) + 1
        # "bits2" variant indexes the doubled (2 * max_tb_height) storage.
        self.max_tb_height_bits2 = max(1, clog2(2 * self.max_tb_height))
        self.max_tb_height_bits = max(1, clog2(self.max_tb_height))
        self.tb_iterator_support_bits = max(
            1,
            clog2(self.tb_iterator_support) + 1)
        # Width shared by output_index_abs, output_index_long and
        # curr_out_start below.
        self.max_range_stride_bits2 = max(2 * self.max_range_bits,
                                          2 * self.max_stride_bits)

        ##########
        # INPUTS #
        ##########

        self.clk = self.clock("clk")
        # active low asynchronous reset
        self.rst_n = self.reset("rst_n", 1)

        # data input from SRAM
        # (a single word is declared as a plain vector, otherwise as a
        # packed array with one entry per fetched word)
        if self.fetch_width == 1:
            self.input_data = self.input("input_data", self.word_width)
        else:
            self.input_data = self.input("input_data",
                                         width=self.word_width,
                                         size=self.fetch_width,
                                         packed=True)
        # valid indicating whether data input from SRAM is valid and
        # should be stored in transpose buffer
        self.valid_data = self.input("valid_data", 1)
        self.ack_in = self.input("ack_in", 1)
        self.ren = self.input("ren", 1)

        self.mem_valid_data = self.input("mem_valid_data", 1)

        ###########################
        # CONFIGURATION REGISTERS #
        ###########################

        # Each input in this section is tagged with ConfigRegAttr.

        # the range of the outer for loop in nested for loop for output
        # column address generation
        self.range_outer = self.input("range_outer", self.max_range_bits)
        self.range_outer.add_attribute(
            ConfigRegAttr("Outer range for output for loop pattern"))

        # the range of the inner for loop in nested for loop for output
        # column address generation
        self.range_inner = self.input("range_inner", self.max_range_inner_bits)
        self.range_inner.add_attribute(
            ConfigRegAttr("Inner range for output for for loop pattern"))

        # stride for the given application
        self.stride = self.input("stride", self.max_stride_bits)
        self.stride.add_attribute(ConfigRegAttr("Application stride"))

        self.tb_height = self.input("tb_height", self.max_tb_height_bits)
        self.tb_height.add_attribute(ConfigRegAttr("Transpose Buffer height"))

        self.dimensionality = self.input("dimensionality",
                                         self.tb_iterator_support_bits)
        self.dimensionality.add_attribute(
            ConfigRegAttr("Transpose Buffer dimensionality"))

        # specifies inner for loop values for output column
        # addressing
        self.indices = self.input(
            "indices",
            width=clog2(2 * self.num_tb * self.fetch_width),
            # the length of indices is equal to range_inner,
            # so the maximum possible size for self.indices
            # is the maximum value of range_inner, which is
            # self.max_range_inner
            size=self.max_range_inner,
            explicit_array=True,
            packed=True)
        self.indices.add_attribute(
            ConfigRegAttr("Output indices for for loop pattern"))

        # offset to start the output address at, e.g. when starting in the
        # middle of a wider fetch-width word
        self.starting_addr = self.input("starting_addr", self.fetch_width_bits)
        self.starting_addr.add_attribute(ConfigRegAttr("TB starting address"))

        ###########
        # OUTPUTS #
        ###########

        self.col_pixels = self.output("col_pixels",
                                      width=self.word_width,
                                      size=self.max_tb_height,
                                      packed=True,
                                      explicit_array=True)
        self.output_valid = self.output("output_valid", 1)
        self.rdy_to_arbiter = self.output("rdy_to_arbiter", 1)

        ###################
        # LOCAL VARIABLES #
        ###################

        # transpose buffer
        # NOTE(review): storage has 2 * max_tb_height rows; presumably this
        # is double buffering selected by input_buf_index / out_buf_index -
        # confirm against the code blocks registered below.
        if self.fetch_width == 1:
            self.tb = self.var("tb",
                               width=self.word_width,
                               size=2 * self.max_tb_height,
                               packed=True)

        else:
            self.tb = self.var("tb",
                               width=self.word_width,
                               size=[2 * self.max_tb_height, self.fetch_width],
                               packed=True)

        # one bit per row of tb
        self.tb_valid = self.var("tb_valid", 2 * self.max_tb_height)

        self.index_outer = self.var("index_outer", self.max_range_bits)
        self.index_inner = self.var("index_inner", self.max_range_inner_bits)

        self.input_buf_index = self.var("input_buf_index", 1)
        self.out_buf_index = self.var("out_buf_index", 1)
        self.switch_out_buf = self.var("switch_out_buf", 1)
        self.switch_next_line = self.var("switch_next_line", 1)
        self.row_index = self.var("row_index", self.max_tb_height_bits)
        self.input_index = self.var("input_index", self.max_tb_height_bits2)

        self.output_index_abs = self.var("output_index_abs",
                                         self.max_range_stride_bits2)

        # These two signals only exist when more than one word is fetched.
        if self.fetch_width != 1:
            self.output_index_long = self.var("output_index_long",
                                              self.max_range_stride_bits2)
            self.output_index = self.var("output_index", self.fetch_width_bits)

        # same width as one entry of the indices input
        self.indices_index_inner = self.var(
            "indices_index_inner", clog2(2 * self.num_tb * self.fetch_width))
        self.curr_out_start = self.var("curr_out_start",
                                       self.max_range_stride_bits2)

        self.start_data = self.var("start_data", 1)
        self.old_start_data = self.var("old_start_data", 1)

        self.pause_tb = self.var("pause_tb", 1)
        self.pause_output = self.var("pause_output", 1)

        self.on_next_line = self.var("on_next_line", 1)

        self.mask_valid = self.var("mask_valid", 1)

        # NOTE(review): the *inv names suggest inverted copies of pause_tb,
        # rdy_to_arbiter and out_buf_index, presumably driven by set_invs -
        # confirm in that method.
        self.pause_tbinv = self.var("pause_tbinv", 1)
        self.rdy_to_arbiterinv = self.var("rdy_to_arbiterinv", 1)
        self.out_buf_indexinv = self.var("out_buf_indexinv", 1)

        ##########################
        # SEQUENTIAL CODE BLOCKS #
        ##########################

        self.add_code(self.set_index_outer)
        self.add_code(self.set_index_inner)
        self.add_code(self.set_pause_tb)
        self.add_code(self.set_row_index)
        self.add_code(self.set_input_buf_index)
        self.add_code(self.input_to_tb)
        self.add_code(self.output_from_tb)
        self.add_code(self.set_output_valid)
        self.add_code(self.set_out_buf_index)
        self.add_code(self.set_rdy_to_arbiter)
        self.add_code(self.set_start_data)
        self.add_code(self.set_curr_out_start)
        # matches the conditional declaration of output_index above
        if self.fetch_width != 1:
            self.add_code(self.set_output_index)
        self.add_code(self.set_old_start_data)
        self.add_code(self.set_on_next_line)

        #############################
        # COMBINATIONAL CODE BLOCKS #
        #############################

        self.add_code(self.set_pause_output)
        self.add_code(self.set_input_index)
        self.add_code(self.set_tb_out_indices)
        self.add_code(self.set_switch_out_buf)
        self.add_code(self.set_switch_next_line)
        # matches the conditional declaration of output_index_long above
        if self.fetch_width != 1:
            self.add_code(self.set_output_index_long)
        self.add_code(self.set_mask_valid)
        self.add_code(self.set_invs)
예제 #13
0
    def __init__(self,
                 data_width=16,  # CGRA Params
                 mem_width=64,
                 mem_depth=512,
                 banks=1,
                 input_iterator_support=6,  # Addr Controllers
                 output_iterator_support=6,
                 input_config_width=16,
                 output_config_width=16,
                 interconnect_input_ports=2,  # Connection to int
                 interconnect_output_ports=2,
                 mem_input_ports=1,
                 mem_output_ports=1,
                 read_delay=1,  # Cycle delay in read (SRAM vs Register File)
                 rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
                 agg_height=4,
                 max_agg_schedule=32,
                 input_max_port_sched=32,
                 output_max_port_sched=32,
                 align_input=1,
                 max_line_length=128,
                 max_tb_height=1,
                 tb_range_max=128,
                 tb_range_inner_max=5,
                 tb_sched_max=64,
                 max_tb_stride=15,
                 num_tb=1,
                 tb_iterator_support=2,
                 multiwrite=1,
                 num_tiles=1,
                 max_prefetch=8,
                 app_ctrl_depth_width=16,
                 remove_tb=False,
                 stcl_valid_iter=4):
        super().__init__("strg_ub")

        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.agg_height = agg_height
        self.max_agg_schedule = max_agg_schedule
        self.input_max_port_sched = input_max_port_sched
        self.output_max_port_sched = output_max_port_sched
        self.input_port_sched_width = clog2(self.interconnect_input_ports)
        self.align_input = align_input
        self.max_line_length = max_line_length
        assert self.mem_width >= self.data_width, "Data width needs to be smaller than mem"
        self.fw_int = int(self.mem_width / self.data_width)
        self.num_tb = num_tb
        self.max_tb_height = max_tb_height
        self.tb_range_max = tb_range_max
        self.tb_range_inner_max = tb_range_inner_max
        self.max_tb_stride = max_tb_stride
        self.tb_sched_max = tb_sched_max
        self.tb_iterator_support = tb_iterator_support
        self.multiwrite = multiwrite
        self.max_prefetch = max_prefetch
        self.num_tiles = num_tiles
        self.app_ctrl_depth_width = app_ctrl_depth_width
        self.remove_tb = remove_tb
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.stcl_valid_iter = stcl_valid_iter
        # phases = [] TODO

        self.address_width = clog2(self.num_tiles * self.mem_depth)

        # CLK and RST
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # INPUTS
        self._data_in = self.input("data_in",
                                   self.data_width,
                                   size=self.interconnect_input_ports,
                                   packed=True,
                                   explicit_array=True)

        self._wen_in = self.input("wen_in", self.interconnect_input_ports)
        self._ren_input = self.input("ren_in", self.interconnect_output_ports)
        # Post rate matched
        self._ren_in = self.var("ren_in_muxed", self.interconnect_output_ports)
        # Processed versions of wen and ren from the app ctrl
        self._wen = self.var("wen", self.interconnect_input_ports)
        self._ren = self.var("ren", self.interconnect_output_ports)

        # Add rate matched
        # If one input port, let any output port use the wen_in as the ren_in
        # If more, do the same thing but also provide port selection
        if self.interconnect_input_ports == 1:
            self._rate_matched = self.input("rate_matched", self.interconnect_output_ports)
            self._rate_matched.add_attribute(ConfigRegAttr("Rate matched - 1 or 0"))
            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_in[i],
                          kts.ternary(self._rate_matched[i],
                                      self._wen_in,
                                      self._ren_input[i]))
        else:
            self._rate_matched = self.input("rate_matched", 1 + kts.clog2(self.interconnect_input_ports),
                                            size=self.interconnect_output_ports,
                                            explicit_array=True,
                                            packed=True)
            self._rate_matched.add_attribute(ConfigRegAttr("Rate matched [input port | on/off]"))
            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_in[i],
                          kts.ternary(self._rate_matched[i][0],
                                      self._wen_in[self._rate_matched[i][kts.clog2(self.interconnect_input_ports), 1]],
                                      self._ren_input[i]))

        self._arb_wen_en = self.var("arb_wen_en", self.interconnect_input_ports)
        self._arb_ren_en = self.var("arb_ren_en", self.interconnect_output_ports)

        self._data_from_strg = self.input("data_from_strg",
                                          self.data_width,
                                          size=(self.banks,
                                                self.mem_output_ports,
                                                self.fw_int),
                                          packed=True,
                                          explicit_array=True)

        self._mem_valid_data = self.input("mem_valid_data",
                                          self.mem_output_ports,
                                          size=self.banks,
                                          explicit_array=True,
                                          packed=True)

        self._out_mem_valid_data = self.var("out_mem_valid_data",
                                            self.mem_output_ports,
                                            size=self.banks,
                                            explicit_array=True,
                                            packed=True)

        # We need to signal valids out of the agg buff, only if one exists...
        if self.agg_height > 0:
            self._to_iac_valid = self.var("ab_to_mem_valid",
                                          self.interconnect_input_ports)

        self._data_out = self.output("data_out",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     packed=True,
                                     explicit_array=True)

        self._valid_out = self.output("valid_out",
                                      self.interconnect_output_ports)

        self._valid_out_alt = self.var("valid_out_alt",
                                       self.interconnect_output_ports)

        self._data_to_strg = self.output("data_to_strg",
                                         self.data_width,
                                         size=(self.banks,
                                               self.mem_input_ports,
                                               self.fw_int),
                                         packed=True,
                                         explicit_array=True)

        # If we can perform a read and a write on the same cycle,
        # this will necessitate a separate read and write address...
        if self.rw_same_cycle:
            self._wr_addr_out = self.output("wr_addr_out",
                                            self.address_width,
                                            size=(self.banks,
                                                  self.mem_input_ports),
                                            explicit_array=True,
                                            packed=True)

            self._rd_addr_out = self.output("rd_addr_out",
                                            self.address_width,
                                            size=(self.banks,
                                                  self.mem_output_ports),
                                            explicit_array=True,
                                            packed=True)

        else:
            self._addr_out = self.output("addr_out",
                                         self.address_width,
                                         size=(self.banks,
                                               self.mem_input_ports),
                                         packed=True,
                                         explicit_array=True)

        self._cen_to_strg = self.output("cen_to_strg", self.mem_output_ports,
                                        size=self.banks,
                                        explicit_array=True,
                                        packed=True)
        self._wen_to_strg = self.output("wen_to_strg", self.mem_input_ports,
                                        size=self.banks,
                                        explicit_array=True,
                                        packed=True)
        if self.num_tb > 0:
            self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)

        self._port_wens = self.var("port_wens", self.interconnect_input_ports)

        ####################
        ##### APP CTRL #####
        ####################
        self._ack_transpose = self.var("ack_transpose",
                                       self.banks,
                                       size=self.interconnect_output_ports,
                                       explicit_array=True,
                                       packed=True)

        self._ack_reduced = self.var("ack_reduced",
                                     self.interconnect_output_ports)

        self.app_ctrl = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                                interconnect_output_ports=self.interconnect_output_ports,
                                depth_width=self.app_ctrl_depth_width,
                                sprt_stcl_valid=True,
                                stcl_iter_support=self.stcl_valid_iter)

        # Some refactoring here for pond to get rid of app controllers...
        # This is honestly pretty messy and should clean up nicely when we have the spec...
        self._ren_out_reduced = self.var("ren_out_reduced",
                                         self.interconnect_output_ports)

        if self.num_tb == 0 or self.remove_tb:
            self.wire(self._wen, self._wen_in)
            self.wire(self._ren, self._ren_in)
            self.wire(self._valid_out, self._valid_out_alt)
            self.wire(self._arb_wen_en, self._wen)
            self.wire(self._arb_ren_en, self._ren)
        else:
            self.add_child("app_ctrl", self.app_ctrl,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._wen_in,
                           ren_in=self._ren_in,
                           #    ren_update=self._tb_valid_out,
                           valid_out_data=self._valid_out,
                           # valid_out_stencil=,
                           wen_out=self._wen,
                           ren_out=self._ren)

            self.wire(self.app_ctrl.ports.tb_valid, self._tb_valid_out)
            self.wire(self.app_ctrl.ports.ren_update, self._tb_valid_out)

            self.app_ctrl_coarse = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                                           interconnect_output_ports=self.interconnect_output_ports,
                                           depth_width=self.app_ctrl_depth_width)
            self.add_child("app_ctrl_coarse", self.app_ctrl_coarse,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._to_iac_valid,  # self._port_wens & self._to_iac_valid,  # Gets valid and the ack
                           ren_in=self._ren_out_reduced,
                           tb_valid=kts.const(0, 1),
                           ren_update=self._ack_reduced,
                           wen_out=self._arb_wen_en,
                           ren_out=self._arb_ren_en)

        ###########################
        ##### INPUT AGG SCHED #####
        ###########################
        ###########################################
        ##### AGGREGATION ALIGNERS (OPTIONAL) #####
        ###########################################
        # These variables are holders and can be swapped out if needed
        self._data_consume = self._data_in
        self._valid_consume = self._wen
        # Zero out if not aligning
        if(self.agg_height > 0):
            self._align_to_agg = self.var("align_input",
                                          self.interconnect_input_ports)
        # Add the aggregation buffer aligners
        if(self.align_input):
            self._data_consume = self.var("data_consume",
                                          self.data_width,
                                          size=self.interconnect_input_ports,
                                          packed=True,
                                          explicit_array=True)
            self._valid_consume = self.var("valid_consume",
                                           self.interconnect_input_ports)

            # Make new aggregation aligners for each port
            for i in range(self.interconnect_input_ports):
                new_child = AggAligner(self.data_width, self.max_line_length)
                self.add_child(f"agg_align_{i}", new_child,
                               clk=self._clk,
                               rst_n=self._rst_n,
                               in_dat=self._data_in[i],
                               in_valid=self._wen[i],
                               align=self._align_to_agg[i],
                               out_valid=self._valid_consume[i],
                               out_dat=self._data_consume[i])
        else:
            if self.agg_height > 0:
                self.wire(self._align_to_agg, const(0, self._align_to_agg.width))
        ################################################
        ##### END: AGGREGATION ALIGNERS (OPTIONAL) #####
        ################################################

        if self.agg_height == 0:
            self._to_iac_dat = self._data_consume
            self._to_iac_valid = self._valid_consume

        ##################################
        ##### AGG BUFFERS (OPTIONAL) #####
        ##################################
        # Only instantiate agg_buffer if needed
        if(self.agg_height > 0):
            self._to_iac_dat = self.var("ab_to_mem_dat",
                                        self.mem_width,
                                        size=self.interconnect_input_ports,
                                        packed=True,
                                        explicit_array=True)

            # self._to_iac_valid = self.var("ab_to_mem_valid",
            #                               self.interconnect_input_ports)

            self._agg_buffers = []
            # Add input aggregations buffers
            for i in range(self.interconnect_input_ports):
                # add children aggregator buffers...
                agg_buffer_new = AggregationBuffer(self.agg_height,
                                                   self.data_width,
                                                   self.mem_width,
                                                   self.max_agg_schedule)
                self._agg_buffers.append(agg_buffer_new)
                self.add_child(f"agg_in_{i}", agg_buffer_new,
                               clk=self._clk,
                               rst_n=self._rst_n,
                               data_in=self._data_consume[i],
                               valid_in=self._valid_consume[i],
                               align=self._align_to_agg[i],
                               data_out=self._to_iac_dat[i],
                               valid_out=self._to_iac_valid[i])

        #######################################
        ##### END: AGG BUFFERS (OPTIONAL) #####
        #######################################

        self._ready_tba = self.var("ready_tba", self.interconnect_output_ports)
        ####################################
        ##### INPUT ADDRESS CONTROLLER #####
        ####################################
        self._wen_to_arb = self.var("wen_to_arb", self.mem_input_ports,
                                    size=self.banks,
                                    explicit_array=True,
                                    packed=True)
        self._addr_to_arb = self.var("addr_to_arb",
                                     self.address_width,
                                     size=(self.banks,
                                           self.mem_input_ports),
                                     explicit_array=True,
                                     packed=True)
        self._data_to_arb = self.var("data_to_arb",
                                     self.data_width,
                                     size=(self.banks,
                                           self.mem_input_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        # Connect these inputs ports to an address generator
        iac = InputAddrCtrl(interconnect_input_ports=self.interconnect_input_ports,
                            mem_depth=self.mem_depth,
                            num_tiles=self.num_tiles,
                            banks=self.banks,
                            iterator_support=self.input_iterator_support,
                            address_width=self.address_width,
                            data_width=self.data_width,
                            fetch_width=self.mem_width,
                            multiwrite=self.multiwrite,
                            strg_wr_ports=self.mem_input_ports,
                            config_width=self.input_config_width)
        self.add_child(f"input_addr_ctrl", iac,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       valid_in=self._to_iac_valid,
                       # wen_en=kts.concat(*([kts.const(1, 1)] * self.interconnect_input_ports)),
                       wen_en=self._arb_wen_en,
                       data_in=self._to_iac_dat,
                       wen_to_sram=self._wen_to_arb,
                       addr_out=self._addr_to_arb,
                       port_out=self._port_wens,
                       data_out=self._data_to_arb)

        #########################################
        ##### END: INPUT ADDRESS CONTROLLER #####
        #########################################
        self._arb_acks = self.var("arb_acks",
                                  self.interconnect_output_ports,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        self._prefetch_step = self.var("prefetch_step", self.interconnect_output_ports)
        self._oac_step = self.var("oac_step", self.interconnect_output_ports)
        self._oac_valid = self.var("oac_valid", self.interconnect_output_ports)
        self._ren_out = self.var("ren_out",
                                 self.interconnect_output_ports,
                                 size=self.banks,
                                 explicit_array=True,
                                 packed=True)
        self._ren_out_tpose = self.var("ren_out_tpose",
                                       self.banks,
                                       size=self.interconnect_output_ports,
                                       explicit_array=True,
                                       packed=True)

        self._oac_addr_out = self.var("oac_addr_out",
                                      self.address_width,
                                      size=self.interconnect_output_ports,
                                      explicit_array=True,
                                      packed=True)
        #####################################
        ##### OUTPUT ADDRESS CONTROLLER #####
        #####################################
        oac = OutputAddrCtrl(interconnect_output_ports=self.interconnect_output_ports,
                             mem_depth=self.mem_depth,
                             num_tiles=self.num_tiles,
                             banks=self.banks,
                             iterator_support=self.output_iterator_support,
                             address_width=self.address_width,
                             config_width=self.output_config_width)

        if self.remove_tb:
            self.wire(self._oac_valid, self._ren)
            self.wire(self._oac_step, self._ren)
        else:
            self.wire(self._oac_valid, self._prefetch_step)
            self.wire(self._oac_step, self._ack_reduced)

        self.chain_idx_bits = max(1, clog2(num_tiles))
        self._enable_chain_output = self.input("enable_chain_output", 1)
        self._chain_idx_output = self.input("chain_idx_output", self.chain_idx_bits)

        self.add_child(f"output_addr_ctrl", oac,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       valid_in=self._oac_valid,
                       ren=self._ren_out,
                       addr_out=self._oac_addr_out,
                       step_in=self._oac_step)

        for i in range(self.interconnect_output_ports):
            for j in range(self.banks):
                self.wire(self._ren_out_tpose[i][j], self._ren_out[j][i])
        ##############################
        ##### READ/WRITE ARBITER #####
        ##############################
        # Hook up the read write arbiters for each bank
        self._arb_dat_out = self.var("arb_dat_out",
                                     self.data_width,
                                     size=(self.banks,
                                           self.mem_output_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        self._arb_port_out = self.var("arb_port_out",
                                      self.interconnect_output_ports,
                                      size=(self.banks,
                                            self.mem_output_ports),
                                      explicit_array=True,
                                      packed=True)
        self._arb_valid_out = self.var("arb_valid_out", self.mem_output_ports,
                                       size=self.banks,
                                       explicit_array=True,
                                       packed=True)

        self._rd_sync_gate = self.var("rd_sync_gate",
                                      self.interconnect_output_ports)

        self.arbiters = []
        for i in range(self.banks):
            rw_arb = RWArbiter(fetch_width=self.mem_width,
                               data_width=self.data_width,
                               memory_depth=self.mem_depth,
                               num_tiles=self.num_tiles,
                               int_in_ports=self.interconnect_input_ports,
                               int_out_ports=self.interconnect_output_ports,
                               strg_wr_ports=self.mem_input_ports,
                               strg_rd_ports=self.mem_output_ports,
                               read_delay=self.read_delay,
                               rw_same_cycle=self.rw_same_cycle,
                               separate_addresses=self.rw_same_cycle)
            self.arbiters.append(rw_arb)
            self.add_child(f"rw_arb_{i}", rw_arb,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._wen_to_arb[i],
                           w_data=self._data_to_arb[i],
                           w_addr=self._addr_to_arb[i],
                           data_from_mem=self._data_from_strg[i],
                           mem_valid_data=self._mem_valid_data[i],
                           out_mem_valid_data=self._out_mem_valid_data[i],
                           ren_en=self._arb_ren_en,
                           rd_addr=self._oac_addr_out,
                           out_data=self._arb_dat_out[i],
                           out_port=self._arb_port_out[i],
                           out_valid=self._arb_valid_out[i],
                           cen_mem=self._cen_to_strg[i],
                           wen_mem=self._wen_to_strg[i],
                           data_to_mem=self._data_to_strg[i],
                           out_ack=self._arb_acks[i])

            # Bind the separate addrs
            if self.rw_same_cycle:
                self.wire(rw_arb.ports.wr_addr_to_mem, self._wr_addr_out[i])
                self.wire(rw_arb.ports.rd_addr_to_mem, self._rd_addr_out[i])
            else:
                self.wire(rw_arb.ports.addr_to_mem, self._addr_out[i])

            if self.remove_tb:
                self.wire(rw_arb.ports.ren_in, self._ren_out[i])
            else:
                self.wire(rw_arb.ports.ren_in, self._ren_out[i] & self._rd_sync_gate)

        self.num_tb_bits = max(1, clog2(self.num_tb))

        self._data_to_sync = self.var("data_to_sync",
                                      self.data_width,
                                      size=(self.interconnect_output_ports,
                                            self.fw_int),
                                      explicit_array=True,
                                      packed=True)

        self._valid_to_sync = self.var("valid_to_sync", self.interconnect_output_ports)

        self._data_to_tba = self.var("data_to_tba",
                                     self.data_width,
                                     size=(self.interconnect_output_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        self._valid_to_tba = self.var("valid_to_tba", self.interconnect_output_ports)

        self._data_to_pref = self.var("data_to_pref",
                                      self.data_width,
                                      size=(self.interconnect_output_ports,
                                            self.fw_int),
                                      explicit_array=True,
                                      packed=True)

        self._valid_to_pref = self.var("valid_to_pref", self.interconnect_output_ports)
        #######################
        ##### DEMUX READS #####
        #######################
        dmux_rd = DemuxReads(fetch_width=self.mem_width,
                             data_width=self.data_width,
                             banks=self.banks,
                             int_out_ports=self.interconnect_output_ports,
                             strg_rd_ports=self.mem_output_ports)

        self._arb_dat_out_f = self.var("arb_dat_out_f",
                                       self.data_width,
                                       size=(self.banks * self.mem_output_ports,
                                             self.fw_int),
                                       explicit_array=True,
                                       packed=True)

        self._arb_port_out_f = self.var("arb_port_out_f",
                                        self.interconnect_output_ports,
                                        size=(self.banks * self.mem_output_ports),
                                        explicit_array=True,
                                        packed=True)
        self._arb_valid_out_f = self.var("arb_valid_out_f", self.mem_output_ports * self.banks)
        self._arb_mem_valid_data_f = self.var("arb_mem_valid_data_f", self.mem_output_ports * self.banks)

        self._arb_mem_valid_data_out = self.var("arb_mem_valid_data_out",
                                                self.interconnect_output_ports)

        self._mem_valid_data_sync = self.var("mem_valid_data_sync",
                                             self.interconnect_output_ports)

        self._mem_valid_data_pref = self.var("mem_valid_data_pref",
                                             self.interconnect_output_ports)

        tmp_cnt = 0
        for i in range(self.banks):
            for j in range(self.mem_output_ports):
                self.wire(self._arb_dat_out_f[tmp_cnt], self._arb_dat_out[i][j])
                self.wire(self._arb_port_out_f[tmp_cnt], self._arb_port_out[i][j])
                self.wire(self._arb_valid_out_f[tmp_cnt], self._arb_valid_out[i][j])
                self.wire(self._arb_mem_valid_data_f[tmp_cnt], self._out_mem_valid_data[i][j])
                tmp_cnt = tmp_cnt + 1

        # If this is end of the road...
        if self.remove_tb:
            assert self.fw_int == 1, "Make it easier on me now..."
            self.add_child("demux_rds", dmux_rd,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._arb_dat_out_f,
                           mem_valid_data=self._arb_mem_valid_data_f,
                           mem_valid_data_out=self._arb_mem_valid_data_out,
                           valid_in=self._arb_valid_out_f,
                           port_in=self._arb_port_out_f,
                           valid_out=self._valid_out_alt)
            for i in range(self.interconnect_output_ports):
                self.wire(self._data_out[i], dmux_rd.ports.data_out[i])

        else:
            self.add_child("demux_rds", dmux_rd,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._arb_dat_out_f,
                           mem_valid_data=self._arb_mem_valid_data_f,
                           mem_valid_data_out=self._arb_mem_valid_data_out,
                           valid_in=self._arb_valid_out_f,
                           port_in=self._arb_port_out_f,
                           data_out=self._data_to_sync,
                           valid_out=self._valid_to_sync)

            #######################
            ##### SYNC GROUPS #####
            #######################
            sync_group = SyncGroups(fetch_width=self.mem_width,
                                    data_width=self.data_width,
                                    int_out_ports=self.interconnect_output_ports)

            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_out_reduced[i], self._ren_out_tpose[i].r_or())

            self.add_child("sync_grp", sync_group,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_to_sync,
                           mem_valid_data=self._arb_mem_valid_data_out,
                           mem_valid_data_out=self._mem_valid_data_sync,
                           valid_in=self._valid_to_sync,
                           data_out=self._data_to_pref,
                           valid_out=self._valid_to_pref,
                           ren_in=self._ren_out_reduced,
                           rd_sync_gate=self._rd_sync_gate,
                           ack_in=self._ack_reduced)

            # This is the end of the line if we aren't using tb
            ######################
            ##### PREFETCHER #####
            ######################
            prefetchers = []
            for i in range(self.interconnect_output_ports):

                pref = Prefetcher(fetch_width=self.mem_width,
                                  data_width=self.data_width,
                                  max_prefetch=self.max_prefetch)

                prefetchers.append(pref)

                if self.num_tb == 0:
                    assert self.fw_int == 1, \
                        "If no transpose buffer, data width needs match memory width"
                    self.add_child(f"pre_fetch_{i}", pref,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   data_in=self._data_to_pref[i],
                                   mem_valid_data=self._mem_valid_data_sync[i],
                                   mem_valid_data_out=self._mem_valid_data_pref[i],
                                   valid_read=self._valid_to_pref[i],
                                   tba_rdy_in=self._ren[i],
                                   #    data_out=self._data_out[i],
                                   valid_out=self._valid_out_alt[i],
                                   prefetch_step=self._prefetch_step[i])
                    self.wire(self._data_out[i], pref.ports.data_out[0])
                else:
                    self.add_child(f"pre_fetch_{i}", pref,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   data_in=self._data_to_pref[i],
                                   mem_valid_data=self._mem_valid_data_sync[i],
                                   mem_valid_data_out=self._mem_valid_data_pref[i],
                                   valid_read=self._valid_to_pref[i],
                                   tba_rdy_in=self._ready_tba[i],
                                   data_out=self._data_to_tba[i],
                                   valid_out=self._valid_to_tba[i],
                                   prefetch_step=self._prefetch_step[i])

                    #############################
                    ##### TRANSPOSE BUFFERS #####
                    #############################
            if self.num_tb > 0:

                self._tb_data_out = self.var("tb_data_out", self.data_width,
                                             size=self.interconnect_output_ports,
                                             explicit_array=True,
                                             packed=True)
                self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)

                for i in range(self.interconnect_output_ports):

                    tba = TransposeBufferAggregation(word_width=self.data_width,
                                                     fetch_width=self.fw_int,
                                                     num_tb=self.num_tb,
                                                     max_tb_height=self.max_tb_height,
                                                     max_range=self.tb_range_max,
                                                     max_range_inner=self.tb_range_inner_max,
                                                     max_stride=self.max_tb_stride,
                                                     tb_iterator_support=self.tb_iterator_support)

                    self.add_child(f"tba_{i}", tba,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   SRAM_to_tb_data=self._data_to_tba[i],
                                   valid_data=self._valid_to_tba[i],
                                   tb_index_for_data=0,
                                   ack_in=self._valid_to_tba[i],
                                   mem_valid_data=self._mem_valid_data_pref[i],
                                   tb_to_interconnect_data=self._tb_data_out[i],
                                   tb_to_interconnect_valid=self._tb_valid_out[i],
                                   tb_arbiter_rdy=self._ready_tba[i],
                                   tba_ren=self._ren[i])

                for i in range(self.interconnect_output_ports):
                    self.wire(self._data_out[i], self._tb_data_out[i])
                    # self.wire(self._valid_out[i], self._tb_valid_out[i])
            else:
                self.wire(self._valid_out, self._valid_out_alt)

        ####################
        ##### ADD CODE #####
        ####################
        self.add_code(self.transpose_acks)
        self.add_code(self.reduce_acks)
예제 #14
0
    def __init__(self, iterator_support=6, config_width=16):
        """Declare the ports, internal signals and wiring of an address generator.

        The generator is named ``addr_gen_<iterator_support>_<config_width>``
        and the base class is constructed with ``debug=True``.

        Args:
            iterator_support: Number of entries in the ``strides`` array. It is
                also used as the bit width of ``max_value`` and determines the
                width of ``mux_sel``.
            config_width: Bit width of every stride entry, of the starting
                address, and of all address signals declared here.
        """
        module_name = f"addr_gen_{iterator_support}_{config_width}"
        super().__init__(module_name, debug=True)

        def tag_as_config(port, description):
            # A configuration input carries two attributes: a ConfigRegAttr
            # holding its description, and a FormalAttr named after the port.
            port.add_attribute(ConfigRegAttr(description))
            port.add_attribute(
                FormalAttr(f"{port.name}", FormalSignalConstraint.SOLVE))

        # Keep the construction parameters on the instance.
        self.iterator_support = iterator_support
        self.config_width = config_width
        addr_bits = config_width

        # ---------------- Ports: inputs ----------------
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")
        self._restart = self.input("restart", 1)

        self._strides = self.input("strides", addr_bits,
                                   size=iterator_support,
                                   packed=True,
                                   explicit_array=True)
        tag_as_config(self._strides, "Strides of address generator")

        self._starting_addr = self.input("starting_addr", addr_bits)
        tag_as_config(self._starting_addr,
                      "Starting address of address generator")

        self._step = self.input("step", 1)

        # The select input is never narrower than one bit.
        sel_bits = max(clog2(iterator_support), 1)
        self._mux_sel = self.input("mux_sel", sel_bits)

        # ---------------- Ports: outputs ----------------
        # TODO why is this config width instead of address width?
        self._addr_out = self.output("addr_out", addr_bits)

        # ---------------- Internal signals ----------------
        self._strt_addr = self.var("strt_addr", addr_bits)
        self._calc_addr = self.var("calc_addr", addr_bits)
        self._max_value = self.var("max_value", iterator_support)

        # ---------------- Wiring ----------------
        # The starting address is passed through to an internal signal and
        # the computed address is what leaves the module.
        self.wire(self._strt_addr, self._starting_addr)
        self.wire(self._addr_out, self._calc_addr)

        # The output address is the starting address plus the running offset.
        # NOTE(review): current_addr is presumably driven by
        # calculate_address, which is not visible here - confirm.
        self._current_addr = self.var("current_addr", addr_bits)
        address_sum = self._strt_addr + self._current_addr
        self.wire(self._calc_addr, address_sum)

        self.add_code(self.calculate_address)
예제 #15
0
    def __init__(self,
                 interconnect_input_ports,
                 interconnect_output_ports,
                 depth_width=16,
                 sprt_stcl_valid=False,
                 stcl_cnt_width=16,
                 stcl_iter_support=4):
        """Declare the ports, config registers and gating logic of "app_ctrl".

        The module gates incoming write/read enables (``wen_in``/``ren_in``)
        into ``wen_out``/``ren_out`` using per-port done flags, a write-delay
        state and the ``prefill`` configuration, and produces the
        ``valid_out_data``/``valid_out_stencil`` outputs from ``tb_valid``.

        Args:
            interconnect_input_ports: Number of input (write) ports; stored as
                ``int_in_ports`` and used as the width of ``wen_in``/``wen_out``
                and the size of the write-side arrays.
            interconnect_output_ports: Number of output (read) ports; stored as
                ``int_out_ports`` and used as the width of ``ren_in``/``ren_out``
                and the size of the read-side arrays.
            depth_width: Bit width of the depth config registers and of the
                read/write counters.
            sprt_stcl_valid: When true, the stencil-valid counters and their
                ``ranges``/``threshold`` config registers are generated;
                otherwise ``valid_out_stencil`` is wired straight to
                ``tb_valid``.
            stcl_cnt_width: Bit width of each ``ranges``, ``threshold`` and
                ``dim_counter`` entry.
            stcl_iter_support: Number of entries in ``ranges``, ``threshold``
                and ``dim_counter`` (and the width of ``update``).
        """
        super().__init__("app_ctrl", debug=True)

        # Keep the construction parameters on the instance
        self.int_in_ports = interconnect_input_ports
        self.int_out_ports = interconnect_output_ports
        self.depth_width = depth_width
        self.sprt_stcl_valid = sprt_stcl_valid
        self.stcl_cnt_width = stcl_cnt_width
        self.stcl_iter_support = stcl_iter_support

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # IO
        self._wen_in = self.input("wen_in", self.int_in_ports)
        self._ren_in = self.input("ren_in", self.int_out_ports)

        self._ren_update = self.input("ren_update", self.int_out_ports)

        self._tb_valid = self.input("tb_valid", self.int_out_ports)

        self._valid_out_data = self.output("valid_out_data", self.int_out_ports)

        self._valid_out_stencil = self.output("valid_out_stencil", self.int_out_ports)

        # Send tb valid to valid out for now...

        if self.sprt_stcl_valid:
            # Add the config registers to watch
            self._ranges = self.input("ranges", self.stcl_cnt_width,
                                      size=self.stcl_iter_support,
                                      packed=True,
                                      explicit_array=True)
            self._ranges.add_attribute(ConfigRegAttr("Ranges of stencil valid generator"))

            self._threshold = self.input("threshold", self.stcl_cnt_width,
                                         size=self.stcl_iter_support,
                                         packed=True,
                                         explicit_array=True)
            self._threshold.add_attribute(ConfigRegAttr("Threshold of stencil valid generator"))

            # One counter per supported stencil dimension
            self._dim_counter = self.var("dim_counter", self.stcl_cnt_width,
                                         size=self.stcl_iter_support,
                                         packed=True,
                                         explicit_array=True)

            self._update = self.var("update", self.stcl_iter_support)

            # update[0] is tied high; update[i + 1] is high only when
            # dim_counter[i] equals ranges[i] - 1 and update[i] is high,
            # so the bits form a chain from dimension 0 upwards.
            self.wire(self._update[0], const(1, 1))
            for i in range(self.stcl_iter_support - 1):
                self.wire(self._update[i + 1],
                          (self._dim_counter[i] == (self._ranges[i] - 1)) & self._update[i])

            # Register one dim_counter_update block per dimension
            for i in range(self.stcl_iter_support):
                self.add_code(self.dim_counter_update, idx=i)

            # Now we need to just compute stencil valid
            # Stencil valid is the AND-reduction of (dim_counter >= threshold)
            # over all dimensions; every output port bit gets the same expression.
            threshold_comps = [self._dim_counter[_i] >= self._threshold[_i] for _i in range(self.stcl_iter_support)]
            self.wire(self._valid_out_stencil[0], kts.concat(*threshold_comps).r_and())
            for i in range(self.int_out_ports - 1):
                # self.wire(self._valid_out_stencil[i + 1], 0)
                # for multiple ports
                self.wire(self._valid_out_stencil[i + 1], kts.concat(*threshold_comps).r_and())

        else:
            # Without stencil-valid support the stencil valid simply follows tb_valid
            self.wire(self._valid_out_stencil, self._tb_valid)

        # Now gate the valid with stencil valid
        self.wire(self._valid_out_data, self._tb_valid & self._valid_out_stencil)
        self._wr_delay_state_n = self.var("wr_delay_state_n", self.int_out_ports)
        self._wen_out = self.output("wen_out", self.int_in_ports)
        self._ren_out = self.output("ren_out", self.int_out_ports)

        # Two write-depth config registers per input port; which one is in
        # effect is chosen below from wr_delay_state_n.
        # NOTE(review): both registers carry the same description string
        # ("Depth of writes") - confirm whether they should be distinguished.
        self._write_depth_wo = self.input("write_depth_wo", self.depth_width,
                                          size=self.int_in_ports,
                                          explicit_array=True,
                                          packed=True)
        self._write_depth_wo.add_attribute(ConfigRegAttr("Depth of writes"))

        self._write_depth_ss = self.input("write_depth_ss", self.depth_width,
                                          size=self.int_in_ports,
                                          explicit_array=True,
                                          packed=True)
        self._write_depth_ss.add_attribute(ConfigRegAttr("Depth of writes"))

        self._write_depth = self.var("write_depth", self.depth_width,
                                     size=self.int_in_ports,
                                     explicit_array=True,
                                     packed=True)

        # Effective write depth per input port: a ternary on wr_delay_state_n[i]
        # between write_depth_ss[i] and write_depth_wo[i] (presumably
        # write_depth_ss when the bit is set - confirm against kts.ternary).
        # NOTE(review): wr_delay_state_n is int_out_ports wide but is indexed
        # here over range(int_in_ports) - confirm this is intended when the
        # input and output port counts differ.
        for i in range(self.int_in_ports):
            self.wire(self._write_depth[i],
                      kts.ternary(self._wr_delay_state_n[i],
                                  self._write_depth_ss[i],
                                  self._write_depth_wo[i]))

        self._read_depth = self.input("read_depth", self.depth_width,
                                      size=self.int_out_ports,
                                      explicit_array=True,
                                      packed=True)
        self._read_depth.add_attribute(ConfigRegAttr("Depth of reads"))

        # Per-port counters and done flags (combinational flag plus an "_ff" copy)
        self._write_count = self.var("write_count", self.depth_width,
                                     size=self.int_in_ports,
                                     explicit_array=True,
                                     packed=True)

        self._read_count = self.var("read_count", self.depth_width,
                                    size=self.int_out_ports,
                                    explicit_array=True,
                                    packed=True)
        self._write_done = self.var("write_done", self.int_in_ports)
        self._write_done_ff = self.var("write_done_ff", self.int_in_ports)
        self._read_done = self.var("read_done", self.int_out_ports)
        self._read_done_ff = self.var("read_done_ff", self.int_out_ports)

        # Port-mapping config registers; each index is at least one bit wide
        self.in_port_bits = max(1, kts.clog2(self.int_in_ports))
        self._input_port = self.input("input_port", self.in_port_bits,
                                      size=self.int_out_ports,
                                      explicit_array=True,
                                      packed=True)
        self._input_port.add_attribute(ConfigRegAttr("Relative input port for an output port"))

        self.out_port_bits = max(1, kts.clog2(self.int_out_ports))
        self._output_port = self.input("output_port", self.out_port_bits,
                                       size=self.int_in_ports,
                                       explicit_array=True,
                                       packed=True)
        self._output_port.add_attribute(ConfigRegAttr("Relative output port for an input port"))

        self._prefill = self.input("prefill", self.int_out_ports)
        self._prefill.add_attribute(ConfigRegAttr("Is the input stream prewritten?"))

        # Register the per-port code blocks. Wherever a "_one_wr" variant
        # exists, it is used when there is exactly one input port.
        for i in range(self.int_out_ports):
            self.add_code(self.set_read_done, idx=i)
            if self.int_in_ports == 1:
                self.add_code(self.set_read_done_ff_one_wr, idx=i)
            else:
                self.add_code(self.set_read_done_ff, idx=i)

        # self._write_done_comb = self.var("write_done_comb", self.int_in_ports)
        for i in range(self.int_in_ports):
            self.add_code(self.set_write_done, idx=i)
            self.add_code(self.set_write_done_ff, idx=i)

        for i in range(self.int_in_ports):
            self.add_code(self.set_write_cnt, idx=i)
        for i in range(self.int_out_ports):
            if self.int_in_ports == 1:
                self.add_code(self.set_read_cnt_one_wr, idx=i)
            else:
                self.add_code(self.set_read_cnt, idx=i)

        for i in range(self.int_out_ports):
            if self.int_in_ports == 1:
                self.add_code(self.set_wr_delay_state_one_wr, idx=i)
            else:
                self.add_code(self.set_wr_delay_state, idx=i)

        # read_on[i] is the OR-reduction of read_depth[i], i.e. high whenever
        # the configured read depth for that port is non-zero.
        self._read_on = self.var("read_on", self.int_out_ports)
        for i in range(self.int_out_ports):
            self.wire(self._read_on[i], self._read_depth[i].r_or())

        # If we have prefill enabled, we are skipping the initial delay step...
        # A read enable is passed on only when (write-delay state or prefill)
        # is set, the read is not already done, and read_on is high.
        self.wire(self._ren_out,
                  (self._wr_delay_state_n | self._prefill) & ~self._read_done_ff & self._ren_in &
                  self._read_on)
        # A write enable is passed on until the registered write-done flag is set.
        self.wire(self._wen_out, ~self._write_done_ff & self._wen_in)
예제 #16
0
파일: util.py 프로젝트: StanfordAHA/lake
def add_config_reg(generator, name, description, bitwidth, **kwargs):
    """Declare an input port on ``generator`` and mark it as a config register.

    Args:
        generator: Object exposing ``input(name, width, **kwargs)``.
        name: Name given to the new input port.
        description: Text handed to ``ConfigRegAttr`` for this register.
        bitwidth: Width passed to ``generator.input``.
        **kwargs: Forwarded unchanged to ``generator.input``.

    Returns:
        The port returned by ``generator.input``, with the attribute attached.
    """
    port = generator.input(name, bitwidth, **kwargs)
    config_attr = ConfigRegAttr(description)
    port.add_attribute(config_attr)
    return port
예제 #17
0
    def __init__(self, fetch_width, data_width, max_prefetch):
        """Assemble the "prefetcher" generator around a ``RegFIFO`` child.

        Args:
            fetch_width: Divided by ``data_width`` to obtain ``fw_int``, the
                number of data words carried by each data bus / FIFO entry.
            data_width: Bit width of a single data word.
            max_prefetch: Depth handed to the FIFO; ``clog2(max_prefetch) + 1``
                is the width of the latency register, its limit and the
                counter.
        """
        super().__init__("prefetcher")

        # Remember the construction parameters on the instance.
        self.fetch_width = fetch_width
        self.data_width = data_width
        self.fw_int = int(self.fetch_width / self.data_width)
        self.max_prefetch = max_prefetch

        # Clock and reset.
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # The input and output data buses share one array shape.
        bus_shape = dict(size=self.fw_int, explicit_array=True, packed=True)

        # ---- Inputs ----
        self._data_in = self.input("data_in", self.data_width, **bus_shape)

        self._valid_read = self.input("valid_read", 1)
        self._tba_rdy_in = self.input("tba_rdy_in", 1)

        # One width serves the latency register, its limit and the counter.
        latency_bits = clog2(self.max_prefetch) + 1

        self._input_latency = self.input("input_latency", latency_bits)
        latency_doc = (
            "This register is set to denote the input latency loop for reads. "
            "This is sent to an internal fifo and an almost full signal is "
            "used to pull more reads that the transpose buffers need."
        )
        self._input_latency.add_attribute(ConfigRegAttr(latency_doc))

        self._max_lat = const(self.max_prefetch - 1, latency_bits)

        # ---- Outputs ----
        self._data_out = self.output("data_out", self.data_width, **bus_shape)

        self._valid_out = self.output("valid_out", 1)
        self._prefetch_step = self.output("prefetch_step", 1)

        # ---- Local signals ----
        self._cnt = self.var("cnt", latency_bits)
        self._fifo_empty = self.var("fifo_empty", 1)
        self._fifo_full = self.var("fifo_full", 1)

        # ---- FIFO child: pushed by valid_read, popped by tba_rdy_in ----
        fifo = RegFIFO(data_width=self.data_width,
                       width_mult=self.fw_int,
                       depth=self.max_prefetch)

        # Insertion order of this mapping is the connection order.
        fifo_connections = {
            "clk": self._clk,
            "rst_n": self._rst_n,
            "clk_en": 1,
            "data_in": self._data_in,
            "data_out": self._data_out,
            "push": self._valid_read,
            "pop": self._tba_rdy_in,
            "empty": self._fifo_empty,
            "full": self._fifo_full,
            "valid": self._valid_out,
        }
        self.add_child("fifo", fifo, **fifo_connections)

        # ---- Behavioural blocks ----
        for code_block in (self.update_cnt, self.set_prefetch_step):
            self.add_code(code_block)
예제 #18
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_depth=32,
            default_iterator_support=3,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            config_data_width=32,
            config_addr_width=8,
            cycle_count_width=16,
            add_clk_enable=True,
            add_flush=True):
        """Build the "pond" generator: a register file with scheduled ports.

        The constructor declares the tile's ports, instantiates one
        ForLoop/AddrGen/SchedGen trio per interconnect write port and per
        interconnect read port, adds a ``StorageConfigSeq`` child
        ("config_seq") and a ``RegisterFile`` child ("rf"), and muxes the
        register-file write enable, data and addresses between the scheduled
        path and the config path on ``config_en``.

        Args:
            data_width: Bit width of each data word.
            mem_depth: Depth passed to the register file; addresses are
                ``clog2(mem_depth)`` bits wide.
            default_iterator_support: ``iterator_support`` given to every
                ForLoop, AddrGen and SchedGen child.
            interconnect_input_ports: Number of ``data_in_pond`` entries and
                of write-side controller trios.
            interconnect_output_ports: Number of ``data_out_pond`` entries and
                of read-side controller trios.
            mem_input_ports: ``write_ports`` of the register file.
            mem_output_ports: ``read_ports`` of the register file.
            config_data_width: Width of ``config_data_in`` / ``config_data_out``.
            config_addr_width: Width of ``config_addr_in``; a config set holds
                ``2**config_addr_width`` words.
            cycle_count_width: Width of the cycle counter and ``config_width``
                of the ForLoop and SchedGen children.
            add_clk_enable: When true, run kratos' automatic clock-enable
                insertion and tag the resulting ``clk_en`` port.
            add_flush: When true, run kratos' automatic synchronous-reset
                insertion (named ``flush``) and tag the resulting port.
        """
        super().__init__("pond", debug=True)

        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.mem_depth = mem_depth
        self.data_width = data_width
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.cycle_count_width = cycle_count_width
        self.default_iterator_support = default_iterator_support
        self.default_config_width = kts.clog2(self.mem_depth)
        # Every memory address signal below uses this width.
        addr_width = self.default_config_width

        # inputs
        self._clk = self.clock("clk")
        self._clk.add_attribute(
            FormalAttr(f"{self._clk.name}", FormalSignalConstraint.CLK))
        self._rst_n = self.reset("rst_n")
        self._rst_n.add_attribute(
            FormalAttr(f"{self._rst_n.name}", FormalSignalConstraint.RSTN))
        self._clk_en = self.clock_en("clk_en", 1)

        # Enable/Disable tile
        self._tile_en = self.input("tile_en", 1)
        self._tile_en.add_attribute(
            ConfigRegAttr("Tile logic enable manifested as clock gate"))

        # Gated clock: clk ANDed with tile_en; all children are clocked by it.
        gclk = self.var("gclk", 1)
        self._gclk = kts.util.clock(gclk)
        self.wire(gclk, kts.util.clock(self._clk & self._tile_en))

        self._cycle_count = add_counter(self, "cycle_count",
                                        self.cycle_count_width)

        # Create write enable + addr, same for read.
        self._write = self.var("write", self.mem_input_ports)

        self._write_addr = self.var("write_addr",
                                    addr_width,
                                    size=self.interconnect_input_ports,
                                    explicit_array=True,
                                    packed=True)

        # Add "_pond" suffix to avoid error during garnet RTL generation
        self._data_in = self.input("data_in_pond",
                                   self.data_width,
                                   size=self.interconnect_input_ports,
                                   explicit_array=True,
                                   packed=True)
        self._data_in.add_attribute(
            FormalAttr(f"{self._data_in.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._data_in.add_attribute(ControlSignalAttr(is_control=False))

        self._read = self.var("read", self.mem_output_ports)
        # Per-interconnect-port write/read strobes, driven by the schedulers.
        self._t_write = self.var("t_write", self.interconnect_input_ports)
        self._t_read = self.var("t_read", self.interconnect_output_ports)

        self._read_addr = self.var("read_addr",
                                   addr_width,
                                   size=self.interconnect_output_ports,
                                   explicit_array=True,
                                   packed=True)

        self._s_read_addr = self.var("s_read_addr",
                                     addr_width,
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)

        self._data_out = self.output("data_out_pond",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)
        self._data_out.add_attribute(
            FormalAttr(f"{self._data_out.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._data_out.add_attribute(ControlSignalAttr(is_control=False))

        self._valid_out = self.output("valid_out_pond",
                                      self.interconnect_output_ports)
        self._valid_out.add_attribute(
            FormalAttr(f"{self._valid_out.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._valid_out.add_attribute(ControlSignalAttr(is_control=False))

        self._mem_data_out = self.var("mem_data_out",
                                      self.data_width,
                                      size=self.mem_output_ports,
                                      explicit_array=True,
                                      packed=True)

        # "s_" signals are per interconnect port, after the config mux;
        # "mem_" signals are per register-file port.
        self._s_mem_data_in = self.var("s_mem_data_in",
                                       self.data_width,
                                       size=self.interconnect_input_ports,
                                       explicit_array=True,
                                       packed=True)

        self._mem_data_in = self.var("mem_data_in",
                                     self.data_width,
                                     size=self.mem_input_ports,
                                     explicit_array=True,
                                     packed=True)

        self._s_mem_write_addr = self.var("s_mem_write_addr",
                                          addr_width,
                                          size=self.interconnect_input_ports,
                                          explicit_array=True,
                                          packed=True)

        self._s_mem_read_addr = self.var("s_mem_read_addr",
                                         addr_width,
                                         size=self.interconnect_output_ports,
                                         explicit_array=True,
                                         packed=True)

        self._mem_write_addr = self.var("mem_write_addr",
                                        addr_width,
                                        size=self.mem_input_ports,
                                        explicit_array=True,
                                        packed=True)

        self._mem_read_addr = self.var("mem_read_addr",
                                       addr_width,
                                       size=self.mem_output_ports,
                                       explicit_array=True,
                                       packed=True)

        # Every interconnect read port is fed from memory read port 0.
        for i in range(self.interconnect_output_ports):
            self.wire(self._data_out[i], self._mem_data_out[0])

        # Valid out is simply passing the read signal through...
        self.wire(self._valid_out, self._t_read)

        # Create write addressors
        for wr_port in range(self.interconnect_input_ports):

            RF_WRITE_ITER = ForLoop(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width)
            RF_WRITE_ADDR = AddrGen(
                iterator_support=self.default_iterator_support,
                config_width=self.default_config_width)
            RF_WRITE_SCHED = SchedGen(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width,
                use_enable=True)

            self.add_child(f"rf_write_iter_{wr_port}",
                           RF_WRITE_ITER,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_write[wr_port])
            # Whatever comes through here should hopefully just pipe through seamlessly
            # addressor modules
            self.add_child(f"rf_write_addr_{wr_port}",
                           RF_WRITE_ADDR,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_write[wr_port],
                           mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                           restart=RF_WRITE_ITER.ports.restart)
            safe_wire(self, self._write_addr[wr_port],
                      RF_WRITE_ADDR.ports.addr_out)

            # The scheduler's valid_output is this port's write strobe.
            self.add_child(f"rf_write_sched_{wr_port}",
                           RF_WRITE_SCHED,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                           finished=RF_WRITE_ITER.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self._t_write[wr_port])

        # Create read addressors
        for rd_port in range(self.interconnect_output_ports):

            RF_READ_ITER = ForLoop(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width)
            RF_READ_ADDR = AddrGen(
                iterator_support=self.default_iterator_support,
                config_width=self.default_config_width)
            RF_READ_SCHED = SchedGen(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width,
                use_enable=True)

            self.add_child(f"rf_read_iter_{rd_port}",
                           RF_READ_ITER,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_read[rd_port])

            self.add_child(f"rf_read_addr_{rd_port}",
                           RF_READ_ADDR,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_read[rd_port],
                           mux_sel=RF_READ_ITER.ports.mux_sel_out,
                           restart=RF_READ_ITER.ports.restart)
            # The same wiring applies whatever the number of read ports.
            safe_wire(self, self._read_addr[rd_port],
                      RF_READ_ADDR.ports.addr_out)

            # The scheduler's valid_output is this port's read strobe.
            self.add_child(f"rf_read_sched_{rd_port}",
                           RF_READ_SCHED,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           mux_sel=RF_READ_ITER.ports.mux_sel_out,
                           finished=RF_READ_ITER.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self._t_read[rd_port])

        # Memory port 0 is shared: the per-port strobes are OR-reduced into
        # the enable and used by decode() to pick the address / data.
        self.wire(self._write, self._t_write.r_or())
        self.wire(self._mem_write_addr[0],
                  decode(self, self._t_write, self._s_mem_write_addr))

        self.wire(self._mem_data_in[0],
                  decode(self, self._t_write, self._s_mem_data_in))

        self.wire(self._read, self._t_read.r_or())
        self.wire(self._mem_read_addr[0],
                  decode(self, self._t_read, self._s_mem_read_addr))
        # ===================================
        # Instantiate config hooks...
        # ===================================
        self.fw_int = 1
        self.data_words_per_set = 2**self.config_addr_width
        self.sets = int(
            (self.fw_int * self.mem_depth) / self.data_words_per_set)

        # At least one set, even when mem_depth is below the words per set.
        self.sets_per_macro = max(
            1, int(self.mem_depth / self.data_words_per_set))
        self.total_sets = max(1, 1 * self.sets_per_macro)

        self._config_data_in = self.input("config_data_in",
                                          self.config_data_width)
        self._config_data_in.add_attribute(ControlSignalAttr(is_control=False))

        # Only the low data_width bits of the config word are used.
        self._config_data_in_shrt = self.var("config_data_in_shrt",
                                             self.data_width)

        self.wire(self._config_data_in_shrt,
                  self._config_data_in[self.data_width - 1, 0])

        self._config_addr_in = self.input("config_addr_in",
                                          self.config_addr_width)
        self._config_addr_in.add_attribute(ControlSignalAttr(is_control=False))

        self._config_data_out_shrt = self.var("config_data_out_shrt",
                                              self.data_width,
                                              size=self.total_sets,
                                              explicit_array=True,
                                              packed=True)

        self._config_data_out = self.output("config_data_out",
                                            self.config_data_width,
                                            size=self.total_sets,
                                            explicit_array=True,
                                            packed=True)
        self._config_data_out.add_attribute(
            ControlSignalAttr(is_control=False))

        # Widen each data_width read-back word to the config bus width.
        for i in range(self.total_sets):
            self.wire(
                self._config_data_out[i],
                self._config_data_out_shrt[i].extend(self.config_data_width))

        self._config_read = self.input("config_read", 1)
        self._config_read.add_attribute(ControlSignalAttr(is_control=False))

        self._config_write = self.input("config_write", 1)
        self._config_write.add_attribute(ControlSignalAttr(is_control=False))

        self._config_en = self.input("config_en", self.total_sets)
        self._config_en.add_attribute(ControlSignalAttr(is_control=False))

        self._mem_data_cfg = self.var("mem_data_cfg",
                                      self.data_width,
                                      explicit_array=True,
                                      packed=True)

        self._mem_addr_cfg = self.var("mem_addr_cfg",
                                      addr_width)

        # Add config...
        stg_cfg_seq = StorageConfigSeq(
            data_width=self.data_width,
            config_addr_width=self.config_addr_width,
            addr_width=addr_width,
            fetch_width=self.data_width,
            total_sets=self.total_sets,
            sets_per_macro=self.sets_per_macro)

        # The config sequencer is clocked from the gated clock; config_en is
        # OR-ed into its clk_en below.
        # NOTE(review): the earlier comment here said the clock is brought
        # back in based on config_en when the tile is off, but the wiring
        # uses gclk (clk & tile_en) only - confirm which is intended.
        cfg_seq_clk = self.var("cfg_seq_clk", 1)
        self._cfg_seq_clk = kts.util.clock(cfg_seq_clk)
        self.wire(cfg_seq_clk, kts.util.clock(self._gclk))

        self.add_child("config_seq",
                       stg_cfg_seq,
                       clk=self._cfg_seq_clk,
                       rst_n=self._rst_n,
                       clk_en=self._clk_en | self._config_en.r_or(),
                       config_data_in=self._config_data_in_shrt,
                       config_addr_in=self._config_addr_in,
                       config_wr=self._config_write,
                       config_rd=self._config_read,
                       config_en=self._config_en,
                       wr_data=self._mem_data_cfg,
                       rd_data_out=self._config_data_out_shrt,
                       addr_out=self._mem_addr_cfg)

        # NOTE(review): this branches on interconnect_output_ports although
        # mem_data_out is sized by mem_output_ports - confirm.
        if self.interconnect_output_ports == 1:
            self.wire(stg_cfg_seq.ports.rd_data_stg, self._mem_data_out)
        else:
            self.wire(stg_cfg_seq.ports.rd_data_stg[0], self._mem_data_out[0])

        self.RF_GEN = RegisterFile(data_width=self.data_width,
                                   write_ports=self.mem_input_ports,
                                   read_ports=self.mem_output_ports,
                                   width_mult=1,
                                   depth=self.mem_depth,
                                   read_delay=0)

        # Now we can instantiate and wire up the register file
        self.add_child("rf",
                       self.RF_GEN,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       data_out=self._mem_data_out)

        # Opt in for config_write: while any config_en bit is set, write
        # port 0 follows config_write and the other write ports are disabled.
        self._write_rf = self.var("write_rf", self.mem_input_ports)
        self.wire(
            self._write_rf[0],
            kts.ternary(self._config_en.r_or(), self._config_write,
                        self._write[0]))
        for i in range(self.mem_input_ports - 1):
            self.wire(
                self._write_rf[i + 1],
                kts.ternary(self._config_en.r_or(), kts.const(0, 1),
                            self._write[i + 1]))
        self.wire(self.RF_GEN.ports.wen, self._write_rf)

        # Opt in for config_data_in
        for i in range(self.interconnect_input_ports):
            self.wire(
                self._s_mem_data_in[i],
                kts.ternary(self._config_en.r_or(), self._mem_data_cfg,
                            self._data_in[i]))
        self.wire(self.RF_GEN.ports.data_in, self._mem_data_in)

        # Opt in for config_addr
        for i in range(self.interconnect_input_ports):
            self.wire(
                self._s_mem_write_addr[i],
                kts.ternary(self._config_en.r_or(), self._mem_addr_cfg,
                            self._write_addr[i]))

        self.wire(self.RF_GEN.ports.wr_addr, self._mem_write_addr[0])

        for i in range(self.interconnect_output_ports):
            self.wire(
                self._s_mem_read_addr[i],
                kts.ternary(self._config_en.r_or(), self._mem_addr_cfg,
                            self._read_addr[i]))

        self.wire(self.RF_GEN.ports.rd_addr, self._mem_read_addr[0])

        if self.add_clk_enable:
            kts.passes.auto_insert_clock_enable(self.internal_generator)
            clk_en_port = self.internal_generator.get_port("clk_en")
            clk_en_port.add_attribute(ControlSignalAttr(False))

        if self.add_flush:
            # The attribute names the synchronous reset port "flush".
            self.add_attribute("sync-reset=flush")
            kts.passes.auto_insert_sync_reset(self.internal_generator)
            flush_port = self.internal_generator.get_port("flush")
            flush_port.add_attribute(ControlSignalAttr(True))

        # Finally, lift the config regs...
        lift_config_reg(self.internal_generator)
예제 #19
0
    def __init__(self,
                 interconnect_input_ports=2,
                 mem_depth=32,
                 num_tiles=1,
                 banks=1,
                 iterator_support=6,
                 address_width=5,
                 data_width=16,
                 fetch_width=16,
                 multiwrite=1,
                 strg_wr_ports=2,
                 config_width=16):
        """Build the "input_addr_ctrl" generator.

        Declares the ports and internal signals, instantiates one ``AddrGen``
        child per interconnect input port, and registers the code blocks
        (via ``add_code``) that drive the write enables, addresses and data.

        Args:
            interconnect_input_ports: Number of input ports; width of
                ``valid_in`` / ``wen_en`` and first dimension of ``data_in``.
            mem_depth: Together with ``num_tiles`` sets the memory address
                width, ``clog2(num_tiles * mem_depth)``.
            num_tiles: See ``mem_depth``.
            banks: Number of banks; adds ``clog2(banks)`` address bits when
                greater than 1.
            iterator_support: ``iterator_support`` of each ``AddrGen`` child.
            address_width: Stored, then overwritten with the computed
                ``mem_addr_width + bank_addr_width``; the passed value does
                not influence the generated signals.
            data_width: Bit width of a data word.
            fetch_width: Divided by ``data_width`` to obtain ``fw_int``.
            multiwrite: Asserted to be between 1 and ``banks``; when greater
                than 1 an ``offsets_cfg`` config input is added.
            strg_wr_ports: Width of each ``wen_to_sram`` entry and a dimension
                of ``addr_out`` / ``data_out``.
            config_width: ``config_width`` of each ``AddrGen`` child.
        """
        super().__init__("input_addr_ctrl", debug=True)

        assert multiwrite >= 1, "Multiwrite must be at least 1..."

        self.interconnect_input_ports = interconnect_input_ports
        self.mem_depth = mem_depth
        self.num_tiles = num_tiles
        self.banks = banks
        self.iterator_support = iterator_support
        self.address_width = address_width
        # At least one bit, even with a single input port.
        self.port_sched_width = max(1, clog2(self.interconnect_input_ports))
        self.data_width = data_width
        self.fetch_width = fetch_width
        self.fw_int = int(self.fetch_width / self.data_width)
        self.multiwrite = multiwrite
        self.strg_wr_ports = strg_wr_ports
        self.config_width = config_width

        self.mem_addr_width = clog2(self.num_tiles * self.mem_depth)
        if self.banks > 1:
            self.bank_addr_width = clog2(self.banks)
        else:
            self.bank_addr_width = 0
        # Replaces the address_width argument stored above.
        self.address_width = self.mem_addr_width + self.bank_addr_width

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # Inputs
        # phases = [] TODO
        # Take in the valid and data and attach an address + direct to a port
        self._valid_in = self.input("valid_in", self.interconnect_input_ports)
        self._wen_en = self.input("wen_en", self.interconnect_input_ports)

        self._wen_en_saved = self.var("wen_en_saved",
                                      self.interconnect_input_ports)

        self._data_in = self.input("data_in",
                                   self.data_width,
                                   size=(self.interconnect_input_ports,
                                         self.fw_int),
                                   explicit_array=True,
                                   packed=True)

        # Same shape as data_in.
        self._data_in_saved = self.var("data_in_saved",
                                       self.data_width,
                                       size=(self.interconnect_input_ports,
                                             self.fw_int),
                                       explicit_array=True,
                                       packed=True)

        # Outputs
        self._wen = self.output("wen_to_sram",
                                self.strg_wr_ports,
                                size=self.banks,
                                explicit_array=True,
                                packed=True)

        # Indexed below as wen_full[input port][multiwrite slot][bank].
        wen_full_size = (self.interconnect_input_ports, self.multiwrite)
        self._wen_full = self.var("wen_full",
                                  self.banks,
                                  size=wen_full_size,
                                  explicit_array=True,
                                  packed=True)

        # Indexed below as wen_reduced[input port][bank].
        self._wen_reduced = self.var("wen_reduced",
                                     self.banks,
                                     size=self.interconnect_input_ports,
                                     explicit_array=True,
                                     packed=True)

        self._wen_reduced_saved = self.var("wen_reduced_saved",
                                           self.banks,
                                           size=self.interconnect_input_ports,
                                           explicit_array=True,
                                           packed=True)

        self._addresses = self.output("addr_out",
                                      self.mem_addr_width,
                                      size=(self.banks, self.strg_wr_ports),
                                      explicit_array=True,
                                      packed=True)

        self._data_out = self.output("data_out",
                                     self.data_width,
                                     size=(self.banks, self.strg_wr_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        # Indexed below as port_out_exp[bank][input port].
        self._port_out_exp = self.var("port_out_exp",
                                      self.interconnect_input_ports,
                                      size=self.banks,
                                      explicit_array=True,
                                      packed=True)

        self._port_out = self.output("port_out", self.interconnect_input_ports)

        self._counter = self.var("counter", self.port_sched_width)

        # Wire to port out: bit i is the OR over all banks of port_out_exp[bank][i]
        for i in range(self.interconnect_input_ports):
            new_tmp = []
            for j in range(self.banks):
                new_tmp.append(self._port_out_exp[j][i])
            self.wire(self._port_out[i], kts.concat(*new_tmp).r_or())

        self._done = self.var("done",
                              self.strg_wr_ports,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)

        # LOCAL VARS
        self._local_addrs = self.var("local_addrs",
                                     self.address_width,
                                     size=(self.interconnect_input_ports,
                                           self.multiwrite),
                                     packed=True,
                                     explicit_array=True)

        self._local_addrs_saved = self.var("local_addrs_saved",
                                           self.address_width,
                                           size=(self.interconnect_input_ports,
                                                 self.multiwrite),
                                           packed=True,
                                           explicit_array=True)

        # wen_reduced[i][j] is the OR over the multiwrite slots of wen_full[i][k][j]
        for i in range(self.interconnect_input_ports):
            for j in range(self.banks):
                concat_ports = []
                for k in range(self.multiwrite):
                    concat_ports.append(self._wen_full[i][k][j])
                self.wire(self._wen_reduced[i][j],
                          kts.concat(*concat_ports).r_or())

        # Pick how wen_full is driven: a direct wire for one bank and one
        # port, otherwise one of two code blocks.
        if self.banks == 1 and self.interconnect_input_ports == 1:
            self.wire(self._wen_full[0][0][0], self._valid_in)
        elif self.banks == 1 and self.interconnect_input_ports > 1:
            self.add_code(self.set_wen_single)
        else:
            self.add_code(self.set_wen_mult)

        # MAIN
        # Iterate through all banks to priority decode the wen
        self.add_code(self.decode_out_lowest)
        # Also set the write ports on the storage
        if self.strg_wr_ports > 1:
            self._idx_cnt = self.var("idx_cnt",
                                     8,
                                     size=(self.banks, self.strg_wr_ports - 1),
                                     explicit_array=True,
                                     packed=True)
            # One extra decode block per additional write port (idx 1 and up).
            for i in range(self.strg_wr_ports - 1):
                self.add_code(self.decode_out_alt, idx=i + 1)

        # Now we should instantiate the child address generators
        # (1 per input port) to send to the sram banks
        for i in range(self.interconnect_input_ports):
            # clk_en is tied high and flush tied low; each generator steps
            # on its own port's valid_in bit.
            self.add_child(f"address_gen_{i}",
                           AddrGen(iterator_support=self.iterator_support,
                                   config_width=self.config_width),
                           clk=self._clk,
                           rst_n=self._rst_n,
                           clk_en=const(1, 1),
                           flush=const(0, 1),
                           step=self._valid_in[i])

        # Need to check that the address falls into the bank for implicit banking

        # Then, obey the input schedule to send the proper Aggregator to the output
        # The wen to sram should be that the valid for the selected port is high
        # Do the same thing for the output address
        assert self.multiwrite <= self.banks and self.multiwrite > 0,\
            "Multiwrite should be between 1 and banks"
        if self.multiwrite > 1:
            # One configurable offset per extra multiwrite slot of each port.
            size = (self.interconnect_input_ports, self.multiwrite - 1)
            self._offsets_cfg = self.input("offsets_cfg",
                                           self.address_width,
                                           size=size,
                                           packed=True,
                                           explicit_array=True)
            doc = "These offsets provide the ability to write to multiple banks explicitly"
            self._offsets_cfg.add_attribute(ConfigRegAttr(doc))
        self.add_code(self.set_multiwrite_addrs)

        # to handle multiple input ports going to fewer SRAM write ports
        self.add_code(self.set_int_ports_counter)
        self.add_code(self.save_mult_int_signals)