def __init__(self, data_width, max_line_length): super().__init__("agg_aligner") # Capture to the object self.data_width = data_width self.max_line_length = max_line_length self.counter_width = clog2(self.max_line_length) # Clock and Reset self._clk = self.clock("clk") self._rst_n = self.reset("rst_n") # Inputs self._in_dat = self.input("in_dat", self.data_width) self._in_valid = self.input("in_valid", 1) self._line_length = self.input("line_length", self.counter_width) self._line_length.add_attribute( ConfigRegAttr("Line Length/Image Width for alignment")) # Outputs self._out_dat = self.output("out_dat", self.data_width) self._out_valid = self.output("out_valid", 1) self._out_align = self.output("align", 1) # Local Signals self._cnt = self.var("cnt", self.counter_width) # Generate self.add_code(self.update_cnt) self.add_code(self.set_align) self.wire(self._out_dat, self._in_dat) self.wire(self._out_valid, self._in_valid)
def visit(self, node):
    """Lift config-register ports up the generator hierarchy.

    For every port on `node` carrying exactly one ConfigRegAttr, create a
    same-shaped port at each ancestor generator (named
    ``<instance_name>_<port_name>``), wire child to parent, and re-tag the
    topmost newly created port with a ConfigRegAttr (preserving its
    documentation) plus any FormalAttr annotation from the original port.
    Ports already at the top level are left untouched.

    Args:
        node: IR node from the kratos visitor pass; only
            ``_kratos.Generator`` instances are processed.
    """
    # Guard clause: only generators have ports to lift.
    if not isinstance(node, _kratos.Generator):
        return
    for port_name in node.get_port_names():
        curr_port = node.get_port(port_name)
        attrs = curr_port.find_attribute(
            lambda a: isinstance(a, ConfigRegAttr))
        annotation_attr = curr_port.find_attribute(
            lambda a: isinstance(a, FormalAttr))
        # Fixed: original used `port_name is "mode"` — identity comparison
        # against a string literal is unreliable (and a SyntaxWarning on
        # CPython >= 3.8); string equality is what was intended.
        if port_name == "mode":
            print("Found mode...")
            print(attrs)
        # Only lift ports tagged with exactly one config-reg attribute.
        if len(attrs) != 1:
            continue
        doc = attrs[0].get_documentation()
        # Walk up the hierarchy, punching a forwarding port through each
        # ancestor and wiring it to the level below.
        parent_gen = node.parent_generator()
        child_port = curr_port
        child_gen = node
        top_lvl_cfg = parent_gen is None
        while parent_gen is not None:
            # Create a parent port based on the child port's definition.
            new_name = child_gen.instance_name + "_" + child_port.name
            p = parent_gen.port(child_port, new_name, False)
            parent_gen.wire(child_port, p)
            # Move up one level.
            child_port = p
            child_gen = parent_gen
            parent_gen = parent_gen.parent_generator()
        # Only tag newly created ports; a port that was already at the top
        # level keeps its original attributes (re-adding would duplicate).
        if not top_lvl_cfg:
            child_port_cra = ConfigRegAttr()
            child_port_cra.set_documentation(doc)
            child_port.add_attribute(child_port_cra)
            if annotation_attr:
                annot_type = annotation_attr[0].get_formal_ann()
                child_port.add_attribute(
                    FormalAttr(f"{child_port}", annot_type))
def __init__(self, name, data_width, init_val):
    """Build a named configuration register generator.

    The generator itself is tagged with ConfigRegAttr; the registered
    output behavior is produced by the `set_out` code block.

    Args:
        name: unique suffix used in the generated module name.
        data_width: bit width of the register / output.
        init_val: reset value used by `set_out` (stored, not applied here).
    """
    super().__init__(f"config_reg_{name}_{data_width}")
    # Mark the whole generator as a configuration register
    self.add_attribute(ConfigRegAttr())
    self.init_val = init_val
    self.data_width = data_width
    # Registered output plus clock/reset; sequential logic is in set_out
    self._out = self.output("o_data_out", data_width)
    self._clk = self.clock("i_clk")
    self._rst_n = self.reset("i_rst_n")
    self.add_code(self.set_out)
def __init__(self, data_width, interconnect_output_ports): super().__init__("Chain", debug=True) # generator parameters self.data_width = data_width self.interconnect_output_ports = interconnect_output_ports # chain enable configuration register self._chain_en = self.input("chain_en", 1) self._chain_en.add_attribute( ConfigRegAttr("Signal indicating whether to enable chaining")) self._chain_en.add_attribute( FormalAttr(self._chain_en.name, FormalSignalConstraint.SET0)) # inputs self._curr_tile_data_out = self.input( "curr_tile_data_out", self.data_width, size=self.interconnect_output_ports, packed=True, explicit_array=True) self._chain_data_in = self.input("chain_data_in", self.data_width, size=self.interconnect_output_ports, packed=True, explicit_array=True) self._accessor_output = self.input("accessor_output", self.interconnect_output_ports) self._data_out_tile = self.output("data_out_tile", self.data_width, size=self.interconnect_output_ports, packed=True, explicit_array=True) self.add_code(self.set_data_out)
def __init__(self, mem_params, word_width):
    """Build a Lake memory generator from a parameter dictionary.

    Creates either an SRAM-macro-backed or a register-array-backed memory
    with separate read/write ports or shared read/write ports, optional
    chaining support, and the addressing I/O needed by each configuration.
    Actual read/write datapath logic is added by `add_write_data_block`
    and `add_read_data_block` at the end.

    Args:
        mem_params: dict describing the memory (capacity, port counts and
            widths, latency info, chaining, macro usage, ...). Keys are
            read explicitly below.
        word_width: bit width of one data word.
    """
    super().__init__("lake_mem", debug=True)
    ################################################################
    # PARAMETERS
    ################################################################
    # print("MEM PARAMS ", mem_params)
    # basic parameters
    self.word_width = word_width
    # general memory parameters
    self.mem_name = mem_params["name"]
    self.capacity = mem_params["capacity"]
    self.rw_same_cycle = mem_params["rw_same_cycle"]
    self.use_macro = mem_params["use_macro"]
    self.macro_name = mem_params["macro_name"]
    # number of port types: read/write (shared), read-only, write-only
    self.num_read_write_ports = mem_params["num_read_write_ports"]
    self.num_read_only_ports = mem_params["num_read_ports"]
    self.num_write_only_ports = mem_params["num_write_ports"]
    # totals: shared ports count toward both reads and writes
    self.num_read_ports = self.num_read_only_ports + self.num_read_write_ports
    self.num_write_ports = self.num_write_only_ports + self.num_read_write_ports
    # info for port types
    self.write_info = mem_params["write_info"]
    self.read_info = mem_params["read_info"]
    self.read_write_info = mem_params["read_write_info"]
    # TODO change - for now, we assume you cannot have read/write and read or write ports
    # should be the max of write vs read_write and need to handle more general case
    if self.num_read_write_ports == 0:
        self.write_width = mem_params["write_port_width"]
        self.read_width = mem_params["read_port_width"]
    else:
        # shared ports use a single width for both directions
        self.write_width = mem_params["read_write_port_width"]
        self.read_width = mem_params["read_write_port_width"]
    assert self.capacity % self.write_width == 0, \
        "Memory capacity is not a multiple of the port width for writes"
    assert self.capacity % self.read_width == 0, \
        "Memory capacity is not a multiple of the port width for reads"
    # innermost dimension for size of memory is the size of whichever port
    # type has a wider width between reads and writes
    self.mem_size = max(self.read_width, self.write_width)
    # this assert has to be true if previous two asserts are true
    assert self.capacity % self.mem_size == 0
    # this is the last dimension for size of memory - equal to the number
    # of the port type with wider width addresses can fit in the memory
    self.mem_last_dim = int(self.capacity / self.mem_size)
    # max(1, ...) keeps signal widths legal when a dimension is 1
    self.mem_size_bits = max(1, clog2(self.mem_size))
    self.mem_last_dim_bits = max(1, clog2(self.mem_last_dim))
    # chaining parameters and config regs
    self.chaining = mem_params["chaining"]
    self.num_chain = mem_params["num_chain"]
    self.num_chain_bits = clog2(self.num_chain)
    if self.chaining:
        # config reg selecting this memory's position in the chain
        self.chain_index = self.var("chain_index",
                                    width=self.num_chain_bits)
        self.chain_index.add_attribute(
            ConfigRegAttr("Chain index for chaining"))
        self.chain_index.add_attribute(
            FormalAttr(self.chain_index.name, FormalSignalConstraint.SET0))
    # minimum required widths for address signals: the narrower port type
    # needs extra low-order bits to address within a wide memory row
    if self.mem_size == self.write_width and self.mem_size == self.read_width:
        self.write_addr_width = self.mem_last_dim_bits + self.num_chain_bits
        self.read_addr_width = self.mem_last_dim_bits + self.num_chain_bits
    elif self.mem_size == self.write_width:
        self.write_addr_width = self.mem_last_dim_bits + self.num_chain_bits
        self.read_addr_width = self.mem_size_bits + self.mem_last_dim_bits + self.num_chain_bits
    elif self.mem_size == self.read_width:
        self.write_addr_width = self.mem_size_bits + self.mem_last_dim_bits + self.num_chain_bits
        self.read_addr_width = self.mem_last_dim_bits + self.num_chain_bits
    else:
        # unreachable given mem_size = max(read_width, write_width);
        # kept as a diagnostic
        print("Error occurred! Memory size does not make sense.")
    ################################################################
    # I/O INTERFACE (WITHOUT ADDRESSING) + MEMORY
    ################################################################
    self.clk = self.clock("clk")
    # active low asynchornous reset
    self.rst_n = self.reset("rst_n", 1)
    self.data_in = self.input("data_in",
                              width=self.word_width,
                              size=(self.num_write_ports, self.write_width),
                              explicit_array=True,
                              packed=True)
    self.chain_en = self.input("chain_en", 1)
    # write enable (high: write, low: read when rw_same_cycle = False, else
    # only indicates write)
    self.write = self.input("write", width=1, size=self.num_write_ports)
    self.data_out = self.output("data_out",
                                width=self.word_width,
                                size=(self.num_read_ports, self.read_width),
                                explicit_array=True,
                                packed=True)
    # write enable gated by chain-index match (driven further below)
    self.write_chain = self.var("write_chain", width=1,
                                size=self.num_write_ports)
    if self.use_macro:
        # NOTE(review): self.addr_width is not assigned anywhere in this
        # constructor - presumably set on the class elsewhere; verify this
        # branch is exercised. TODO confirm.
        self.read_write_addr = self.input("read_write_addr",
                                          width=self.addr_width,
                                          size=self.num_read_write_ports,
                                          explicit_array=True)
        sram = SRAM(not self.use_macro,
                    self.macro_name,
                    word_width,
                    mem_params["read_write_port_width"],
                    mem_params["capacity"],
                    mem_params["num_read_write_ports"],
                    mem_params["num_read_write_ports"],
                    clog2(mem_params["capacity"]),
                    0,
                    1)
        self.add_child(
            "SRAM_" + mem_params["name"],
            sram,
            clk=self.clk,
            clk_en=1,
            mem_data_in_bank=self.data_in,
            mem_data_out_bank=self.data_out,
            mem_addr_in_bank=self.read_write_addr,
            # TODO adjust
            mem_cen_in_bank=1,
            mem_wen_in_bank=self.write_chain,
            wtsel=0,
            rtsel=1)
    else:
        # memory variable (not I/O) - behavioral register array
        self.memory = self.var("memory",
                               width=self.word_width,
                               size=(self.mem_last_dim, self.mem_size),
                               explicit_array=True,
                               packed=True)
    ################################################################
    # ADDRESSING I/O AND SIGNALS
    ################################################################
    # I/O is different depending on whether we have read and write ports or
    # read/write ports
    # we keep address width at 16 to avoid unpacked
    # safe_wire errors for addr in hw_top_lake - can change by changing
    # default_config_width for those addr gens while accounting for muxing
    # bits, but the extra bits are unused anyway
    if self.rw_same_cycle:
        self.read = self.input("read", width=1, size=self.num_read_ports)
    else:
        # no external read enable - reads are always on
        self.read = self.var("read", width=1, size=self.num_read_ports)
        for i in range(self.num_read_ports):
            self.wire(self.read[i], 1)
    # TODO change later - same read/write or read and write assumption as above
    if self.num_write_only_ports != 0 and self.num_read_only_ports != 0:
        # writes
        self.write_addr = self.input(
            "write_addr",
            width=16,  # self.write_addr_width,
            size=self.num_write_ports,
            explicit_array=True)
        assert self.write_info[0]["latency"] > 0, \
            "Latency for write ports must be greater than 1 clock cycle."
        # reads
        self.read_addr = self.input(
            "read_addr",
            width=16,  # self.read_addr_width,
            size=self.num_read_ports,
            explicit_array=True)
        # TODO for now assuming all read ports have same latency
        # TODO also should add support for other latencies
    # rw_same_cycle is not valid here because read/write share the same port
    elif self.num_read_write_ports != 0:
        self.read_write_addr = self.input(
            "read_write_addr",
            width=16,  # max(self.read_addr_width, self.write_addr_width),
            size=self.num_read_write_ports,
            explicit_array=True)
        # writes: internal copy of the shared address, truncated via safe_wire
        self.write_addr = self.var(
            "write_addr",
            width=16,  # self.write_addr_width,
            size=self.num_read_write_ports,
            explicit_array=True)
        for p in range(self.num_read_write_ports):
            safe_wire(gen=self,
                      w_to=self.write_addr[p],
                      w_from=self.read_write_addr[p])
        # reads: same shared address, sized to the minimum read width
        self.read_addr = self.var("read_addr",
                                  width=self.read_addr_width,
                                  size=self.num_read_write_ports,
                                  explicit_array=True)
        for p in range(self.num_read_write_ports):
            safe_wire(gen=self,
                      w_to=self.read_addr[p],
                      w_from=self.read_write_addr[p])
        # TODO in self.read_write_info we should allow for different read
        # and write latencies?
        self.read_info = self.read_write_info
    # TODO just doing chaining for SRAM
    if self.chaining and self.num_read_write_ports > 0:
        # A write only lands in this memory when chaining is off, or when
        # the high address bits select this memory's chain_index.
        # NOTE(review): the slice [waw + ncb, waw] is (ncb + 1) bits wide
        # while chain_index is ncb bits - looks like a possible off-by-one
        # (waw + ncb - 1 expected); confirm against kratos slice semantics.
        self.wire(
            self.write_chain,
            # chaining not enabled
            (
                ~self.chain_en |
                # chaining enabled
                (self.chain_en &
                 (self.chain_index ==
                  self.read_write_addr[
                      self.write_addr_width + self.num_chain_bits,
                      self.write_addr_width]))) &
            self.write)
    # chaining not supported
    else:
        self.wire(self.write_chain, self.write)
    if self.use_macro:
        # NOTE(review): mem_wen_in_bank was already connected via add_child
        # above; verify this second wire is not a duplicate drive.
        self.wire(sram.ports.mem_wen_in_bank, self.write_chain)
    # generate the actual write/read datapath logic
    self.add_write_data_block()
    self.add_read_data_block()
def __init__(self,
             data_width=16,
             banks=1,
             memory_width=64,
             rw_same_cycle=False,
             read_delay=1,
             addr_width=9):
    """Build a FIFO implemented on top of wide banked storage.

    A narrow "front" register FIFO accumulates incoming words until a full
    wide word can be written to storage; a "back" register FIFO is
    parallel-loaded from wide storage reads and popped one word at a time.
    Occupancy tracking, bank selection, and address generation are added
    via named code blocks (set_*, send_*) defined elsewhere on the class.

    Args:
        data_width: width of one FIFO word.
        banks: number of storage banks.
        memory_width: width of one storage row (>= data_width).
        rw_same_cycle: storage supports a read and a write per cycle.
        read_delay: storage read latency in cycles (1 or 0 handled).
        addr_width: width of the storage address.
    """
    super().__init__("strg_fifo")
    # Generation parameters
    self.banks = banks
    self.data_width = data_width
    self.memory_width = memory_width
    self.rw_same_cycle = rw_same_cycle
    self.read_delay = read_delay
    self.addr_width = addr_width
    # fetch width: number of FIFO words per storage row
    self.fw_int = int(self.memory_width / self.data_width)
    # assert banks > 1 or rw_same_cycle is True or self.fw_int > 1, \
    #     "Can't sustain throughput with this setup. Need potential bandwidth for " + \
    #     "1 write and 1 read in a cycle - try using more banks or a macro that supports 1R1W"
    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    # Inputs + Outputs
    self._push = self.input("push", 1)
    self._data_in = self.input("data_in", self.data_width)
    self._pop = self.input("pop", 1)
    self._data_out = self.output("data_out", self.data_width)
    self._valid_out = self.output("valid_out", 1)
    self._empty = self.output("empty", 1)
    self._full = self.output("full", 1)
    # get relevant signals from the storage banks
    self._data_from_strg = self.input("data_from_strg",
                                      self.data_width,
                                      size=(self.banks, self.fw_int),
                                      explicit_array=True,
                                      packed=True)
    self._wen_addr = self.var("wen_addr",
                              self.addr_width,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)
    self._ren_addr = self.var("ren_addr",
                              self.addr_width,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)
    # one full storage row assembled from the front FIFO contents
    self._front_combined = self.var("front_combined",
                                    self.data_width,
                                    size=self.fw_int,
                                    explicit_array=True,
                                    packed=True)
    self._data_to_strg = self.output("data_to_strg",
                                     self.data_width,
                                     size=(self.banks, self.fw_int),
                                     explicit_array=True,
                                     packed=True)
    self._wen_to_strg = self.output("wen_to_strg", self.banks)
    self._ren_to_strg = self.output("ren_to_strg", self.banks)
    self._num_words_mem = self.var("num_words_mem", self.data_width)
    if self.banks == 1:
        # single bank: bank selectors are constant 0
        self._curr_bank_wr = self.var("curr_bank_wr", 1)
        self.wire(self._curr_bank_wr, kts.const(0, 1))
        self._curr_bank_rd = self.var("curr_bank_rd", 1)
        self.wire(self._curr_bank_rd, kts.const(0, 1))
    else:
        self._curr_bank_wr = self.var("curr_bank_wr", kts.clog2(self.banks))
        self._curr_bank_rd = self.var("curr_bank_rd", kts.clog2(self.banks))
    # per-bank pending write data
    self._write_queue = self.var("write_queue",
                                 self.data_width,
                                 size=(self.banks, self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    # Lets us know if the bank has a write queued up
    self._queued_write = self.var("queued_write", self.banks)
    # front FIFO interface signals
    self._front_data_out = self.var("front_data_out", self.data_width)
    self._front_pop = self.var("front_pop", 1)
    self._front_empty = self.var("front_empty", 1)
    self._front_full = self.var("front_full", 1)
    self._front_valid = self.var("front_valid", 1)
    self._front_par_read = self.var("front_par_read", 1)
    self._front_par_out = self.var("front_par_out",
                                   self.data_width,
                                   size=(self.fw_int, 1),
                                   explicit_array=True,
                                   packed=True)
    self._front_rd_ptr = self.var("front_rd_ptr", max(1, clog2(self.fw_int)))
    self._front_push = self.var("front_push", 1)
    # accept a push when not full, or when a simultaneous pop frees a slot
    self.wire(self._front_push, self._push & (~self._full | self._pop))
    self._front_rf = RegFIFO(data_width=self.data_width,
                             width_mult=1,
                             depth=self.fw_int,
                             parallel=True,
                             break_out_rd_ptr=True)
    # This one breaks out the read pointer so we can properly
    # reorder the data to storage
    self.add_child("front_rf",
                   self._front_rf,
                   clk=self._clk,
                   clk_en=kts.const(1, 1),
                   rst_n=self._rst_n,
                   push=self._front_push,
                   pop=self._front_pop,
                   empty=self._front_empty,
                   full=self._front_full,
                   valid=self._front_valid,
                   parallel_read=self._front_par_read,
                   # We don't need to parallel load the front
                   parallel_load=kts.const(0, 1),
                   # Same reason as above
                   parallel_in=0,
                   parallel_out=self._front_par_out,
                   num_load=0,
                   rd_ptr_out=self._front_rd_ptr)
    self.wire(self._front_rf.ports.data_in[0], self._data_in)
    self.wire(self._front_data_out, self._front_rf.ports.data_out[0])
    # back FIFO interface signals
    self._back_data_in = self.var("back_data_in", self.data_width)
    self._back_data_out = self.var("back_data_out", self.data_width)
    self._back_push = self.var("back_push", 1)
    self._back_empty = self.var("back_empty", 1)
    self._back_full = self.var("back_full", 1)
    self._back_valid = self.var("back_valid", 1)
    # back parallel-load strobe (set from storage read, possibly delayed)
    self._back_pl = self.var("back_pl", 1)
    self._back_par_in = self.var("back_par_in",
                                 self.data_width,
                                 size=(self.fw_int, 1),
                                 explicit_array=True,
                                 packed=True)
    # occupancy counters need one extra bit to represent "completely full"
    self._back_num_load = self.var("back_num_load", clog2(self.fw_int) + 1)
    self._back_occ = self.var("back_occ", clog2(self.fw_int) + 1)
    self._front_occ = self.var("front_occ", clog2(self.fw_int) + 1)
    self._back_rf = RegFIFO(data_width=self.data_width,
                            width_mult=1,
                            depth=self.fw_int,
                            parallel=True,
                            break_out_rd_ptr=False)
    self._fw_is_1 = self.var("fw_is_1", 1)
    self.wire(self._fw_is_1, kts.const(self.fw_int == 1, 1))
    self._back_pop = self.var("back_pop", 1)
    if self.fw_int == 1:
        # with fw == 1 a parallel load and a pop conflict, so gate the pop
        self.wire(self._back_pop,
                  self._pop & (~self._empty | self._push) & ~self._back_pl)
    else:
        self.wire(self._back_pop, self._pop & (~self._empty | self._push))
    self.add_child("back_rf",
                   self._back_rf,
                   clk=self._clk,
                   clk_en=kts.const(1, 1),
                   rst_n=self._rst_n,
                   push=self._back_push,
                   pop=self._back_pop,
                   empty=self._back_empty,
                   full=self._back_full,
                   valid=self._back_valid,
                   parallel_read=kts.const(0, 1),
                   # Only do back load when data is going there
                   parallel_load=self._back_pl & self._back_num_load.r_or(),
                   parallel_in=self._back_par_in,
                   num_load=self._back_num_load)
    self.wire(self._back_rf.ports.data_in[0], self._back_data_in)
    self.wire(self._back_data_out, self._back_rf.ports.data_out[0])
    # send the writes through when a read isn't happening
    for i in range(self.banks):
        self.add_code(self.send_writes, idx=i)
        self.add_code(self.send_reads, idx=i)
    # Set the parallel load to back bank - if no delay it's immediate
    # if not, it's delayed :)
    if self.read_delay == 1:
        self._ren_delay = self.var("ren_delay", 1)
        self.add_code(self.set_parallel_ld_delay_1)
        self.wire(self._back_pl, self._ren_delay)
    else:
        self.wire(self._back_pl, self._ren_to_strg.r_or())
    # Combine front end data - just the items + incoming
    # this data is actually based on the rd_ptr from the front fifo
    for i in range(self.fw_int):
        self.wire(self._front_combined[i],
                  self._front_par_out[self._front_rd_ptr + i])
    # This is always true
    # self.wire(self._front_combined[self.fw_int - 1], self._data_in)
    # prioritize queued writes, otherwise send combined data
    for i in range(self.banks):
        self.wire(self._data_to_strg[i],
                  kts.ternary(self._queued_write[i],
                              self._write_queue[i],
                              self._front_combined))
    # Wire the thin output from front to thin input to back
    self.wire(self._back_data_in, self._front_data_out)
    self.wire(self._back_push, self._front_valid)
    self.add_code(self.set_front_pop)
    # Queue writes
    for i in range(self.banks):
        self.add_code(self.set_write_queue, idx=i)
    # Track number of words in memory
    # if self.read_delay == 1:
    #     self.add_code(self.set_num_words_mem_delay)
    # else:
    self.add_code(self.set_num_words_mem)
    # Track occupancy of the two small fifos
    self.add_code(self.set_front_occ)
    self.add_code(self.set_back_occ)
    if self.banks > 1:
        self.add_code(self.set_curr_bank_wr)
        self.add_code(self.set_curr_bank_rd)
    if self.read_delay == 1:
        # remember which bank a delayed read came from
        self._prev_bank_rd = self.var("prev_bank_rd",
                                      max(1, kts.clog2(self.banks)))
        self.add_code(self.set_prev_bank_rd)
    # Parallel load data to back - based on num_load
    index_into = self._curr_bank_rd
    if self.read_delay == 1:
        index_into = self._prev_bank_rd
    for i in range(self.fw_int - 1):
        # Shift data over if you bypassed from the memory output
        self.wire(self._back_par_in[i],
                  kts.ternary(self._back_num_load == self.fw_int,
                              self._data_from_strg[index_into][i],
                              self._data_from_strg[index_into][i + 1]))
    self.wire(self._back_par_in[self.fw_int - 1],
              kts.ternary(self._back_num_load == self.fw_int,
                          self._data_from_strg[index_into][self.fw_int - 1],
                          kts.const(0, self.data_width)))
    # Set the parallel read to the front fifo - analogous with trying to write to the memory
    self.add_code(self.set_front_par_read)
    # Set the number being parallely loaded into the register
    self.add_code(self.set_back_num_load)
    # Data out and valid out are (in the general case) just the data and valid from the back fifo
    # In the case where we have a fresh memory read, it would be from that
    bank_idx_read = self._curr_bank_rd
    if self.read_delay == 1:
        bank_idx_read = self._prev_bank_rd
    self.wire(self._data_out,
              kts.ternary(self._back_pl,
                          self._data_from_strg[bank_idx_read][0],
                          self._back_data_out))
    self.wire(self._valid_out,
              kts.ternary(self._back_pl,
                          self._pop,
                          self._back_valid))
    # Set addresses to storage
    for i in range(self.banks):
        self.add_code(self.set_wen_addr, idx=i)
        self.add_code(self.set_ren_addr, idx=i)
    # Now deal with a shared address vs separate addresses
    if self.rw_same_cycle:
        # Separate
        self._wen_addr_out = self.output("wen_addr_out",
                                         self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
        self._ren_addr_out = self.output("ren_addr_out",
                                         self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
        self.wire(self._wen_addr_out, self._wen_addr)
        self.wire(self._ren_addr_out, self._ren_addr)
    else:
        self._addr_out = self.output("addr_out",
                                     self.addr_width,
                                     size=self.banks,
                                     explicit_array=True,
                                     packed=True)
        # If sharing the addresses, send read addr with priority
        for i in range(self.banks):
            self.wire(self._addr_out[i],
                      kts.ternary(self._wen_to_strg[i],
                                  self._wen_addr[i],
                                  self._ren_addr[i]))
    # Do final empty/full
    self._num_items = self.var("num_items", self.data_width)
    self.add_code(self.set_num_items)
    # runtime-configurable depth (config register)
    self._fifo_depth = self.input("fifo_depth", self.data_width)
    self._fifo_depth.add_attribute(ConfigRegAttr("Fifo depth..."))
    self.wire(self._empty, self._num_items == 0)
    self.wire(self._full, self._num_items == (self._fifo_depth))
def __init__(self, word_width, input_ports, output_ports, memories, edges): super().__init__("LakeTop", debug=True) # parameters self.word_width = word_width self.input_ports = input_ports self.output_ports = output_ports self.default_config_width = 16 self.cycle_count_width = 16 self.stencil_valid = False # objects self.memories = memories self.edges = edges # tile enable and clock self.tile_en = self.input("tile_en", 1) self.tile_en.add_attribute(ConfigRegAttr("Tile logic enable manifested as clock gate")) self.tile_en.add_attribute(FormalAttr(self.tile_en.name, FormalSignalConstraint.SET1)) self.clk_mem = self.clock("clk") self.clk_mem.add_attribute(FormalAttr(self.clk_mem.name, FormalSignalConstraint.CLK)) # chaining chain_supported = False for mem in self.memories.keys(): if self.memories[mem]["chaining"]: chain_supported = True break if chain_supported: self.chain_en = self.input("chain_en", 1) self.chain_en.add_attribute(ConfigRegAttr("Chaining enable")) self.chain_en.add_attribute(FormalAttr(self.chain_en.name, FormalSignalConstraint.SET0)) else: self.chain_en = self.var("chain_en", 1) self.wire(self.chain_en, 0) # gate clock with tile_en gclk = self.var("gclk", 1) self.gclk = kts.util.clock(gclk) self.wire(gclk, self.clk_mem & self.tile_en) self.clk_en = self.clock_en("clk_en", 1) # active low asynchornous reset self.rst_n = self.reset("rst_n", 1) self.rst_n.add_attribute(FormalAttr(self.rst_n.name, FormalSignalConstraint.RSTN)) # data in and out of top level Lake memory object self.data_in = self.input("data_in", width=self.word_width, size=self.input_ports, explicit_array=True, packed=True) self.data_in.add_attribute(FormalAttr(self.data_in.name, FormalSignalConstraint.SEQUENCE)) self.data_out = self.output("data_out", width=self.word_width, size=self.output_ports, explicit_array=True, packed=True) self.data_out.add_attribute(FormalAttr(self.data_out.name, FormalSignalConstraint.SEQUENCE)) # global cycle count for accessor comparison self._cycle_count = 
self.var("cycle_count", 16) @always_ff((posedge, self.gclk), (negedge, "rst_n")) def increment_cycle_count(self): if ~self.rst_n: self._cycle_count = 0 else: self._cycle_count = self._cycle_count + 1 self.add_always(increment_cycle_count) # info about memories num_mem = len(memories) subscript_mems = list(self.memories.keys()) # list of the data out from each memory self.mem_data_outs = [self.var(f"mem_data_out_{subscript_mems[i]}", width=self.word_width, size=self.memories[subscript_mems[i]] ["read_port_width" if "read_port_width" in self.memories[subscript_mems[i]] else "read_write_port_width"], explicit_array=True, packed=True) for i in range(num_mem)] # keep track of write, read_addr, and write_addr vars for read/write memories # to later check whether there is a write and what to use for the shared port self.mem_read_write_addrs = {} # create memory instance for each memory self.mem_insts = {} i = 0 for mem in self.memories.keys(): m = mem_inst(self.memories[mem], self.word_width) self.mem_insts[mem] = m self.add_child(mem, m, clk=self.gclk, rst_n=self.rst_n, # put data out in memory data out list data_out=self.mem_data_outs[i], chain_en=self.chain_en) i += 1 # get input and output memories is_input, is_output = [], [] for mem_name in self.memories.keys(): mem = self.memories[mem_name] if mem["is_input"]: is_input.append(mem_name) if mem["is_output"]: is_output.append(mem_name) # TODO direct connection to write doesn't work (?), so have to do this... 
self.low = self.var("low", 1) self.wire(self.low, 0) # TODO adding multiple ports to 1 memory after talking about mux with compiler team # set up input memories for i in range(len(is_input)): in_mem = is_input[i] # input addressor / accessor parameters input_dim = self.memories[in_mem]["input_edge_params"]["dim"] input_range = self.memories[in_mem]["input_edge_params"]["max_range"] input_stride = self.memories[in_mem]["input_edge_params"]["max_stride"] # input port associated with memory input_port_index = self.memories[in_mem]["input_port"] self.valid = self.var( f"input_port{input_port_index}_2{in_mem}_accessor_valid", 1) self.wire(self.mem_insts[in_mem].ports.write, self.valid) # hook up data from the specified input port to the memory safe_wire(self, self.mem_insts[in_mem].ports.data_in[0], self.data_in[input_port_index]) if self.memories[in_mem]["num_read_write_ports"] > 0: self.mem_read_write_addrs[in_mem] = {"write": self.valid} # create IteratorDomain, AddressGenerator, and ScheduleGenerator # for writes to this input memory forloop = ForLoop(iterator_support=input_dim, config_width=max(1, clog2(input_range))) # self.default_config_width) loop_itr = forloop.get_iter() loop_wth = forloop.get_cfg_width() self.add_child(f"input_port{input_port_index}_2{in_mem}_forloop", forloop, clk=self.gclk, rst_n=self.rst_n, step=self.valid) newAG = AddrGen(iterator_support=input_dim, config_width=max(1, clog2(input_stride))) # self.default_config_width) self.add_child(f"input_port{input_port_index}_2{in_mem}_write_addr_gen", newAG, clk=self.gclk, rst_n=self.rst_n, step=self.valid, mux_sel=forloop.ports.mux_sel_out, restart=forloop.ports.restart) if self.memories[in_mem]["num_read_write_ports"] == 0: safe_wire(self, self.mem_insts[in_mem].ports.write_addr[0], newAG.ports.addr_out) else: self.mem_read_write_addrs[in_mem]["write_addr"] = newAG.ports.addr_out newSG = SchedGen(iterator_support=input_dim, config_width=self.cycle_count_width) 
self.add_child(f"input_port{input_port_index}_2{in_mem}_write_sched_gen", newSG, clk=self.gclk, rst_n=self.rst_n, mux_sel=forloop.ports.mux_sel_out, finished=forloop.ports.restart, cycle_count=self._cycle_count, valid_output=self.valid) # set up output memories for i in range(len(is_output)): out_mem = is_output[i] # output addressor / accessor parameters output_dim = self.memories[out_mem]["output_edge_params"]["dim"] output_range = self.memories[out_mem]["output_edge_params"]["max_range"] output_stride = self.memories[out_mem]["output_edge_params"]["max_stride"] # output port associated with memory output_port_index = self.memories[out_mem]["output_port"] # hook up data from the memory to the specified output port self.wire(self.data_out[output_port_index], self.mem_insts[out_mem].ports.data_out[0][0]) # self.mem_data_outs[subscript_mems.index(out_mem)][0]) self.valid = self.var(f"{out_mem}2output_port{output_port_index}_accessor_valid", 1) if self.memories[out_mem]["rw_same_cycle"]: self.wire(self.mem_insts[out_mem].ports.read, self.valid) # create IteratorDomain, AddressGenerator, and ScheduleGenerator # for reads from this output memory forloop = ForLoop(iterator_support=output_dim, config_width=max(1, clog2(output_range))) # self.default_config_width) loop_itr = forloop.get_iter() loop_wth = forloop.get_cfg_width() self.add_child(f"{out_mem}2output_port{output_port_index}_forloop", forloop, clk=self.gclk, rst_n=self.rst_n, step=self.valid) newAG = AddrGen(iterator_support=output_dim, config_width=max(1, clog2(output_stride))) # self.default_config_width) self.add_child(f"{out_mem}2output_port{output_port_index}_read_addr_gen", newAG, clk=self.gclk, rst_n=self.rst_n, step=self.valid, mux_sel=forloop.ports.mux_sel_out, restart=forloop.ports.restart) if self.memories[out_mem]["num_read_write_ports"] == 0: safe_wire(self, self.mem_insts[out_mem].ports.read_addr[0], newAG.ports.addr_out) else: self.mem_read_write_addrs[in_mem]["read_addr"] = newAG.ports.addr_out 
newSG = SchedGen(iterator_support=output_dim, config_width=self.cycle_count_width) # self.default_config_width) self.add_child(f"{out_mem}2output_port{output_port_index}_read_sched_gen", newSG, clk=self.gclk, rst_n=self.rst_n, mux_sel=forloop.ports.mux_sel_out, finished=forloop.ports.restart, cycle_count=self._cycle_count, valid_output=self.valid) # create shared IteratorDomains and accessors as well as # read/write addressors for memories connected by each edge for edge in self.edges: # see how many signals need to be selected between for # from and to signals for edge num_mux_from = len(edge["from_signal"]) num_mux_to = len(edge["to_signal"]) # get unique edge_name identifier for hardware modules edge_name = get_edge_name(edge) # create forloop and accessor valid output signal self.valid = self.var(edge_name + "_accessor_valid", 1) forloop = ForLoop(iterator_support=edge["dim"]) self.forloop = forloop loop_itr = forloop.get_iter() loop_wth = forloop.get_cfg_width() self.add_child(edge_name + "_forloop", forloop, clk=self.gclk, rst_n=self.rst_n, step=self.valid) # create input addressor readAG = AddrGen(iterator_support=edge["dim"], config_width=self.default_config_width) self.add_child(f"{edge_name}_read_addr_gen", readAG, clk=self.gclk, rst_n=self.rst_n, step=self.valid, mux_sel=forloop.ports.mux_sel_out, restart=forloop.ports.restart) # assign read address to all from memories if self.memories[edge["from_signal"][0]]["num_read_write_ports"] == 0: # can assign same read addrs to all the memories for i in range(len(edge["from_signal"])): safe_wire(self, self.mem_insts[edge["from_signal"][i]].ports.read_addr[0], readAG.ports.addr_out) else: for i in range(len(edge["from_signal"])): self.mem_read_write_addrs[edge["from_signal"][i]]["read_addr"] = readAG.ports.addr_out # if needing to mux, choose which from memory we get data # from for to memory data in if num_mux_from > 1: num_mux_bits = clog2(num_mux_from) self.mux_sel = self.var(f"{edge_name}_mux_sel", 
width=num_mux_bits) read_addr_width = max(1, clog2(self.memories[edge["from_signal"][0]]["capacity"])) # decide which memory to get data from for to memory's data in safe_wire(self, self.mux_sel, readAG.ports.addr_out[read_addr_width + num_mux_from - 1, read_addr_width]) comb_mux_from = self.combinational() # for i in range(num_mux_from): # TODO want to use a switch statement here, but get add_fn_ln issue if_mux_sel = IfStmt(self.mux_sel == 0) for j in range(len(edge["to_signal"])): # print("TO ", edge["to_signal"][j]) # print("FROM ", edge["from_signal"][i]) if_mux_sel.then_(self.mem_insts[edge["to_signal"][j]].ports.data_in.assign(self.mem_insts[edge["from_signal"][0]].ports.data_out)) if_mux_sel.else_(self.mem_insts[edge["to_signal"][j]].ports.data_in.assign(self.mem_insts[edge["from_signal"][1]].ports.data_out)) comb_mux_from.add_stmt(if_mux_sel) # no muxing from, data_out from the one and only memory # goes to all to memories (valid determines whether it is # actually written) else: for j in range(len(edge["to_signal"])): # print("TO ", edge["to_signal"][j]) # print("FROM ", edge["from_signal"][0]) safe_wire(self, self.mem_insts[edge["to_signal"][j]].ports.data_in, # only one memory to read from self.mem_insts[edge["from_signal"][0]].ports.data_out) # create output addressor writeAG = AddrGen(iterator_support=edge["dim"], config_width=self.default_config_width) # step, mux_sel, restart may need delayed signals (assigned later) self.add_child(f"{edge_name}_write_addr_gen", writeAG, clk=self.gclk, rst_n=self.rst_n) # set write addr for to memories if self.memories[edge["to_signal"][0]]["num_read_write_ports"] == 0: for i in range(len(edge["to_signal"])): safe_wire(self, self.mem_insts[edge["to_signal"][i]].ports.write_addr[0], writeAG.ports.addr_out) else: for i in range(len(edge["to_signal"])): self.mem_read_write_addrs[edge["to_signal"][i]] = {"write": self.valid, "write_addr": writeAG.ports.addr_out} # calculate necessary delay between from_signal to 
to_signal # TODO this may need to be more sophisticated and based on II as well # TODO just need to add for loops for all the ports if self.memories[edge["from_signal"][0]]["num_read_write_ports"] == 0: self.delay = self.memories[edge["from_signal"][0]]["read_info"][0]["latency"] else: self.delay = self.memories[edge["from_signal"][0]]["read_write_info"][0]["latency"] if self.delay > 0: # signals that need to be delayed due to edge latency self.delayed_writes = self.var(f"{edge_name}_delayed_writes", width=self.delay) self.delayed_mux_sels = self.var(f"{edge_name}_delayed_mux_sels", width=self.forloop.ports.mux_sel_out.width, size=self.delay, explicit_array=True, packed=True) self.delayed_restarts = self.var(f"{edge_name}_delayed_restarts", width=self.delay) # delay in valid between read from memory and write to next memory @always_ff((posedge, self.gclk), (negedge, "rst_n")) def get_delayed_write(self): if ~self.rst_n: self.delayed_writes = 0 self.delayed_mux_sels = 0 self.delayed_restarts = 0 else: for i in range(self.delay - 1): self.delayed_writes[i + 1] = self.delayed_writes[i] self.delayed_mux_sels[i + 1] = self.delayed_mux_sels[i] self.delayed_restarts[i + 1] = self.delayed_restarts[i] self.delayed_writes[0] = self.valid self.delayed_mux_sels[0] = self.forloop.ports.mux_sel_out self.delayed_restarts[0] = self.forloop.ports.restart self.add_always(get_delayed_write) # if we have a mux for the destination memories, # choose which mux to write to if num_mux_to > 1: num_mux_bits = clog2(num_mux_to) self.mux_sel_to = self.var(f"{edge_name}_mux_sel_to", width=num_mux_bits) write_addr_width = max(1, clog2(self.memories[edge["to_signal"][0]]["capacity"])) # decide which destination memory gets written to safe_wire(self, self.mux_sel_to, writeAG.ports.addr_out[write_addr_width + num_mux_to - 1, write_addr_width]) # wire the write (or if needed, delayed write) signal to the selected destination memory # and set write enable low for all other destination memories 
comb_mux_to = self.combinational() for i in range(num_mux_to): if_mux_sel_to = IfStmt(self.mux_sel_to == i) if self.delay == 0: if_mux_sel_to.then_(self.mem_insts[edge["to_signal"][i]].ports.write.assign(self.valid)) else: if_mux_sel_to.then_(self.mem_insts[edge["to_signal"][i]].ports.write.assign(self.delayed_writes[self.delay - 1])) if_mux_sel_to.else_(self.mem_insts[edge["to_signal"][i]].ports.write.assign(self.low)) comb_mux_to.add_stmt(if_mux_sel_to) # no muxing to, just write to the one destination memory else: if self.delay == 0: self.wire(self.mem_insts[edge["to_signal"][0]].ports.write, self.valid) else: self.wire(self.mem_insts[edge["to_signal"][0]].ports.write, self.delayed_writes[self.delay - 1]) # assign delayed signals for write addressor if needed if self.delay == 0: self.wire(writeAG.ports.step, self.valid) self.wire(writeAG.ports.mux_sel, self.forloop.ports.mux_sel_out) self.wire(writeAG.ports.restart, self.forloop.ports.restart) else: self.wire(writeAG.ports.step, self.delayed_writes[self.delay - 1]) self.wire(writeAG.ports.mux_sel, self.delayed_mux_sels[self.delay - 1]) self.wire(writeAG.ports.restart, self.delayed_restarts[self.delay - 1]) # create accessor for edge newSG = SchedGen(iterator_support=edge["dim"], config_width=self.cycle_count_width) # self.default_config_width) self.add_child(edge_name + "_sched_gen", newSG, clk=self.gclk, rst_n=self.rst_n, mux_sel=forloop.ports.mux_sel_out, finished=forloop.ports.restart, cycle_count=self._cycle_count, valid_output=self.valid) # for read write memories, choose either read or write address based on whether # we are writing to the memory (whether write enable is high) read_write_addr_comb = self.combinational() for mem_name in self.memories: if mem_name in self.mem_read_write_addrs: mem_info = self.mem_read_write_addrs[mem_name] if_write = IfStmt(mem_info["write"] == 1) addr_width = self.mem_insts[mem_name].ports.read_write_addr[0].width 
if_write.then_(self.mem_insts[mem_name].ports.read_write_addr[0].assign(mem_info["write_addr"][addr_width - 1, 0])) if_write.else_(self.mem_insts[mem_name].ports.read_write_addr[0].assign(mem_info["read_addr"][addr_width - 1, 0])) read_write_addr_comb.add_stmt(if_write) # clock enable and flush passes kts.passes.auto_insert_clock_enable(self.internal_generator) clk_en_port = self.internal_generator.get_port("clk_en") clk_en_port.add_attribute(FormalAttr(clk_en_port.name, FormalSignalConstraint.SET1)) self.add_attribute("sync-reset=flush") kts.passes.auto_insert_sync_reset(self.internal_generator) flush_port = self.internal_generator.get_port("flush") # bring config registers up to top level lift_config_reg(self.internal_generator)
def __init__(self, fetch_width, data_width, int_out_ports):
    """Build the output-port synchronization-group generator.

    Ports configured into the same sync group only present valid data
    once every member of the group has data ready; reads are gated
    accordingly via `rd_sync_gate`.

    Parameters:
        fetch_width: total memory word width in bits (must be a power of 2).
        data_width: width of a single data element in bits.
        int_out_ports: number of interconnect output ports (also the
            number of possible sync groups).
    """
    assert not (fetch_width & (fetch_width - 1)), "Memory width needs to be a power of 2"
    super().__init__("sync_groups")
    # Absorb inputs
    self.fetch_width = fetch_width
    self.data_width = data_width
    # Number of data elements packed into one memory word.
    self.fw_int = int(self.fetch_width / self.data_width)
    self.int_out_ports = int_out_ports
    self.groups = self.int_out_ports
    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    # Inputs
    self._ack_in = self.input("ack_in", self.int_out_ports)
    self._data_in = self.input("data_in",
                               self.data_width,
                               size=(self.int_out_ports,
                                     self.fw_int),
                               explicit_array=True,
                               packed=True)
    self._mem_valid_data = self.input("mem_valid_data", self.int_out_ports)
    self._mem_valid_data_out = self.output("mem_valid_data_out", self.int_out_ports)
    self._valid_in = self.input("valid_in", self.int_out_ports)
    # Indicates which port belongs to which synchronization group
    self._sync_group = self.input("sync_group",
                                  self.int_out_ports,
                                  size=self.int_out_ports,
                                  explicit_array=True,
                                  packed=True)
    sync_config = ConfigRegAttr(
        "This array of one hot vectors" +
        " is used to denote which ports are synchronized to eachother." +
        " If multiple ports should output data relative to eachother" +
        " one should put them in the same group.")
    self._sync_group.add_attribute(sync_config)
    # Outputs
    self._data_out = self.output("data_out",
                                 self.data_width,
                                 size=(self.int_out_ports,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    self._valid_out = self.output("valid_out", self.int_out_ports)
    # Locals
    self._sync_agg = self.var("sync_agg",
                              self.int_out_ports,
                              size=self.int_out_ports,
                              explicit_array=True,
                              packed=True)
    self._sync_valid = self.var("sync_valid", self.int_out_ports)
    self._data_reg = self.var("data_reg",
                              self.data_width,
                              size=(self.int_out_ports,
                                    self.fw_int),
                              explicit_array=True,
                              packed=True)
    self._valid_reg = self.var("valid_reg", self.int_out_ports)
    # This signal allows us to orchestrate the synchronization groups
    # at the output address controller
    self._ren_in = self.input("ren_in", self.int_out_ports)
    self._ren_int = self.var("ren_int", self.int_out_ports)
    self._rd_sync_gate = self.output("rd_sync_gate", self.int_out_ports)
    self._local_gate_bus = self.var("local_gate_bus",
                                    self.int_out_ports,
                                    size=self.int_out_ports,
                                    explicit_array=True,
                                    packed=True)
    self._local_gate_bus_n = self.var("local_gate_bus_n",
                                      self.int_out_ports,
                                      size=self.int_out_ports,
                                      explicit_array=True,
                                      packed=True)
    # The gate bus is stored inverted; the usable bus is its bitwise NOT.
    self.wire(self._local_gate_bus, ~self._local_gate_bus_n)
    self._local_gate_bus_tpose = self.var("local_gate_bus_tpose",
                                          self.int_out_ports,
                                          size=self.int_out_ports,
                                          explicit_array=True,
                                          packed=True)
    self._local_gate_reduced = self.var("local_gate_reduced", self.int_out_ports)
    self._local_gate_mask = self.var("local_gate_mask",
                                     self.int_out_ports,
                                     size=self.int_out_ports,
                                     explicit_array=True,
                                     packed=True)
    self._group_finished = self.var("group_finished", self.int_out_ports)
    self._grp_fin_large = self.var("grp_fin_large",
                                   self.int_out_ports,
                                   size=self.int_out_ports,
                                   explicit_array=True,
                                   packed=True)
    self._done = self.var("done", self.int_out_ports)
    self._done_alt = self.var("done_alt", self.int_out_ports)
    # Output data is ungated
    self.wire(self._data_out, self._data_reg)
    self.wire(self._rd_sync_gate, self._local_gate_reduced)
    # Valid requires gating based on sync_valid
    self.wire(self._ren_int, self._ren_in & self._local_gate_reduced)
    # Add Code
    self.add_code(self.set_sync_agg, unroll_for=True)
    self.add_code(self.set_sync_valid)
    for i in range(self.int_out_ports):
        self.add_code(self.set_sync_stage, idx=i)
    self.add_code(self.set_out_valid)
    self.add_code(self.set_reduce_gate)
    for i in range(self.groups):
        self.add_code(self.set_rd_gates, idx=i)
    self.add_code(self.set_tpose)
    self.add_code(self.set_finished)
    self.add_code(self.next_gate_mask, unroll_for=True)
    self.add_code(self.set_grp_fin, unroll_for=True)
def __init__(self, agg_height, data_width, mem_width, max_agg_schedule):
    """Build the aggregation-buffer generator.

    Gathers narrow (data_width) input elements into wide (mem_width)
    memory words using `agg_height` child Aggregators, with
    configurable input/output scheduling across the buffers.

    Parameters:
        agg_height: number of parallel aggregation buffers.
        data_width: width of one input element in bits.
        mem_width: width of the assembled memory word in bits.
        max_agg_schedule: maximum length of the in/out schedules.
    """
    super().__init__("aggregation_buffer")

    self.agg_height = agg_height
    self.data_width = data_width
    self.mem_width = mem_width
    # This is the maximum length of the schedule
    self.max_agg_schedule = max_agg_schedule
    # Elements per memory word.
    self.fw_int = int(self.mem_width / self.data_width)
    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    # Inputs
    # Bring in a single element into an AggregationBuffer w/ valid signaling
    self._data_in = self.input("data_in", self.data_width)
    self._valid_in = self.input("valid_in", 1)
    self._align = self.input("align", 1)
    # Outputs
    self._data_out = self.output("data_out", self.mem_width)
    self._valid_out = self.output("valid_out", 1)
    # Expose each element-sized slice of data_out as its own port.
    self._data_out_chop = []
    for i in range(self.fw_int):
        self._data_out_chop.append(
            self.output(f"data_out_chop_{i}", self.data_width))
        self.add_stmt(self._data_out_chop[i].assign(
            self._data_out[(self.data_width * (i + 1)) - 1, self.data_width * i]))
    # CONFIG:
    # We receive a periodic (doesn't need to be, but has a maximum schedule,
    # so...possibly the schedule is a for loop?
    # Tells us where to write successive elements...
    self._in_schedule = self.input("in_sched",
                                   max(1, clog2(self.agg_height)),
                                   size=self.max_agg_schedule,
                                   explicit_array=True,
                                   packed=True)
    doc = "Input schedule for aggregation buffer. Enumerate which" + \
        f" of {self.agg_height} buffers to write to."
    self._in_schedule.add_attribute(ConfigRegAttr(doc))
    self._in_period = self.input("in_period", clog2(self.max_agg_schedule))
    doc = "Input period for aggregation buffer. 1 is a reasonable" + \
        " setting for most applications"
    self._in_period.add_attribute(ConfigRegAttr(doc))
    # ...and which order to output the blocks
    self._out_schedule = self.input("out_sched",
                                    max(1, clog2(agg_height)),
                                    size=self.max_agg_schedule,
                                    explicit_array=True,
                                    packed=True)
    doc = "Output schedule for aggregation buffer. Enumerate which" + \
        f" of {self.agg_height} buffers to write to SRAM from."
    self._out_schedule.add_attribute(ConfigRegAttr(doc))
    self._out_period = self.input("out_period", clog2(self.max_agg_schedule))
    self._out_period.add_attribute(
        ConfigRegAttr("Output period for aggregation buffer"))
    # Pointers that walk the configured schedules.
    self._in_sched_ptr = self.var("in_sched_ptr", clog2(self.max_agg_schedule))
    self._out_sched_ptr = self.var("out_sched_ptr", clog2(self.max_agg_schedule))
    # Local Signals
    self._aggs_out = self.var("aggs_out",
                              self.mem_width,
                              size=self.agg_height,
                              packed=True,
                              explicit_array=True)
    self._aggs_sep = []
    for i in range(self.agg_height):
        self._aggs_sep.append(
            self.var(f"aggs_sep_{i}",
                     self.data_width,
                     size=self.fw_int,
                     packed=True))
    self._valid_demux = self.var("valid_demux", self.agg_height)
    self._align_demux = self.var("align_demux", self.agg_height)
    self._next_full = self.var("next_full", self.agg_height)
    self._valid_out_mux = self.var("valid_out_mux", self.agg_height)
    for i in range(self.agg_height):
        # Add in the children aggregators...
        self.add_child(f"agg_{i}",
                       Aggregator(self.data_width,
                                  mem_word_width=self.fw_int),
                       clk=self._clk,
                       rst_n=self._rst_n,
                       in_pixels=self._data_in,
                       valid_in=self._valid_demux[i],
                       agg_out=self._aggs_sep[i],
                       valid_out=self._valid_out_mux[i],
                       next_full=self._next_full[i],
                       align=self._align_demux[i])
        portlist = []
        if self.fw_int == 1:
            self.wire(self._aggs_out[i], self._aggs_sep[i])
        else:
            # Concatenate element slices (MSB-first) into the wide word.
            for j in range(self.fw_int):
                portlist.append(self._aggs_sep[i][self.fw_int - 1 - j])
            self.wire(self._aggs_out[i], kts.concat(*portlist))
    # Sequential code blocks
    self.add_code(self.update_in_sched_ptr)
    self.add_code(self.update_out_sched_ptr)
    # Combinational code blocks
    self.add_code(self.valid_demux_comb)
    self.add_code(self.align_demux_comb)
    self.add_code(self.valid_out_comb)
    self.add_code(self.output_data_comb)
def __init__(self, iterator_support=6, config_width=16, use_enable=True):
    """Build a schedule generator.

    Wraps an AddrGen whose "address" output is interpreted as the global
    cycle count at which the next access should fire; downstream logic
    (set_valid_out / set_valid_output) compares it against `cycle_count`
    to assert `valid_output`. The controller is one-shot: once `finished`
    pulses, `valid_gate` drops and stays low until reset.

    Parameters:
        iterator_support: number of nested loop dimensions supported.
        config_width: bit width of schedule/cycle configuration values.
        use_enable: if True, expose a config-reg `enable` input; otherwise
            tie enable high and leave pruning to synthesis.
    """
    super().__init__(f"sched_gen_{iterator_support}_{config_width}")
    self.iterator_support = iterator_support
    self.config_width = config_width
    self.use_enable = use_enable
    # PORT DEFS: begin
    # INPUTS
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    # OUTPUTS
    self._valid_output = self.output("valid_output", 1)
    # VARS
    self._valid_out = self.var("valid_out", 1)
    self._cycle_count = self.input("cycle_count", self.config_width)
    self._mux_sel = self.input("mux_sel", max(clog2(self.iterator_support), 1))
    self._addr_out = self.var("addr_out", self.config_width)
    # Receive signal on last iteration of looping structure and
    # gate the output...
    self._finished = self.input("finished", 1)
    self._valid_gate_inv = self.var("valid_gate_inv", 1)
    self._valid_gate = self.var("valid_gate", 1)
    self.wire(self._valid_gate, ~self._valid_gate_inv)
    # Since dim = 0 is not sufficient, we need a way to prevent
    # the controllers from firing on the starting offset
    if self.use_enable:
        self._enable = self.input("enable", 1)
        self._enable.add_attribute(
            ConfigRegAttr("Disable the controller so it never fires..."))
        self._enable.add_attribute(
            FormalAttr(f"{self._enable.name}", FormalSignalConstraint.SOLVE))
    # Otherwise we set it as a 1 and leave it up to synthesis...
    else:
        self._enable = self.var("enable", 1)
        self.wire(self._enable, kratos.const(1, 1))

    @always_ff((posedge, "clk"), (negedge, "rst_n"))
    def valid_gate_inv_ff():
        if ~self._rst_n:
            self._valid_gate_inv = 0
        # If we are finishing the looping structure, turn this off to implement one-shot
        elif self._finished:
            self._valid_gate_inv = 1
    self.add_code(valid_gate_inv_ff)
    # Compare based on minimum of addr + global cycle...
    self.c_a_cmp = min(self._cycle_count.width, self._addr_out.width)
    # PORT DEFS: end
    # Fix: instance name was written as f"sched_addr_gen" (an f-string with
    # no placeholders); a plain string literal is equivalent and clearer.
    self.add_child("sched_addr_gen",
                   AddrGen(iterator_support=self.iterator_support,
                           config_width=self.config_width),
                   clk=self._clk,
                   rst_n=self._rst_n,
                   step=self._valid_out,
                   mux_sel=self._mux_sel,
                   addr_out=self._addr_out,
                   restart=const(0, 1))
    self.add_code(self.set_valid_out)
    self.add_code(self.set_valid_output)
def __init__(self, iterator_support=6, config_width=16):
    """Build a nested for-loop iteration-domain generator.

    Maintains one counter per supported loop dimension; `mux_sel`
    selects the dimension currently advancing, and `restart` pulses
    when the configured iteration domain wraps around.

    Parameters:
        iterator_support: maximum number of nested loop dimensions.
        config_width: bit width of each range configuration register.
    """
    super().__init__(f"for_loop_{iterator_support}_{config_width}", debug=True)

    self.iterator_support = iterator_support
    self.config_width = config_width
    # Create params for instancing this module...
    self.iterator_support_par = self.param("ITERATOR_SUPPORT",
                                           clog2(iterator_support) + 1,
                                           value=self.iterator_support)
    self.config_width_par = self.param("CONFIG_WIDTH",
                                       clog2(config_width) + 1,
                                       value=self.config_width)
    # PORT DEFS: begin
    # INPUTS
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    self._ranges = self.input("ranges",
                              self.config_width,
                              size=self.iterator_support,
                              packed=True,
                              explicit_array=True)
    self._ranges.add_attribute(
        ConfigRegAttr("Ranges of address generator"))
    self._ranges.add_attribute(
        FormalAttr(f"{self._ranges.name}", FormalSignalConstraint.SOLVE))
    self._dimensionality = self.input("dimensionality",
                                      1 + clog2(self.iterator_support))
    self._dimensionality.add_attribute(
        ConfigRegAttr("Dimensionality of address generator"))
    self._dimensionality.add_attribute(
        FormalAttr(f"{self._dimensionality.name}", FormalSignalConstraint.SOLVE))
    self._step = self.input("step", 1)
    # OUTPUTS
    # PORT DEFS: end
    # LOCAL VARIABLES: begin
    self._dim_counter = self.var("dim_counter",
                                 self.config_width,
                                 size=self.iterator_support,
                                 packed=True,
                                 explicit_array=True)
    self._strt_addr = self.var("strt_addr", self.config_width)
    self._counter_update = self.var("counter_update", 1)
    self._calc_addr = self.var("calc_addr", self.config_width)
    # One bit per dimension: that dimension has hit its configured range.
    self._max_value = self.var("max_value", self.iterator_support)
    self._mux_sel = self.var("mux_sel", max(clog2(self.iterator_support), 1))
    self._mux_sel_out = self.output("mux_sel_out", max(clog2(self.iterator_support), 1))
    self.wire(self._mux_sel_out, self._mux_sel)
    # LOCAL VARIABLES: end
    # GENERATION LOGIC: begin
    self._done = self.var("done", 1)
    self._clear = self.var("clear", self.iterator_support)
    self._inc = self.var("inc", self.iterator_support)
    # Counter of the selected dimension, incremented.
    self._inced_cnt = self.var("inced_cnt", self._dim_counter.width)
    self.wire(self._inced_cnt, self._dim_counter[self._mux_sel] + 1)
    # Next_max_value
    self._maxed_value = self.var("maxed_value", 1)
    self.wire(
        self._maxed_value,
        (self._dim_counter[self._mux_sel] == self._ranges[self._mux_sel]) &
        self._inc[self._mux_sel])
    self.add_code(self.set_mux_sel)
    for i in range(self.iterator_support):
        self.add_code(self.set_clear, idx=i)
        self.add_code(self.set_inc, idx=i)
        self.add_code(self.dim_counter_update, idx=i)
        self.add_code(self.max_value_update, idx=i)
    # GENERATION LOGIC: end
    self._restart = self.output("restart", 1)
    self.wire(self._restart, self._step & (~self._done))
def __init__(self,
             # number of bits in a word
             word_width,
             # number of words that can be stored at an address in SRAM
             # note fetch_width must be powers of 2
             fetch_width,
             # total number of transpose buffers
             num_tb,
             # height of this particular transpose buffer
             max_tb_height,
             # maximum value for range parameters in nested for loop
             # (and as a result, maximum length of indices input vector)
             # specifying inner for loop values for output column
             # addressing
             max_range,
             max_range_inner,
             max_stride,
             tb_iterator_support):
    """Build a double-buffered transpose buffer generator.

    Accepts wide (fetch_width-element) words from SRAM and emits
    column-oriented pixels according to a configurable two-level
    (outer/inner) output loop pattern.
    """
    super().__init__("transpose_buffer", debug=True)

    #########################
    # GENERATION PARAMETERS #
    #########################
    self.word_width = word_width
    self.fetch_width = fetch_width
    self.num_tb = num_tb
    self.max_tb_height = max_tb_height
    self.max_range = max_range
    self.max_range_inner = max_range_inner
    self.max_stride = max_stride
    self.tb_iterator_support = tb_iterator_support

    ##################################
    # BITS FOR GENERATION PARAMETERS #
    ##################################
    self.fetch_width_bits = max(1, clog2(self.fetch_width))
    self.num_tb_bits = max(1, clog2(self.num_tb))
    self.max_range_bits = max(1, clog2(self.max_range))
    self.max_range_inner_bits = max(1, clog2(self.max_range_inner))
    self.max_stride_bits = max(1, clog2(self.max_stride))
    self.tb_col_index_bits = 2 * max(self.fetch_width_bits, self.num_tb_bits) + 1
    # *_bits2 widths cover the doubled (ping-pong) buffer space.
    self.max_tb_height_bits2 = max(1, clog2(2 * self.max_tb_height))
    self.max_tb_height_bits = max(1, clog2(self.max_tb_height))
    self.tb_iterator_support_bits = max(
        1, clog2(self.tb_iterator_support) + 1)
    self.max_range_stride_bits2 = max(2 * self.max_range_bits,
                                      2 * self.max_stride_bits)

    ##########
    # INPUTS #
    ##########
    self.clk = self.clock("clk")
    # active low asynchronous reset
    self.rst_n = self.reset("rst_n", 1)
    # data input from SRAM
    if self.fetch_width == 1:
        self.input_data = self.input("input_data", self.word_width)
    else:
        self.input_data = self.input("input_data",
                                     width=self.word_width,
                                     size=self.fetch_width,
                                     packed=True)
    # valid indicating whether data input from SRAM is valid and
    # should be stored in transpose buffer
    self.valid_data = self.input("valid_data", 1)
    self.ack_in = self.input("ack_in", 1)
    self.ren = self.input("ren", 1)
    self.mem_valid_data = self.input("mem_valid_data", 1)

    ###########################
    # CONFIGURATION REGISTERS #
    ###########################
    # the range of the outer for loop in nested for loop for output
    # column address generation
    self.range_outer = self.input("range_outer", self.max_range_bits)
    self.range_outer.add_attribute(
        ConfigRegAttr("Outer range for output for loop pattern"))
    # the range of the inner for loop in nested for loop for output
    # column address generation
    self.range_inner = self.input("range_inner", self.max_range_inner_bits)
    self.range_inner.add_attribute(
        ConfigRegAttr("Inner range for output for for loop pattern"))
    # stride for the given application
    self.stride = self.input("stride", self.max_stride_bits)
    self.stride.add_attribute(ConfigRegAttr("Application stride"))
    self.tb_height = self.input("tb_height", self.max_tb_height_bits)
    self.tb_height.add_attribute(ConfigRegAttr("Transpose Buffer height"))
    self.dimensionality = self.input("dimensionality",
                                     self.tb_iterator_support_bits)
    self.dimensionality.add_attribute(
        ConfigRegAttr("Transpose Buffer dimensionality"))
    # specifies inner for loop values for output column
    # addressing
    self.indices = self.input(
        "indices",
        width=clog2(2 * self.num_tb * self.fetch_width),
        # the length of indices is equal to range_inner,
        # so the maximum possible size for self.indices
        # is the maximum value of range_inner, which is
        # self.max_range_inner
        size=self.max_range_inner,
        explicit_array=True,
        packed=True)
    self.indices.add_attribute(
        ConfigRegAttr("Output indices for for loop pattern"))
    # offset to start output address if we're starting in the middle of a wider
    # fetch width word for example
    self.starting_addr = self.input("starting_addr", self.fetch_width_bits)
    self.starting_addr.add_attribute(ConfigRegAttr("TB starting address"))

    ###########
    # OUTPUTS #
    ###########
    self.col_pixels = self.output("col_pixels",
                                  width=self.word_width,
                                  size=self.max_tb_height,
                                  packed=True,
                                  explicit_array=True)
    self.output_valid = self.output("output_valid", 1)
    self.rdy_to_arbiter = self.output("rdy_to_arbiter", 1)

    ###################
    # LOCAL VARIABLES #
    ###################
    # transpose buffer storage: 2x height for ping-pong buffering
    if self.fetch_width == 1:
        self.tb = self.var("tb",
                           width=self.word_width,
                           size=2 * self.max_tb_height,
                           packed=True)
    else:
        self.tb = self.var("tb",
                           width=self.word_width,
                           size=[2 * self.max_tb_height, self.fetch_width],
                           packed=True)
    self.tb_valid = self.var("tb_valid", 2 * self.max_tb_height)
    self.index_outer = self.var("index_outer", self.max_range_bits)
    self.index_inner = self.var("index_inner", self.max_range_inner_bits)
    self.input_buf_index = self.var("input_buf_index", 1)
    self.out_buf_index = self.var("out_buf_index", 1)
    self.switch_out_buf = self.var("switch_out_buf", 1)
    self.switch_next_line = self.var("switch_next_line", 1)
    self.row_index = self.var("row_index", self.max_tb_height_bits)
    self.input_index = self.var("input_index", self.max_tb_height_bits2)
    self.output_index_abs = self.var("output_index_abs",
                                     self.max_range_stride_bits2)
    if self.fetch_width != 1:
        self.output_index_long = self.var("output_index_long",
                                          self.max_range_stride_bits2)
        self.output_index = self.var("output_index", self.fetch_width_bits)
    self.indices_index_inner = self.var(
        "indices_index_inner",
        clog2(2 * self.num_tb * self.fetch_width))
    self.curr_out_start = self.var("curr_out_start",
                                   self.max_range_stride_bits2)
    self.start_data = self.var("start_data", 1)
    self.old_start_data = self.var("old_start_data", 1)
    self.pause_tb = self.var("pause_tb", 1)
    self.pause_output = self.var("pause_output", 1)
    self.on_next_line = self.var("on_next_line", 1)
    self.mask_valid = self.var("mask_valid", 1)
    # Inverted shadows driven by set_invs.
    self.pause_tbinv = self.var("pause_tbinv", 1)
    self.rdy_to_arbiterinv = self.var("rdy_to_arbiterinv", 1)
    self.out_buf_indexinv = self.var("out_buf_indexinv", 1)

    ##########################
    # SEQUENTIAL CODE BLOCKS #
    ##########################
    self.add_code(self.set_index_outer)
    self.add_code(self.set_index_inner)
    self.add_code(self.set_pause_tb)
    self.add_code(self.set_row_index)
    self.add_code(self.set_input_buf_index)
    self.add_code(self.input_to_tb)
    self.add_code(self.output_from_tb)
    self.add_code(self.set_output_valid)
    self.add_code(self.set_out_buf_index)
    self.add_code(self.set_rdy_to_arbiter)
    self.add_code(self.set_start_data)
    self.add_code(self.set_curr_out_start)
    if self.fetch_width != 1:
        self.add_code(self.set_output_index)
    self.add_code(self.set_old_start_data)
    self.add_code(self.set_on_next_line)

    #############################
    # COMBINATIONAL CODE BLOCKS #
    #############################
    self.add_code(self.set_pause_output)
    self.add_code(self.set_input_index)
    self.add_code(self.set_tb_out_indices)
    self.add_code(self.set_switch_out_buf)
    self.add_code(self.set_switch_next_line)
    if self.fetch_width != 1:
        self.add_code(self.set_output_index_long)
    self.add_code(self.set_mask_valid)
    self.add_code(self.set_invs)
def __init__(self, data_width=16, # CGRA Params mem_width=64, mem_depth=512, banks=1, input_iterator_support=6, # Addr Controllers output_iterator_support=6, input_config_width=16, output_config_width=16, interconnect_input_ports=2, # Connection to int interconnect_output_ports=2, mem_input_ports=1, mem_output_ports=1, read_delay=1, # Cycle delay in read (SRAM vs Register File) rw_same_cycle=False, # Does the memory allow r+w in same cycle? agg_height=4, max_agg_schedule=32, input_max_port_sched=32, output_max_port_sched=32, align_input=1, max_line_length=128, max_tb_height=1, tb_range_max=128, tb_range_inner_max=5, tb_sched_max=64, max_tb_stride=15, num_tb=1, tb_iterator_support=2, multiwrite=1, num_tiles=1, max_prefetch=8, app_ctrl_depth_width=16, remove_tb=False, stcl_valid_iter=4): super().__init__("strg_ub") self.data_width = data_width self.mem_width = mem_width self.mem_depth = mem_depth self.banks = banks self.input_iterator_support = input_iterator_support self.output_iterator_support = output_iterator_support self.input_config_width = input_config_width self.output_config_width = output_config_width self.interconnect_input_ports = interconnect_input_ports self.interconnect_output_ports = interconnect_output_ports self.mem_input_ports = mem_input_ports self.mem_output_ports = mem_output_ports self.agg_height = agg_height self.max_agg_schedule = max_agg_schedule self.input_max_port_sched = input_max_port_sched self.output_max_port_sched = output_max_port_sched self.input_port_sched_width = clog2(self.interconnect_input_ports) self.align_input = align_input self.max_line_length = max_line_length assert self.mem_width >= self.data_width, "Data width needs to be smaller than mem" self.fw_int = int(self.mem_width / self.data_width) self.num_tb = num_tb self.max_tb_height = max_tb_height self.tb_range_max = tb_range_max self.tb_range_inner_max = tb_range_inner_max self.max_tb_stride = max_tb_stride self.tb_sched_max = tb_sched_max self.tb_iterator_support = 
tb_iterator_support self.multiwrite = multiwrite self.max_prefetch = max_prefetch self.num_tiles = num_tiles self.app_ctrl_depth_width = app_ctrl_depth_width self.remove_tb = remove_tb self.read_delay = read_delay self.rw_same_cycle = rw_same_cycle self.stcl_valid_iter = stcl_valid_iter # phases = [] TODO self.address_width = clog2(self.num_tiles * self.mem_depth) # CLK and RST self._clk = self.clock("clk") self._rst_n = self.reset("rst_n") # INPUTS self._data_in = self.input("data_in", self.data_width, size=self.interconnect_input_ports, packed=True, explicit_array=True) self._wen_in = self.input("wen_in", self.interconnect_input_ports) self._ren_input = self.input("ren_in", self.interconnect_output_ports) # Post rate matched self._ren_in = self.var("ren_in_muxed", self.interconnect_output_ports) # Processed versions of wen and ren from the app ctrl self._wen = self.var("wen", self.interconnect_input_ports) self._ren = self.var("ren", self.interconnect_output_ports) # Add rate matched # If one input port, let any output port use the wen_in as the ren_in # If more, do the same thing but also provide port selection if self.interconnect_input_ports == 1: self._rate_matched = self.input("rate_matched", self.interconnect_output_ports) self._rate_matched.add_attribute(ConfigRegAttr("Rate matched - 1 or 0")) for i in range(self.interconnect_output_ports): self.wire(self._ren_in[i], kts.ternary(self._rate_matched[i], self._wen_in, self._ren_input[i])) else: self._rate_matched = self.input("rate_matched", 1 + kts.clog2(self.interconnect_input_ports), size=self.interconnect_output_ports, explicit_array=True, packed=True) self._rate_matched.add_attribute(ConfigRegAttr("Rate matched [input port | on/off]")) for i in range(self.interconnect_output_ports): self.wire(self._ren_in[i], kts.ternary(self._rate_matched[i][0], self._wen_in[self._rate_matched[i][kts.clog2(self.interconnect_input_ports), 1]], self._ren_input[i])) self._arb_wen_en = self.var("arb_wen_en", 
self.interconnect_input_ports) self._arb_ren_en = self.var("arb_ren_en", self.interconnect_output_ports) self._data_from_strg = self.input("data_from_strg", self.data_width, size=(self.banks, self.mem_output_ports, self.fw_int), packed=True, explicit_array=True) self._mem_valid_data = self.input("mem_valid_data", self.mem_output_ports, size=self.banks, explicit_array=True, packed=True) self._out_mem_valid_data = self.var("out_mem_valid_data", self.mem_output_ports, size=self.banks, explicit_array=True, packed=True) # We need to signal valids out of the agg buff, only if one exists... if self.agg_height > 0: self._to_iac_valid = self.var("ab_to_mem_valid", self.interconnect_input_ports) self._data_out = self.output("data_out", self.data_width, size=self.interconnect_output_ports, packed=True, explicit_array=True) self._valid_out = self.output("valid_out", self.interconnect_output_ports) self._valid_out_alt = self.var("valid_out_alt", self.interconnect_output_ports) self._data_to_strg = self.output("data_to_strg", self.data_width, size=(self.banks, self.mem_input_ports, self.fw_int), packed=True, explicit_array=True) # If we can perform a read and a write on the same cycle, # this will necessitate a separate read and write address... 
if self.rw_same_cycle: self._wr_addr_out = self.output("wr_addr_out", self.address_width, size=(self.banks, self.mem_input_ports), explicit_array=True, packed=True) self._rd_addr_out = self.output("rd_addr_out", self.address_width, size=(self.banks, self.mem_output_ports), explicit_array=True, packed=True) else: self._addr_out = self.output("addr_out", self.address_width, size=(self.banks, self.mem_input_ports), packed=True, explicit_array=True) self._cen_to_strg = self.output("cen_to_strg", self.mem_output_ports, size=self.banks, explicit_array=True, packed=True) self._wen_to_strg = self.output("wen_to_strg", self.mem_input_ports, size=self.banks, explicit_array=True, packed=True) if self.num_tb > 0: self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports) self._port_wens = self.var("port_wens", self.interconnect_input_ports) #################### ##### APP CTRL ##### #################### self._ack_transpose = self.var("ack_transpose", self.banks, size=self.interconnect_output_ports, explicit_array=True, packed=True) self._ack_reduced = self.var("ack_reduced", self.interconnect_output_ports) self.app_ctrl = AppCtrl(interconnect_input_ports=self.interconnect_input_ports, interconnect_output_ports=self.interconnect_output_ports, depth_width=self.app_ctrl_depth_width, sprt_stcl_valid=True, stcl_iter_support=self.stcl_valid_iter) # Some refactoring here for pond to get rid of app controllers... # This is honestly pretty messy and should clean up nicely when we have the spec... 
self._ren_out_reduced = self.var("ren_out_reduced", self.interconnect_output_ports) if self.num_tb == 0 or self.remove_tb: self.wire(self._wen, self._wen_in) self.wire(self._ren, self._ren_in) self.wire(self._valid_out, self._valid_out_alt) self.wire(self._arb_wen_en, self._wen) self.wire(self._arb_ren_en, self._ren) else: self.add_child("app_ctrl", self.app_ctrl, clk=self._clk, rst_n=self._rst_n, wen_in=self._wen_in, ren_in=self._ren_in, # ren_update=self._tb_valid_out, valid_out_data=self._valid_out, # valid_out_stencil=, wen_out=self._wen, ren_out=self._ren) self.wire(self.app_ctrl.ports.tb_valid, self._tb_valid_out) self.wire(self.app_ctrl.ports.ren_update, self._tb_valid_out) self.app_ctrl_coarse = AppCtrl(interconnect_input_ports=self.interconnect_input_ports, interconnect_output_ports=self.interconnect_output_ports, depth_width=self.app_ctrl_depth_width) self.add_child("app_ctrl_coarse", self.app_ctrl_coarse, clk=self._clk, rst_n=self._rst_n, wen_in=self._to_iac_valid, # self._port_wens & self._to_iac_valid, # Gets valid and the ack ren_in=self._ren_out_reduced, tb_valid=kts.const(0, 1), ren_update=self._ack_reduced, wen_out=self._arb_wen_en, ren_out=self._arb_ren_en) ########################### ##### INPUT AGG SCHED ##### ########################### ########################################### ##### AGGREGATION ALIGNERS (OPTIONAL) ##### ########################################### # These variables are holders and can be swapped out if needed self._data_consume = self._data_in self._valid_consume = self._wen # Zero out if not aligning if(self.agg_height > 0): self._align_to_agg = self.var("align_input", self.interconnect_input_ports) # Add the aggregation buffer aligners if(self.align_input): self._data_consume = self.var("data_consume", self.data_width, size=self.interconnect_input_ports, packed=True, explicit_array=True) self._valid_consume = self.var("valid_consume", self.interconnect_input_ports) # Make new aggregation aligners for each port for i in 
range(self.interconnect_input_ports): new_child = AggAligner(self.data_width, self.max_line_length) self.add_child(f"agg_align_{i}", new_child, clk=self._clk, rst_n=self._rst_n, in_dat=self._data_in[i], in_valid=self._wen[i], align=self._align_to_agg[i], out_valid=self._valid_consume[i], out_dat=self._data_consume[i]) else: if self.agg_height > 0: self.wire(self._align_to_agg, const(0, self._align_to_agg.width)) ################################################ ##### END: AGGREGATION ALIGNERS (OPTIONAL) ##### ################################################ if self.agg_height == 0: self._to_iac_dat = self._data_consume self._to_iac_valid = self._valid_consume ################################## ##### AGG BUFFERS (OPTIONAL) ##### ################################## # Only instantiate agg_buffer if needed if(self.agg_height > 0): self._to_iac_dat = self.var("ab_to_mem_dat", self.mem_width, size=self.interconnect_input_ports, packed=True, explicit_array=True) # self._to_iac_valid = self.var("ab_to_mem_valid", # self.interconnect_input_ports) self._agg_buffers = [] # Add input aggregations buffers for i in range(self.interconnect_input_ports): # add children aggregator buffers... 
agg_buffer_new = AggregationBuffer(self.agg_height, self.data_width, self.mem_width, self.max_agg_schedule) self._agg_buffers.append(agg_buffer_new) self.add_child(f"agg_in_{i}", agg_buffer_new, clk=self._clk, rst_n=self._rst_n, data_in=self._data_consume[i], valid_in=self._valid_consume[i], align=self._align_to_agg[i], data_out=self._to_iac_dat[i], valid_out=self._to_iac_valid[i]) ####################################### ##### END: AGG BUFFERS (OPTIONAL) ##### ####################################### self._ready_tba = self.var("ready_tba", self.interconnect_output_ports) #################################### ##### INPUT ADDRESS CONTROLLER ##### #################################### self._wen_to_arb = self.var("wen_to_arb", self.mem_input_ports, size=self.banks, explicit_array=True, packed=True) self._addr_to_arb = self.var("addr_to_arb", self.address_width, size=(self.banks, self.mem_input_ports), explicit_array=True, packed=True) self._data_to_arb = self.var("data_to_arb", self.data_width, size=(self.banks, self.mem_input_ports, self.fw_int), explicit_array=True, packed=True) # Connect these inputs ports to an address generator iac = InputAddrCtrl(interconnect_input_ports=self.interconnect_input_ports, mem_depth=self.mem_depth, num_tiles=self.num_tiles, banks=self.banks, iterator_support=self.input_iterator_support, address_width=self.address_width, data_width=self.data_width, fetch_width=self.mem_width, multiwrite=self.multiwrite, strg_wr_ports=self.mem_input_ports, config_width=self.input_config_width) self.add_child(f"input_addr_ctrl", iac, clk=self._clk, rst_n=self._rst_n, valid_in=self._to_iac_valid, # wen_en=kts.concat(*([kts.const(1, 1)] * self.interconnect_input_ports)), wen_en=self._arb_wen_en, data_in=self._to_iac_dat, wen_to_sram=self._wen_to_arb, addr_out=self._addr_to_arb, port_out=self._port_wens, data_out=self._data_to_arb) ######################################### ##### END: INPUT ADDRESS CONTROLLER ##### ######################################### 
self._arb_acks = self.var("arb_acks", self.interconnect_output_ports, size=self.banks, explicit_array=True, packed=True) self._prefetch_step = self.var("prefetch_step", self.interconnect_output_ports) self._oac_step = self.var("oac_step", self.interconnect_output_ports) self._oac_valid = self.var("oac_valid", self.interconnect_output_ports) self._ren_out = self.var("ren_out", self.interconnect_output_ports, size=self.banks, explicit_array=True, packed=True) self._ren_out_tpose = self.var("ren_out_tpose", self.banks, size=self.interconnect_output_ports, explicit_array=True, packed=True) self._oac_addr_out = self.var("oac_addr_out", self.address_width, size=self.interconnect_output_ports, explicit_array=True, packed=True) ##################################### ##### OUTPUT ADDRESS CONTROLLER ##### ##################################### oac = OutputAddrCtrl(interconnect_output_ports=self.interconnect_output_ports, mem_depth=self.mem_depth, num_tiles=self.num_tiles, banks=self.banks, iterator_support=self.output_iterator_support, address_width=self.address_width, config_width=self.output_config_width) if self.remove_tb: self.wire(self._oac_valid, self._ren) self.wire(self._oac_step, self._ren) else: self.wire(self._oac_valid, self._prefetch_step) self.wire(self._oac_step, self._ack_reduced) self.chain_idx_bits = max(1, clog2(num_tiles)) self._enable_chain_output = self.input("enable_chain_output", 1) self._chain_idx_output = self.input("chain_idx_output", self.chain_idx_bits) self.add_child(f"output_addr_ctrl", oac, clk=self._clk, rst_n=self._rst_n, valid_in=self._oac_valid, ren=self._ren_out, addr_out=self._oac_addr_out, step_in=self._oac_step) for i in range(self.interconnect_output_ports): for j in range(self.banks): self.wire(self._ren_out_tpose[i][j], self._ren_out[j][i]) ############################## ##### READ/WRITE ARBITER ##### ############################## # Hook up the read write arbiters for each bank self._arb_dat_out = self.var("arb_dat_out", 
self.data_width, size=(self.banks, self.mem_output_ports, self.fw_int), explicit_array=True, packed=True) self._arb_port_out = self.var("arb_port_out", self.interconnect_output_ports, size=(self.banks, self.mem_output_ports), explicit_array=True, packed=True) self._arb_valid_out = self.var("arb_valid_out", self.mem_output_ports, size=self.banks, explicit_array=True, packed=True) self._rd_sync_gate = self.var("rd_sync_gate", self.interconnect_output_ports) self.arbiters = [] for i in range(self.banks): rw_arb = RWArbiter(fetch_width=self.mem_width, data_width=self.data_width, memory_depth=self.mem_depth, num_tiles=self.num_tiles, int_in_ports=self.interconnect_input_ports, int_out_ports=self.interconnect_output_ports, strg_wr_ports=self.mem_input_ports, strg_rd_ports=self.mem_output_ports, read_delay=self.read_delay, rw_same_cycle=self.rw_same_cycle, separate_addresses=self.rw_same_cycle) self.arbiters.append(rw_arb) self.add_child(f"rw_arb_{i}", rw_arb, clk=self._clk, rst_n=self._rst_n, wen_in=self._wen_to_arb[i], w_data=self._data_to_arb[i], w_addr=self._addr_to_arb[i], data_from_mem=self._data_from_strg[i], mem_valid_data=self._mem_valid_data[i], out_mem_valid_data=self._out_mem_valid_data[i], ren_en=self._arb_ren_en, rd_addr=self._oac_addr_out, out_data=self._arb_dat_out[i], out_port=self._arb_port_out[i], out_valid=self._arb_valid_out[i], cen_mem=self._cen_to_strg[i], wen_mem=self._wen_to_strg[i], data_to_mem=self._data_to_strg[i], out_ack=self._arb_acks[i]) # Bind the separate addrs if self.rw_same_cycle: self.wire(rw_arb.ports.wr_addr_to_mem, self._wr_addr_out[i]) self.wire(rw_arb.ports.rd_addr_to_mem, self._rd_addr_out[i]) else: self.wire(rw_arb.ports.addr_to_mem, self._addr_out[i]) if self.remove_tb: self.wire(rw_arb.ports.ren_in, self._ren_out[i]) else: self.wire(rw_arb.ports.ren_in, self._ren_out[i] & self._rd_sync_gate) self.num_tb_bits = max(1, clog2(self.num_tb)) self._data_to_sync = self.var("data_to_sync", self.data_width, 
size=(self.interconnect_output_ports, self.fw_int), explicit_array=True, packed=True) self._valid_to_sync = self.var("valid_to_sync", self.interconnect_output_ports) self._data_to_tba = self.var("data_to_tba", self.data_width, size=(self.interconnect_output_ports, self.fw_int), explicit_array=True, packed=True) self._valid_to_tba = self.var("valid_to_tba", self.interconnect_output_ports) self._data_to_pref = self.var("data_to_pref", self.data_width, size=(self.interconnect_output_ports, self.fw_int), explicit_array=True, packed=True) self._valid_to_pref = self.var("valid_to_pref", self.interconnect_output_ports) ####################### ##### DEMUX READS ##### ####################### dmux_rd = DemuxReads(fetch_width=self.mem_width, data_width=self.data_width, banks=self.banks, int_out_ports=self.interconnect_output_ports, strg_rd_ports=self.mem_output_ports) self._arb_dat_out_f = self.var("arb_dat_out_f", self.data_width, size=(self.banks * self.mem_output_ports, self.fw_int), explicit_array=True, packed=True) self._arb_port_out_f = self.var("arb_port_out_f", self.interconnect_output_ports, size=(self.banks * self.mem_output_ports), explicit_array=True, packed=True) self._arb_valid_out_f = self.var("arb_valid_out_f", self.mem_output_ports * self.banks) self._arb_mem_valid_data_f = self.var("arb_mem_valid_data_f", self.mem_output_ports * self.banks) self._arb_mem_valid_data_out = self.var("arb_mem_valid_data_out", self.interconnect_output_ports) self._mem_valid_data_sync = self.var("mem_valid_data_sync", self.interconnect_output_ports) self._mem_valid_data_pref = self.var("mem_valid_data_pref", self.interconnect_output_ports) tmp_cnt = 0 for i in range(self.banks): for j in range(self.mem_output_ports): self.wire(self._arb_dat_out_f[tmp_cnt], self._arb_dat_out[i][j]) self.wire(self._arb_port_out_f[tmp_cnt], self._arb_port_out[i][j]) self.wire(self._arb_valid_out_f[tmp_cnt], self._arb_valid_out[i][j]) self.wire(self._arb_mem_valid_data_f[tmp_cnt], 
self._out_mem_valid_data[i][j]) tmp_cnt = tmp_cnt + 1 # If this is end of the road... if self.remove_tb: assert self.fw_int == 1, "Make it easier on me now..." self.add_child("demux_rds", dmux_rd, clk=self._clk, rst_n=self._rst_n, data_in=self._arb_dat_out_f, mem_valid_data=self._arb_mem_valid_data_f, mem_valid_data_out=self._arb_mem_valid_data_out, valid_in=self._arb_valid_out_f, port_in=self._arb_port_out_f, valid_out=self._valid_out_alt) for i in range(self.interconnect_output_ports): self.wire(self._data_out[i], dmux_rd.ports.data_out[i]) else: self.add_child("demux_rds", dmux_rd, clk=self._clk, rst_n=self._rst_n, data_in=self._arb_dat_out_f, mem_valid_data=self._arb_mem_valid_data_f, mem_valid_data_out=self._arb_mem_valid_data_out, valid_in=self._arb_valid_out_f, port_in=self._arb_port_out_f, data_out=self._data_to_sync, valid_out=self._valid_to_sync) ####################### ##### SYNC GROUPS ##### ####################### sync_group = SyncGroups(fetch_width=self.mem_width, data_width=self.data_width, int_out_ports=self.interconnect_output_ports) for i in range(self.interconnect_output_ports): self.wire(self._ren_out_reduced[i], self._ren_out_tpose[i].r_or()) self.add_child("sync_grp", sync_group, clk=self._clk, rst_n=self._rst_n, data_in=self._data_to_sync, mem_valid_data=self._arb_mem_valid_data_out, mem_valid_data_out=self._mem_valid_data_sync, valid_in=self._valid_to_sync, data_out=self._data_to_pref, valid_out=self._valid_to_pref, ren_in=self._ren_out_reduced, rd_sync_gate=self._rd_sync_gate, ack_in=self._ack_reduced) # This is the end of the line if we aren't using tb ###################### ##### PREFETCHER ##### ###################### prefetchers = [] for i in range(self.interconnect_output_ports): pref = Prefetcher(fetch_width=self.mem_width, data_width=self.data_width, max_prefetch=self.max_prefetch) prefetchers.append(pref) if self.num_tb == 0: assert self.fw_int == 1, \ "If no transpose buffer, data width needs match memory width" 
self.add_child(f"pre_fetch_{i}", pref, clk=self._clk, rst_n=self._rst_n, data_in=self._data_to_pref[i], mem_valid_data=self._mem_valid_data_sync[i], mem_valid_data_out=self._mem_valid_data_pref[i], valid_read=self._valid_to_pref[i], tba_rdy_in=self._ren[i], # data_out=self._data_out[i], valid_out=self._valid_out_alt[i], prefetch_step=self._prefetch_step[i]) self.wire(self._data_out[i], pref.ports.data_out[0]) else: self.add_child(f"pre_fetch_{i}", pref, clk=self._clk, rst_n=self._rst_n, data_in=self._data_to_pref[i], mem_valid_data=self._mem_valid_data_sync[i], mem_valid_data_out=self._mem_valid_data_pref[i], valid_read=self._valid_to_pref[i], tba_rdy_in=self._ready_tba[i], data_out=self._data_to_tba[i], valid_out=self._valid_to_tba[i], prefetch_step=self._prefetch_step[i]) ############################# ##### TRANSPOSE BUFFERS ##### ############################# if self.num_tb > 0: self._tb_data_out = self.var("tb_data_out", self.data_width, size=self.interconnect_output_ports, explicit_array=True, packed=True) self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports) for i in range(self.interconnect_output_ports): tba = TransposeBufferAggregation(word_width=self.data_width, fetch_width=self.fw_int, num_tb=self.num_tb, max_tb_height=self.max_tb_height, max_range=self.tb_range_max, max_range_inner=self.tb_range_inner_max, max_stride=self.max_tb_stride, tb_iterator_support=self.tb_iterator_support) self.add_child(f"tba_{i}", tba, clk=self._clk, rst_n=self._rst_n, SRAM_to_tb_data=self._data_to_tba[i], valid_data=self._valid_to_tba[i], tb_index_for_data=0, ack_in=self._valid_to_tba[i], mem_valid_data=self._mem_valid_data_pref[i], tb_to_interconnect_data=self._tb_data_out[i], tb_to_interconnect_valid=self._tb_valid_out[i], tb_arbiter_rdy=self._ready_tba[i], tba_ren=self._ren[i]) for i in range(self.interconnect_output_ports): self.wire(self._data_out[i], self._tb_data_out[i]) # self.wire(self._valid_out[i], self._tb_valid_out[i]) else: 
self.wire(self._valid_out, self._valid_out_alt) #################### ##### ADD CODE ##### #################### self.add_code(self.transpose_acks) self.add_code(self.reduce_acks)
def __init__(self, iterator_support=6, config_width=16): super().__init__(f"addr_gen_{iterator_support}_{config_width}", debug=True) # Store local... self.iterator_support = iterator_support self.config_width = config_width # PORT DEFS: begin # INPUTS self._clk = self.clock("clk") self._rst_n = self.reset("rst_n") self._restart = self.input("restart", 1) self._strides = self.input("strides", self.config_width, size=self.iterator_support, packed=True, explicit_array=True) self._strides.add_attribute( ConfigRegAttr("Strides of address generator")) self._strides.add_attribute( FormalAttr(f"{self._strides.name}", FormalSignalConstraint.SOLVE)) self._starting_addr = self.input("starting_addr", self.config_width) self._starting_addr.add_attribute( ConfigRegAttr("Starting address of address generator")) self._starting_addr.add_attribute( FormalAttr(f"{self._starting_addr.name}", FormalSignalConstraint.SOLVE)) self._step = self.input("step", 1) self._mux_sel = self.input("mux_sel", max(clog2(self.iterator_support), 1)) # OUTPUTS # TODO why is this config width instead of address width? self._addr_out = self.output("addr_out", self.config_width) # PORT DEFS: end # LOCAL VARIABLES: begin self._strt_addr = self.var("strt_addr", self.config_width) self._calc_addr = self.var("calc_addr", self.config_width) self._max_value = self.var("max_value", self.iterator_support) # LOCAL VARIABLES: end # GENERATION LOGIC: begin self.wire(self._strt_addr, self._starting_addr) self.wire(self._addr_out, self._calc_addr) self._current_addr = self.var("current_addr", self.config_width) # Calculate address by taking previous calculation and adding the muxed stride self.wire(self._calc_addr, self._strt_addr + self._current_addr) self.add_code(self.calculate_address)
def __init__(self,
             interconnect_input_ports,
             interconnect_output_ports,
             depth_width=16,
             sprt_stcl_valid=False,
             stcl_cnt_width=16,
             stcl_iter_support=4):
    """Build the application controller.

    Gates write/read enables by per-port configured depths (counting writes
    and reads until each port is "done") and optionally generates a stencil
    valid signal from configured iteration ranges/thresholds.

    Args:
        interconnect_input_ports: Number of write-side interconnect ports.
        interconnect_output_ports: Number of read-side interconnect ports.
        depth_width: Bit width of the write/read depth config registers.
        sprt_stcl_valid: If True, generate stencil-valid from dim counters;
            otherwise stencil valid is just a passthrough of ``tb_valid``.
        stcl_cnt_width: Bit width of each stencil dimension counter.
        stcl_iter_support: Number of stencil iteration dimensions.
    """
    super().__init__("app_ctrl", debug=True)
    self.int_in_ports = interconnect_input_ports
    self.int_out_ports = interconnect_output_ports
    self.depth_width = depth_width
    self.sprt_stcl_valid = sprt_stcl_valid
    self.stcl_cnt_width = stcl_cnt_width
    self.stcl_iter_support = stcl_iter_support
    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    # IO
    self._wen_in = self.input("wen_in", self.int_in_ports)
    self._ren_in = self.input("ren_in", self.int_out_ports)
    # NOTE(review): ren_update is declared but not referenced in this
    # constructor — presumably consumed by one of the add_code blocks
    # (set_read_cnt / set_read_done); confirm.
    self._ren_update = self.input("ren_update", self.int_out_ports)
    self._tb_valid = self.input("tb_valid", self.int_out_ports)
    self._valid_out_data = self.output("valid_out_data", self.int_out_ports)
    self._valid_out_stencil = self.output("valid_out_stencil",
                                          self.int_out_ports)
    # Send tb valid to valid out for now...
    if self.sprt_stcl_valid:
        # Add the config registers to watch
        self._ranges = self.input("ranges", self.stcl_cnt_width,
                                  size=self.stcl_iter_support,
                                  packed=True,
                                  explicit_array=True)
        self._ranges.add_attribute(ConfigRegAttr("Ranges of stencil valid generator"))
        self._threshold = self.input("threshold", self.stcl_cnt_width,
                                     size=self.stcl_iter_support,
                                     packed=True,
                                     explicit_array=True)
        self._threshold.add_attribute(ConfigRegAttr("Threshold of stencil valid generator"))
        self._dim_counter = self.var("dim_counter", self.stcl_cnt_width,
                                     size=self.stcl_iter_support,
                                     packed=True,
                                     explicit_array=True)
        # update[i] is a ripple-carry "this dimension should advance" chain:
        # dim 0 always updates; dim i+1 updates when dim i wraps.
        self._update = self.var("update", self.stcl_iter_support)
        self.wire(self._update[0], const(1, 1))
        for i in range(self.stcl_iter_support - 1):
            self.wire(self._update[i + 1],
                      (self._dim_counter[i] == (self._ranges[i] - 1)) & self._update[i])
        for i in range(self.stcl_iter_support):
            self.add_code(self.dim_counter_update, idx=i)
        # Now we need to just compute stencil valid
        # Stencil valid asserts once every dim counter reached its threshold.
        threshold_comps = [self._dim_counter[_i] >= self._threshold[_i]
                           for _i in range(self.stcl_iter_support)]
        self.wire(self._valid_out_stencil[0],
                  kts.concat(*threshold_comps).r_and())
        for i in range(self.int_out_ports - 1):
            # self.wire(self._valid_out_stencil[i + 1], 0) # for multiple ports
            # All output ports share the same stencil-valid computation.
            self.wire(self._valid_out_stencil[i + 1],
                      kts.concat(*threshold_comps).r_and())
    else:
        self.wire(self._valid_out_stencil, self._tb_valid)
    # Now gate the valid with stencil valid
    self.wire(self._valid_out_data, self._tb_valid & self._valid_out_stencil)
    self._wr_delay_state_n = self.var("wr_delay_state_n", self.int_out_ports)
    self._wen_out = self.output("wen_out", self.int_in_ports)
    self._ren_out = self.output("ren_out", self.int_out_ports)
    # Two write-depth configs: write-only phase vs steady state; the active
    # one is selected per port by wr_delay_state_n below.
    self._write_depth_wo = self.input("write_depth_wo", self.depth_width,
                                      size=self.int_in_ports,
                                      explicit_array=True,
                                      packed=True)
    self._write_depth_wo.add_attribute(ConfigRegAttr("Depth of writes"))
    self._write_depth_ss = self.input("write_depth_ss", self.depth_width,
                                      size=self.int_in_ports,
                                      explicit_array=True,
                                      packed=True)
    self._write_depth_ss.add_attribute(ConfigRegAttr("Depth of writes"))
    self._write_depth = self.var("write_depth", self.depth_width,
                                 size=self.int_in_ports,
                                 explicit_array=True,
                                 packed=True)
    for i in range(self.int_in_ports):
        self.wire(self._write_depth[i],
                  kts.ternary(self._wr_delay_state_n[i],
                              self._write_depth_ss[i],
                              self._write_depth_wo[i]))
    self._read_depth = self.input("read_depth", self.depth_width,
                                  size=self.int_out_ports,
                                  explicit_array=True,
                                  packed=True)
    self._read_depth.add_attribute(ConfigRegAttr("Depth of reads"))
    self._write_count = self.var("write_count", self.depth_width,
                                 size=self.int_in_ports,
                                 explicit_array=True,
                                 packed=True)
    self._read_count = self.var("read_count", self.depth_width,
                                size=self.int_out_ports,
                                explicit_array=True,
                                packed=True)
    self._write_done = self.var("write_done", self.int_in_ports)
    self._write_done_ff = self.var("write_done_ff", self.int_in_ports)
    self._read_done = self.var("read_done", self.int_out_ports)
    self._read_done_ff = self.var("read_done_ff", self.int_out_ports)
    # max(1, clog2(...)) keeps port widths nonzero for single-port configs.
    self.in_port_bits = max(1, kts.clog2(self.int_in_ports))
    self._input_port = self.input("input_port", self.in_port_bits,
                                  size=self.int_out_ports,
                                  explicit_array=True,
                                  packed=True)
    self._input_port.add_attribute(ConfigRegAttr("Relative input port for an output port"))
    self.out_port_bits = max(1, kts.clog2(self.int_out_ports))
    self._output_port = self.input("output_port", self.out_port_bits,
                                   size=self.int_in_ports,
                                   explicit_array=True,
                                   packed=True)
    self._output_port.add_attribute(ConfigRegAttr("Relative output port for an input port"))
    self._prefill = self.input("prefill", self.int_out_ports)
    self._prefill.add_attribute(ConfigRegAttr("Is the input stream prewritten?"))
    # Register the sequential/combinational blocks; the *_one_wr variants
    # specialize the single-write-port case.
    for i in range(self.int_out_ports):
        self.add_code(self.set_read_done, idx=i)
        if self.int_in_ports == 1:
            self.add_code(self.set_read_done_ff_one_wr, idx=i)
        else:
            self.add_code(self.set_read_done_ff, idx=i)
    # self._write_done_comb = self.var("write_done_comb", self.int_in_ports)
    for i in range(self.int_in_ports):
        self.add_code(self.set_write_done, idx=i)
        self.add_code(self.set_write_done_ff, idx=i)
    for i in range(self.int_in_ports):
        self.add_code(self.set_write_cnt, idx=i)
    for i in range(self.int_out_ports):
        if self.int_in_ports == 1:
            self.add_code(self.set_read_cnt_one_wr, idx=i)
        else:
            self.add_code(self.set_read_cnt, idx=i)
    for i in range(self.int_out_ports):
        if self.int_in_ports == 1:
            self.add_code(self.set_wr_delay_state_one_wr, idx=i)
        else:
            self.add_code(self.set_wr_delay_state, idx=i)
    # read_on[i]: the port is active at all iff its configured depth is nonzero.
    self._read_on = self.var("read_on", self.int_out_ports)
    for i in range(self.int_out_ports):
        self.wire(self._read_on[i], self._read_depth[i].r_or())
    # If we have prefill enabled, we are skipping the initial delay step...
    self.wire(self._ren_out,
              (self._wr_delay_state_n | self._prefill) & ~self._read_done_ff & self._ren_in & self._read_on)
    self.wire(self._wen_out, ~self._write_done_ff & self._wen_in)
def add_config_reg(generator, name, description, bitwidth, **kwargs):
    """Declare an input port on *generator* and mark it as a config register.

    Args:
        generator: Kratos generator to attach the port to.
        name: Port name.
        description: Human-readable documentation stored in the ConfigRegAttr.
        bitwidth: Width of the port in bits.
        **kwargs: Forwarded to ``generator.input`` (e.g. size/packed).

    Returns:
        The newly created port, tagged with a ``ConfigRegAttr``.
    """
    port = generator.input(name, bitwidth, **kwargs)
    port.add_attribute(ConfigRegAttr(description))
    return port
def __init__(self, fetch_width, data_width, max_prefetch): super().__init__("prefetcher") # Capture to the object self.fetch_width = fetch_width self.data_width = data_width self.fw_int = int(self.fetch_width / self.data_width) self.max_prefetch = max_prefetch # Clock and Reset self._clk = self.clock("clk") self._rst_n = self.reset("rst_n") # Inputs self._data_in = self.input("data_in", self.data_width, size=self.fw_int, explicit_array=True, packed=True) self._valid_read = self.input("valid_read", 1) self._tba_rdy_in = self.input("tba_rdy_in", 1) self._input_latency = self.input("input_latency", clog2(self.max_prefetch) + 1) doc = "This register is set to denote the input latency loop for reads. " + \ "This is sent to an internal fifo and an almost full signal is " + \ "used to pull more reads that the transpose buffers need." self._input_latency.add_attribute(ConfigRegAttr(doc)) self._max_lat = const(self.max_prefetch - 1, clog2(self.max_prefetch) + 1) # Outputs self._data_out = self.output("data_out", self.data_width, size=self.fw_int, explicit_array=True, packed=True) self._valid_out = self.output("valid_out", 1) self._prefetch_step = self.output("prefetch_step", 1) # Local Signals self._cnt = self.var("cnt", clog2(self.max_prefetch) + 1) self._fifo_empty = self.var("fifo_empty", 1) self._fifo_full = self.var("fifo_full", 1) reg_fifo = RegFIFO(data_width=self.data_width, width_mult=self.fw_int, depth=self.max_prefetch) self.add_child("fifo", reg_fifo, clk=self._clk, rst_n=self._rst_n, clk_en=1, data_in=self._data_in, data_out=self._data_out, push=self._valid_read, pop=self._tba_rdy_in, empty=self._fifo_empty, full=self._fifo_full, valid=self._valid_out) # Generate self.add_code(self.update_cnt) self.add_code(self.set_prefetch_step)
def __init__(self,
             data_width=16,  # CGRA Params
             mem_depth=32,
             default_iterator_support=3,
             interconnect_input_ports=2,  # Connection to int
             interconnect_output_ports=2,
             mem_input_ports=1,
             mem_output_ports=1,
             config_data_width=32,
             config_addr_width=8,
             cycle_count_width=16,
             add_clk_enable=True,
             add_flush=True):
    """Build the Pond tile: a small register-file memory with per-port
    iteration-domain write/read addressors, a config-register interface,
    and optional auto-inserted clock-enable / flush (sync reset).

    Fixes relative to the original: two identical if/else branches
    (data_out wiring and the read-address safe_wire) are collapsed into
    their single shared statement, and pointless f-prefixes on constant
    child names ("config_seq", "rf") are dropped. Generated hardware is
    unchanged.

    Args:
        data_width: Word width in bits.
        mem_depth: Register-file depth in words.
        default_iterator_support: Loop levels for the addressors/schedulers.
        interconnect_input_ports: Write-side interconnect ports.
        interconnect_output_ports: Read-side interconnect ports.
        mem_input_ports: Physical RF write ports.
        mem_output_ports: Physical RF read ports.
        config_data_width: Width of the config data bus.
        config_addr_width: Width of the config address bus.
        cycle_count_width: Width of the global cycle counter.
        add_clk_enable: Auto-insert clock enable across the design.
        add_flush: Auto-insert a synchronous "flush" reset across the design.
    """
    super().__init__("pond", debug=True)
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.mem_depth = mem_depth
    self.data_width = data_width
    self.config_data_width = config_data_width
    self.config_addr_width = config_addr_width
    self.add_clk_enable = add_clk_enable
    self.add_flush = add_flush
    self.cycle_count_width = cycle_count_width
    self.default_iterator_support = default_iterator_support
    self.default_config_width = kts.clog2(self.mem_depth)
    # inputs
    self._clk = self.clock("clk")
    self._clk.add_attribute(
        FormalAttr(f"{self._clk.name}", FormalSignalConstraint.CLK))
    self._rst_n = self.reset("rst_n")
    self._rst_n.add_attribute(
        FormalAttr(f"{self._rst_n.name}", FormalSignalConstraint.RSTN))
    self._clk_en = self.clock_en("clk_en", 1)
    # Enable/Disable tile
    self._tile_en = self.input("tile_en", 1)
    self._tile_en.add_attribute(
        ConfigRegAttr("Tile logic enable manifested as clock gate"))
    # Gated clock: the tile's logic runs on clk & tile_en.
    gclk = self.var("gclk", 1)
    self._gclk = kts.util.clock(gclk)
    self.wire(gclk, kts.util.clock(self._clk & self._tile_en))
    # Free-running cycle counter that drives the schedule generators.
    self._cycle_count = add_counter(self, "cycle_count", self.cycle_count_width)
    # Create write enable + addr, same for read.
    # self._write = self.input("write", self.interconnect_input_ports)
    self._write = self.var("write", self.mem_input_ports)
    # self._write.add_attribute(ControlSignalAttr(is_control=True))
    self._write_addr = self.var("write_addr", kts.clog2(self.mem_depth),
                                size=self.interconnect_input_ports,
                                explicit_array=True,
                                packed=True)
    # Add "_pond" suffix to avoid error during garnet RTL generation
    self._data_in = self.input("data_in_pond", self.data_width,
                               size=self.interconnect_input_ports,
                               explicit_array=True,
                               packed=True)
    self._data_in.add_attribute(
        FormalAttr(f"{self._data_in.name}", FormalSignalConstraint.SEQUENCE))
    self._data_in.add_attribute(ControlSignalAttr(is_control=False))
    self._read = self.var("read", self.mem_output_ports)
    # t_write/t_read: per-interconnect-port transaction strobes produced by
    # the schedule generators below.
    self._t_write = self.var("t_write", self.interconnect_input_ports)
    self._t_read = self.var("t_read", self.interconnect_output_ports)
    # self._read.add_attribute(ControlSignalAttr(is_control=True))
    self._read_addr = self.var("read_addr", kts.clog2(self.mem_depth),
                               size=self.interconnect_output_ports,
                               explicit_array=True,
                               packed=True)
    self._s_read_addr = self.var("s_read_addr", kts.clog2(self.mem_depth),
                                 size=self.interconnect_output_ports,
                                 explicit_array=True,
                                 packed=True)
    self._data_out = self.output("data_out_pond", self.data_width,
                                 size=self.interconnect_output_ports,
                                 explicit_array=True,
                                 packed=True)
    self._data_out.add_attribute(
        FormalAttr(f"{self._data_out.name}", FormalSignalConstraint.SEQUENCE))
    self._data_out.add_attribute(ControlSignalAttr(is_control=False))
    self._valid_out = self.output("valid_out_pond",
                                  self.interconnect_output_ports)
    self._valid_out.add_attribute(
        FormalAttr(f"{self._valid_out.name}", FormalSignalConstraint.SEQUENCE))
    self._valid_out.add_attribute(ControlSignalAttr(is_control=False))
    self._mem_data_out = self.var("mem_data_out", self.data_width,
                                  size=self.mem_output_ports,
                                  explicit_array=True,
                                  packed=True)
    # s_* variants are the pre-config-mux signals; the bare variants feed
    # the physical RF ports after decode/config muxing.
    self._s_mem_data_in = self.var("s_mem_data_in", self.data_width,
                                   size=self.interconnect_input_ports,
                                   explicit_array=True,
                                   packed=True)
    self._mem_data_in = self.var("mem_data_in", self.data_width,
                                 size=self.mem_input_ports,
                                 explicit_array=True,
                                 packed=True)
    self._s_mem_write_addr = self.var("s_mem_write_addr", kts.clog2(self.mem_depth),
                                      size=self.interconnect_input_ports,
                                      explicit_array=True,
                                      packed=True)
    self._s_mem_read_addr = self.var("s_mem_read_addr", kts.clog2(self.mem_depth),
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)
    self._mem_write_addr = self.var("mem_write_addr", kts.clog2(self.mem_depth),
                                    size=self.mem_input_ports,
                                    explicit_array=True,
                                    packed=True)
    self._mem_read_addr = self.var("mem_read_addr", kts.clog2(self.mem_depth),
                                   size=self.mem_output_ports,
                                   explicit_array=True,
                                   packed=True)
    # Every interconnect output mirrors the single physical read port.
    # (Original had an identical special case for one output port.)
    for i in range(self.interconnect_output_ports):
        self.wire(self._data_out[i], self._mem_data_out[0])
    # Valid out is simply passing the read signal through...
    self.wire(self._valid_out, self._t_read)
    # Create write addressors
    for wr_port in range(self.interconnect_input_ports):
        RF_WRITE_ITER = ForLoop(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width)
        RF_WRITE_ADDR = AddrGen(
            iterator_support=self.default_iterator_support,
            config_width=self.default_config_width)
        RF_WRITE_SCHED = SchedGen(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width,
            use_enable=True)
        self.add_child(f"rf_write_iter_{wr_port}",
                       RF_WRITE_ITER,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_write[wr_port])
        # Whatever comes through here should hopefully just pipe through seamlessly
        # addressor modules
        self.add_child(f"rf_write_addr_{wr_port}",
                       RF_WRITE_ADDR,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_write[wr_port],
                       mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                       restart=RF_WRITE_ITER.ports.restart)
        safe_wire(self, self._write_addr[wr_port], RF_WRITE_ADDR.ports.addr_out)
        self.add_child(f"rf_write_sched_{wr_port}",
                       RF_WRITE_SCHED,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                       finished=RF_WRITE_ITER.ports.restart,
                       cycle_count=self._cycle_count,
                       valid_output=self._t_write[wr_port])
    # Create read addressors
    for rd_port in range(self.interconnect_output_ports):
        RF_READ_ITER = ForLoop(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width)
        RF_READ_ADDR = AddrGen(
            iterator_support=self.default_iterator_support,
            config_width=self.default_config_width)
        RF_READ_SCHED = SchedGen(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width,
            use_enable=True)
        self.add_child(f"rf_read_iter_{rd_port}",
                       RF_READ_ITER,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_read[rd_port])
        self.add_child(f"rf_read_addr_{rd_port}",
                       RF_READ_ADDR,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_read[rd_port],
                       mux_sel=RF_READ_ITER.ports.mux_sel_out,
                       restart=RF_READ_ITER.ports.restart)
        # Original had an if/else with byte-identical branches here.
        safe_wire(self, self._read_addr[rd_port], RF_READ_ADDR.ports.addr_out)
        self.add_child(f"rf_read_sched_{rd_port}",
                       RF_READ_SCHED,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       mux_sel=RF_READ_ITER.ports.mux_sel_out,
                       finished=RF_READ_ITER.ports.restart,
                       cycle_count=self._cycle_count,
                       valid_output=self._t_read[rd_port])
    # Collapse per-port strobes/addresses onto the single physical RF port.
    self.wire(self._write, self._t_write.r_or())
    self.wire(self._mem_write_addr[0],
              decode(self, self._t_write, self._s_mem_write_addr))
    self.wire(self._mem_data_in[0],
              decode(self, self._t_write, self._s_mem_data_in))
    self.wire(self._read, self._t_read.r_or())
    self.wire(self._mem_read_addr[0],
              decode(self, self._t_read, self._s_mem_read_addr))
    # ===================================
    # Instantiate config hooks...
    # ===================================
    self.fw_int = 1
    self.data_words_per_set = 2**self.config_addr_width
    self.sets = int(
        (self.fw_int * self.mem_depth) / self.data_words_per_set)
    self.sets_per_macro = max(
        1, int(self.mem_depth / self.data_words_per_set))
    self.total_sets = max(1, 1 * self.sets_per_macro)
    self._config_data_in = self.input("config_data_in",
                                      self.config_data_width)
    self._config_data_in.add_attribute(ControlSignalAttr(is_control=False))
    # Config bus is wider than the data word; use the low data_width bits.
    self._config_data_in_shrt = self.var("config_data_in_shrt",
                                         self.data_width)
    self.wire(self._config_data_in_shrt,
              self._config_data_in[self.data_width - 1, 0])
    self._config_addr_in = self.input("config_addr_in",
                                      self.config_addr_width)
    self._config_addr_in.add_attribute(ControlSignalAttr(is_control=False))
    self._config_data_out_shrt = self.var("config_data_out_shrt", self.data_width,
                                          size=self.total_sets,
                                          explicit_array=True,
                                          packed=True)
    self._config_data_out = self.output("config_data_out", self.config_data_width,
                                        size=self.total_sets,
                                        explicit_array=True,
                                        packed=True)
    self._config_data_out.add_attribute(
        ControlSignalAttr(is_control=False))
    # Zero-extend each set's readback word to the config bus width.
    for i in range(self.total_sets):
        self.wire(
            self._config_data_out[i],
            self._config_data_out_shrt[i].extend(self.config_data_width))
    self._config_read = self.input("config_read", 1)
    self._config_read.add_attribute(ControlSignalAttr(is_control=False))
    self._config_write = self.input("config_write", 1)
    self._config_write.add_attribute(ControlSignalAttr(is_control=False))
    self._config_en = self.input("config_en", self.total_sets)
    self._config_en.add_attribute(ControlSignalAttr(is_control=False))
    self._mem_data_cfg = self.var("mem_data_cfg", self.data_width,
                                  explicit_array=True,
                                  packed=True)
    self._mem_addr_cfg = self.var("mem_addr_cfg", kts.clog2(self.mem_depth))
    # Add config...
    stg_cfg_seq = StorageConfigSeq(
        data_width=self.data_width,
        config_addr_width=self.config_addr_width,
        addr_width=kts.clog2(self.mem_depth),
        fetch_width=self.data_width,
        total_sets=self.total_sets,
        sets_per_macro=self.sets_per_macro)
    # The clock to config sequencer needs to be the normal clock or
    # if the tile is off, we bring the clock back in based on config_en
    cfg_seq_clk = self.var("cfg_seq_clk", 1)
    self._cfg_seq_clk = kts.util.clock(cfg_seq_clk)
    self.wire(cfg_seq_clk, kts.util.clock(self._gclk))
    self.add_child("config_seq", stg_cfg_seq,
                   clk=self._cfg_seq_clk,
                   rst_n=self._rst_n,
                   clk_en=self._clk_en | self._config_en.r_or(),
                   config_data_in=self._config_data_in_shrt,
                   config_addr_in=self._config_addr_in,
                   config_wr=self._config_write,
                   config_rd=self._config_read,
                   config_en=self._config_en,
                   wr_data=self._mem_data_cfg,
                   rd_data_out=self._config_data_out_shrt,
                   addr_out=self._mem_addr_cfg)
    if self.interconnect_output_ports == 1:
        self.wire(stg_cfg_seq.ports.rd_data_stg, self._mem_data_out)
    else:
        self.wire(stg_cfg_seq.ports.rd_data_stg[0], self._mem_data_out[0])
    self.RF_GEN = RegisterFile(data_width=self.data_width,
                               write_ports=self.mem_input_ports,
                               read_ports=self.mem_output_ports,
                               width_mult=1,
                               depth=self.mem_depth,
                               read_delay=0)
    # Now we can instantiate and wire up the register file
    self.add_child("rf", self.RF_GEN,
                   clk=self._gclk,
                   rst_n=self._rst_n,
                   data_out=self._mem_data_out)
    # Opt in for config_write
    # Port 0 carries config writes; extra physical write ports are forced
    # off while config is active.
    self._write_rf = self.var("write_rf", self.mem_input_ports)
    self.wire(
        self._write_rf[0],
        kts.ternary(self._config_en.r_or(),
                    self._config_write,
                    self._write[0]))
    for i in range(self.mem_input_ports - 1):
        self.wire(
            self._write_rf[i + 1],
            kts.ternary(self._config_en.r_or(),
                        kts.const(0, 1),
                        self._write[i + 1]))
    self.wire(self.RF_GEN.ports.wen, self._write_rf)
    # Opt in for config_data_in
    for i in range(self.interconnect_input_ports):
        self.wire(
            self._s_mem_data_in[i],
            kts.ternary(self._config_en.r_or(),
                        self._mem_data_cfg,
                        self._data_in[i]))
    self.wire(self.RF_GEN.ports.data_in, self._mem_data_in)
    # Opt in for config_addr
    for i in range(self.interconnect_input_ports):
        self.wire(
            self._s_mem_write_addr[i],
            kts.ternary(self._config_en.r_or(),
                        self._mem_addr_cfg,
                        self._write_addr[i]))
    self.wire(self.RF_GEN.ports.wr_addr, self._mem_write_addr[0])
    for i in range(self.interconnect_output_ports):
        self.wire(
            self._s_mem_read_addr[i],
            kts.ternary(self._config_en.r_or(),
                        self._mem_addr_cfg,
                        self._read_addr[i]))
    self.wire(self.RF_GEN.ports.rd_addr, self._mem_read_addr[0])
    if self.add_clk_enable:
        # self.clock_en("clk_en")
        kts.passes.auto_insert_clock_enable(self.internal_generator)
        clk_en_port = self.internal_generator.get_port("clk_en")
        clk_en_port.add_attribute(ControlSignalAttr(False))
    if self.add_flush:
        self.add_attribute("sync-reset=flush")
        kts.passes.auto_insert_sync_reset(self.internal_generator)
        flush_port = self.internal_generator.get_port("flush")
        flush_port.add_attribute(ControlSignalAttr(True))
    # Finally, lift the config regs...
    lift_config_reg(self.internal_generator)
def __init__(self,
             interconnect_input_ports=2,
             mem_depth=32,
             num_tiles=1,
             banks=1,
             iterator_support=6,
             address_width=5,
             data_width=16,
             fetch_width=16,
             multiwrite=1,
             strg_wr_ports=2,
             config_width=16):
    """Input-side address controller ("input_addr_ctrl").

    Accepts valid/data on the interconnect input ports, generates write
    addresses via one AddrGen child per port, and decodes which bank and
    storage write port each request targets (with optional multiwrite
    fan-out across banks).

    NOTE(review): the ``address_width`` parameter is captured but then
    overwritten below as mem_addr_width + bank_addr_width, so the passed
    value is effectively ignored -- confirm this is intentional.
    """
    super().__init__("input_addr_ctrl", debug=True)

    assert multiwrite >= 1, "Multiwrite must be at least 1..."

    # Capture configuration onto the object
    self.interconnect_input_ports = interconnect_input_ports
    self.mem_depth = mem_depth
    self.num_tiles = num_tiles
    self.banks = banks
    self.iterator_support = iterator_support
    self.address_width = address_width
    # Width of the port scheduling counter (at least 1 bit even for 1 port)
    self.port_sched_width = max(1, clog2(self.interconnect_input_ports))
    self.data_width = data_width
    self.fetch_width = fetch_width
    # Words per fetch: integer division instead of int(a / b), which
    # round-trips through float and is fragile for large widths.
    self.fw_int = self.fetch_width // self.data_width
    self.multiwrite = multiwrite
    self.strg_wr_ports = strg_wr_ports
    self.config_width = config_width

    self.mem_addr_width = clog2(self.num_tiles * self.mem_depth)
    # Bank-select bits are only needed when there is more than one bank
    if self.banks > 1:
        self.bank_addr_width = clog2(self.banks)
    else:
        self.bank_addr_width = 0
    self.address_width = self.mem_addr_width + self.bank_addr_width

    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")

    # Inputs
    # phases = [] TODO
    # Take in the valid and data and attach an address + direct to a port
    self._valid_in = self.input("valid_in", self.interconnect_input_ports)
    self._wen_en = self.input("wen_en", self.interconnect_input_ports)
    self._wen_en_saved = self.var("wen_en_saved",
                                  self.interconnect_input_ports)
    self._data_in = self.input("data_in",
                               self.data_width,
                               size=(self.interconnect_input_ports,
                                     self.fw_int),
                               explicit_array=True,
                               packed=True)
    self._data_in_saved = self.var("data_in_saved",
                                   self.data_width,
                                   size=(self.interconnect_input_ports,
                                         self.fw_int),
                                   explicit_array=True,
                                   packed=True)

    # Outputs
    self._wen = self.output("wen_to_sram",
                            self.strg_wr_ports,
                            size=self.banks,
                            explicit_array=True,
                            packed=True)

    # wen_full[i][k][j]: input port i, multiwrite copy k, targets bank j
    wen_full_size = (self.interconnect_input_ports, self.multiwrite)
    self._wen_full = self.var("wen_full",
                              self.banks,
                              size=wen_full_size,
                              explicit_array=True,
                              packed=True)
    self._wen_reduced = self.var("wen_reduced",
                                 self.banks,
                                 size=self.interconnect_input_ports,
                                 explicit_array=True,
                                 packed=True)
    self._wen_reduced_saved = self.var("wen_reduced_saved",
                                       self.banks,
                                       size=self.interconnect_input_ports,
                                       explicit_array=True,
                                       packed=True)
    self._addresses = self.output("addr_out",
                                  self.mem_addr_width,
                                  size=(self.banks, self.strg_wr_ports),
                                  explicit_array=True,
                                  packed=True)
    self._data_out = self.output("data_out",
                                 self.data_width,
                                 size=(self.banks,
                                       self.strg_wr_ports,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    self._port_out_exp = self.var("port_out_exp",
                                  self.interconnect_input_ports,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)
    self._port_out = self.output("port_out", self.interconnect_input_ports)
    self._counter = self.var("counter", self.port_sched_width)

    # Wire to port out: port i is flagged if any bank selected it
    for i in range(self.interconnect_input_ports):
        new_tmp = []
        for j in range(self.banks):
            new_tmp.append(self._port_out_exp[j][i])
        self.wire(self._port_out[i], kts.concat(*new_tmp).r_or())

    self._done = self.var("done",
                          self.strg_wr_ports,
                          size=self.banks,
                          explicit_array=True,
                          packed=True)

    # LOCAL VARS
    self._local_addrs = self.var("local_addrs",
                                 self.address_width,
                                 size=(self.interconnect_input_ports,
                                       self.multiwrite),
                                 packed=True,
                                 explicit_array=True)
    self._local_addrs_saved = self.var("local_addrs_saved",
                                       self.address_width,
                                       size=(self.interconnect_input_ports,
                                             self.multiwrite),
                                       packed=True,
                                       explicit_array=True)

    # Reduce multiwrite copies: wen_reduced[i][j] is high when any of
    # port i's multiwrite requests targets bank j
    for i in range(self.interconnect_input_ports):
        for j in range(self.banks):
            concat_ports = []
            for k in range(self.multiwrite):
                concat_ports.append(self._wen_full[i][k][j])
            self.wire(self._wen_reduced[i][j],
                      kts.concat(*concat_ports).r_or())

    # Degenerate single-bank/single-port case can be wired directly;
    # otherwise use procedural decode blocks
    if self.banks == 1 and self.interconnect_input_ports == 1:
        self.wire(self._wen_full[0][0][0], self._valid_in)
    elif self.banks == 1 and self.interconnect_input_ports > 1:
        self.add_code(self.set_wen_single)
    else:
        self.add_code(self.set_wen_mult)

    # MAIN
    # Iterate through all banks to priority decode the wen
    self.add_code(self.decode_out_lowest)

    # Also set the write ports on the storage
    if self.strg_wr_ports > 1:
        self._idx_cnt = self.var("idx_cnt",
                                 8,
                                 size=(self.banks, self.strg_wr_ports - 1),
                                 explicit_array=True,
                                 packed=True)
        for i in range(self.strg_wr_ports - 1):
            self.add_code(self.decode_out_alt, idx=i + 1)

    # Now we should instantiate the child address generators
    # (1 per input port) to send to the sram banks
    for i in range(self.interconnect_input_ports):
        self.add_child(f"address_gen_{i}",
                       AddrGen(iterator_support=self.iterator_support,
                               config_width=self.config_width),
                       clk=self._clk,
                       rst_n=self._rst_n,
                       clk_en=const(1, 1),
                       flush=const(0, 1),
                       step=self._valid_in[i])

    # Need to check that the address falls into the bank for implicit banking
    # Then, obey the input schedule to send the proper Aggregator to the output
    # The wen to sram should be that the valid for the selected port is high
    # Do the same thing for the output address
    # (multiwrite > 0 is already guaranteed by the assert at the top)
    assert 1 <= self.multiwrite <= self.banks, \
        "Multiwrite should be between 1 and banks"
    if self.multiwrite > 1:
        # Config regs holding the explicit bank offsets for extra writes
        size = (self.interconnect_input_ports, self.multiwrite - 1)
        self._offsets_cfg = self.input("offsets_cfg",
                                       self.address_width,
                                       size=size,
                                       packed=True,
                                       explicit_array=True)
        doc = "These offsets provide the ability to write to multiple banks explicitly"
        self._offsets_cfg.add_attribute(ConfigRegAttr(doc))
        self.add_code(self.set_multiwrite_addrs)

    # to handle multiple input ports going to fewer SRAM write ports
    self.add_code(self.set_int_ports_counter)
    self.add_code(self.save_mult_int_signals)