def __init__(self, interconnect_output_ports, mem_depth, num_tiles, banks,
             iterator_support, address_width, data_width, fetch_width,
             chain_idx_output):
    """Software model of the output address controller.

    Holds one AddrGenModel per interconnect output port, an all-zero
    default configuration for each generator, and per-bank read-enable /
    per-port address scratch state.
    """
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_depth = mem_depth
    self.num_tiles = num_tiles
    self.banks = banks
    self.iterator_support = iterator_support
    self.address_width = address_width
    self.data_width = data_width
    self.fetch_width = fetch_width
    # Fetch width expressed in data words.
    self.fw_int = int(self.fetch_width / self.data_width)
    self.chain_idx_output = chain_idx_output
    self.config = {}

    # One child address generator per output port.
    self.addr_gens = [
        AddrGenModel(iterator_support=self.iterator_support,
                     address_width=self.address_width)
        for _ in range(self.interconnect_output_ports)
    ]

    self.mem_addr_width = kts.clog2(self.num_tiles * self.mem_depth)
    self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

    # Per-port address scratchpad.
    self.addresses = [0] * self.interconnect_output_ports

    # All-zero default configuration for every generator.
    for port in range(self.interconnect_output_ports):
        self.config[f"address_gen_{port}_starting_addr"] = 0
        self.config[f"address_gen_{port}_dimensionality"] = 0
        for dim in range(self.iterator_support):
            self.config[f"address_gen_{port}_strides_{dim}"] = 0
            self.config[f"address_gen_{port}_ranges_{dim}"] = 0

    # Read enables (banks x output ports) and one memory address per port.
    self.ren = [[0] * self.interconnect_output_ports
                for _ in range(self.banks)]
    self.mem_addresses = [0] * self.interconnect_output_ports
def __init__(self, _params: GlobalBufferParams):
    """Nested-loop iterator generator (glb_loop_iter).

    Walks a multi-dimensional iteration space described by `ranges` and
    `dim`, advancing one step per `step` pulse. Emits the currently
    active loop level on `mux_sel_out` and a one-cycle `restart` when
    the whole iteration space has been exhausted.
    """
    super().__init__(f"glb_loop_iter")
    self._params = _params
    # INPUTS
    self.clk = self.clock("clk")
    self.clk_en = self.clock_en("clk_en")
    self.reset = self.reset("reset")
    # Per-level trip counts; one axi_data_width word per loop level.
    self.ranges = self.input("ranges", self._params.axi_data_width,
                             size=self._params.loop_level,
                             packed=True, explicit_array=True)
    # Active dimensionality (how many of the loop levels are in use).
    self.dim = self.input("dim", 1 + clog2(self._params.loop_level))
    self.step = self.input("step", 1)
    self.mux_sel_out = self.output("mux_sel_out",
                                   max(clog2(self._params.loop_level), 1))
    self.restart = self.output("restart", 1)

    # local variables
    # One counter per loop level, compared against `ranges`.
    self.dim_counter = self.var("dim_counter", self._params.axi_data_width,
                                size=self._params.loop_level,
                                packed=True, explicit_array=True)
    self.max_value = self.var("max_value", self._params.loop_level)
    self.mux_sel = self.var("mux_sel", max(clog2(self._params.loop_level), 1))
    self.wire(self.mux_sel_out, self.mux_sel)
    self.not_done = self.var("not_done", 1)
    self.clear = self.var("clear", self._params.loop_level)
    self.inc = self.var("inc", self._params.loop_level)
    self.is_maxed = self.var("is_maxed", 1)
    # Selected counter has hit its range bound while being incremented.
    self.wire(self.is_maxed, (self.dim_counter[self.mux_sel] ==
                              self.ranges[self.mux_sel]) & self.inc[self.mux_sel])

    self.add_code(self.set_mux_sel)
    # Instantiate the per-level clear/inc/counter/max logic once per level.
    for i in range(self._params.loop_level):
        self.add_code(self.set_clear, idx=i)
        self.add_code(self.set_inc, idx=i)
        self.add_code(self.dim_counter_update, idx=i)
        self.add_code(self.max_value_update, idx=i)
    # Restart fires on the step that finishes the iteration space.
    self.wire(self.restart, self.step & (~self.not_done))
def test_nested_scope():
    """Check kratos debug scope-context for a loop variable inside nested ifs.

    Builds a priority-encoder-style generator, compiles it with debug info,
    and verifies that statements inside the unrolled innermost `if` carry
    the loop variable `i` in their scope context as a constant string
    (the last unrolled iteration, "3", for width=4).
    """
    from kratos import clog2
    mod = Generator("FindHighestBit", True)
    width = 4
    data = mod.input("data", width)
    h_bit = mod.output("h_bit", clog2(width))
    done = mod.var("done", 1)

    @always_comb
    def find_bit():
        done = 0
        h_bit = 0
        # Loop is unrolled by kratos; `i` becomes a per-iteration constant.
        for i in range(width):
            if ~done:
                if data[i]:
                    done = 1
                    h_bit = i

    mod.add_always(find_bit, label="block")
    # insert_debug_info=True is required for scope_context to be populated.
    verilog(mod, insert_debug_info=True)
    block = mod.get_marked_stmt("block")
    # Last statement in the block is the final unrolled `if ~done` branch.
    last_if = block[-1]
    for i in range(len(last_if.then_[-1].then_)):
        stmt = last_if.then_[-1].then_[i]
        context = stmt.scope_context
        if len(context) > 0:
            assert "i" in context
            # Context entry is (is_var, value); loop index is a constant,
            # not a variable, so is_var is False and value is the literal "3".
            is_var, var = context["i"]
            assert not is_var
            assert var == "3"
def __init__(self, _params: GlobalBufferParams):
    """Stride-based address generator (glb_addr_gen).

    Accumulates `strides[mux_sel]` into an internal offset on each `step`
    (see `calculate_address`), restarting from zero on `restart`, and
    drives `addr_out = start_addr + current_addr`.
    """
    super().__init__(f"glb_addr_gen")
    self._params = _params
    self.clk = self.clock("clk")
    self.clk_en = self.clock_en("clk_en")
    self.reset = self.reset("reset")
    self.restart = self.input("restart", 1)
    # One stride word per loop level; indexed by `mux_sel`.
    self.strides = self.input("strides", self._params.axi_data_width,
                              size=self._params.loop_level,
                              packed=True, explicit_array=True)
    self.start_addr = self.input("start_addr", self._params.axi_data_width)
    self.step = self.input("step", 1)
    # Active loop level, driven by the companion loop iterator.
    self.mux_sel = self.input("mux_sel",
                              max(clog2(self._params.loop_level), 1))
    self.addr_out = self.output("addr_out", self._params.axi_data_width)

    # local variables
    # Running offset relative to start_addr.
    self.current_addr = self.var("current_addr", self._params.axi_data_width)

    # output address
    self.wire(self.addr_out, self.start_addr + self.current_addr)
    self.add_always(self.calculate_address)
def __init__(self, data_width, width_mult, depth, num_tiles):
    """Software model of a chained SRAM.

    Backing store is `depth` rows of `width_mult` words, with a read
    register of the same row width and chain-index bookkeeping for
    multi-tile operation.
    """
    self.data_width = data_width
    self.width_mult = width_mult
    self.depth = depth
    self.num_tiles = num_tiles
    # Addresses span the whole chained capacity (num_tiles * depth rows).
    self.address_width = kts.clog2(self.num_tiles * self.depth)
    self.chain_idx_bits = max(1, kts.clog2(num_tiles))
    self.chain_idx_tile = 0
    # Read register: one word per fetch-width slot, all zeroed.
    self.rd_reg = [0] * self.width_mult
    # Memory contents: depth rows x width_mult words, all zeroed.
    self.mem = [[0] * self.width_mult for _ in range(self.depth)]
def __init__(self, width, depth):
    """Synthesizable FIFO generator.

    Ports: "in"/"out" (`width` bits), "wen"/"ren" strobes, and
    "is_full"/"is_empty" status flags. Storage is a `depth`-entry array;
    the read/write pointers are clog2(depth) bits wide and wrap
    naturally on overflow.
    """
    super().__init__("FIFO")
    in_ = self.input("in", width)
    out = self.output("out", width)
    ren = self.input("ren", 1)
    wen = self.input("wen", 1)
    data = self.var("data", width, size=depth)
    clk = self.clock("clk")
    reset = self.reset("rst")
    read_ptr = self.var("read_ptr", clog2(depth))
    write_ptr = self.var("write_ptr", clog2(depth))
    full = self.output("is_full", 1)
    empty = self.output("is_empty", 1)
    full_next = self.var("is_full_next", 1)

    # Pointers are equal both when empty and when full, so the full flag
    # disambiguates the two cases.
    self.wire(empty, ~full & (read_ptr == write_ptr))

    def comb():
        if wen & (~ren) & ((write_ptr + 1) == read_ptr):
            # A write with no concurrent read that advances onto the
            # read pointer fills the FIFO.
            full_next = True
        elif ren & (~wen):
            # BUG FIX: was `elif wen and full:`. A *read* without a
            # concurrent write is what frees an entry, so that is the
            # condition that clears `full`; the old condition cleared
            # it on a write. Also, Python's `and` cannot be overloaded
            # on kratos signals — bitwise `&` is required for the
            # expression to translate to RTL.
            full_next = False
        else:
            full_next = full

    @always((posedge, "clk"), (posedge, "rst"))
    def seq():
        if reset:
            read_ptr = 0
            write_ptr = 0
            full = 0
        else:
            if ren:
                read_ptr = read_ptr + 1
                out = data[read_ptr]
            if wen:
                write_ptr = write_ptr + 1
                data[write_ptr] = in_
            full = full_next

    self.add_code(comb)
    self.add_code(seq)
def mem_ff(self):
    """Clocked memory process of the SRAM behavioral model.

    When the (active-low, presumably) chip enable CEB is asserted, reads
    the four consecutive sub-words at row A<<2 .. A<<2+3 and concatenates
    them into Q_w (MSB-first). When WEB (write enable, also active-low)
    is additionally asserted, writes input bits D through the per-bit
    active-low write mask BWEB.
    """
    if self.CEB == 0:
        # Read: one logical word = 4 consecutive sub-words starting at A<<2.
        self.Q_w = concat(
            self.mem[resize((self.A << 2) + 3, self.addr_width + 2)],
            self.mem[resize((self.A << 2) + 2, self.addr_width + 2)],
            self.mem[resize((self.A << 2) + 1, self.addr_width + 2)],
            self.mem[resize(
                (self.A << 2), self.addr_width + 2)])
        if self.WEB == 0:
            for i in range(self.data_width):
                # BWEB[i] == 0 enables writing data bit i.
                if self.BWEB[i] == 0:
                    # Bit i lands in sub-word i // 16 at bit position
                    # i % 16 — assumes 16-bit sub-words, i.e.
                    # cgra_data_width == 16; TODO confirm against params.
                    self.mem[resize(
                        (self.A << 2) + i // 16,
                        self.addr_width + 2)][resize(
                            i % 16,
                            clog2(
                                self._params.cgra_data_width))] = self.D[i]
def __init__(self, width: int, depth: int, flatten_output=False,
             reset_high=False):
    """Parameterizable register pipeline generator.

    A `depth`-deep, `width`-wide shift register. With depth == 0 the
    output is wired straight to the input. With flatten_output, every
    pipeline stage is exposed on `out_` (width x depth); otherwise only
    the last stage is. The generated module name encodes width, depth,
    and the flatten/reset options so distinct configurations get
    distinct module definitions.
    """
    name_suffix = ""
    if flatten_output:
        name_suffix += "_array"
    if reset_high:
        name_suffix += "_reset_high"
    super().__init__(f"pipeline_w_{width}_d_{depth}{name_suffix}")
    self.clk = self.clock("clk")
    self.clk_en = self.clock_en("clk_en")
    self.reset = self.reset("reset")
    self.width = width
    self.depth = depth
    self.reset_high = reset_high
    if self.depth == 0:
        # Zero-depth pipeline degenerates to a pass-through wire.
        self.in_ = self.input("in_", self.width)
        self.out_ = self.output("out_", self.width)
        self.wire(self.out_, self.in_)
    else:
        self.depth_width = max(clog2(self.depth), 1)
        self.in_ = self.input("in_", self.width)
        if flatten_output:
            self.out_ = self.output("out_", self.width, size=self.depth)
        else:
            self.out_ = self.output("out_", self.width)
        if self.depth == 1 and self.width == 1:
            # 1x1 case: avoid explicit_array so kratos emits a plain var.
            self.pipeline_r = self.var("pipeline_r", width=self.width,
                                       size=self.depth)
        else:
            self.pipeline_r = self.var("pipeline_r", width=self.width,
                                       size=self.depth,
                                       explicit_array=True)
        if flatten_output:
            # Expose every stage.
            self.wire(self.out_, self.pipeline_r)
        else:
            # Expose only the final stage.
            self.wire(self.out_, self.pipeline_r[self.depth - 1])
        self.add_always(self.pipeline)
def add_done_pulse_pipeline(self):
    """Delay the store-DMA done pulse to match the data network latency.

    Instantiates a flattened 1-bit pipeline sized for the worst case
    (2 * num_glb_tiles of network traversal plus the fixed internal
    latency) and taps the stage selected by the configured network
    latency plus the fixed default latency.
    """
    maximum_latency = 2 * self._params.num_glb_tiles + self.default_latency
    latency_width = clog2(maximum_latency)
    # All pipeline taps are exposed so the tap point can be runtime-selected.
    self.done_pulse_d_arr = self.var(
        "done_pulse_d_arr", 1, size=maximum_latency, explicit_array=True)
    self.done_pulse_pipeline = Pipeline(width=1,
                                        depth=maximum_latency,
                                        flatten_output=True)
    self.add_child("done_pulse_pipeline",
                   self.done_pulse_pipeline,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   in_=self.done_pulse_w,
                   out_=self.done_pulse_d_arr)
    # Select the stage matching the configured network latency.
    self.wire(self.st_dma_done_pulse,
              self.done_pulse_d_arr[resize(self.cfg_data_network_latency,
                                           latency_width)
                                    + self.default_latency])
def __init__(self,
             use_sram_stub,
             sram_name,
             data_width,
             fw_int,
             mem_depth,
             mem_input_ports,
             mem_output_ports,
             address_width,
             bank_num,
             num_tiles,
             # configuration registers passed down from top level
             enable_chain_input,
             enable_chain_output,
             chain_idx_input,
             chain_idx_output):
    """Software model of a single memory bank backed by an SRAMModel."""
    # Generation parameters, captured verbatim.
    (self.use_sram_stub, self.sram_name, self.data_width, self.fw_int,
     self.mem_depth, self.mem_input_ports, self.mem_output_ports,
     self.address_width, self.bank_num, self.num_tiles) = (
        use_sram_stub, sram_name, data_width, fw_int, mem_depth,
        mem_input_ports, mem_output_ports, address_width, bank_num,
        num_tiles)

    # Chaining configuration registers passed down from the top level.
    self.enable_chain_input = enable_chain_input
    self.enable_chain_output = enable_chain_output
    self.chain_idx_input = chain_idx_input
    self.chain_idx_output = chain_idx_output
    self.chain_idx_bits = max(1, kts.clog2(num_tiles))

    # Previous-cycle write/chip enables start out idle.
    self.prev_wen = 0
    self.prev_cen = 0

    # Backing storage shared across the chained tiles.
    self.sram = SRAMModel(data_width, fw_int, mem_depth, num_tiles)
def add_strm_rd_addr_pipeline(self):
    """Delay the stream read address alongside the returning read data.

    The read address is pipelined for the same worst-case latency as the
    data network so that, when read data arrives, the byte-offset bits
    of the matching address are available to select the CGRA word out of
    the wide bank word.
    """
    maximum_latency = 2 * self._params.num_glb_tiles + self.default_latency
    latency_width = clog2(maximum_latency)
    # Keep every delayed copy so the tap point can be runtime-selected.
    self.strm_rd_addr_d_arr = self.var("strm_rd_addr_d_arr",
                                       width=self._params.glb_addr_width,
                                       size=maximum_latency,
                                       explicit_array=True)
    self.strm_rd_addr_pipeline = Pipeline(
        width=self._params.glb_addr_width,
        depth=maximum_latency,
        flatten_output=True)
    self.add_child("strm_rd_addr_pipeline",
                   self.strm_rd_addr_pipeline,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   in_=self.strm_rd_addr_w,
                   out_=self.strm_rd_addr_d_arr)
    # Word-select = sub-bank byte-offset bits of the latency-matched address.
    self.strm_data_sel = self.strm_rd_addr_d_arr[
        resize(self.cfg_data_network_latency, latency_width)
        + self.default_latency][self._params.bank_byte_offset - 1,
                                self._params.cgra_byte_offset]
def tile2tile_w2e_wiring(self):
    """Chain the west-to-east packet streams across adjacent GLB tiles.

    Tile 0's west-side inputs are fed from this module (proc packets) or
    tied off to zero (strm/pcfg packets); every other tile's west input
    comes from its western neighbor's east output.
    """
    # Tile 0: proc stream enters here, the other two streams are tied low.
    self.wire(self.proc_packet_w2e_wsti[0], self.proc_packet_d)
    self.wire(self.strm_packet_w2e_wsti[0], 0)
    self.wire(self.pcfg_packet_w2e_wsti[0], 0)

    idx_width = clog2(self._params.num_glb_tiles)
    for tile in range(1, self._params.num_glb_tiles):
        west_in = const(tile, idx_width)
        east_out = const(tile - 1, idx_width)
        # Each tile's west input is its western neighbor's east output.
        self.wire(self.proc_packet_w2e_wsti[west_in],
                  self.proc_packet_w2e_esto[east_out])
        self.wire(self.strm_packet_w2e_wsti[west_in],
                  self.strm_packet_w2e_esto[east_out])
        self.wire(self.pcfg_packet_w2e_wsti[west_in],
                  self.pcfg_packet_w2e_esto[east_out])
def __init__(self,
             data_width=16,  # CGRA Params
             mem_width=64,
             mem_depth=512,
             banks=1,
             input_iterator_support=6,  # Addr Controllers
             output_iterator_support=6,
             input_config_width=16,
             output_config_width=16,
             interconnect_input_ports=2,  # Connection to int
             interconnect_output_ports=2,
             mem_input_ports=1,
             mem_output_ports=1,
             use_sram_stub=1,
             sram_macro_info=SRAMMacroInfo("TS1N16FFCLLSBLVTC512X32M4S"),
             read_delay=1,  # Cycle delay in read (SRAM vs Register File)
             rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
             agg_height=4,
             max_agg_schedule=16,
             input_max_port_sched=16,
             output_max_port_sched=16,
             align_input=1,
             max_line_length=128,
             max_tb_height=1,
             tb_range_max=1024,
             tb_range_inner_max=64,
             tb_sched_max=16,
             max_tb_stride=15,
             num_tb=1,
             tb_iterator_support=2,
             multiwrite=1,
             max_prefetch=8,
             config_data_width=32,
             config_addr_width=8,
             num_tiles=2,
             app_ctrl_depth_width=16,
             remove_tb=False,
             fifo_mode=True,
             add_clk_enable=True,
             add_flush=True,
             core_reset_pos=False,
             stcl_valid_iter=4):
    """Memory-core tile wrapper around a generated LakeTop circuit.

    Builds (or fetches from a class-level cache) the kratos-generated
    LakeTop memory, wraps it as a magma circuit, exposes the tile's
    interconnect ports, inserts register/mux overrides for the control
    strobes, enumerates the configuration register space (tile feature
    plus one feature per SRAM set), and dumps the resulting register map
    to mem_cfg.txt / mem_synth.txt as a side effect.
    """
    super().__init__(config_addr_width, config_data_width)

    # Capture everything to the tile object
    self.data_width = data_width
    self.mem_width = mem_width
    self.mem_depth = mem_depth
    self.banks = banks
    # Fetch width in data words (mem_width / data_width).
    self.fw_int = int(self.mem_width / self.data_width)
    self.input_iterator_support = input_iterator_support
    self.output_iterator_support = output_iterator_support
    self.input_config_width = input_config_width
    self.output_config_width = output_config_width
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.use_sram_stub = use_sram_stub
    self.sram_macro_info = sram_macro_info
    self.read_delay = read_delay
    self.rw_same_cycle = rw_same_cycle
    self.agg_height = agg_height
    self.max_agg_schedule = max_agg_schedule
    self.input_max_port_sched = input_max_port_sched
    self.output_max_port_sched = output_max_port_sched
    self.align_input = align_input
    self.max_line_length = max_line_length
    self.max_tb_height = max_tb_height
    self.tb_range_max = tb_range_max
    self.tb_range_inner_max = tb_range_inner_max
    self.tb_sched_max = tb_sched_max
    self.max_tb_stride = max_tb_stride
    self.num_tb = num_tb
    self.tb_iterator_support = tb_iterator_support
    self.multiwrite = multiwrite
    self.max_prefetch = max_prefetch
    self.config_data_width = config_data_width
    self.config_addr_width = config_addr_width
    self.num_tiles = num_tiles
    self.remove_tb = remove_tb
    self.fifo_mode = fifo_mode
    self.add_clk_enable = add_clk_enable
    self.add_flush = add_flush
    self.core_reset_pos = core_reset_pos
    self.app_ctrl_depth_width = app_ctrl_depth_width
    self.stcl_valid_iter = stcl_valid_iter

    # Typedefs for ease
    TData = magma.Bits[self.data_width]
    TBit = magma.Bits[1]

    self.__inputs = []
    self.__outputs = []

    # Enumerate input and output ports
    # (clk and reset are assumed)
    # Multi-port variants get numbered port names; single-port variants
    # keep the bare names so existing netlists keep working.
    if self.interconnect_input_ports > 1:
        for i in range(self.interconnect_input_ports):
            self.add_port(f"addr_in_{i}", magma.In(TData))
            self.__inputs.append(self.ports[f"addr_in_{i}"])
            self.add_port(f"data_in_{i}", magma.In(TData))
            self.__inputs.append(self.ports[f"data_in_{i}"])
            self.add_port(f"wen_in_{i}", magma.In(TBit))
            self.__inputs.append(self.ports[f"wen_in_{i}"])
    else:
        self.add_port("addr_in", magma.In(TData))
        self.__inputs.append(self.ports[f"addr_in"])
        self.add_port("data_in", magma.In(TData))
        self.__inputs.append(self.ports[f"data_in"])
        self.add_port("wen_in", magma.In(TBit))
        self.__inputs.append(self.ports.wen_in)

    if self.interconnect_output_ports > 1:
        for i in range(self.interconnect_output_ports):
            self.add_port(f"data_out_{i}", magma.Out(TData))
            self.__outputs.append(self.ports[f"data_out_{i}"])
            self.add_port(f"ren_in_{i}", magma.In(TBit))
            self.__inputs.append(self.ports[f"ren_in_{i}"])
            self.add_port(f"valid_out_{i}", magma.Out(TBit))
            self.__outputs.append(self.ports[f"valid_out_{i}"])
            # Chaining
            self.add_port(f"chain_valid_in_{i}", magma.In(TBit))
            self.__inputs.append(self.ports[f"chain_valid_in_{i}"])
            self.add_port(f"chain_data_in_{i}", magma.In(TData))
            self.__inputs.append(self.ports[f"chain_data_in_{i}"])
            self.add_port(f"chain_data_out_{i}", magma.Out(TData))
            self.__outputs.append(self.ports[f"chain_data_out_{i}"])
            self.add_port(f"chain_valid_out_{i}", magma.Out(TBit))
            self.__outputs.append(self.ports[f"chain_valid_out_{i}"])
    else:
        self.add_port("data_out", magma.Out(TData))
        self.__outputs.append(self.ports[f"data_out"])
        self.add_port(f"ren_in", magma.In(TBit))
        self.__inputs.append(self.ports[f"ren_in"])
        self.add_port(f"valid_out", magma.Out(TBit))
        self.__outputs.append(self.ports[f"valid_out"])
        self.add_port(f"chain_valid_in", magma.In(TBit))
        self.__inputs.append(self.ports[f"chain_valid_in"])
        self.add_port(f"chain_data_in", magma.In(TData))
        self.__inputs.append(self.ports[f"chain_data_in"])
        self.add_port(f"chain_data_out", magma.Out(TData))
        self.__outputs.append(self.ports[f"chain_data_out"])
        self.add_port(f"chain_valid_out", magma.Out(TBit))
        self.__outputs.append(self.ports[f"chain_valid_out"])

    self.add_ports(flush=magma.In(TBit),
                   full=magma.Out(TBit),
                   empty=magma.Out(TBit),
                   stall=magma.In(TBit),
                   sram_ready_out=magma.Out(TBit))

    self.__inputs.append(self.ports.flush)
    # self.__inputs.append(self.ports.stall)
    self.__outputs.append(self.ports.full)
    self.__outputs.append(self.ports.empty)
    self.__outputs.append(self.ports.sram_ready_out)

    # NOTE(review): this key omits input_config_width, output_config_width
    # and tb_range_inner_max even though they are passed to LakeTop below —
    # two configs differing only in those would share a cached circuit.
    # Verify this is intentional.
    cache_key = (self.data_width, self.mem_width, self.mem_depth,
                 self.banks, self.input_iterator_support,
                 self.output_iterator_support,
                 self.interconnect_input_ports,
                 self.interconnect_output_ports,
                 self.use_sram_stub, self.sram_macro_info, self.read_delay,
                 self.rw_same_cycle, self.agg_height,
                 self.max_agg_schedule, self.input_max_port_sched,
                 self.output_max_port_sched, self.align_input,
                 self.max_line_length, self.max_tb_height,
                 self.tb_range_max, self.tb_sched_max, self.max_tb_stride,
                 self.num_tb, self.tb_iterator_support, self.multiwrite,
                 self.max_prefetch,
                 self.config_data_width,
                 self.config_addr_width, self.num_tiles, self.remove_tb,
                 self.fifo_mode, self.stcl_valid_iter,
                 self.add_clk_enable, self.add_flush,
                 self.app_ctrl_depth_width)

    # Check for circuit caching
    if cache_key not in MemCore.__circuit_cache:
        # Instantiate core object here - will only use the object representation to
        # query for information. The circuit representation will be cached and retrieved
        # in the following steps.
        lt_dut = LakeTop(
            data_width=self.data_width,
            mem_width=self.mem_width,
            mem_depth=self.mem_depth,
            banks=self.banks,
            input_iterator_support=self.input_iterator_support,
            output_iterator_support=self.output_iterator_support,
            input_config_width=self.input_config_width,
            output_config_width=self.output_config_width,
            interconnect_input_ports=self.interconnect_input_ports,
            interconnect_output_ports=self.interconnect_output_ports,
            use_sram_stub=self.use_sram_stub,
            sram_macro_info=self.sram_macro_info,
            read_delay=self.read_delay,
            rw_same_cycle=self.rw_same_cycle,
            agg_height=self.agg_height,
            max_agg_schedule=self.max_agg_schedule,
            input_max_port_sched=self.input_max_port_sched,
            output_max_port_sched=self.output_max_port_sched,
            align_input=self.align_input,
            max_line_length=self.max_line_length,
            max_tb_height=self.max_tb_height,
            tb_range_max=self.tb_range_max,
            tb_range_inner_max=self.tb_range_inner_max,
            tb_sched_max=self.tb_sched_max,
            max_tb_stride=self.max_tb_stride,
            num_tb=self.num_tb,
            tb_iterator_support=self.tb_iterator_support,
            multiwrite=self.multiwrite,
            max_prefetch=self.max_prefetch,
            config_data_width=self.config_data_width,
            config_addr_width=self.config_addr_width,
            num_tiles=self.num_tiles,
            app_ctrl_depth_width=self.app_ctrl_depth_width,
            remove_tb=self.remove_tb,
            fifo_mode=self.fifo_mode,
            add_clk_enable=self.add_clk_enable,
            add_flush=self.add_flush,
            stcl_valid_iter=self.stcl_valid_iter)

        change_sram_port_pass = change_sram_port_names(use_sram_stub,
                                                       sram_macro_info)
        circ = kts.util.to_magma(lt_dut,
                                 flatten_array=True,
                                 check_multiple_driver=False,
                                 optimize_if=False,
                                 check_flip_flop_always_ff=False,
                                 additional_passes={
                                     "change_sram_port": change_sram_port_pass})
        MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
    else:
        circ, lt_dut = MemCore.__circuit_cache[cache_key]

    # Save as underlying circuit object
    self.underlying = FromMagma(circ)

    self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

    # put a 1-bit register and a mux to select the control signals
    # TODO: check if enable_chain_output needs to be here? I don't think so?
    control_signals = [("wen_in", self.interconnect_input_ports),
                       ("ren_in", self.interconnect_output_ports),
                       ("flush", 1),
                       ("chain_valid_in", self.interconnect_output_ports)]
    for control_signal, width in control_signals:
        # TODO: consult with Ankita to see if we can use the normal
        # mux here
        if width == 1:
            mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
            reg_value_name = f"{control_signal}_reg_value"
            reg_sel_name = f"{control_signal}_reg_sel"
            self.add_config(reg_value_name, 1)
            self.add_config(reg_sel_name, 1)
            self.wire(mux.ports.I[0], self.ports[control_signal])
            self.wire(mux.ports.I[1],
                      self.registers[reg_value_name].ports.O)
            self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
            # 0 is the default wire, which takes from the routing network
            self.wire(mux.ports.O[0],
                      self.underlying.ports[control_signal][0])
        else:
            for i in range(width):
                mux = MuxWrapper(2, 1, name=f"{control_signal}_{i}_sel")
                reg_value_name = f"{control_signal}_{i}_reg_value"
                reg_sel_name = f"{control_signal}_{i}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0],
                          self.ports[f"{control_signal}_{i}"])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S,
                          self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][i])

    # Straight-through wiring for the datapath ports.
    if self.interconnect_input_ports > 1:
        for i in range(self.interconnect_input_ports):
            self.wire(self.ports[f"data_in_{i}"],
                      self.underlying.ports[f"data_in_{i}"])
            self.wire(self.ports[f"addr_in_{i}"],
                      self.underlying.ports[f"addr_in_{i}"])
    else:
        self.wire(self.ports.addr_in, self.underlying.ports.addr_in)
        self.wire(self.ports.data_in, self.underlying.ports.data_in)

    if self.interconnect_output_ports > 1:
        for i in range(self.interconnect_output_ports):
            self.wire(self.ports[f"data_out_{i}"],
                      self.underlying.ports[f"data_out_{i}"])
            self.wire(self.ports[f"chain_data_in_{i}"],
                      self.underlying.ports[f"chain_data_in_{i}"])
            self.wire(self.ports[f"chain_data_out_{i}"],
                      self.underlying.ports[f"chain_data_out_{i}"])
    else:
        self.wire(self.ports.data_out, self.underlying.ports.data_out)
        self.wire(self.ports.chain_data_in,
                  self.underlying.ports.chain_data_in)
        self.wire(self.ports.chain_data_out,
                  self.underlying.ports.chain_data_out)

    # Need to invert this: the tile reset is active-high, LakeTop's rst_n
    # is active-low.
    self.resetInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.resetInverter.ports.I[0], self.ports.reset)
    self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
    self.wire(self.ports.clk, self.underlying.ports.clk)
    if self.interconnect_output_ports == 1:
        self.wire(self.ports.valid_out[0],
                  self.underlying.ports.valid_out[0])
        self.wire(self.ports.chain_valid_out[0],
                  self.underlying.ports.chain_valid_out[0])
    else:
        for j in range(self.interconnect_output_ports):
            self.wire(self.ports[f"valid_out_{j}"][0],
                      self.underlying.ports.valid_out[j])
            self.wire(self.ports[f"chain_valid_out_{j}"][0],
                      self.underlying.ports.chain_valid_out[j])
    self.wire(self.ports.empty[0], self.underlying.ports.empty[0])
    self.wire(self.ports.full[0], self.underlying.ports.full[0])

    # PE core uses clk_en (essentially active low stall)
    self.stallInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.stallInverter.ports.I, self.ports.stall)
    self.wire(self.stallInverter.ports.O[0],
              self.underlying.ports.clk_en[0])

    self.wire(self.ports.sram_ready_out[0],
              self.underlying.ports.sram_ready_out[0])

    # Configuration feature list: feature 0 is the tile itself, the
    # remaining features are one per SRAM set in the LakeTop instance.
    # we have six? features in total
    # 0: TILE
    # 1: TILE
    # 1-4: SMEM
    # Feature 0: Tile
    self.__features: List[CoreFeature] = [self]
    # Features 1-4: SRAM
    self.num_sram_features = lt_dut.total_sets
    for sram_index in range(self.num_sram_features):
        core_feature = CoreFeature(self, sram_index + 1)
        self.__features.append(core_feature)

    # Wire the config
    for idx, core_feature in enumerate(self.__features):
        if (idx > 0):
            self.add_port(f"config_{idx}",
                          magma.In(ConfigurationType(8, 32)))
            # port aliasing
            core_feature.ports["config"] = self.ports[f"config_{idx}"]
    self.add_port("config", magma.In(ConfigurationType(8, 32)))

    # or the signal up: only the addressed feature drives non-zero, so an
    # OR across features yields the active address/data.
    t = ConfigurationType(8, 32)
    t_names = ["config_addr", "config_data"]
    or_gates = {}
    for t_name in t_names:
        port_type = t[t_name]
        or_gate = FromMagma(
            mantle.DefineOr(len(self.__features), len(port_type)))
        or_gate.instance_name = f"OR_{t_name}_FEATURE"
        for idx, core_feature in enumerate(self.__features):
            self.wire(or_gate.ports[f"I{idx}"],
                      core_feature.ports.config[t_name])
        or_gates[t_name] = or_gate

    self.wire(or_gates["config_addr"].ports.O,
              self.underlying.ports.config_addr_in[0:8])
    self.wire(or_gates["config_data"].ports.O,
              self.underlying.ports.config_data_in)

    # read data out
    for idx, core_feature in enumerate(self.__features):
        if (idx > 0):
            # self.add_port(f"read_config_data_{idx}",
            self.add_port(f"read_config_data_{idx}",
                          magma.Out(magma.Bits[32]))
            # port aliasing
            core_feature.ports["read_config_data"] = \
                self.ports[f"read_config_data_{idx}"]

    # MEM Config: (register name, width) pairs for the tile feature.
    configurations = [("tile_en", 1),
                      ("fifo_ctrl_fifo_depth", 16),
                      ("mode", 2),
                      ("enable_chain_output", 1),
                      ("enable_chain_input", 1)]
    # ("stencil_width", 16), NOT YET

    # merged_configs entries are (name, total_width, field_count): several
    # narrow underlying ports packed into one wide config register.
    merged_configs = []
    merged_in_sched = []
    merged_out_sched = []

    # Add config registers to configurations
    # TODO: Have lake spit this information out automatically from the wrapper
    configurations.append((f"chain_idx_input", self.chain_idx_bits))
    configurations.append((f"chain_idx_output", self.chain_idx_bits))

    for i in range(self.interconnect_input_ports):
        configurations.append((f"strg_ub_agg_align_{i}_line_length",
                               kts.clog2(self.max_line_length)))
        configurations.append((f"strg_ub_agg_in_{i}_in_period",
                               kts.clog2(self.input_max_port_sched)))
        # num_bits_in_sched = kts.clog2(self.agg_height)
        # sched_per_feat = math.floor(self.config_data_width / num_bits_in_sched)
        # new_width = num_bits_in_sched * sched_per_feat
        # feat_num = 0
        # num_feats_merge = math.ceil(self.input_max_port_sched / sched_per_feat)
        # for k in range(num_feats_merge):
        #     num_here = sched_per_feat
        #     if self.input_max_port_sched - (k * sched_per_feat) < sched_per_feat:
        #         num_here = self.input_max_port_sched - (k * sched_per_feat)
        #     merged_configs.append((f"strg_ub_agg_in_{i}_in_sched_merged_{k * sched_per_feat}",
        #                            num_here * num_bits_in_sched, num_here))
        for j in range(self.input_max_port_sched):
            configurations.append((f"strg_ub_agg_in_{i}_in_sched_{j}",
                                   kts.clog2(self.agg_height)))
        configurations.append((f"strg_ub_agg_in_{i}_out_period",
                               kts.clog2(self.input_max_port_sched)))
        for j in range(self.output_max_port_sched):
            configurations.append((f"strg_ub_agg_in_{i}_out_sched_{j}",
                                   kts.clog2(self.agg_height)))
        configurations.append((f"strg_ub_app_ctrl_write_depth_wo_{i}",
                               self.app_ctrl_depth_width))
        configurations.append((f"strg_ub_app_ctrl_write_depth_ss_{i}",
                               self.app_ctrl_depth_width))
        configurations.append(
            (f"strg_ub_app_ctrl_coarse_write_depth_wo_{i}",
             self.app_ctrl_depth_width))
        configurations.append(
            (f"strg_ub_app_ctrl_coarse_write_depth_ss_{i}",
             self.app_ctrl_depth_width))
        configurations.append(
            (f"strg_ub_input_addr_ctrl_address_gen_{i}_dimensionality",
             1 + kts.clog2(self.input_iterator_support)))
        configurations.append(
            (f"strg_ub_input_addr_ctrl_address_gen_{i}_starting_addr",
             self.input_config_width))
        for j in range(self.input_iterator_support):
            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_ranges_{j}",
                 self.input_config_width))
            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_strides_{j}",
                 self.input_config_width))

    configurations.append(
        (f"strg_ub_app_ctrl_prefill", self.interconnect_output_ports))
    configurations.append((f"strg_ub_app_ctrl_coarse_prefill",
                           self.interconnect_output_ports))

    for i in range(self.stcl_valid_iter):
        configurations.append((f"strg_ub_app_ctrl_ranges_{i}", 16))
        configurations.append((f"strg_ub_app_ctrl_threshold_{i}", 16))

    for i in range(self.interconnect_output_ports):
        configurations.append((f"strg_ub_app_ctrl_input_port_{i}",
                               kts.clog2(self.interconnect_input_ports)))
        configurations.append((f"strg_ub_app_ctrl_read_depth_{i}",
                               self.app_ctrl_depth_width))
        configurations.append((f"strg_ub_app_ctrl_coarse_input_port_{i}",
                               kts.clog2(self.interconnect_input_ports)))
        configurations.append((f"strg_ub_app_ctrl_coarse_read_depth_{i}",
                               self.app_ctrl_depth_width))
        configurations.append(
            (f"strg_ub_output_addr_ctrl_address_gen_{i}_dimensionality",
             1 + kts.clog2(self.output_iterator_support)))
        configurations.append(
            (f"strg_ub_output_addr_ctrl_address_gen_{i}_starting_addr",
             self.output_config_width))
        for j in range(self.output_iterator_support):
            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_ranges_{j}",
                 self.output_config_width))
            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_strides_{j}",
                 self.output_config_width))
        configurations.append((f"strg_ub_pre_fetch_{i}_input_latency",
                               kts.clog2(self.max_prefetch) + 1))
        configurations.append((f"strg_ub_sync_grp_sync_group_{i}",
                               self.interconnect_output_ports))
        configurations.append(
            (f"strg_ub_rate_matched_{i}",
             1 + kts.clog2(self.interconnect_input_ports)))
        for j in range(self.num_tb):
            configurations.append(
                (f"strg_ub_tba_{i}_tb_{j}_dimensionality", 2))
            # Pack as many tb indices as fit into one config word.
            num_indices_bits = 1 + kts.clog2(self.fw_int)
            indices_per_feat = math.floor(self.config_data_width /
                                          num_indices_bits)
            # NOTE(review): new_width and feat_num are computed but never
            # used below — leftovers from the commented-out scheme above?
            new_width = num_indices_bits * indices_per_feat
            feat_num = 0
            num_feats_merge = math.ceil(self.tb_range_inner_max /
                                        indices_per_feat)
            for k in range(num_feats_merge):
                num_idx = indices_per_feat
                # Last chunk may hold fewer indices.
                if (self.tb_range_inner_max -
                        (k * indices_per_feat)) < indices_per_feat:
                    num_idx = self.tb_range_inner_max - (k * indices_per_feat)
                merged_configs.append((
                    f"strg_ub_tba_{i}_tb_{j}_indices_merged_{k * indices_per_feat}",
                    num_idx * num_indices_bits, num_idx))
            # for k in range(self.tb_range_inner_max):
            #     configurations.append((f"strg_ub_tba_{i}_tb_{j}_indices_{k}", kts.clog2(self.fw_int) + 1))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_inner",
                                   kts.clog2(self.tb_range_inner_max)))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_outer",
                                   kts.clog2(self.tb_range_max)))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_stride",
                                   kts.clog2(self.max_tb_stride)))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_tb_height",
                                   max(1, kts.clog2(self.num_tb))))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_starting_addr",
                                   max(1, kts.clog2(self.fw_int))))

    # Do all the stuff for the main config
    main_feature = self.__features[0]
    for config_reg_name, width in configurations:
        main_feature.add_config(config_reg_name, width)
        if (width == 1):
            self.wire(main_feature.registers[config_reg_name].ports.O[0],
                      self.underlying.ports[config_reg_name][0])
        else:
            self.wire(main_feature.registers[config_reg_name].ports.O,
                      self.underlying.ports[config_reg_name])

    # Fan merged config registers back out to the individual narrow ports.
    for config_reg_name, width, num_merged in merged_configs:
        main_feature.add_config(config_reg_name, width)
        # NOTE(review): token_under is unused.
        token_under = config_reg_name.split("_")
        base_name = config_reg_name.split("_merged")[0]
        base_indices = int(config_reg_name.split("_merged_")[1])
        num_bits = width // num_merged
        for i in range(num_merged):
            self.wire(
                main_feature.registers[config_reg_name].ports.
                O[i * num_bits:(i + 1) * num_bits],
                self.underlying.ports[f"{base_name}_{base_indices + i}"])

    # SRAM
    # These should also account for num features
    # or_all_cfg_rd = FromMagma(mantle.DefineOr(4, 1))
    # NOTE(review): the instance names below look swapped (rd gate named
    # "...WR_SRAM", wr gate named "...RD_SRAM") — cosmetic, but confirm.
    or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
    or_all_cfg_rd.instance_name = f"OR_CONFIG_WR_SRAM"
    or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
    or_all_cfg_wr.instance_name = f"OR_CONFIG_RD_SRAM"
    for sram_index in range(self.num_sram_features):
        core_feature = self.__features[sram_index + 1]
        self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
        # port aliasing
        core_feature.ports["config_en"] = \
            self.ports[f"config_en_{sram_index}"]
        self.wire(core_feature.ports.read_config_data,
                  self.underlying.ports[f"config_data_out_{sram_index}"])
        # also need to wire the sram signal
        # the config enable is the OR of the rd+wr
        # NOTE(review): or_gate_en's output O is never wired — config_en is
        # connected directly from the feature port below. Dead logic?
        or_gate_en = FromMagma(mantle.DefineOr(2, 1))
        or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"
        self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
        self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
        self.wire(core_feature.ports.config_en,
                  self.underlying.ports["config_en"][sram_index])
        # Still connect to the OR of all the config rd/wr
        self.wire(core_feature.ports.config.write,
                  or_all_cfg_wr.ports[f"I{sram_index}"])
        self.wire(core_feature.ports.config.read,
                  or_all_cfg_rd.ports[f"I{sram_index}"])

    self.wire(or_all_cfg_rd.ports.O[0],
              self.underlying.ports.config_read[0])
    self.wire(or_all_cfg_wr.ports.O[0],
              self.underlying.ports.config_write[0])

    self._setup_config()

    # Side effect: dump the sorted register map for documentation/synthesis.
    conf_names = list(self.registers.keys())
    conf_names.sort()
    with open("mem_cfg.txt", "w+") as cfg_dump:
        for idx, reg in enumerate(conf_names):
            write_line = f"|{reg}|{idx}|{self.registers[reg].width}||\n"
            cfg_dump.write(write_line)
    with open("mem_synth.txt", "w+") as cfg_dump:
        for idx, reg in enumerate(conf_names):
            write_line = f"{reg}\n"
            cfg_dump.write(write_line)
def __init__(self, _params: GlobalBufferParams):
    """Global-buffer load DMA: streams words from GLB banks to the fabric.

    Combines a shared loop iterator (walks the configured affine loop
    nest), a schedule generator (per-cycle read strobes) and two address
    generators (cycle-level and data-level addresses), all driven by the
    currently selected DMA header.
    """
    super().__init__("glb_core_load_dma")
    self._params = _params
    self.header = GlbHeader(self._params)
    # A bank word holds exactly four CGRA words; the strm_data_sel word
    # selection below assumes this ratio.
    assert self._params.bank_data_width == self._params.cgra_data_width * 4

    # Clock / reset
    self.clk = self.clock("clk")
    self.clk_en = self.clock_en("clk_en")
    self.reset = self.reset("reset")

    # Streaming interface to the CGRA fabric
    self.data_g2f = self.output("data_g2f", width=self._params.cgra_data_width)
    self.data_valid_g2f = self.output("data_valid_g2f", width=1)

    # Read request / response packets to and from the banks
    self.rdrq_packet = self.output("rdrq_packet", self.header.rdrq_packet_t)
    self.rdrs_packet = self.input("rdrs_packet", self.header.rdrs_packet_t)

    # Configuration inputs
    self.cfg_ld_dma_num_repeat = self.input(
        "cfg_ld_dma_num_repeat", clog2(self._params.queue_depth) + 1)
    self.cfg_ld_dma_ctrl_use_valid = self.input(
        "cfg_ld_dma_ctrl_use_valid", 1)
    self.cfg_ld_dma_ctrl_mode = self.input("cfg_ld_dma_ctrl_mode", 2)
    self.cfg_data_network_latency = self.input("cfg_data_network_latency",
                                               self._params.latency_width)
    # One DMA header per queue entry.
    self.cfg_ld_dma_header = self.input("cfg_ld_dma_header",
                                        self.header.cfg_dma_header_t,
                                        size=self._params.queue_depth)
    self.ld_dma_start_pulse = self.input("ld_dma_start_pulse", 1)
    self.ld_dma_done_pulse = self.output("ld_dma_done_pulse", 1)

    # local parameter: fixed pipeline latency from read request to response
    self.default_latency = (
        self._params.glb_switch_pipeline_depth
        + self._params.glb_bank_memory_pipeline_depth
        + self._params.sram_gen_pipeline_depth
        + self._params.sram_gen_output_pipeline_depth
        + 1  # SRAM macro read latency
        + self._params.glb_switch_pipeline_depth
        + 2  # FIXME: Unnecessary delay of moving back and forth btw switch and router
        + 1  # load_dma cache register delay
    )

    # local variables
    self.strm_data = self.var("strm_data", self._params.cgra_data_width)
    self.strm_data_r = self.var("strm_data_r", self._params.cgra_data_width)
    self.strm_data_valid = self.var("strm_data_valid", 1)
    self.strm_data_valid_r = self.var("strm_data_valid_r", 1)
    # Selects one CGRA word out of a cached bank word.
    self.strm_data_sel = self.var(
        "strm_data_sel",
        self._params.bank_byte_offset - self._params.cgra_byte_offset)
    self.strm_rd_en_w = self.var("strm_rd_en_w", 1)
    self.strm_rd_addr_w = self.var("strm_rd_addr_w",
                                   self._params.glb_addr_width)
    self.last_strm_rd_addr_r = self.var("last_strm_rd_addr_r",
                                        self._params.glb_addr_width)
    self.ld_dma_start_pulse_next = self.var("ld_dma_start_pulse_next", 1)
    self.ld_dma_start_pulse_r = self.var("ld_dma_start_pulse_r", 1)
    self.is_first = self.var("is_first", 1)
    self.ld_dma_done_pulse_w = self.var("ld_dma_done_pulse_w", 1)
    self.bank_addr_match = self.var("bank_addr_match", 1)
    self.bank_rdrq_rd_en = self.var("bank_rdrq_rd_en", 1)
    self.bank_rdrq_rd_addr = self.var("bank_rdrq_rd_addr",
                                      self._params.glb_addr_width)
    self.bank_rdrs_data_cache_r = self.var("bank_rdrs_data_cache_r",
                                           self._params.bank_data_width)
    self.strm_run = self.var("strm_run", 1)
    self.loop_done = self.var("loop_done", 1)
    self.cycle_valid = self.var("cycle_valid", 1)
    self.cycle_count = self.var("cycle_count", self._params.axi_data_width)
    self.cycle_current_addr = self.var("cycle_current_addr",
                                       self._params.axi_data_width)
    self.data_current_addr = self.var("data_current_addr",
                                      self._params.axi_data_width)
    self.loop_mux_sel = self.var("loop_mux_sel",
                                 clog2(self._params.loop_level))
    self.repeat_cnt = self.var("repeat_cnt",
                               clog2(self._params.queue_depth) + 1)
    if self._params.queue_depth != 1:
        # NOTE(review): width is clog2 of repeat_cnt's *width*, not of
        # queue_depth — same expression is used by the store DMA; confirm
        # this is intentional.
        self.queue_sel_r = self.var("queue_sel_r",
                                    max(1, clog2(self.repeat_cnt.width)))

    # Current dma header (selected queue entry, or the single entry)
    self.current_dma_header = self.var("current_dma_header",
                                       self.header.cfg_dma_header_t)
    if self._params.queue_depth == 1:
        self.wire(self.cfg_ld_dma_header, self.current_dma_header)
    else:
        self.wire(self.cfg_ld_dma_header[self.queue_sel_r],
                  self.current_dma_header)

    # Sequential / combinational processes
    if self._params.queue_depth != 1:
        # queue_sel_r only exists for multi-entry queues
        self.add_always(self.queue_sel_ff)
    self.add_always(self.repeat_cnt_ff)
    self.add_always(self.cycle_counter)
    self.add_always(self.is_first_ff)
    self.add_always(self.strm_run_ff)
    self.add_always(self.strm_data_ff)
    self.add_strm_data_start_pulse_pipeline()
    self.add_ld_dma_done_pulse_pipeline()
    self.add_strm_rd_en_pipeline()
    self.add_strm_rd_addr_pipeline()
    self.add_always(self.ld_dma_start_pulse_logic)
    self.add_always(self.ld_dma_start_pulse_ff)
    self.add_always(self.strm_data_mux)
    self.add_always(self.ld_dma_done_pulse_logic)
    self.add_always(self.strm_rdrq_packet_ff)
    self.add_always(self.last_strm_rd_addr_ff)
    self.add_always(self.bank_rdrq_packet_logic)
    self.add_always(self.bank_rdrs_data_cache_ff)
    self.add_always(self.strm_data_logic)

    # Loop iteration shared for cycle and data
    self.loop_iter = GlbLoopIter(self._params)
    self.add_child("loop_iter",
                   self.loop_iter,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   step=self.cycle_valid,
                   mux_sel_out=self.loop_mux_sel,
                   restart=self.loop_done)
    self.wire(self.loop_iter.dim, self.current_dma_header[f"dim"])
    for i in range(self._params.loop_level):
        self.wire(self.loop_iter.ranges[i],
                  self.current_dma_header[f"range_{i}"])

    # Cycle stride: decides *when* each loop step fires
    self.cycle_stride_sched_gen = GlbSchedGen(self._params)
    self.add_child("cycle_stride_sched_gen",
                   self.cycle_stride_sched_gen,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   restart=self.ld_dma_start_pulse_r,
                   cycle_count=self.cycle_count,
                   current_addr=self.cycle_current_addr,
                   finished=self.loop_done,
                   valid_output=self.cycle_valid)
    self.cycle_stride_addr_gen = GlbAddrGen(self._params)
    self.add_child("cycle_stride_addr_gen",
                   self.cycle_stride_addr_gen,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   restart=self.ld_dma_start_pulse_r,
                   step=self.cycle_valid,
                   mux_sel=self.loop_mux_sel,
                   addr_out=self.cycle_current_addr)
    self.wire(
        self.cycle_stride_addr_gen.start_addr,
        ext(self.current_dma_header[f"cycle_start_addr"],
            self._params.axi_data_width))
    for i in range(self._params.loop_level):
        self.wire(self.cycle_stride_addr_gen.strides[i],
                  self.current_dma_header[f"cycle_stride_{i}"])

    # Data stride: decides *where* each step reads from
    self.data_stride_addr_gen = GlbAddrGen(self._params)
    self.add_child("data_stride_addr_gen",
                   self.data_stride_addr_gen,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   restart=self.ld_dma_start_pulse_r,
                   step=self.cycle_valid,
                   mux_sel=self.loop_mux_sel,
                   addr_out=self.data_current_addr)
    self.wire(
        self.data_stride_addr_gen.start_addr,
        ext(self.current_dma_header[f"start_addr"],
            self._params.axi_data_width))
    for i in range(self._params.loop_level):
        self.wire(self.data_stride_addr_gen.strides[i],
                  self.current_dma_header[f"stride_{i}"])
def test_kratos_single_instance(find_free_port, simulator):
    """End-to-end hgdb debugging test against a single kratos instance.

    Builds a small kratos design (an SSA-transformed combinational sum over
    a 4-entry buffer plus a sequential writer), compiles it to SystemVerilog
    with debug info, runs it under the given simulator, and drives an hgdb
    client against the generated debug database: first plain breakpoints,
    then a conditional breakpoint.

    Fix: uses ``asyncio.run`` instead of the deprecated
    ``asyncio.get_event_loop().run_until_complete`` (deprecated when no
    loop is running since Python 3.10).
    """
    if not simulator.available():
        pytest.skip(simulator.__name__ + " not available")
    from kratos import Generator, clog2, always_ff, always_comb, verilog, posedge

    input_width = 16
    buffer_size = 4
    mod = Generator("mod", debug=True)
    clk = mod.clock("clk")
    rst = mod.reset("rst")
    in_ = mod.input("in", input_width)
    out = mod.output("out", input_width)
    counter = mod.var("count", clog2(buffer_size))
    # notice that verilator does not support probing on packed array!!!
    data = mod.var("data", input_width, size=buffer_size)

    @always_comb
    def sum_data():
        out = 0
        for i in range(buffer_size):
            out = out + data[i]

    @always_ff((posedge, clk), (posedge, rst))
    def buffer_logic():
        if rst:
            for i in range(buffer_size):
                data[i] = 0
            counter = 0
        else:
            data[counter] = in_
            counter += 1

    mod.add_always(sum_data, ssa_transform=True)
    mod.add_always(buffer_logic)

    # Breakpoint target: the accumulation line inside sum_data.
    py_line_num = get_line_num(py_filename, " out = out + data[i]")

    with tempfile.TemporaryDirectory() as temp:
        temp = os.path.abspath(temp)
        db_filename = os.path.join(temp, "debug.db")
        sv_filename = os.path.join(temp, "mod.sv")
        verilog(mod, filename=sv_filename, insert_debug_info=True,
                debug_db_filename=db_filename, ssa_transform=True,
                insert_verilator_info=True)
        # run verilator (C++ harness) or the SV testbench for other simulators
        tb = "test_kratos.cc" if simulator == VerilatorTester else "test_kratos.sv"
        main_file = get_vector_file(tb)
        with simulator(sv_filename, main_file, cwd=temp) as tester:
            port = find_free_port()
            uri = get_uri(port)
            # set the port
            tester.run(blocking=False, DEBUG_PORT=port, DEBUG_LOG=True)

            async def client_logic():
                client = hgdb.HGDBClient(uri, db_filename)
                await client.connect()
                # set breakpoint
                await client.set_breakpoint(py_filename, py_line_num)
                await client.continue_()
                # The SSA-unrolled loop hits the breakpoint once per index.
                for i in range(4):
                    bp = await client.recv()
                    assert bp["payload"]["instances"][0]["local"]["i"] == str(i)
                    if simulator == XceliumTester:
                        assert bp["payload"]["time"] == 10
                    await client.continue_()
                for i in range(4):
                    # the first breakpoint, out is not calculated yet
                    # so it should be 0
                    # after that, it should be 1
                    bp = await client.recv()
                    if simulator == XceliumTester:
                        assert bp["payload"]["time"] == 30
                    if i == 0:
                        assert bp["payload"]["instances"][0]["local"]["out"] == "0"
                    else:
                        assert bp["payload"]["instances"][0]["local"]["out"] == "1"
                    await client.continue_()
                # remove the breakpoint and set a conditional breakpoint
                # discard the current breakpoint information
                await client.recv()
                # remove the current breakpoint
                await client.remove_breakpoint(py_filename, py_line_num)
                await client.set_breakpoint(py_filename, py_line_num,
                                            cond="out == 6 && i == 3")
                await client.continue_()
                bp = await client.recv()
                assert bp["payload"]["instances"][0]["local"]["out"] == "6"
                assert bp["payload"]["instances"][0]["local"]["i"] == "3"
                assert bp["payload"]["instances"][0]["local"]["data.2"] == "3"

            # Runs the client on a fresh event loop and closes it afterwards;
            # replaces deprecated get_event_loop().run_until_complete().
            asyncio.run(client_logic())
def __init__(self, _params: GlobalBufferParams):
    """Global-buffer store DMA: streams words from the fabric into GLB banks.

    Mirrors the load DMA: a shared loop iterator walks the configured loop
    nest, a schedule generator produces the per-cycle valid, and two address
    generators compute cycle-level and data-level addresses. Incoming CGRA
    words are packed into bank-wide words with byte strobes before being
    emitted on wr_packet.
    """
    super().__init__("glb_core_store_dma")
    self._params = _params
    self.header = GlbHeader(self._params)
    # A bank word holds exactly four CGRA words; the data/strobe caching
    # below assumes this ratio.
    assert self._params.bank_data_width == self._params.cgra_data_width * 4

    # Clock / reset
    self.clk = self.clock("clk")
    self.clk_en = self.clock_en("clk_en")
    self.reset = self.reset("reset")

    # Streaming interface from the CGRA fabric
    self.data_f2g = self.input(
        "data_f2g", width=self._params.cgra_data_width)
    self.data_valid_f2g = self.input("data_valid_f2g", width=1)

    # Write request packet to the banks
    self.wr_packet = self.output(
        "wr_packet", self.header.wr_packet_t)

    # Configuration inputs
    self.cfg_st_dma_num_repeat = self.input("cfg_st_dma_num_repeat",
                                            clog2(self._params.queue_depth) + 1)
    self.cfg_st_dma_ctrl_mode = self.input("cfg_st_dma_ctrl_mode", 2)
    self.cfg_st_dma_ctrl_use_valid = self.input("cfg_st_dma_ctrl_use_valid", 1)
    self.cfg_data_network_latency = self.input(
        "cfg_data_network_latency", self._params.latency_width)
    # One DMA header per queue entry.
    self.cfg_st_dma_header = self.input(
        "cfg_st_dma_header",
        self.header.cfg_dma_header_t,
        size=self._params.queue_depth,
        explicit_array=True)
    self.st_dma_start_pulse = self.input("st_dma_start_pulse", 1)
    self.st_dma_done_pulse = self.output("st_dma_done_pulse", 1)

    # localparam: fixed pipeline depth between this DMA and the banks
    self.default_latency = (self._params.glb_bank_memory_pipeline_depth
                            + self._params.sram_gen_pipeline_depth
                            + self._params.glb_switch_pipeline_depth
                            )
    # Byte-strobe width / all-bytes-enabled value for one CGRA word.
    self.cgra_strb_width = self._params.cgra_data_width // 8
    self.cgra_strb_value = 2 ** (self._params.cgra_data_width // 8) - 1

    # local variables
    self.strm_wr_data_w = self.var("strm_wr_data_w",
                                   width=self._params.cgra_data_width)
    self.strm_wr_addr_w = self.var("strm_wr_addr_w",
                                   width=self._params.glb_addr_width)
    self.last_strm_wr_addr_r = self.var("last_strm_wr_addr_r",
                                        width=self._params.glb_addr_width)
    self.strm_wr_en_w = self.var("strm_wr_en_w", width=1)
    # Selects which CGRA-word slot of the cached bank word is written.
    self.strm_data_sel = self.var("strm_data_sel",
                                  self._params.bank_byte_offset - self._params.cgra_byte_offset)
    self.bank_addr_match = self.var("bank_addr_match", 1)
    self.bank_wr_en = self.var("bank_wr_en", 1)
    self.bank_wr_addr = self.var("bank_wr_addr",
                                 width=self._params.glb_addr_width)
    self.bank_wr_data_cache_r = self.var("bank_wr_data_cache_r",
                                         self._params.bank_data_width)
    self.bank_wr_data_cache_w = self.var("bank_wr_data_cache_w",
                                         self._params.bank_data_width)
    self.bank_wr_strb_cache_r = self.var("bank_wr_strb_cache_r",
                                         math.ceil(self._params.bank_data_width / 8))
    self.bank_wr_strb_cache_w = self.var("bank_wr_strb_cache_w",
                                         math.ceil(self._params.bank_data_width / 8))
    self.done_pulse_w = self.var("done_pulse_w", 1)
    self.st_dma_start_pulse_next = self.var("st_dma_start_pulse_next", 1)
    self.st_dma_start_pulse_r = self.var("st_dma_start_pulse_r", 1)
    self.is_first = self.var("is_first", 1)
    self.is_last = self.var("is_last", 1)
    self.strm_run = self.var("strm_run", 1)
    self.loop_done = self.var("loop_done", 1)
    self.cycle_valid = self.var("cycle_valid", 1)
    self.cycle_valid_muxed = self.var("cycle_valid_muxed", 1)
    self.cycle_count = self.var("cycle_count", self._params.axi_data_width)
    self.cycle_current_addr = self.var("cycle_current_addr",
                                       self._params.axi_data_width)
    self.data_current_addr = self.var("data_current_addr",
                                      self._params.axi_data_width)
    self.loop_mux_sel = self.var("loop_mux_sel",
                                 clog2(self._params.loop_level))
    self.repeat_cnt = self.var("repeat_cnt",
                               clog2(self._params.queue_depth) + 1)
    if self._params.queue_depth != 1:
        # NOTE(review): width is clog2 of repeat_cnt's *width*, not of
        # queue_depth — same expression is used by the load DMA; confirm
        # this is intentional.
        self.queue_sel_r = self.var("queue_sel_r",
                                    max(1, clog2(self.repeat_cnt.width)))

    # Current dma header (selected queue entry, or the single entry)
    self.current_dma_header = self.var("current_dma_header",
                                       self.header.cfg_dma_header_t)
    if self._params.queue_depth == 1:
        self.wire(self.cfg_st_dma_header, self.current_dma_header)
    else:
        self.wire(self.cfg_st_dma_header[self.queue_sel_r],
                  self.current_dma_header)

    # Sequential / combinational processes
    if self._params.queue_depth != 1:
        # queue_sel_r only exists for multi-entry queues
        self.add_always(self.queue_sel_ff)
    self.add_always(self.repeat_cnt_ff)
    self.add_always(self.is_first_ff)
    self.add_always(self.is_last_ff)
    self.add_always(self.strm_run_ff)
    self.add_always(self.st_dma_start_pulse_logic)
    self.add_always(self.st_dma_start_pulse_ff)
    self.add_always(self.cycle_counter)
    self.add_always(self.cycle_valid_comb)
    self.add_always(self.strm_wr_packet_comb)
    self.add_always(self.last_strm_wr_addr_ff)
    self.add_always(self.strm_data_sel_comb)
    self.add_always(self.bank_wr_packet_cache_comb)
    self.add_always(self.bank_wr_packet_cache_ff)
    self.add_always(self.bank_wr_packet_logic)
    self.add_always(self.wr_packet_logic)
    self.add_always(self.strm_done_pulse_logic)
    self.add_done_pulse_pipeline()

    # Loop iteration shared for cycle and data
    self.loop_iter = GlbLoopIter(self._params)
    self.add_child("loop_iter",
                   self.loop_iter,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   step=self.cycle_valid_muxed,
                   mux_sel_out=self.loop_mux_sel,
                   restart=self.loop_done)
    self.wire(self.loop_iter.dim, self.current_dma_header[f"dim"])
    for i in range(self._params.loop_level):
        self.wire(self.loop_iter.ranges[i],
                  self.current_dma_header[f"range_{i}"])

    # Cycle stride: decides *when* each loop step fires
    self.cycle_stride_sched_gen = GlbSchedGen(self._params)
    self.add_child("cycle_stride_sched_gen",
                   self.cycle_stride_sched_gen,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   restart=self.st_dma_start_pulse_r,
                   cycle_count=self.cycle_count,
                   current_addr=self.cycle_current_addr,
                   finished=self.loop_done,
                   valid_output=self.cycle_valid)
    self.cycle_stride_addr_gen = GlbAddrGen(self._params)
    self.add_child("cycle_stride_addr_gen",
                   self.cycle_stride_addr_gen,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   restart=self.st_dma_start_pulse_r,
                   step=self.cycle_valid_muxed,
                   mux_sel=self.loop_mux_sel,
                   addr_out=self.cycle_current_addr)
    self.wire(self.cycle_stride_addr_gen.start_addr, ext(
        self.current_dma_header[f"cycle_start_addr"],
        self._params.axi_data_width))
    for i in range(self._params.loop_level):
        self.wire(self.cycle_stride_addr_gen.strides[i],
                  self.current_dma_header[f"cycle_stride_{i}"])

    # Data stride: decides *where* each step writes to
    self.data_stride_addr_gen = GlbAddrGen(self._params)
    self.add_child("data_stride_addr_gen",
                   self.data_stride_addr_gen,
                   clk=self.clk,
                   clk_en=self.clk_en,
                   reset=self.reset,
                   restart=self.st_dma_start_pulse_r,
                   step=self.cycle_valid_muxed,
                   mux_sel=self.loop_mux_sel,
                   addr_out=self.data_current_addr)
    self.wire(self.data_stride_addr_gen.start_addr, ext(
        self.current_dma_header[f"start_addr"],
        self._params.axi_data_width))
    for i in range(self._params.loop_level):
        self.wire(self.data_stride_addr_gen.strides[i],
                  self.current_dma_header[f"stride_{i}"])
def __init__(
        self,
        data_width=16,  # CGRA Params
        mem_width=64,
        mem_depth=512,
        banks=2,
        input_iterator_support=6,  # Addr Controllers
        output_iterator_support=6,
        interconnect_input_ports=1,  # Connection to int
        interconnect_output_ports=3,
        mem_input_ports=1,
        mem_output_ports=1,
        use_sram_stub=1,
        sram_name="default_name",
        read_delay=1,
        agg_height=8,
        max_agg_schedule=64,
        input_max_port_sched=64,
        output_max_port_sched=64,
        align_input=1,
        max_line_length=2048,
        tb_height=1,
        tb_range_max=2048,
        tb_range_inner_max=5,
        tb_sched_max=64,
        num_tb=1,
        multiwrite=2,
        max_prefetch=64,
        num_tiles=1,
        stcl_valid_iter=4):
    """Functional model of the whole memory tile.

    Instantiates every sub-model (app controllers, aggregation, address
    controllers, arbiters, memories, demux, sync groups, prefetchers and
    transpose buffers) and builds ``self.config`` — a flat
    register-name -> value dict with one entry per configuration register,
    all initialized to 0.

    Fixes relative to the previous revision:
    - AGG BUFF config registers now index with the loop variable ``port``
      (a stale ``i`` from an earlier loop was used before, so only one
      port's registers were ever created).
    - OUTPUT ADDR CTRL ranges/strides registers iterate over
      ``output_iterator_support`` (matching the OAC construction) instead
      of ``input_iterator_support``.
    """
    self.data_width = data_width
    self.mem_width = mem_width
    self.mem_depth = mem_depth
    self.banks = banks
    self.input_iterator_support = input_iterator_support
    self.output_iterator_support = output_iterator_support
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.use_sram_stub = use_sram_stub
    self.sram_name = sram_name
    self.agg_height = agg_height
    self.max_agg_schedule = max_agg_schedule
    self.input_max_port_sched = input_max_port_sched
    self.output_max_port_sched = output_max_port_sched
    self.input_port_sched_width = kts.clog2(self.interconnect_input_ports)
    self.align_input = align_input
    self.max_line_length = max_line_length
    assert self.mem_width > self.data_width, "Data width needs to be smaller than mem"
    # Fetch width: data words per memory word.
    self.fw_int = int(self.mem_width / self.data_width)
    self.num_tb = num_tb
    self.tb_height = tb_height
    self.tb_range_max = tb_range_max
    self.tb_range_inner_max = tb_range_inner_max
    self.tb_sched_max = tb_sched_max
    self.multiwrite = multiwrite
    self.max_prefetch = max_prefetch
    self.read_delay = read_delay
    self.num_tiles = num_tiles
    self.stcl_valid_iter = stcl_valid_iter
    self.chain_idx_bits = max(1, kts.clog2(num_tiles))
    # Addresses span the memory of all chained tiles.
    self.address_width = kts.clog2(self.mem_depth * num_tiles)

    self.config = {}

    # top level configuration registers
    # chaining
    self.config["enable_chain_input"] = 0
    self.config["enable_chain_output"] = 0
    self.config["chain_idx_input"] = 0
    self.config["chain_idx_output"] = 0
    self.config["tile_en"] = 0
    self.config["mode"] = 0
    for i in range(self.interconnect_output_ports):
        self.config[f"rate_matched_{i}"] = 0

    # Set up model..
    ### APP CTRL (FINE-GRAINED)
    self.app_ctrl = AppCtrlModel(
        int_in_ports=self.interconnect_input_ports,
        int_out_ports=self.interconnect_output_ports,
        sprt_stcl_valid=True,
        stcl_iter_support=self.stcl_valid_iter)
    for i in range(self.interconnect_input_ports):
        self.config[f"app_ctrl_write_depth_{i}"] = 0
    for i in range(self.interconnect_output_ports):
        self.config[f"app_ctrl_input_port_{i}"] = 0
        self.config[f"app_ctrl_read_depth_{i}"] = 0
        self.config[f"app_ctrl_prefill_{i}"] = 0
    # calculate stencil_valid with fine-grained application controller
    for i in range(self.stcl_valid_iter):
        self.config[f"app_ctrl_ranges_{i}"] = 0
        self.config[f"app_ctrl_app_ctrl_threshold_{i}"] = 0

    ### COARSE APP CTRL
    self.app_ctrl_coarse = AppCtrlModel(
        int_in_ports=self.interconnect_input_ports,
        int_out_ports=self.interconnect_output_ports,
        sprt_stcl_valid=False,
        # unused, stencil valid comes from app ctrl,
        # not coarse app ctrl
        stcl_iter_support=0)
    for i in range(self.interconnect_input_ports):
        self.config[f"app_ctrl_coarse_write_depth_{i}"] = 0
    for i in range(self.interconnect_output_ports):
        self.config[f"app_ctrl_coarse_input_port_{i}"] = 0
        self.config[f"app_ctrl_coarse_read_depth_{i}"] = 0
        self.config[f"app_ctrl_coarse_prefill_{i}"] = 0

    ### INST AGG ALIGNER
    if (self.agg_height > 0):
        self.agg_aligners = []
        for i in range(self.interconnect_input_ports):
            self.agg_aligners.append(
                AggAlignerModel(data_width=self.data_width,
                                max_line_length=self.max_line_length))
            self.config[f"agg_align_{i}_line_length"] = 0

    ### AGG BUFF
    self.agg_buffs = []
    for port in range(self.interconnect_input_ports):
        self.agg_buffs.append(
            AggBuffModel(agg_height=self.agg_height,
                         data_width=self.data_width,
                         mem_width=self.mem_width,
                         max_agg_schedule=self.max_agg_schedule))
        # BUGFIX: index with `port` (a stale `i` from the aligner loop
        # was used here before).
        self.config[f"agg_in_{port}_in_period"] = 0
        self.config[f"agg_in_{port}_out_period"] = 0
        for j in range(self.max_agg_schedule):
            self.config[f"agg_in_{port}_in_sched_{j}"] = 0
            self.config[f"agg_in_{port}_out_sched_{j}"] = 0

    ### INPUT ADDR CTRL
    self.iac = InputAddrCtrlModel(
        interconnect_input_ports=self.interconnect_input_ports,
        data_width=self.data_width,
        fetch_width=self.mem_width,
        mem_depth=self.mem_depth,
        num_tiles=self.num_tiles,
        banks=self.banks,
        iterator_support=self.input_iterator_support,
        max_port_schedule=self.input_max_port_sched,
        address_width=self.address_width)
    for i in range(self.interconnect_input_ports):
        self.config[f"input_addr_ctrl_address_gen_{i}_dimensionality"] = 0
        self.config[f"input_addr_ctrl_address_gen_{i}_starting_addr"] = 0
        for j in range(self.input_iterator_support):
            self.config[f"input_addr_ctrl_address_gen_{i}_ranges_{j}"] = 0
            self.config[f"input_addr_ctrl_address_gen_{i}_strides_{j}"] = 0
        for j in range(self.multiwrite):
            self.config[f"input_addr_ctrl_offsets_cfg_{i}_{j}"] = 0

    ### OUTPUT ADDR CTRL
    self.oac = OutputAddrCtrlModel(
        interconnect_output_ports=self.interconnect_output_ports,
        mem_depth=self.mem_depth,
        num_tiles=self.num_tiles,
        data_width=self.data_width,
        fetch_width=self.mem_width,
        banks=self.banks,
        iterator_support=self.output_iterator_support,
        address_width=self.address_width,
        chain_idx_output=self.config["chain_idx_output"])
    for i in range(self.interconnect_output_ports):
        self.config[f"output_addr_ctrl_address_gen_{i}_dimensionality"] = 0
        self.config[f"output_addr_ctrl_address_gen_{i}_starting_addr"] = 0
        # BUGFIX: iterate over the *output* iterator support to match the
        # OAC constructed above (previously used input_iterator_support).
        for j in range(self.output_iterator_support):
            self.config[f"output_addr_ctrl_address_gen_{i}_ranges_{j}"] = 0
            self.config[f"output_addr_ctrl_address_gen_{i}_strides_{j}"] = 0

    ### RW ARBITER
    # Per bank allocation
    self.rw_arbs = []
    for bank in range(self.banks):
        self.rw_arbs.append(
            RWArbiterModel(fetch_width=self.mem_width,
                           data_width=self.data_width,
                           memory_depth=self.mem_depth,
                           int_out_ports=self.interconnect_output_ports,
                           read_delay=self.read_delay))

    self.mems = []
    if self.read_delay == 1:
        ### SRAMS
        for bank in range(self.banks):
            self.mems.append(
                SRAMWrapperModel(
                    use_sram_stub=self.use_sram_stub,
                    sram_name=self.sram_name,
                    data_width=self.data_width,
                    fw_int=self.fw_int,
                    mem_depth=self.mem_depth,
                    mem_input_ports=self.mem_input_ports,
                    mem_output_ports=self.mem_output_ports,
                    address_width=self.address_width,
                    bank_num=bank,
                    num_tiles=self.num_tiles,
                    enable_chain_input=self.config["enable_chain_input"],
                    enable_chain_output=self.config["enable_chain_output"],
                    chain_idx_input=self.config["chain_idx_input"],
                    chain_idx_output=self.config["chain_idx_output"]))
    else:
        ### REGFILES
        for bank in range(self.banks):
            self.mems.append(
                RegisterFileModel(data_width=self.data_width,
                                  write_ports=self.mem_input_ports,
                                  read_ports=self.mem_output_ports,
                                  width_mult=self.fw_int,
                                  depth=self.mem_depth))

    ### DEMUX READS
    self.demux_reads = DemuxReadsModel(
        fetch_width=self.mem_width,
        data_width=self.data_width,
        banks=self.banks,
        int_out_ports=self.interconnect_output_ports)

    ### SYNC GROUPS
    self.sync_groups = SyncGroupsModel(
        fetch_width=self.mem_width,
        data_width=self.data_width,
        int_out_ports=self.interconnect_output_ports)
    for i in range(self.interconnect_output_ports):
        self.config[f"sync_grp_sync_group_{i}"] = 0

    ### PREFETCHERS
    self.prefetchers = []
    for port in range(self.interconnect_output_ports):
        self.prefetchers.append(
            PrefetcherModel(fetch_width=self.mem_width,
                            data_width=self.data_width,
                            max_prefetch=self.max_prefetch))
        self.config[f"pre_fetch_{port}_input_latency"] = 0

    ### TBAS
    self.tbas = []
    for port in range(self.interconnect_output_ports):
        self.tbas.append(
            TBAModel(word_width=self.data_width,
                     fetch_width=self.fw_int,
                     num_tb=self.num_tb,
                     tb_height=self.tb_height,
                     max_range=self.tb_range_max,
                     max_range_inner=self.tb_range_inner_max))
        for i in range(self.tb_height):
            self.config[f"tba_{port}_tb_{i}_range_inner"] = 0
            self.config[f"tba_{port}_tb_{i}_range_outer"] = 0
            self.config[f"tba_{port}_tb_{i}_stride"] = 0
            self.config[f"tba_{port}_tb_{i}_dimensionality"] = 0
            self.config[f"tba_{port}_tb_{i}_starting_addr"] = 0
            for j in range(self.tb_sched_max):
                self.config[f"tba_{port}_tb_{i}_indices_{j}"] = 0
def __init__(
        self,
        data_width=16,  # CGRA Params
        mem_depth=32,
        default_iterator_support=3,
        interconnect_input_ports=1,  # Connection to int
        interconnect_output_ports=1,
        config_data_width=32,
        config_addr_width=8,
        cycle_count_width=16,
        add_clk_enable=True,
        add_flush=True):
    """Wrap a Pond memory as a CGRA core.

    Builds (or fetches from the class-level circuit cache) the kratos Pond
    generator and its magma circuit, wraps the circuit, and dumps the
    sorted configuration-register listing to pond_cfg.txt / pond_synth.txt.
    """
    lake_name = "Pond_pond"  # NOTE(review): appears unused here — confirm before removing
    super().__init__(config_data_width=config_data_width,
                     config_addr_width=config_addr_width,
                     data_width=data_width,
                     name="PondCore")

    # Capture everything to the tile object
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_depth = mem_depth
    self.data_width = data_width
    self.config_data_width = config_data_width
    self.config_addr_width = config_addr_width
    self.add_clk_enable = add_clk_enable
    self.add_flush = add_flush
    self.cycle_count_width = cycle_count_width
    self.default_iterator_support = default_iterator_support
    self.default_config_width = kts.clog2(self.mem_depth)

    # Every parameter that influences the generated circuit participates
    # in the cache key.
    cache_key = (self.data_width,
                 self.mem_depth,
                 self.interconnect_input_ports,
                 self.interconnect_output_ports,
                 self.config_data_width,
                 self.config_addr_width,
                 self.add_clk_enable,
                 self.add_flush,
                 self.cycle_count_width,
                 self.default_iterator_support)

    # Check for circuit caching
    if cache_key in LakeCoreBase._circuit_cache:
        circ, self.dut = LakeCoreBase._circuit_cache[cache_key]
    else:
        # First time with this parameterization: elaborate the Pond
        # generator (object used only for queries later) and lower it to a
        # magma circuit, then cache both for subsequent instantiations.
        self.dut = Pond(data_width=data_width,  # CGRA Params
                        mem_depth=mem_depth,
                        default_iterator_support=default_iterator_support,
                        interconnect_input_ports=interconnect_input_ports,  # Connection to int
                        interconnect_output_ports=interconnect_output_ports,
                        config_data_width=config_data_width,
                        config_addr_width=config_addr_width,
                        cycle_count_width=cycle_count_width,
                        add_clk_enable=add_clk_enable,
                        add_flush=add_flush)
        circ = kts.util.to_magma(self.dut,
                                 flatten_array=True,
                                 check_multiple_driver=False,
                                 optimize_if=False,
                                 check_flip_flop_always_ff=False)
        LakeCoreBase._circuit_cache[cache_key] = (circ, self.dut)

    # Save as underlying circuit object
    self.underlying = FromMagma(circ)
    self.wrap_lake_core()

    # Dump the configuration registers, sorted by name.
    reg_names = sorted(self.registers)
    with open("pond_cfg.txt", "w+") as cfg_dump:
        cfg_dump.writelines(
            f"(\"{reg}\", 0), # {self.registers[reg].width}\n"
            for reg in reg_names)
    with open("pond_synth.txt", "w+") as cfg_dump:
        cfg_dump.writelines(f"{reg}\n" for reg in reg_names)
def __init__(self, _params: GlobalBufferParams):
    """Declare the packed-struct types and port-name lists shared by the GLB.

    Field order matters: it defines the bit layout of every struct, so the
    field lists below must not be reordered.
    """
    self._params = _params
    p = self._params

    # Network configuration (data and pcfg networks share one shape).
    self.cfg_data_network_t = PackedStruct(
        "cfg_data_network_t",
        [("tile_connected", 1), ("latency", p.latency_width)])
    self.cfg_pcfg_network_t = PackedStruct(
        "cfg_pcfg_network_t",
        [("tile_connected", 1), ("latency", p.latency_width)])

    # DMA control register fields.
    self.cfg_dma_ctrl_t = PackedStruct(
        "dma_ctrl_t",
        [("mode", 2),
         ("use_valid", 1),
         ("data_mux", 2),
         ("num_repeat", clog2(p.queue_depth) + 1)])

    # DMA header. NOTE: Kratos does not support struct of struct now,
    # so the per-loop-level fields are flattened into one struct.
    header_fields = [("start_addr", p.glb_addr_width),
                     ("cycle_start_addr", p.glb_addr_width),
                     ("dim", 1 + clog2(p.loop_level))]
    for lvl in range(p.loop_level):
        header_fields.append((f"range_{lvl}", p.axi_data_width))
        header_fields.append((f"stride_{lvl}", p.axi_data_width))
        header_fields.append((f"cycle_stride_{lvl}", p.axi_data_width))
    self.cfg_dma_header_t = PackedStruct("dma_header_t", header_fields)

    # pcfg dma header
    self.cfg_pcfg_dma_ctrl_t = PackedStruct("pcfg_dma_ctrl_t", [("mode", 1)])
    self.cfg_pcfg_dma_header_t = PackedStruct(
        "pcfg_dma_header_t",
        [("start_addr", p.glb_addr_width),
         ("num_cfg", p.max_num_cfg_width)])

    # Bank packet field groups.
    wr_fields = [("wr_en", 1),
                 ("wr_strb", math.ceil(p.bank_data_width / 8)),
                 ("wr_addr", p.glb_addr_width),
                 ("wr_data", p.bank_data_width)]
    rdrq_fields = [("rd_en", 1),
                   ("rd_addr", p.glb_addr_width)]
    rdrs_fields = [("rd_data", p.bank_data_width),
                   ("rd_data_valid", 1)]

    self.packet_t = PackedStruct(
        "packet_t", wr_fields + rdrq_fields + rdrs_fields)
    self.rd_packet_t = PackedStruct(
        "rd_packet_t", rdrq_fields + rdrs_fields)
    self.rdrq_packet_t = PackedStruct("rdrq_packet_t", rdrq_fields)
    self.rdrs_packet_t = PackedStruct("rdrs_packet_t", rdrs_fields)
    self.wr_packet_t = PackedStruct("wr_packet_t", wr_fields)

    # NOTE: Kratos currently does not support struct of struct, so flat
    # name lists are kept alongside the structs for port hookup.
    # This can become cleaner if it does.
    def names_of(fields):
        return [field_name for field_name, _width in fields]

    self.wr_packet_ports = names_of(wr_fields)
    self.rdrq_packet_ports = names_of(rdrq_fields)
    self.rdrs_packet_ports = names_of(rdrs_fields)
    self.rd_packet_ports = names_of(rdrq_fields + rdrs_fields)
    self.packet_ports = names_of(rdrq_fields + rdrs_fields + wr_fields)

    # CGRA configuration packet.
    self.cgra_cfg_t = PackedStruct(
        "cgra_cfg_t",
        [("rd_en", 1),
         ("wr_en", 1),
         ("addr", p.cgra_cfg_addr_width),
         ("data", p.cgra_cfg_data_width)])
def __init__(self,
             data_width=16,  # CGRA Params
             mem_width=64,
             mem_depth=512,
             banks=1,
             input_iterator_support=6,  # Addr Controllers
             output_iterator_support=6,
             input_config_width=16,
             output_config_width=16,
             interconnect_input_ports=2,  # Connection to int
             interconnect_output_ports=2,
             mem_input_ports=1,
             mem_output_ports=1,
             read_delay=1,  # Cycle delay in read (SRAM vs Register File)
             rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
             agg_height=4,
             max_agg_schedule=32,
             input_max_port_sched=32,
             output_max_port_sched=32,
             align_input=1,
             max_line_length=128,
             max_tb_height=1,
             tb_range_max=128,
             tb_range_inner_max=5,
             tb_sched_max=64,
             max_tb_stride=15,
             num_tb=1,
             tb_iterator_support=2,
             multiwrite=1,
             num_tiles=1,
             max_prefetch=8,
             app_ctrl_depth_width=16,
             remove_tb=False,
             stcl_valid_iter=4):
    """Elaborate the ``strg_ub`` storage unified-buffer generator.

    Constructs and wires the full datapath: optional aggregation
    aligners/buffers on the input side, input/output address controllers,
    one read/write arbiter per bank, a read demux, sync groups,
    per-output-port prefetchers and (unless ``num_tb == 0`` or
    ``remove_tb``) transpose-buffer aggregations on the output side.

    All parameters are captured onto ``self`` so the ``add_code`` blocks
    (``transpose_acks``/``reduce_acks``) can reference them.
    """
    super().__init__("strg_ub")

    # Capture constructor parameters.
    self.data_width = data_width
    self.mem_width = mem_width
    self.mem_depth = mem_depth
    self.banks = banks
    self.input_iterator_support = input_iterator_support
    self.output_iterator_support = output_iterator_support
    self.input_config_width = input_config_width
    self.output_config_width = output_config_width
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.agg_height = agg_height
    self.max_agg_schedule = max_agg_schedule
    self.input_max_port_sched = input_max_port_sched
    self.output_max_port_sched = output_max_port_sched
    self.input_port_sched_width = clog2(self.interconnect_input_ports)
    self.align_input = align_input
    self.max_line_length = max_line_length
    assert self.mem_width >= self.data_width, "Data width needs to be smaller than mem"
    # fetch width: memory words per macro row.
    self.fw_int = int(self.mem_width / self.data_width)
    self.num_tb = num_tb
    self.max_tb_height = max_tb_height
    self.tb_range_max = tb_range_max
    self.tb_range_inner_max = tb_range_inner_max
    self.max_tb_stride = max_tb_stride
    self.tb_sched_max = tb_sched_max
    self.tb_iterator_support = tb_iterator_support
    self.multiwrite = multiwrite
    self.max_prefetch = max_prefetch
    self.num_tiles = num_tiles
    self.app_ctrl_depth_width = app_ctrl_depth_width
    self.remove_tb = remove_tb
    self.read_delay = read_delay
    self.rw_same_cycle = rw_same_cycle
    self.stcl_valid_iter = stcl_valid_iter

    # phases = [] TODO

    # Flat address space covers all tiles' memory.
    self.address_width = clog2(self.num_tiles * self.mem_depth)

    # CLK and RST
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")

    # INPUTS
    self._data_in = self.input("data_in",
                               self.data_width,
                               size=self.interconnect_input_ports,
                               packed=True,
                               explicit_array=True)

    self._wen_in = self.input("wen_in", self.interconnect_input_ports)
    self._ren_input = self.input("ren_in", self.interconnect_output_ports)
    # Post rate matched
    self._ren_in = self.var("ren_in_muxed", self.interconnect_output_ports)

    # Processed versions of wen and ren from the app ctrl
    self._wen = self.var("wen", self.interconnect_input_ports)
    self._ren = self.var("ren", self.interconnect_output_ports)

    # Add rate matched
    # If one input port, let any output port use the wen_in as the ren_in
    # If more, do the same thing but also provide port selection
    if self.interconnect_input_ports == 1:
        self._rate_matched = self.input("rate_matched", self.interconnect_output_ports)
        self._rate_matched.add_attribute(ConfigRegAttr("Rate matched - 1 or 0"))
        for i in range(self.interconnect_output_ports):
            self.wire(self._ren_in[i],
                      kts.ternary(self._rate_matched[i],
                                  self._wen_in,
                                  self._ren_input[i]))
    else:
        # rate_matched[i] encodes [input-port select | enable] per output port.
        self._rate_matched = self.input("rate_matched",
                                        1 + kts.clog2(self.interconnect_input_ports),
                                        size=self.interconnect_output_ports,
                                        explicit_array=True,
                                        packed=True)
        self._rate_matched.add_attribute(ConfigRegAttr("Rate matched [input port | on/off]"))
        for i in range(self.interconnect_output_ports):
            # Bit 0 enables rate matching; upper bits pick which wen_in to mirror.
            self.wire(self._ren_in[i],
                      kts.ternary(self._rate_matched[i][0],
                                  self._wen_in[self._rate_matched[i][kts.clog2(self.interconnect_input_ports), 1]],
                                  self._ren_input[i]))

    self._arb_wen_en = self.var("arb_wen_en", self.interconnect_input_ports)
    self._arb_ren_en = self.var("arb_ren_en", self.interconnect_output_ports)

    self._data_from_strg = self.input("data_from_strg",
                                      self.data_width,
                                      size=(self.banks,
                                            self.mem_output_ports,
                                            self.fw_int),
                                      packed=True,
                                      explicit_array=True)

    self._mem_valid_data = self.input("mem_valid_data",
                                      self.mem_output_ports,
                                      size=self.banks,
                                      explicit_array=True,
                                      packed=True)

    self._out_mem_valid_data = self.var("out_mem_valid_data",
                                        self.mem_output_ports,
                                        size=self.banks,
                                        explicit_array=True,
                                        packed=True)

    # We need to signal valids out of the agg buff, only if one exists...
    if self.agg_height > 0:
        self._to_iac_valid = self.var("ab_to_mem_valid", self.interconnect_input_ports)

    self._data_out = self.output("data_out",
                                 self.data_width,
                                 size=self.interconnect_output_ports,
                                 packed=True,
                                 explicit_array=True)

    self._valid_out = self.output("valid_out", self.interconnect_output_ports)
    self._valid_out_alt = self.var("valid_out_alt", self.interconnect_output_ports)

    self._data_to_strg = self.output("data_to_strg",
                                     self.data_width,
                                     size=(self.banks,
                                           self.mem_input_ports,
                                           self.fw_int),
                                     packed=True,
                                     explicit_array=True)

    # If we can perform a read and a write on the same cycle,
    # this will necessitate a separate read and write address...
    if self.rw_same_cycle:
        self._wr_addr_out = self.output("wr_addr_out",
                                        self.address_width,
                                        size=(self.banks,
                                              self.mem_input_ports),
                                        explicit_array=True,
                                        packed=True)
        self._rd_addr_out = self.output("rd_addr_out",
                                        self.address_width,
                                        size=(self.banks,
                                              self.mem_output_ports),
                                        explicit_array=True,
                                        packed=True)
    else:
        # Single shared address bus when reads and writes cannot collide.
        self._addr_out = self.output("addr_out",
                                     self.address_width,
                                     size=(self.banks,
                                           self.mem_input_ports),
                                     packed=True,
                                     explicit_array=True)

    self._cen_to_strg = self.output("cen_to_strg",
                                    self.mem_output_ports,
                                    size=self.banks,
                                    explicit_array=True,
                                    packed=True)

    self._wen_to_strg = self.output("wen_to_strg",
                                    self.mem_input_ports,
                                    size=self.banks,
                                    explicit_array=True,
                                    packed=True)

    if self.num_tb > 0:
        self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)

    self._port_wens = self.var("port_wens", self.interconnect_input_ports)

    ####################
    ##### APP CTRL #####
    ####################
    self._ack_transpose = self.var("ack_transpose",
                                   self.banks,
                                   size=self.interconnect_output_ports,
                                   explicit_array=True,
                                   packed=True)

    self._ack_reduced = self.var("ack_reduced",
                                 self.interconnect_output_ports)

    self.app_ctrl = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                            interconnect_output_ports=self.interconnect_output_ports,
                            depth_width=self.app_ctrl_depth_width,
                            sprt_stcl_valid=True,
                            stcl_iter_support=self.stcl_valid_iter)

    # Some refactoring here for pond to get rid of app controllers...
    # This is honestly pretty messy and should clean up nicely when we have the spec...
    self._ren_out_reduced = self.var("ren_out_reduced",
                                     self.interconnect_output_ports)

    if self.num_tb == 0 or self.remove_tb:
        # No app controllers: pass enables straight through.
        self.wire(self._wen, self._wen_in)
        self.wire(self._ren, self._ren_in)
        self.wire(self._valid_out, self._valid_out_alt)
        self.wire(self._arb_wen_en, self._wen)
        self.wire(self._arb_ren_en, self._ren)
    else:
        self.add_child("app_ctrl", self.app_ctrl,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       wen_in=self._wen_in,
                       ren_in=self._ren_in,
                       # ren_update=self._tb_valid_out,
                       valid_out_data=self._valid_out,
                       # valid_out_stencil=,
                       wen_out=self._wen,
                       ren_out=self._ren)
        self.wire(self.app_ctrl.ports.tb_valid, self._tb_valid_out)
        self.wire(self.app_ctrl.ports.ren_update, self._tb_valid_out)

        # Coarse-grained controller drives the arbiter-level enables.
        # NOTE(review): placement inside this else-branch inferred — the
        # if-branch above already drives _arb_wen_en/_arb_ren_en; confirm
        # against the original formatting.
        self.app_ctrl_coarse = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                                       interconnect_output_ports=self.interconnect_output_ports,
                                       depth_width=self.app_ctrl_depth_width)

        self.add_child("app_ctrl_coarse", self.app_ctrl_coarse,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       wen_in=self._to_iac_valid,  # self._port_wens & self._to_iac_valid,
                       # Gets valid and the ack
                       ren_in=self._ren_out_reduced,
                       tb_valid=kts.const(0, 1),
                       ren_update=self._ack_reduced,
                       wen_out=self._arb_wen_en,
                       ren_out=self._arb_ren_en)

    ###########################
    ##### INPUT AGG SCHED #####
    ###########################

    ###########################################
    ##### AGGREGATION ALIGNERS (OPTIONAL) #####
    ###########################################
    # These variables are holders and can be swapped out if needed
    self._data_consume = self._data_in
    self._valid_consume = self._wen

    # Zero out if not aligning
    if(self.agg_height > 0):
        self._align_to_agg = self.var("align_input", self.interconnect_input_ports)

    # Add the aggregation buffer aligners
    if(self.align_input):
        self._data_consume = self.var("data_consume",
                                      self.data_width,
                                      size=self.interconnect_input_ports,
                                      packed=True,
                                      explicit_array=True)
        self._valid_consume = self.var("valid_consume", self.interconnect_input_ports)
        # Make new aggregation aligners for each port
        for i in range(self.interconnect_input_ports):
            new_child = AggAligner(self.data_width, self.max_line_length)
            self.add_child(f"agg_align_{i}", new_child,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           in_dat=self._data_in[i],
                           in_valid=self._wen[i],
                           align=self._align_to_agg[i],
                           out_valid=self._valid_consume[i],
                           out_dat=self._data_consume[i])
    else:
        if self.agg_height > 0:
            self.wire(self._align_to_agg, const(0, self._align_to_agg.width))
    ################################################
    ##### END: AGGREGATION ALIGNERS (OPTIONAL) #####
    ################################################

    if self.agg_height == 0:
        # No agg buffer stage: data/valid feed the input addr ctrl directly.
        self._to_iac_dat = self._data_consume
        self._to_iac_valid = self._valid_consume

    ##################################
    ##### AGG BUFFERS (OPTIONAL) #####
    ##################################
    # Only instantiate agg_buffer if needed
    if(self.agg_height > 0):
        self._to_iac_dat = self.var("ab_to_mem_dat",
                                    self.mem_width,
                                    size=self.interconnect_input_ports,
                                    packed=True,
                                    explicit_array=True)

        # self._to_iac_valid = self.var("ab_to_mem_valid",
        #                               self.interconnect_input_ports)

        self._agg_buffers = []
        # Add input aggregations buffers
        for i in range(self.interconnect_input_ports):
            # add children aggregator buffers...
            agg_buffer_new = AggregationBuffer(self.agg_height,
                                               self.data_width,
                                               self.mem_width,
                                               self.max_agg_schedule)
            self._agg_buffers.append(agg_buffer_new)
            self.add_child(f"agg_in_{i}", agg_buffer_new,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_consume[i],
                           valid_in=self._valid_consume[i],
                           align=self._align_to_agg[i],
                           data_out=self._to_iac_dat[i],
                           valid_out=self._to_iac_valid[i])
    #######################################
    ##### END: AGG BUFFERS (OPTIONAL) #####
    #######################################

    self._ready_tba = self.var("ready_tba", self.interconnect_output_ports)

    ####################################
    ##### INPUT ADDRESS CONTROLLER #####
    ####################################
    self._wen_to_arb = self.var("wen_to_arb",
                                self.mem_input_ports,
                                size=self.banks,
                                explicit_array=True,
                                packed=True)
    self._addr_to_arb = self.var("addr_to_arb",
                                 self.address_width,
                                 size=(self.banks,
                                       self.mem_input_ports),
                                 explicit_array=True,
                                 packed=True)
    self._data_to_arb = self.var("data_to_arb",
                                 self.data_width,
                                 size=(self.banks,
                                       self.mem_input_ports,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)

    # Connect these inputs ports to an address generator
    iac = InputAddrCtrl(interconnect_input_ports=self.interconnect_input_ports,
                        mem_depth=self.mem_depth,
                        num_tiles=self.num_tiles,
                        banks=self.banks,
                        iterator_support=self.input_iterator_support,
                        address_width=self.address_width,
                        data_width=self.data_width,
                        fetch_width=self.mem_width,
                        multiwrite=self.multiwrite,
                        strg_wr_ports=self.mem_input_ports,
                        config_width=self.input_config_width)
    self.add_child(f"input_addr_ctrl", iac,
                   clk=self._clk,
                   rst_n=self._rst_n,
                   valid_in=self._to_iac_valid,
                   # wen_en=kts.concat(*([kts.const(1, 1)] * self.interconnect_input_ports)),
                   wen_en=self._arb_wen_en,
                   data_in=self._to_iac_dat,
                   wen_to_sram=self._wen_to_arb,
                   addr_out=self._addr_to_arb,
                   port_out=self._port_wens,
                   data_out=self._data_to_arb)
    #########################################
    ##### END: INPUT ADDRESS CONTROLLER #####
    #########################################

    self._arb_acks = self.var("arb_acks",
                              self.interconnect_output_ports,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)

    self._prefetch_step = self.var("prefetch_step", self.interconnect_output_ports)
    self._oac_step = self.var("oac_step", self.interconnect_output_ports)
    self._oac_valid = self.var("oac_valid", self.interconnect_output_ports)
    self._ren_out = self.var("ren_out",
                             self.interconnect_output_ports,
                             size=self.banks,
                             explicit_array=True,
                             packed=True)
    self._ren_out_tpose = self.var("ren_out_tpose",
                                   self.banks,
                                   size=self.interconnect_output_ports,
                                   explicit_array=True,
                                   packed=True)
    self._oac_addr_out = self.var("oac_addr_out",
                                  self.address_width,
                                  size=self.interconnect_output_ports,
                                  explicit_array=True,
                                  packed=True)

    #####################################
    ##### OUTPUT ADDRESS CONTROLLER #####
    #####################################
    oac = OutputAddrCtrl(interconnect_output_ports=self.interconnect_output_ports,
                         mem_depth=self.mem_depth,
                         num_tiles=self.num_tiles,
                         banks=self.banks,
                         iterator_support=self.output_iterator_support,
                         address_width=self.address_width,
                         config_width=self.output_config_width)

    if self.remove_tb:
        # No prefetch/tb path: step directly off the read enables.
        self.wire(self._oac_valid, self._ren)
        self.wire(self._oac_step, self._ren)
    else:
        self.wire(self._oac_valid, self._prefetch_step)
        self.wire(self._oac_step, self._ack_reduced)

    # Chaining configuration for multi-tile setups.
    self.chain_idx_bits = max(1, clog2(num_tiles))
    self._enable_chain_output = self.input("enable_chain_output", 1)
    self._chain_idx_output = self.input("chain_idx_output", self.chain_idx_bits)

    self.add_child(f"output_addr_ctrl", oac,
                   clk=self._clk,
                   rst_n=self._rst_n,
                   valid_in=self._oac_valid,
                   ren=self._ren_out,
                   addr_out=self._oac_addr_out,
                   step_in=self._oac_step)

    # Transpose (bank, port) -> (port, bank) view of the read enables.
    for i in range(self.interconnect_output_ports):
        for j in range(self.banks):
            self.wire(self._ren_out_tpose[i][j], self._ren_out[j][i])

    ##############################
    ##### READ/WRITE ARBITER #####
    ##############################
    # Hook up the read write arbiters for each bank
    self._arb_dat_out = self.var("arb_dat_out",
                                 self.data_width,
                                 size=(self.banks,
                                       self.mem_output_ports,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    self._arb_port_out = self.var("arb_port_out",
                                  self.interconnect_output_ports,
                                  size=(self.banks,
                                        self.mem_output_ports),
                                  explicit_array=True,
                                  packed=True)
    self._arb_valid_out = self.var("arb_valid_out",
                                   self.mem_output_ports,
                                   size=self.banks,
                                   explicit_array=True,
                                   packed=True)

    self._rd_sync_gate = self.var("rd_sync_gate", self.interconnect_output_ports)

    self.arbiters = []
    for i in range(self.banks):
        rw_arb = RWArbiter(fetch_width=self.mem_width,
                           data_width=self.data_width,
                           memory_depth=self.mem_depth,
                           num_tiles=self.num_tiles,
                           int_in_ports=self.interconnect_input_ports,
                           int_out_ports=self.interconnect_output_ports,
                           strg_wr_ports=self.mem_input_ports,
                           strg_rd_ports=self.mem_output_ports,
                           read_delay=self.read_delay,
                           rw_same_cycle=self.rw_same_cycle,
                           separate_addresses=self.rw_same_cycle)
        self.arbiters.append(rw_arb)
        self.add_child(f"rw_arb_{i}", rw_arb,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       wen_in=self._wen_to_arb[i],
                       w_data=self._data_to_arb[i],
                       w_addr=self._addr_to_arb[i],
                       data_from_mem=self._data_from_strg[i],
                       mem_valid_data=self._mem_valid_data[i],
                       out_mem_valid_data=self._out_mem_valid_data[i],
                       ren_en=self._arb_ren_en,
                       rd_addr=self._oac_addr_out,
                       out_data=self._arb_dat_out[i],
                       out_port=self._arb_port_out[i],
                       out_valid=self._arb_valid_out[i],
                       cen_mem=self._cen_to_strg[i],
                       wen_mem=self._wen_to_strg[i],
                       data_to_mem=self._data_to_strg[i],
                       out_ack=self._arb_acks[i])
        # Bind the separate addrs
        if self.rw_same_cycle:
            self.wire(rw_arb.ports.wr_addr_to_mem, self._wr_addr_out[i])
            self.wire(rw_arb.ports.rd_addr_to_mem, self._rd_addr_out[i])
        else:
            self.wire(rw_arb.ports.addr_to_mem, self._addr_out[i])

        if self.remove_tb:
            self.wire(rw_arb.ports.ren_in, self._ren_out[i])
        else:
            # Gate reads with the sync-group ready signal.
            self.wire(rw_arb.ports.ren_in, self._ren_out[i] & self._rd_sync_gate)

    self.num_tb_bits = max(1, clog2(self.num_tb))

    self._data_to_sync = self.var("data_to_sync",
                                  self.data_width,
                                  size=(self.interconnect_output_ports,
                                        self.fw_int),
                                  explicit_array=True,
                                  packed=True)
    self._valid_to_sync = self.var("valid_to_sync", self.interconnect_output_ports)

    self._data_to_tba = self.var("data_to_tba",
                                 self.data_width,
                                 size=(self.interconnect_output_ports,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    self._valid_to_tba = self.var("valid_to_tba", self.interconnect_output_ports)

    self._data_to_pref = self.var("data_to_pref",
                                  self.data_width,
                                  size=(self.interconnect_output_ports,
                                        self.fw_int),
                                  explicit_array=True,
                                  packed=True)
    self._valid_to_pref = self.var("valid_to_pref", self.interconnect_output_ports)

    #######################
    ##### DEMUX READS #####
    #######################
    dmux_rd = DemuxReads(fetch_width=self.mem_width,
                         data_width=self.data_width,
                         banks=self.banks,
                         int_out_ports=self.interconnect_output_ports,
                         strg_rd_ports=self.mem_output_ports)

    # Flattened (bank * port) views of the arbiter outputs for the demux.
    self._arb_dat_out_f = self.var("arb_dat_out_f",
                                   self.data_width,
                                   size=(self.banks * self.mem_output_ports,
                                         self.fw_int),
                                   explicit_array=True,
                                   packed=True)
    self._arb_port_out_f = self.var("arb_port_out_f",
                                    self.interconnect_output_ports,
                                    size=(self.banks * self.mem_output_ports),
                                    explicit_array=True,
                                    packed=True)
    self._arb_valid_out_f = self.var("arb_valid_out_f",
                                     self.mem_output_ports * self.banks)
    self._arb_mem_valid_data_f = self.var("arb_mem_valid_data_f",
                                          self.mem_output_ports * self.banks)
    self._arb_mem_valid_data_out = self.var("arb_mem_valid_data_out",
                                            self.interconnect_output_ports)
    self._mem_valid_data_sync = self.var("mem_valid_data_sync",
                                         self.interconnect_output_ports)
    self._mem_valid_data_pref = self.var("mem_valid_data_pref",
                                         self.interconnect_output_ports)

    tmp_cnt = 0
    for i in range(self.banks):
        for j in range(self.mem_output_ports):
            self.wire(self._arb_dat_out_f[tmp_cnt], self._arb_dat_out[i][j])
            self.wire(self._arb_port_out_f[tmp_cnt], self._arb_port_out[i][j])
            self.wire(self._arb_valid_out_f[tmp_cnt], self._arb_valid_out[i][j])
            self.wire(self._arb_mem_valid_data_f[tmp_cnt],
                      self._out_mem_valid_data[i][j])
            tmp_cnt = tmp_cnt + 1

    # If this is end of the road...
    if self.remove_tb:
        assert self.fw_int == 1, "Make it easier on me now..."
        self.add_child("demux_rds", dmux_rd,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       data_in=self._arb_dat_out_f,
                       mem_valid_data=self._arb_mem_valid_data_f,
                       mem_valid_data_out=self._arb_mem_valid_data_out,
                       valid_in=self._arb_valid_out_f,
                       port_in=self._arb_port_out_f,
                       valid_out=self._valid_out_alt)
        for i in range(self.interconnect_output_ports):
            self.wire(self._data_out[i], dmux_rd.ports.data_out[i])
    else:
        self.add_child("demux_rds", dmux_rd,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       data_in=self._arb_dat_out_f,
                       mem_valid_data=self._arb_mem_valid_data_f,
                       mem_valid_data_out=self._arb_mem_valid_data_out,
                       valid_in=self._arb_valid_out_f,
                       port_in=self._arb_port_out_f,
                       data_out=self._data_to_sync,
                       valid_out=self._valid_to_sync)

    #######################
    ##### SYNC GROUPS #####
    #######################
    sync_group = SyncGroups(fetch_width=self.mem_width,
                            data_width=self.data_width,
                            int_out_ports=self.interconnect_output_ports)

    # Reduce the per-port read enables across banks.
    for i in range(self.interconnect_output_ports):
        self.wire(self._ren_out_reduced[i], self._ren_out_tpose[i].r_or())

    self.add_child("sync_grp", sync_group,
                   clk=self._clk,
                   rst_n=self._rst_n,
                   data_in=self._data_to_sync,
                   mem_valid_data=self._arb_mem_valid_data_out,
                   mem_valid_data_out=self._mem_valid_data_sync,
                   valid_in=self._valid_to_sync,
                   data_out=self._data_to_pref,
                   valid_out=self._valid_to_pref,
                   ren_in=self._ren_out_reduced,
                   rd_sync_gate=self._rd_sync_gate,
                   ack_in=self._ack_reduced)

    # This is the end of the line if we aren't using tb
    ######################
    ##### PREFETCHER #####
    ######################
    prefetchers = []
    for i in range(self.interconnect_output_ports):
        pref = Prefetcher(fetch_width=self.mem_width,
                          data_width=self.data_width,
                          max_prefetch=self.max_prefetch)
        prefetchers.append(pref)
        if self.num_tb == 0:
            assert self.fw_int == 1, \
                "If no transpose buffer, data width needs match memory width"
            self.add_child(f"pre_fetch_{i}", pref,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_to_pref[i],
                           mem_valid_data=self._mem_valid_data_sync[i],
                           mem_valid_data_out=self._mem_valid_data_pref[i],
                           valid_read=self._valid_to_pref[i],
                           tba_rdy_in=self._ren[i],
                           # data_out=self._data_out[i],
                           valid_out=self._valid_out_alt[i],
                           prefetch_step=self._prefetch_step[i])
            self.wire(self._data_out[i], pref.ports.data_out[0])
        else:
            self.add_child(f"pre_fetch_{i}", pref,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_to_pref[i],
                           mem_valid_data=self._mem_valid_data_sync[i],
                           mem_valid_data_out=self._mem_valid_data_pref[i],
                           valid_read=self._valid_to_pref[i],
                           tba_rdy_in=self._ready_tba[i],
                           data_out=self._data_to_tba[i],
                           valid_out=self._valid_to_tba[i],
                           prefetch_step=self._prefetch_step[i])

    #############################
    ##### TRANSPOSE BUFFERS #####
    #############################
    if self.num_tb > 0:
        self._tb_data_out = self.var("tb_data_out",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)
        # NOTE(review): "tb_valid_out" was already declared above when
        # num_tb > 0 — confirm kratos tolerates this re-declaration.
        self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)

        for i in range(self.interconnect_output_ports):
            tba = TransposeBufferAggregation(word_width=self.data_width,
                                             fetch_width=self.fw_int,
                                             num_tb=self.num_tb,
                                             max_tb_height=self.max_tb_height,
                                             max_range=self.tb_range_max,
                                             max_range_inner=self.tb_range_inner_max,
                                             max_stride=self.max_tb_stride,
                                             tb_iterator_support=self.tb_iterator_support)
            self.add_child(f"tba_{i}", tba,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           SRAM_to_tb_data=self._data_to_tba[i],
                           valid_data=self._valid_to_tba[i],
                           tb_index_for_data=0,
                           ack_in=self._valid_to_tba[i],
                           mem_valid_data=self._mem_valid_data_pref[i],
                           tb_to_interconnect_data=self._tb_data_out[i],
                           tb_to_interconnect_valid=self._tb_valid_out[i],
                           tb_arbiter_rdy=self._ready_tba[i],
                           tba_ren=self._ren[i])

        for i in range(self.interconnect_output_ports):
            self.wire(self._data_out[i], self._tb_data_out[i])
            # self.wire(self._valid_out[i], self._tb_valid_out[i])
    else:
        self.wire(self._valid_out, self._valid_out_alt)

    ####################
    ##### ADD CODE #####
    ####################
    self.add_code(self.transpose_acks)
    self.add_code(self.reduce_acks)
def __init__(self,
             data_width=16,  # CGRA Params
             mem_width=64,
             mem_depth=512,
             banks=1,
             input_addr_iterator_support=6,
             output_addr_iterator_support=6,
             input_sched_iterator_support=6,
             output_sched_iterator_support=6,
             config_width=16,
             #  output_config_width=16,
             interconnect_input_ports=2,  # Connection to int
             interconnect_output_ports=2,
             mem_input_ports=1,
             mem_output_ports=1,
             read_delay=1,  # Cycle delay in read (SRAM vs Register File)
             rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
             agg_height=4,
             tb_height=2):
    """Elaborate the ``strg_ub_tb_only`` generator (transpose-buffer stage only).

    Captures SRAM read data into a small per-output-port transpose buffer
    (one cycle after the SRAM read), then reads words back out under the
    control of a per-port for-loop / address-generator / schedule-generator
    trio, driving ``data_out`` and the ``accessor_output`` valids.
    """
    super().__init__("strg_ub_tb_only")

    ##################################################################################
    # Capture constructor parameter...
    ##################################################################################
    self.fetch_width = mem_width // data_width
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.agg_height = agg_height
    self.tb_height = tb_height
    self.mem_width = mem_width
    self.mem_depth = mem_depth
    self.config_width = config_width
    self.data_width = data_width
    self.input_addr_iterator_support = input_addr_iterator_support
    self.input_sched_iterator_support = input_sched_iterator_support

    self.default_iterator_support = 6
    self.default_config_width = 16
    self.sram_iterator_support = 6
    self.agg_rd_addr_gen_width = 8

    ##################################################################################
    # IO
    ##################################################################################
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")

    self._cycle_count = self.input("cycle_count", 16)

    # data from SRAM
    self._sram_read_data = self.input("sram_read_data",
                                      self.data_width,
                                      size=self.fetch_width,
                                      packed=True,
                                      explicit_array=True)
    # read enable from SRAM
    self._t_read = self.input("t_read", self.interconnect_output_ports)

    # sram to tb for loop
    self._loops_sram2tb_mux_sel = self.input(
        "loops_sram2tb_mux_sel",
        width=max(clog2(self.default_iterator_support), 1),
        size=self.interconnect_output_ports,
        explicit_array=True,
        packed=True)
    self._loops_sram2tb_restart = self.input(
        "loops_sram2tb_restart",
        width=1,
        size=self.interconnect_output_ports,
        explicit_array=True,
        packed=True)

    # Port is named "accessor_output" externally; internal attr is _valid_out.
    self._valid_out = self.output("accessor_output", self.interconnect_output_ports)
    self._data_out = self.output("data_out",
                                 self.data_width,
                                 size=self.interconnect_output_ports,
                                 packed=True,
                                 explicit_array=True)

    ##################################################################################
    # TB RELEVANT SIGNALS
    ##################################################################################
    self._tb = self.var("tb",
                        width=self.data_width,
                        size=(self.interconnect_output_ports,
                              self.tb_height,
                              self.fetch_width),
                        packed=True,
                        explicit_array=True)

    self._tb_write_addr = self.var("tb_write_addr", 2 + max(1, clog2(self.tb_height)),
                                   size=self.interconnect_output_ports,
                                   packed=True,
                                   explicit_array=True)

    self._tb_read_addr = self.var("tb_read_addr", 2 + max(1, clog2(self.tb_height)),
                                  size=self.interconnect_output_ports,
                                  packed=True,
                                  explicit_array=True)

    # write enable to tb, delayed 1 cycle from SRAM reads
    self._t_read_d1 = self.var("t_read_d1", self.interconnect_output_ports)
    # read enable for reads from tb
    self._tb_read = self.var("tb_read", self.interconnect_output_ports)

    # Break out valids...
    self.wire(self._valid_out, self._tb_read)

    # delayed input mux_sel and restart signals from sram read/tb write
    # for loop and scheduling
    self._mux_sel_d1 = self.var("mux_sel_d1",
                                kts.clog2(self.default_iterator_support),
                                size=self.interconnect_output_ports,
                                packed=True,
                                explicit_array=True)
    self._restart_d1 = self.var("restart_d1",
                                width=1,
                                size=self.interconnect_output_ports,
                                explicit_array=True,
                                packed=True)

    for i in range(self.interconnect_output_ports):
        # signals delayed by 1 cycle from SRAM
        @always_ff((posedge, "clk"), (negedge, "rst_n"))
        def delay_read():
            if ~self._rst_n:
                self._t_read_d1[i] = 0
                self._mux_sel_d1[i] = 0
                self._restart_d1[i] = 0
            else:
                self._t_read_d1[i] = self._t_read[i]
                self._mux_sel_d1[i] = self._loops_sram2tb_mux_sel[i]
                self._restart_d1[i] = self._loops_sram2tb_restart[i]
        # add_code elaborates the block now, with the current loop index i.
        self.add_code(delay_read)

    ##################################################################################
    # TB PATHS
    ##################################################################################
    for i in range(self.interconnect_output_ports):
        # Per-port iterator/address/range widths (same value each iteration).
        self.tb_iter_support = 6
        self.tb_addr_width = 4
        self.tb_range_width = 16

        # WRITE TO TB: address generator stepped by the delayed SRAM read.
        _AG = AddrGen(iterator_support=self.default_iterator_support,
                      config_width=self.tb_addr_width)
        self.add_child(f"tb_write_addr_gen_{i}", _AG,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       step=self._t_read_d1[i],
                       mux_sel=self._mux_sel_d1[i],
                       restart=self._restart_d1[i])
        safe_wire(gen=self, w_to=self._tb_write_addr[i], w_from=_AG.ports.addr_out)

        @always_ff((posedge, "clk"))
        def tb_ctrl():
            # Capture a full SRAM row into the addressed tb slot.
            if self._t_read_d1[i]:
                self._tb[i][self._tb_write_addr[i][0]] = \
                    self._sram_read_data
        self.add_code(tb_ctrl)

        # READ FROM TB
        fl_ctr_tb_rd = ForLoop(iterator_support=self.tb_iter_support,
                               config_width=self.tb_range_width)
        self.add_child(f"loops_buf2out_read_{i}",
                       fl_ctr_tb_rd,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       step=self._tb_read[i])

        # _AG is rebound here to the read-side address generator.
        _AG = AddrGen(iterator_support=self.tb_iter_support,
                      config_width=self.tb_addr_width)
        self.add_child(f"tb_read_addr_gen_{i}", _AG,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       step=self._tb_read[i],
                       # addr_out=self._tb_read_addr[i])
                       mux_sel=fl_ctr_tb_rd.ports.mux_sel_out,
                       restart=fl_ctr_tb_rd.ports.restart)
        safe_wire(gen=self, w_to=self._tb_read_addr[i], w_from=_AG.ports.addr_out)

        # Schedule generator asserts tb_read[i] on configured cycle counts.
        self.add_child(f"tb_read_sched_gen_{i}",
                       SchedGen(iterator_support=self.tb_iter_support,
                                # config_width=self.tb_addr_width),
                                config_width=16),
                       clk=self._clk,
                       rst_n=self._rst_n,
                       cycle_count=self._cycle_count,
                       mux_sel=fl_ctr_tb_rd.ports.mux_sel_out,
                       finished=fl_ctr_tb_rd.ports.restart,
                       valid_output=self._tb_read[i])

        @always_comb
        def tb_to_out():
            # Kratos [msb, lsb] slices: upper read-addr bits select the tb
            # row, lower bits select the word within the fetched row.
            self._data_out[i] = self._tb[i][self._tb_read_addr[i][
                clog2(self.tb_height) + clog2(self.fetch_width) - 1,
                clog2(self.fetch_width)]][self._tb_read_addr[i][
                    clog2(self.fetch_width) - 1, 0]]
        self.add_code(tb_to_out)
def __init__(self,
             data_width=16,
             banks=1,
             memory_width=64,
             rw_same_cycle=False,
             read_delay=1,
             addr_width=9):
    """Elaborate a FIFO built on top of wide storage banks.

    Two narrow register FIFOs ("front" and "back") adapt the external
    data_width interface to the wider memory_width storage word: the front
    FIFO gathers words until a full storage word can be written, and the
    back FIFO is parallel-loaded from storage reads and drained a word at
    a time.

    Args:
        data_width: width in bits of the external push/pop data.
        banks: number of storage banks behind this FIFO.
        memory_width: width in bits of one storage word.
        rw_same_cycle: if True, storage supports a read and a write in the
            same cycle, so separate read/write address outputs are exposed.
        read_delay: cycles of storage read latency (1 for SRAM-like macros).
        addr_width: width in bits of storage addresses.
    """
    super().__init__("strg_fifo")

    # Generation parameters
    self.banks = banks
    self.data_width = data_width
    self.memory_width = memory_width
    self.rw_same_cycle = rw_same_cycle
    self.read_delay = read_delay
    self.addr_width = addr_width
    # Number of narrow data words packed into one storage word.
    self.fw_int = int(self.memory_width / self.data_width)

    # NOTE(review): a previous throughput assertion was disabled here —
    # banks > 1 or rw_same_cycle or fw_int > 1 was required to sustain
    # 1 write + 1 read per cycle; confirm current configurations are safe.

    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")

    # Inputs + Outputs (external FIFO interface)
    self._push = self.input("push", 1)
    self._data_in = self.input("data_in", self.data_width)
    self._pop = self.input("pop", 1)
    self._data_out = self.output("data_out", self.data_width)
    self._valid_out = self.output("valid_out", 1)
    self._empty = self.output("empty", 1)
    self._full = self.output("full", 1)

    # get relevant signals from the storage banks
    self._data_from_strg = self.input("data_from_strg",
                                      self.data_width,
                                      size=(self.banks,
                                            self.fw_int),
                                      explicit_array=True,
                                      packed=True)
    self._wen_addr = self.var("wen_addr",
                              self.addr_width,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)
    self._ren_addr = self.var("ren_addr",
                              self.addr_width,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)
    # One full storage word assembled from the front FIFO contents.
    self._front_combined = self.var("front_combined",
                                    self.data_width,
                                    size=self.fw_int,
                                    explicit_array=True,
                                    packed=True)
    self._data_to_strg = self.output("data_to_strg",
                                     self.data_width,
                                     size=(self.banks,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)
    self._wen_to_strg = self.output("wen_to_strg", self.banks)
    self._ren_to_strg = self.output("ren_to_strg", self.banks)
    self._num_words_mem = self.var("num_words_mem", self.data_width)

    # Bank-select pointers; constant 0 when only one bank exists.
    if self.banks == 1:
        self._curr_bank_wr = self.var("curr_bank_wr", 1)
        self.wire(self._curr_bank_wr, kts.const(0, 1))
        self._curr_bank_rd = self.var("curr_bank_rd", 1)
        self.wire(self._curr_bank_rd, kts.const(0, 1))
    else:
        self._curr_bank_wr = self.var("curr_bank_wr", kts.clog2(self.banks))
        self._curr_bank_rd = self.var("curr_bank_rd", kts.clog2(self.banks))

    # Per-bank holding register for a write that lost arbitration to a read.
    self._write_queue = self.var("write_queue",
                                 self.data_width,
                                 size=(self.banks,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    # Lets us know if the bank has a write queued up
    self._queued_write = self.var("queued_write", self.banks)

    # Front (write-side) register FIFO signals.
    self._front_data_out = self.var("front_data_out", self.data_width)
    self._front_pop = self.var("front_pop", 1)
    self._front_empty = self.var("front_empty", 1)
    self._front_full = self.var("front_full", 1)
    self._front_valid = self.var("front_valid", 1)
    self._front_par_read = self.var("front_par_read", 1)
    self._front_par_out = self.var("front_par_out",
                                   self.data_width,
                                   size=(self.fw_int, 1),
                                   explicit_array=True,
                                   packed=True)
    self._front_rd_ptr = self.var("front_rd_ptr", max(1, clog2(self.fw_int)))
    self._front_push = self.var("front_push", 1)
    # Accept a push when not full, or when a simultaneous pop frees a slot.
    self.wire(self._front_push, self._push & (~self._full | self._pop))

    self._front_rf = RegFIFO(data_width=self.data_width,
                             width_mult=1,
                             depth=self.fw_int,
                             parallel=True,
                             break_out_rd_ptr=True)
    # This one breaks out the read pointer so we can properly
    # reorder the data to storage
    self.add_child("front_rf",
                   self._front_rf,
                   clk=self._clk,
                   clk_en=kts.const(1, 1),
                   rst_n=self._rst_n,
                   push=self._front_push,
                   pop=self._front_pop,
                   empty=self._front_empty,
                   full=self._front_full,
                   valid=self._front_valid,
                   parallel_read=self._front_par_read,
                   parallel_load=kts.const(0, 1),
                   # We don't need to parallel load the front
                   parallel_in=0,
                   # Same reason as above
                   parallel_out=self._front_par_out,
                   num_load=0,
                   rd_ptr_out=self._front_rd_ptr)
    self.wire(self._front_rf.ports.data_in[0], self._data_in)
    self.wire(self._front_data_out, self._front_rf.ports.data_out[0])

    # Back (read-side) register FIFO signals.
    self._back_data_in = self.var("back_data_in", self.data_width)
    self._back_data_out = self.var("back_data_out", self.data_width)
    self._back_push = self.var("back_push", 1)
    self._back_empty = self.var("back_empty", 1)
    self._back_full = self.var("back_full", 1)
    self._back_valid = self.var("back_valid", 1)
    # Parallel-load strobe for the back FIFO (a storage read landed).
    self._back_pl = self.var("back_pl", 1)
    self._back_par_in = self.var("back_par_in",
                                 self.data_width,
                                 size=(self.fw_int, 1),
                                 explicit_array=True,
                                 packed=True)
    self._back_num_load = self.var("back_num_load", clog2(self.fw_int) + 1)
    # Occupancy counters for the two small FIFOs (0..fw_int).
    self._back_occ = self.var("back_occ", clog2(self.fw_int) + 1)
    self._front_occ = self.var("front_occ", clog2(self.fw_int) + 1)

    self._back_rf = RegFIFO(data_width=self.data_width,
                            width_mult=1,
                            depth=self.fw_int,
                            parallel=True,
                            break_out_rd_ptr=False)

    self._fw_is_1 = self.var("fw_is_1", 1)
    self.wire(self._fw_is_1, kts.const(self.fw_int == 1, 1))

    self._back_pop = self.var("back_pop", 1)
    # When fw_int == 1 a parallel load bypasses the back FIFO entirely,
    # so suppress the pop in that cycle.
    if self.fw_int == 1:
        self.wire(self._back_pop,
                  self._pop & (~self._empty | self._push) & ~self._back_pl)
    else:
        self.wire(self._back_pop, self._pop & (~self._empty | self._push))

    self.add_child("back_rf",
                   self._back_rf,
                   clk=self._clk,
                   clk_en=kts.const(1, 1),
                   rst_n=self._rst_n,
                   push=self._back_push,
                   pop=self._back_pop,
                   empty=self._back_empty,
                   full=self._back_full,
                   valid=self._back_valid,
                   parallel_read=kts.const(0, 1),
                   # Only do back load when data is going there
                   parallel_load=self._back_pl & self._back_num_load.r_or(),
                   parallel_in=self._back_par_in,
                   num_load=self._back_num_load)
    self.wire(self._back_rf.ports.data_in[0], self._back_data_in)
    self.wire(self._back_data_out, self._back_rf.ports.data_out[0])

    # send the writes through when a read isn't happening
    for i in range(self.banks):
        self.add_code(self.send_writes, idx=i)
        self.add_code(self.send_reads, idx=i)

    # Set the parallel load to back bank - if no delay it's immediate
    # if not, it's delayed :)
    if self.read_delay == 1:
        self._ren_delay = self.var("ren_delay", 1)
        self.add_code(self.set_parallel_ld_delay_1)
        self.wire(self._back_pl, self._ren_delay)
    else:
        self.wire(self._back_pl, self._ren_to_strg.r_or())

    # Combine front end data - just the items + incoming
    # this data is actually based on the rd_ptr from the front fifo
    for i in range(self.fw_int):
        self.wire(self._front_combined[i],
                  self._front_par_out[self._front_rd_ptr + i])

    # prioritize queued writes, otherwise send combined data
    for i in range(self.banks):
        self.wire(self._data_to_strg[i],
                  kts.ternary(self._queued_write[i],
                              self._write_queue[i],
                              self._front_combined))

    # Wire the thin output from front to thin input to back
    self.wire(self._back_data_in, self._front_data_out)
    self.wire(self._back_push, self._front_valid)
    self.add_code(self.set_front_pop)

    # Queue writes
    for i in range(self.banks):
        self.add_code(self.set_write_queue, idx=i)

    # Track number of words in memory
    self.add_code(self.set_num_words_mem)

    # Track occupancy of the two small fifos
    self.add_code(self.set_front_occ)
    self.add_code(self.set_back_occ)

    if self.banks > 1:
        self.add_code(self.set_curr_bank_wr)
        self.add_code(self.set_curr_bank_rd)

    # With read latency, the data that arrives belongs to the bank that
    # was selected on the previous cycle.
    if self.read_delay == 1:
        self._prev_bank_rd = self.var("prev_bank_rd",
                                      max(1, kts.clog2(self.banks)))
        self.add_code(self.set_prev_bank_rd)

    # Parallel load data to back - based on num_load
    index_into = self._curr_bank_rd
    if self.read_delay == 1:
        index_into = self._prev_bank_rd
    for i in range(self.fw_int - 1):
        # Shift data over if you bypassed from the memory output
        self.wire(self._back_par_in[i],
                  kts.ternary(self._back_num_load == self.fw_int,
                              self._data_from_strg[index_into][i],
                              self._data_from_strg[index_into][i + 1]))
    self.wire(self._back_par_in[self.fw_int - 1],
              kts.ternary(self._back_num_load == self.fw_int,
                          self._data_from_strg[index_into][self.fw_int - 1],
                          kts.const(0, self.data_width)))

    # Set the parallel read to the front fifo - analogous with trying to write to the memory
    self.add_code(self.set_front_par_read)
    # Set the number being parallely loaded into the register
    self.add_code(self.set_back_num_load)

    # Data out and valid out are (in the general case) just the data and valid from the back fifo
    # In the case where we have a fresh memory read, it would be from that
    bank_idx_read = self._curr_bank_rd
    if self.read_delay == 1:
        bank_idx_read = self._prev_bank_rd
    self.wire(self._data_out,
              kts.ternary(self._back_pl,
                          self._data_from_strg[bank_idx_read][0],
                          self._back_data_out))
    self.wire(self._valid_out,
              kts.ternary(self._back_pl,
                          self._pop,
                          self._back_valid))

    # Set addresses to storage
    for i in range(self.banks):
        self.add_code(self.set_wen_addr, idx=i)
        self.add_code(self.set_ren_addr, idx=i)

    # Now deal with a shared address vs separate addresses
    if self.rw_same_cycle:
        # Separate
        self._wen_addr_out = self.output("wen_addr_out",
                                         self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
        self._ren_addr_out = self.output("ren_addr_out",
                                         self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
        self.wire(self._wen_addr_out, self._wen_addr)
        self.wire(self._ren_addr_out, self._ren_addr)
    else:
        self._addr_out = self.output("addr_out",
                                     self.addr_width,
                                     size=self.banks,
                                     explicit_array=True,
                                     packed=True)
        # If sharing the addresses, send read addr with priority
        for i in range(self.banks):
            self.wire(self._addr_out[i],
                      kts.ternary(self._wen_to_strg[i],
                                  self._wen_addr[i],
                                  self._ren_addr[i]))

    # Do final empty/full
    self._num_items = self.var("num_items", self.data_width)
    self.add_code(self.set_num_items)
    # Runtime-configurable depth; empty/full are derived from the total
    # item count rather than from the small FIFOs' flags.
    self._fifo_depth = self.input("fifo_depth", self.data_width)
    self._fifo_depth.add_attribute(ConfigRegAttr("Fifo depth..."))
    self.wire(self._empty, self._num_items == 0)
    self.wire(self._full, self._num_items == (self._fifo_depth))
def test_reg_file_basic(data_width,
                        depth,
                        width_mult,
                        write_ports,
                        read_ports):
    """Drive the RegisterFile DUT with random traffic and compare every
    read against the RegisterFileModel golden model, cycle by cycle."""
    addr_width = kts.clog2(depth)

    # Golden model, configured with an empty config.
    model_rf = RegisterFileModel(data_width=data_width,
                                 write_ports=write_ports,
                                 read_ports=read_ports,
                                 width_mult=width_mult,
                                 depth=depth)
    new_config = {}
    model_rf.set_config(new_config=new_config)

    # DUT lowered to magma and wrapped in a fault tester.
    dut = RegisterFile(data_width=data_width,
                       write_ports=write_ports,
                       read_ports=read_ports,
                       width_mult=width_mult,
                       depth=depth)
    magma_dut = kts.util.to_magma(dut,
                                  flatten_array=True,
                                  check_flip_flop_always_ff=False)
    tester = fault.Tester(magma_dut, magma_dut.clk)

    for key, value in new_config.items():
        setattr(tester.circuit, key, value)

    # Reset pulse before any traffic.
    tester.circuit.clk = 0
    tester.circuit.rst_n = 0
    tester.step(2)
    tester.circuit.rst_n = 1
    tester.step(2)

    # Fixed seed keeps the generated trace reproducible.
    rand.seed(0)

    for _trial in range(1000):
        # Random stimulus. The randint call order (per write port: wen,
        # addr, then the data words) matches the recorded trace exactly.
        wen = []
        wr_addr = []
        wr_data = []
        for _port in range(write_ports):
            wen.append(rand.randint(0, 1))
            wr_addr.append(rand.randint(0, depth - 1))
            wr_data.append([rand.randint(0, 2 ** data_width - 1)
                            for _word in range(width_mult)])
        rd_addr = [rand.randint(0, depth - 1) for _port in range(read_ports)]

        # Write addresses: a lone port keeps the scalar name, otherwise
        # to_magma flattened the array into indexed port names.
        if write_ports == 1:
            tester.circuit.wr_addr = wr_addr[0]
        else:
            for port, addr in enumerate(wr_addr):
                setattr(tester.circuit, f"wr_addr_{port}", addr)

        for port, en in enumerate(wen):
            tester.circuit.wen[port] = en

        # Write data: four cases depending on how the 2-D array was
        # flattened (port dimension, word dimension, both, or neither).
        if width_mult == 1 and write_ports == 1:
            tester.circuit.data_in = wr_data[0][0]
        elif width_mult == 1:
            for port, words in enumerate(wr_data):
                setattr(tester.circuit, f"data_in_{port}_0", words[0])
        elif write_ports == 1:
            for word, value in enumerate(wr_data[0]):
                setattr(tester.circuit, f"data_in_{word}", value)
        else:
            for port, words in enumerate(wr_data):
                for word, value in enumerate(words):
                    setattr(tester.circuit, f"data_in_{port}_{word}", value)

        # Read addresses, same scalar-vs-flattened naming rule.
        if read_ports == 1:
            tester.circuit.rd_addr = rd_addr[0]
        else:
            for port, addr in enumerate(rd_addr):
                setattr(tester.circuit, f"rd_addr_{port}", addr)

        # Golden result for this cycle, then settle combinational logic.
        model_dat_out = model_rf.interact(wen, wr_addr, rd_addr, wr_data)
        tester.eval()

        # Check every read port/word against the model.
        if width_mult == 1 and read_ports == 1:
            tester.circuit.data_out.expect(model_dat_out[0][0])
        elif width_mult == 1:
            for port in range(read_ports):
                getattr(tester.circuit,
                        f"data_out_{port}_0").expect(model_dat_out[port][0])
        elif read_ports == 1:
            for word in range(width_mult):
                getattr(tester.circuit,
                        f"data_out_{word}").expect(model_dat_out[0][word])
        else:
            for port in range(read_ports):
                for word in range(width_mult):
                    getattr(tester.circuit,
                            f"data_out_{port}_{word}").expect(model_dat_out[port][word])

        tester.step(2)

    # Replay the recorded trace through verilator.
    with tempfile.TemporaryDirectory() as tempdir:
        tester.compile_and_run(target="verilator",
                               directory=tempdir,
                               magma_output="verilog",
                               flags=["-Wno-fatal"])
def __init__(self,
             interconnect_input_ports,
             interconnect_output_ports,
             depth_width=16,
             sprt_stcl_valid=False,
             stcl_cnt_width=16,
             stcl_iter_support=4):
    """Elaborate the application controller.

    Gates write/read enables with per-port depth counters (write/read
    "done" tracking) and optionally generates a stencil-valid signal from
    a small nested-counter compare network.

    Args:
        interconnect_input_ports: number of write-side interconnect ports.
        interconnect_output_ports: number of read-side interconnect ports.
        depth_width: bit width of the depth/count registers.
        sprt_stcl_valid: if True, build the stencil-valid counter/threshold
            hardware; otherwise stencil valid is just tb_valid passed through.
        stcl_cnt_width: bit width of each stencil counter.
        stcl_iter_support: number of nested stencil loop levels.
    """
    super().__init__("app_ctrl", debug=True)

    self.int_in_ports = interconnect_input_ports
    self.int_out_ports = interconnect_output_ports
    self.depth_width = depth_width
    self.sprt_stcl_valid = sprt_stcl_valid
    self.stcl_cnt_width = stcl_cnt_width
    self.stcl_iter_support = stcl_iter_support

    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")

    # IO
    self._wen_in = self.input("wen_in", self.int_in_ports)
    self._ren_in = self.input("ren_in", self.int_out_ports)
    self._ren_update = self.input("ren_update", self.int_out_ports)
    self._tb_valid = self.input("tb_valid", self.int_out_ports)
    self._valid_out_data = self.output("valid_out_data", self.int_out_ports)
    self._valid_out_stencil = self.output("valid_out_stencil",
                                          self.int_out_ports)

    # Send tb valid to valid out for now...
    if self.sprt_stcl_valid:
        # Add the config registers to watch
        self._ranges = self.input("ranges", self.stcl_cnt_width,
                                  size=self.stcl_iter_support,
                                  packed=True,
                                  explicit_array=True)
        self._ranges.add_attribute(ConfigRegAttr("Ranges of stencil valid generator"))
        self._threshold = self.input("threshold", self.stcl_cnt_width,
                                     size=self.stcl_iter_support,
                                     packed=True,
                                     explicit_array=True)
        self._threshold.add_attribute(ConfigRegAttr("Threshold of stencil valid generator"))
        self._dim_counter = self.var("dim_counter", self.stcl_cnt_width,
                                     size=self.stcl_iter_support,
                                     packed=True,
                                     explicit_array=True)

        # Ripple "update" chain: level i+1 advances only when level i
        # wraps (counter hits range-1) while itself updating.
        self._update = self.var("update", self.stcl_iter_support)

        self.wire(self._update[0], const(1, 1))
        for i in range(self.stcl_iter_support - 1):
            self.wire(self._update[i + 1],
                      (self._dim_counter[i] == (self._ranges[i] - 1)) &
                      self._update[i])

        for i in range(self.stcl_iter_support):
            self.add_code(self.dim_counter_update, idx=i)

        # Now we need to just compute stencil valid:
        # valid when every loop counter has reached its threshold.
        threshold_comps = [self._dim_counter[_i] >= self._threshold[_i]
                           for _i in range(self.stcl_iter_support)]
        self.wire(self._valid_out_stencil[0],
                  kts.concat(*threshold_comps).r_and())
        # NOTE(review): every output port currently shares the same
        # stencil-valid expression — confirm per-port valids are intended
        # to be identical.
        for i in range(self.int_out_ports - 1):
            self.wire(self._valid_out_stencil[i + 1],
                      kts.concat(*threshold_comps).r_and())
    else:
        self.wire(self._valid_out_stencil, self._tb_valid)

    # Now gate the valid with stencil valid
    self.wire(self._valid_out_data, self._tb_valid & self._valid_out_stencil)

    # High once the initial write-delay phase has passed (per port).
    self._wr_delay_state_n = self.var("wr_delay_state_n", self.int_out_ports)

    self._wen_out = self.output("wen_out", self.int_in_ports)
    self._ren_out = self.output("ren_out", self.int_out_ports)

    # Two write depths: one used before ("wo") and one after ("ss") the
    # delay state transitions — selected per port below.
    self._write_depth_wo = self.input("write_depth_wo", self.depth_width,
                                      size=self.int_in_ports,
                                      explicit_array=True,
                                      packed=True)
    self._write_depth_wo.add_attribute(ConfigRegAttr("Depth of writes"))
    self._write_depth_ss = self.input("write_depth_ss", self.depth_width,
                                      size=self.int_in_ports,
                                      explicit_array=True,
                                      packed=True)
    self._write_depth_ss.add_attribute(ConfigRegAttr("Depth of writes"))

    self._write_depth = self.var("write_depth", self.depth_width,
                                 size=self.int_in_ports,
                                 explicit_array=True,
                                 packed=True)
    for i in range(self.int_in_ports):
        self.wire(self._write_depth[i],
                  kts.ternary(self._wr_delay_state_n[i],
                              self._write_depth_ss[i],
                              self._write_depth_wo[i]))

    self._read_depth = self.input("read_depth", self.depth_width,
                                  size=self.int_out_ports,
                                  explicit_array=True,
                                  packed=True)
    self._read_depth.add_attribute(ConfigRegAttr("Depth of reads"))

    # Per-port progress counters and done flags (comb + registered).
    self._write_count = self.var("write_count", self.depth_width,
                                 size=self.int_in_ports,
                                 explicit_array=True,
                                 packed=True)
    self._read_count = self.var("read_count", self.depth_width,
                                size=self.int_out_ports,
                                explicit_array=True,
                                packed=True)

    self._write_done = self.var("write_done", self.int_in_ports)
    self._write_done_ff = self.var("write_done_ff", self.int_in_ports)
    self._read_done = self.var("read_done", self.int_out_ports)
    self._read_done_ff = self.var("read_done_ff", self.int_out_ports)

    # Port cross-references: which input port feeds each output port and
    # vice versa (config registers).
    self.in_port_bits = max(1, kts.clog2(self.int_in_ports))
    self._input_port = self.input("input_port", self.in_port_bits,
                                  size=self.int_out_ports,
                                  explicit_array=True,
                                  packed=True)
    self._input_port.add_attribute(ConfigRegAttr("Relative input port for an output port"))
    self.out_port_bits = max(1, kts.clog2(self.int_out_ports))
    self._output_port = self.input("output_port", self.out_port_bits,
                                   size=self.int_in_ports,
                                   explicit_array=True,
                                   packed=True)
    self._output_port.add_attribute(ConfigRegAttr("Relative output port for an input port"))

    self._prefill = self.input("prefill", self.int_out_ports)
    self._prefill.add_attribute(ConfigRegAttr("Is the input stream prewritten?"))

    # Single-writer configurations get simplified "one_wr" code blocks.
    for i in range(self.int_out_ports):
        self.add_code(self.set_read_done, idx=i)
        if self.int_in_ports == 1:
            self.add_code(self.set_read_done_ff_one_wr, idx=i)
        else:
            self.add_code(self.set_read_done_ff, idx=i)

    for i in range(self.int_in_ports):
        self.add_code(self.set_write_done, idx=i)
        self.add_code(self.set_write_done_ff, idx=i)

    for i in range(self.int_in_ports):
        self.add_code(self.set_write_cnt, idx=i)

    for i in range(self.int_out_ports):
        if self.int_in_ports == 1:
            self.add_code(self.set_read_cnt_one_wr, idx=i)
        else:
            self.add_code(self.set_read_cnt, idx=i)

    for i in range(self.int_out_ports):
        if self.int_in_ports == 1:
            self.add_code(self.set_wr_delay_state_one_wr, idx=i)
        else:
            self.add_code(self.set_wr_delay_state, idx=i)

    # A port is "on" for reads only if its configured depth is nonzero.
    self._read_on = self.var("read_on", self.int_out_ports)
    for i in range(self.int_out_ports):
        self.wire(self._read_on[i], self._read_depth[i].r_or())

    # If we have prefill enabled, we are skipping the initial delay step...
    self.wire(self._ren_out,
              (self._wr_delay_state_n | self._prefill) &
              ~self._read_done_ff &
              self._ren_in &
              self._read_on)
    self.wire(self._wen_out, ~self._write_done_ff & self._wen_in)
def __init__(self,
             data_width=16,  # CGRA Params
             mem_width=16,
             mem_depth=256,
             banks=1,
             input_iterator_support=6,  # Addr Controllers
             output_iterator_support=6,
             input_config_width=16,
             output_config_width=16,
             interconnect_input_ports=1,  # Connection to int
             interconnect_output_ports=1,
             mem_input_ports=1,
             mem_output_ports=1,
             use_sram_stub=True,
             sram_macro_info=SRAMMacroInfo(),
             read_delay=1,  # Cycle delay in read (SRAM vs Register File)
             rw_same_cycle=True,  # Does the memory allow r+w in same cycle?
             agg_height=4,
             tb_sched_max=16,
             config_data_width=32,
             config_addr_width=8,
             num_tiles=1,
             remove_tb=False,
             fifo_mode=False,
             add_clk_enable=True,
             add_flush=True,
             override_name=None):
    """Wrap a generated LakeTop circuit as a CGRA memory core.

    Elaborates (or fetches from a class-level cache) the LakeTop hardware,
    exposes its ports on this core, inserts per-control-signal constant
    muxes, and builds the configuration/feature plumbing (tile config plus
    one feature per SRAM set).

    NOTE(review): the default `sram_macro_info=SRAMMacroInfo()` is a shared
    mutable-looking default instance — confirm SRAMMacroInfo is immutable
    or that sharing one instance across calls is intended.
    NOTE(review): `tb_sched_max`, `mem_input_ports` and `mem_output_ports`
    are accepted but `tb_sched_max` is never captured or used below.
    """
    # name
    if override_name:
        self.__name = override_name + "Core"
        lake_name = override_name
    else:
        self.__name = "MemCore"
        lake_name = "LakeTop"

    super().__init__(config_addr_width, config_data_width)

    # Capture everything to the tile object
    self.data_width = data_width
    self.mem_width = mem_width
    self.mem_depth = mem_depth
    self.banks = banks
    # Words of data_width packed into one memory word.
    self.fw_int = int(self.mem_width / self.data_width)
    self.input_iterator_support = input_iterator_support
    self.output_iterator_support = output_iterator_support
    self.input_config_width = input_config_width
    self.output_config_width = output_config_width
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.use_sram_stub = use_sram_stub
    self.sram_macro_info = sram_macro_info
    self.read_delay = read_delay
    self.rw_same_cycle = rw_same_cycle
    self.agg_height = agg_height
    self.config_data_width = config_data_width
    self.config_addr_width = config_addr_width
    self.num_tiles = num_tiles
    self.remove_tb = remove_tb
    self.fifo_mode = fifo_mode
    self.add_clk_enable = add_clk_enable
    self.add_flush = add_flush

    # Typedefs for ease
    TData = magma.Bits[self.data_width]
    TBit = magma.Bits[1]

    self.__inputs = []
    self.__outputs = []

    # Cache key covers every parameter that changes the generated circuit.
    cache_key = (self.data_width, self.mem_width, self.mem_depth, self.banks,
                 self.input_iterator_support, self.output_iterator_support,
                 self.interconnect_input_ports, self.interconnect_output_ports,
                 self.use_sram_stub, self.sram_macro_info, self.read_delay,
                 self.rw_same_cycle, self.agg_height,
                 self.config_data_width, self.config_addr_width,
                 self.num_tiles, self.remove_tb, self.fifo_mode,
                 self.add_clk_enable, self.add_flush)

    # Check for circuit caching
    if cache_key not in MemCore.__circuit_cache:
        # Instantiate core object here - will only use the object representation to
        # query for information. The circuit representation will be cached and retrieved
        # in the following steps.
        lt_dut = LakeTop(
            data_width=self.data_width,
            mem_width=self.mem_width,
            mem_depth=self.mem_depth,
            banks=self.banks,
            input_iterator_support=self.input_iterator_support,
            output_iterator_support=self.output_iterator_support,
            input_config_width=self.input_config_width,
            output_config_width=self.output_config_width,
            interconnect_input_ports=self.interconnect_input_ports,
            interconnect_output_ports=self.interconnect_output_ports,
            use_sram_stub=self.use_sram_stub,
            sram_macro_info=self.sram_macro_info,
            read_delay=self.read_delay,
            rw_same_cycle=self.rw_same_cycle,
            agg_height=self.agg_height,
            config_data_width=self.config_data_width,
            config_addr_width=self.config_addr_width,
            num_tiles=self.num_tiles,
            remove_tb=self.remove_tb,
            fifo_mode=self.fifo_mode,
            add_clk_enable=self.add_clk_enable,
            add_flush=self.add_flush,
            name=lake_name,
            gen_addr=False)

        # Rename SRAM ports to match the macro before lowering to magma.
        change_sram_port_pass = change_sram_port_names(
            use_sram_stub, sram_macro_info)
        circ = kts.util.to_magma(
            lt_dut,
            flatten_array=True,
            check_multiple_driver=False,
            optimize_if=False,
            check_flip_flop_always_ff=False,
            additional_passes={"change_sram_port": change_sram_port_pass})
        MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
    else:
        circ, lt_dut = MemCore.__circuit_cache[cache_key]

    # Save as underlying circuit object
    self.underlying = FromMagma(circ)

    # Enumerate input and output ports
    # (clk and reset are assumed)
    core_interface = get_interface(lt_dut)
    cfgs = extract_top_config(lt_dut)
    assert len(cfgs) > 0, "No configs?"

    # We basically add in the configuration bus differently
    # than the other ports...
    skip_names = [
        "config_data_in", "config_write", "config_addr_in",
        "config_data_out", "config_read", "config_en", "clk_en"
    ]

    # Create a list of signals that will be able to be
    # hardwired to a constant at runtime...
    control_signals = []
    # The rest of the signals to wire to the underlying representation...
    other_signals = []

    for io_info in core_interface:
        if io_info.port_name in skip_names:
            continue
        ind_ports = io_info.port_width
        intf_type = TBit
        # For our purposes, an explicit array means the inner data HAS to be 16 bits
        if io_info.expl_arr:
            ind_ports = io_info.port_size[0]
            intf_type = TData
        dir_type = magma.In
        app_list = self.__inputs
        if io_info.port_dir == "PortDirection.Out":
            dir_type = magma.Out
            app_list = self.__outputs
        # Multi-element ports are split into one tile port per element.
        if ind_ports > 1:
            for i in range(ind_ports):
                self.add_port(f"{io_info.port_name}_{i}",
                              dir_type(intf_type))
                app_list.append(self.ports[f"{io_info.port_name}_{i}"])
        else:
            self.add_port(io_info.port_name, dir_type(intf_type))
            app_list.append(self.ports[io_info.port_name])
        # classify each signal for wiring to underlying representation...
        if io_info.is_ctrl:
            control_signals.append((io_info.port_name, io_info.port_width))
        else:
            if ind_ports > 1:
                for i in range(ind_ports):
                    other_signals.append(
                        (f"{io_info.port_name}_{i}",
                         io_info.port_dir,
                         io_info.expl_arr,
                         i,
                         io_info.port_name))
            else:
                other_signals.append(
                    (io_info.port_name,
                     io_info.port_dir,
                     io_info.expl_arr,
                     0,
                     io_info.port_name))

    assert (len(self.__outputs) > 0)

    # We call clk_en stall at this level for legacy reasons????
    self.add_ports(stall=magma.In(TBit), )

    self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

    # put a 1-bit register and a mux to select the control signals:
    # S=1 selects the config-register constant, S=0 the routing network.
    for control_signal, width in control_signals:
        if width == 1:
            mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
            reg_value_name = f"{control_signal}_reg_value"
            reg_sel_name = f"{control_signal}_reg_sel"
            self.add_config(reg_value_name, 1)
            self.add_config(reg_sel_name, 1)
            self.wire(mux.ports.I[0], self.ports[control_signal])
            self.wire(mux.ports.I[1], self.registers[reg_value_name].ports.O)
            self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
            # 0 is the default wire, which takes from the routing network
            self.wire(mux.ports.O[0],
                      self.underlying.ports[control_signal][0])
        else:
            for i in range(width):
                mux = MuxWrapper(2, 1, name=f"{control_signal}_{i}_sel")
                reg_value_name = f"{control_signal}_{i}_reg_value"
                reg_sel_name = f"{control_signal}_{i}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0],
                          self.ports[f"{control_signal}_{i}"])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][i])

    # Wire the other signals up...
    for pname, pdir, expl_arr, ind, uname in other_signals:
        # If we are in an explicit array moment, use the given wire name...
        if expl_arr is False:
            # And if not, use the index
            self.wire(self.ports[pname][0],
                      self.underlying.ports[uname][ind])
        else:
            self.wire(self.ports[pname], self.underlying.ports[pname])

    # CLK, RESET, and STALL PER STANDARD PROCEDURE
    # Need to invert this
    self.resetInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.resetInverter.ports.I[0], self.ports.reset)
    self.wire(self.resetInverter.ports.O[0],
              self.underlying.ports.rst_n)
    self.wire(self.ports.clk, self.underlying.ports.clk)
    # Mem core uses clk_en (essentially active low stall)
    self.stallInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.stallInverter.ports.I, self.ports.stall)
    self.wire(self.stallInverter.ports.O[0],
              self.underlying.ports.clk_en[0])

    # we have six? features in total
    # 0: TILE
    # 1: TILE
    # 1-4: SMEM
    # Feature 0: Tile
    self.__features: List[CoreFeature] = [self]
    # Features 1-4: SRAM
    self.num_sram_features = lt_dut.total_sets
    for sram_index in range(self.num_sram_features):
        core_feature = CoreFeature(self, sram_index + 1)
        self.__features.append(core_feature)

    # Wire the config
    for idx, core_feature in enumerate(self.__features):
        if (idx > 0):
            self.add_port(
                f"config_{idx}",
                magma.In(
                    ConfigurationType(self.config_addr_width,
                                      self.config_data_width)))
            # port aliasing
            core_feature.ports["config"] = self.ports[f"config_{idx}"]
    self.add_port(
        "config",
        magma.In(
            ConfigurationType(self.config_addr_width,
                              self.config_data_width)))

    # or the signal up — per-feature config buses are OR-reduced onto the
    # single underlying config bus.
    t = ConfigurationType(self.config_addr_width, self.config_data_width)
    t_names = ["config_addr", "config_data"]
    or_gates = {}
    for t_name in t_names:
        port_type = t[t_name]
        or_gate = FromMagma(
            mantle.DefineOr(len(self.__features), len(port_type)))
        or_gate.instance_name = f"OR_{t_name}_FEATURE"
        for idx, core_feature in enumerate(self.__features):
            self.wire(or_gate.ports[f"I{idx}"],
                      core_feature.ports.config[t_name])
        or_gates[t_name] = or_gate

    self.wire(
        or_gates["config_addr"].ports.O,
        self.underlying.ports.config_addr_in[0:self.config_addr_width])
    self.wire(or_gates["config_data"].ports.O,
              self.underlying.ports.config_data_in)

    # read data out
    for idx, core_feature in enumerate(self.__features):
        if (idx > 0):
            self.add_port(f"read_config_data_{idx}",
                          magma.Out(magma.Bits[self.config_data_width]))
            # port aliasing
            core_feature.ports["read_config_data"] = \
                self.ports[f"read_config_data_{idx}"]

    # MEM Config — flatten explicit-array config ports into one register
    # per element.
    configurations = []
    skip_cfgs = []
    for cfg_info in cfgs:
        if cfg_info.port_name in skip_cfgs:
            continue
        if cfg_info.expl_arr:
            if cfg_info.port_size[0] > 1:
                for i in range(cfg_info.port_size[0]):
                    configurations.append(
                        (f"{cfg_info.port_name}_{i}", cfg_info.port_width))
            else:
                configurations.append(
                    (cfg_info.port_name, cfg_info.port_width))
        else:
            configurations.append(
                (cfg_info.port_name, cfg_info.port_width))

    # Do all the stuff for the main config
    main_feature = self.__features[0]
    for config_reg_name, width in configurations:
        main_feature.add_config(config_reg_name, width)
        if (width == 1):
            self.wire(main_feature.registers[config_reg_name].ports.O[0],
                      self.underlying.ports[config_reg_name][0])
        else:
            self.wire(main_feature.registers[config_reg_name].ports.O,
                      self.underlying.ports[config_reg_name])

    # SRAM
    # These should also account for num features
    # NOTE(review): the instance names below look swapped — the rd gate is
    # labeled OR_CONFIG_WR_SRAM and the wr gate OR_CONFIG_RD_SRAM; the
    # actual wiring (rd→config_read, wr→config_write) appears correct.
    or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
    or_all_cfg_rd.instance_name = f"OR_CONFIG_WR_SRAM"
    or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
    or_all_cfg_wr.instance_name = f"OR_CONFIG_RD_SRAM"

    for sram_index in range(self.num_sram_features):
        core_feature = self.__features[sram_index + 1]
        self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
        # port aliasing
        core_feature.ports["config_en"] = \
            self.ports[f"config_en_{sram_index}"]
        # Sort of a temp hack - the name is just config_data_out
        if self.num_sram_features == 1:
            self.wire(core_feature.ports.read_config_data,
                      self.underlying.ports["config_data_out"])
        else:
            self.wire(
                core_feature.ports.read_config_data,
                self.underlying.ports[f"config_data_out_{sram_index}"])
        # also need to wire the sram signal
        # the config enable is the OR of the rd+wr
        # NOTE(review): or_gate_en's output O is never wired to anything
        # here — confirm whether it is intentionally unused.
        or_gate_en = FromMagma(mantle.DefineOr(2, 1))
        or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"
        self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
        self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
        self.wire(core_feature.ports.config_en,
                  self.underlying.ports["config_en"][sram_index])
        # Still connect to the OR of all the config rd/wr
        self.wire(core_feature.ports.config.write,
                  or_all_cfg_wr.ports[f"I{sram_index}"])
        self.wire(core_feature.ports.config.read,
                  or_all_cfg_rd.ports[f"I{sram_index}"])

    self.wire(or_all_cfg_rd.ports.O[0],
              self.underlying.ports.config_read[0])
    self.wire(or_all_cfg_wr.ports.O[0],
              self.underlying.ports.config_write[0])

    self._setup_config()

    # Dump the sorted config register names for external tooling.
    conf_names = list(self.registers.keys())
    conf_names.sort()
    with open("mem_cfg.txt", "w+") as cfg_dump:
        for idx, reg in enumerate(conf_names):
            write_line = f"(\"{reg}\", 0), # {self.registers[reg].width}\n"
            cfg_dump.write(write_line)
    with open("mem_synth.txt", "w+") as cfg_dump:
        for idx, reg in enumerate(conf_names):
            write_line = f"{reg}\n"
            cfg_dump.write(write_line)
def _glb_dma_ctrl_reg(name, params):
    """Build a store/load DMA control register.

    Fields: mode (2b), use_valid (1b), data_mux (2b), and num_repeat
    (wide enough to count params.queue_depth repeats).
    """
    ctrl_r = Reg(name)
    ctrl_r.add_child(Field("mode", 2))
    ctrl_r.add_child(Field("use_valid", 1))
    ctrl_r.add_child(Field("data_mux", 2))
    ctrl_r.add_child(Field("num_repeat", clog2(params.queue_depth) + 1))
    return ctrl_r


def _glb_dma_header_regfile(name, params):
    """Build a DMA header register file with one entry per queue slot.

    Contains scalar dim/start_addr/cycle_start_addr registers plus
    range/stride/cycle_stride register arrays sized by params.loop_level.
    When queue_depth == 1 the RegFile keeps the "_0" name suffix, matching
    the original single-entry naming.
    """
    if params.queue_depth == 1:
        header_rf = RegFile(f"{name}_0", size=params.queue_depth)
    else:
        header_rf = RegFile(name, size=params.queue_depth)
    # dim reg
    dim_r = Reg("dim")
    dim_r.add_child(Field("dim", width=clog2(params.loop_level) + 1))
    header_rf.add_child(dim_r)
    # start_addr reg
    start_addr_r = Reg("start_addr")
    start_addr_r.add_child(Field("start_addr", width=params.glb_addr_width))
    header_rf.add_child(start_addr_r)
    # cycle_start_addr reg
    cycle_start_addr_r = Reg("cycle_start_addr")
    cycle_start_addr_r.add_child(
        Field("cycle_start_addr", width=params.glb_addr_width))
    header_rf.add_child(cycle_start_addr_r)
    # range / stride / cycle_stride register arrays (one per loop level)
    for field_name in ("range", "stride", "cycle_stride"):
        arr_r = Reg(field_name, size=params.loop_level)
        arr_r.add_child(Field(field_name, width=params.axi_data_width))
        header_rf.add_child(arr_r)
    return header_rf


def gen_global_buffer_rdl(name, params):
    """Build the global-buffer RDL address map.

    Lays out, in order: data/pcfg network control registers, the store-DMA
    control register and header regfile, the load-DMA control register and
    header regfile (identical layout to the store side), and the pcfg-DMA
    control register and header regfile.

    Args:
        name: name of the top-level address map.
        params: GlobalBufferParams-like object providing latency_width,
            queue_depth, loop_level, axi_data_width, glb_addr_width and
            max_num_cfg_width.

    Returns:
        Rdl wrapping the populated AddrMap.
    """
    addr_map = AddrMap(name)

    # Data Network Ctrl Register
    data_network_ctrl = Reg("data_network")
    data_network_ctrl.add_children(
        [Field("tile_connected", 1), Field("latency", params.latency_width)])
    addr_map.add_child(data_network_ctrl)

    # Pcfg Network Ctrl Register (same layout as the data network reg)
    pcfg_network_ctrl = Reg("pcfg_network")
    pcfg_network_ctrl.add_children(
        [Field("tile_connected", 1), Field("latency", params.latency_width)])
    addr_map.add_child(pcfg_network_ctrl)

    # Store DMA ctrl + header — built by the shared helpers; the load DMA
    # below uses the exact same layout (previously duplicated inline).
    addr_map.add_child(_glb_dma_ctrl_reg("st_dma_ctrl", params))
    addr_map.add_child(_glb_dma_header_regfile("st_dma_header", params))

    # Load DMA ctrl + header
    addr_map.add_child(_glb_dma_ctrl_reg("ld_dma_ctrl", params))
    addr_map.add_child(_glb_dma_header_regfile("ld_dma_header", params))

    # Pcfg DMA Ctrl: a single 1-bit mode field
    pcfg_dma_ctrl_r = Reg("pcfg_dma_ctrl")
    pcfg_dma_ctrl_r.add_child(Field("mode", 1))
    addr_map.add_child(pcfg_dma_ctrl_r)

    # Pcfg DMA Header RegFile: start address + number of config words
    pcfg_dma_header_rf = RegFile("pcfg_dma_header")
    start_addr_r = Reg("start_addr")
    start_addr_r.add_child(Field("start_addr", width=params.glb_addr_width))
    pcfg_dma_header_rf.add_child(start_addr_r)
    num_cfg_r = Reg("num_cfg")
    num_cfg_r.add_child(Field("num_cfg", width=params.max_num_cfg_width))
    pcfg_dma_header_rf.add_child(num_cfg_r)
    addr_map.add_child(pcfg_dma_header_rf)

    return Rdl(addr_map)
def __init__(
        self,
        data_width=16,  # CGRA Params
        mem_depth=32,
        default_iterator_support=3,
        interconnect_input_ports=2,  # Connection to int
        interconnect_output_ports=2,
        mem_input_ports=1,
        mem_output_ports=1,
        config_data_width=32,
        config_addr_width=8,
        cycle_count_width=16,
        add_clk_enable=True,
        add_flush=True):
    """Generate the 'pond' register-file tile.

    Builds, via kratos, the hardware for a small register-file memory:
    clock/reset/tile-enable plumbing, per-port data and address signals,
    one (ForLoop, AddrGen, SchedGen) addressor trio per write and per
    read port, a StorageConfigSeq configuration sequencer that takes
    over the memory ports while config_en is asserted, and the backing
    RegisterFile child. Optionally runs the clock-enable and flush
    (sync-reset) passes, then lifts the config registers to the top.

    Args:
        data_width: width in bits of one data word.
        mem_depth: number of words in the register file.
        default_iterator_support: loop dimensionality of each addressor.
        interconnect_input_ports: write ports facing the interconnect.
        interconnect_output_ports: read ports facing the interconnect.
        mem_input_ports: physical memory write ports.
        mem_output_ports: physical memory read ports.
        config_data_width: configuration data bus width.
        config_addr_width: configuration address bus width.
        cycle_count_width: width of the free-running cycle counter.
        add_clk_enable: if True, run the auto clock-enable pass.
        add_flush: if True, add the flush/sync-reset pass.
    """
    super().__init__("pond", debug=True)
    # Stash constructor parameters; several are re-read by the config
    # hooks further down.
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.mem_depth = mem_depth
    self.data_width = data_width
    self.config_data_width = config_data_width
    self.config_addr_width = config_addr_width
    self.add_clk_enable = add_clk_enable
    self.add_flush = add_flush
    self.cycle_count_width = cycle_count_width
    self.default_iterator_support = default_iterator_support
    # Addresses only need clog2(depth) bits.
    self.default_config_width = kts.clog2(self.mem_depth)

    # inputs
    self._clk = self.clock("clk")
    self._clk.add_attribute(
        FormalAttr(f"{self._clk.name}", FormalSignalConstraint.CLK))
    self._rst_n = self.reset("rst_n")
    self._rst_n.add_attribute(
        FormalAttr(f"{self._rst_n.name}", FormalSignalConstraint.RSTN))
    self._clk_en = self.clock_en("clk_en", 1)

    # Enable/Disable tile
    self._tile_en = self.input("tile_en", 1)
    self._tile_en.add_attribute(
        ConfigRegAttr("Tile logic enable manifested as clock gate"))
    # Gated clock: the tile clock is qualified by tile_en.
    gclk = self.var("gclk", 1)
    self._gclk = kts.util.clock(gclk)
    self.wire(gclk, kts.util.clock(self._clk & self._tile_en))

    # Free-running counter consumed by the schedule generators below.
    self._cycle_count = add_counter(self, "cycle_count",
                                    self.cycle_count_width)

    # Create write enable + addr, same for read.
    # self._write = self.input("write", self.interconnect_input_ports)
    self._write = self.var("write", self.mem_input_ports)
    # self._write.add_attribute(ControlSignalAttr(is_control=True))
    self._write_addr = self.var("write_addr",
                                kts.clog2(self.mem_depth),
                                size=self.interconnect_input_ports,
                                explicit_array=True,
                                packed=True)
    # Add "_pond" suffix to avoid error during garnet RTL generation
    self._data_in = self.input("data_in_pond",
                               self.data_width,
                               size=self.interconnect_input_ports,
                               explicit_array=True,
                               packed=True)
    self._data_in.add_attribute(
        FormalAttr(f"{self._data_in.name}",
                   FormalSignalConstraint.SEQUENCE))
    self._data_in.add_attribute(ControlSignalAttr(is_control=False))

    self._read = self.var("read", self.mem_output_ports)
    # Per-port transaction enables driven by the schedule generators.
    self._t_write = self.var("t_write", self.interconnect_input_ports)
    self._t_read = self.var("t_read", self.interconnect_output_ports)
    # self._read.add_attribute(ControlSignalAttr(is_control=True))
    self._read_addr = self.var("read_addr",
                               kts.clog2(self.mem_depth),
                               size=self.interconnect_output_ports,
                               explicit_array=True,
                               packed=True)
    self._s_read_addr = self.var("s_read_addr",
                                 kts.clog2(self.mem_depth),
                                 size=self.interconnect_output_ports,
                                 explicit_array=True,
                                 packed=True)
    self._data_out = self.output("data_out_pond",
                                 self.data_width,
                                 size=self.interconnect_output_ports,
                                 explicit_array=True,
                                 packed=True)
    self._data_out.add_attribute(
        FormalAttr(f"{self._data_out.name}",
                   FormalSignalConstraint.SEQUENCE))
    self._data_out.add_attribute(ControlSignalAttr(is_control=False))

    self._valid_out = self.output("valid_out_pond",
                                  self.interconnect_output_ports)
    self._valid_out.add_attribute(
        FormalAttr(f"{self._valid_out.name}",
                   FormalSignalConstraint.SEQUENCE))
    self._valid_out.add_attribute(ControlSignalAttr(is_control=False))

    # "mem_*" signals are the physical memory-port view; "s_*" signals
    # are the per-interconnect-port view decoded onto them further down.
    self._mem_data_out = self.var("mem_data_out",
                                  self.data_width,
                                  size=self.mem_output_ports,
                                  explicit_array=True,
                                  packed=True)
    self._s_mem_data_in = self.var("s_mem_data_in",
                                   self.data_width,
                                   size=self.interconnect_input_ports,
                                   explicit_array=True,
                                   packed=True)
    self._mem_data_in = self.var("mem_data_in",
                                 self.data_width,
                                 size=self.mem_input_ports,
                                 explicit_array=True,
                                 packed=True)
    self._s_mem_write_addr = self.var("s_mem_write_addr",
                                      kts.clog2(self.mem_depth),
                                      size=self.interconnect_input_ports,
                                      explicit_array=True,
                                      packed=True)
    self._s_mem_read_addr = self.var("s_mem_read_addr",
                                     kts.clog2(self.mem_depth),
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)
    self._mem_write_addr = self.var("mem_write_addr",
                                    kts.clog2(self.mem_depth),
                                    size=self.mem_input_ports,
                                    explicit_array=True,
                                    packed=True)
    self._mem_read_addr = self.var("mem_read_addr",
                                   kts.clog2(self.mem_depth),
                                   size=self.mem_output_ports,
                                   explicit_array=True,
                                   packed=True)

    # Every interconnect output port observes memory read port 0.
    if self.interconnect_output_ports == 1:
        self.wire(self._data_out[0], self._mem_data_out[0])
    else:
        for i in range(self.interconnect_output_ports):
            self.wire(self._data_out[i], self._mem_data_out[0])

    # Valid out is simply passing the read signal through...
    self.wire(self._valid_out, self._t_read)

    # Create write addressors
    for wr_port in range(self.interconnect_input_ports):
        # One iteration domain + address generator + schedule generator
        # per write port.
        RF_WRITE_ITER = ForLoop(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width)
        RF_WRITE_ADDR = AddrGen(
            iterator_support=self.default_iterator_support,
            config_width=self.default_config_width)
        RF_WRITE_SCHED = SchedGen(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width,
            use_enable=True)
        self.add_child(f"rf_write_iter_{wr_port}",
                       RF_WRITE_ITER,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_write[wr_port])
        # Whatever comes through here should hopefully just pipe through
        # seamlessly
        # addressor modules
        self.add_child(f"rf_write_addr_{wr_port}",
                       RF_WRITE_ADDR,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_write[wr_port],
                       mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                       restart=RF_WRITE_ITER.ports.restart)
        safe_wire(self, self._write_addr[wr_port],
                  RF_WRITE_ADDR.ports.addr_out)
        self.add_child(f"rf_write_sched_{wr_port}",
                       RF_WRITE_SCHED,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                       finished=RF_WRITE_ITER.ports.restart,
                       cycle_count=self._cycle_count,
                       valid_output=self._t_write[wr_port])

    # Create read addressors
    for rd_port in range(self.interconnect_output_ports):
        RF_READ_ITER = ForLoop(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width)
        RF_READ_ADDR = AddrGen(
            iterator_support=self.default_iterator_support,
            config_width=self.default_config_width)
        RF_READ_SCHED = SchedGen(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width,
            use_enable=True)
        self.add_child(f"rf_read_iter_{rd_port}",
                       RF_READ_ITER,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_read[rd_port])
        self.add_child(f"rf_read_addr_{rd_port}",
                       RF_READ_ADDR,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_read[rd_port],
                       mux_sel=RF_READ_ITER.ports.mux_sel_out,
                       restart=RF_READ_ITER.ports.restart)
        # NOTE(review): both branches of this conditional are identical,
        # so the if/else is redundant as written.
        if self.interconnect_output_ports > 1:
            safe_wire(self, self._read_addr[rd_port],
                      RF_READ_ADDR.ports.addr_out)
        else:
            safe_wire(self, self._read_addr[rd_port],
                      RF_READ_ADDR.ports.addr_out)
        self.add_child(f"rf_read_sched_{rd_port}",
                       RF_READ_SCHED,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       mux_sel=RF_READ_ITER.ports.mux_sel_out,
                       finished=RF_READ_ITER.ports.restart,
                       cycle_count=self._cycle_count,
                       valid_output=self._t_read[rd_port])

    # Collapse per-port enables/addresses/data onto physical memory
    # port 0: OR-reduce the enables, one-hot decode the addr/data.
    self.wire(self._write, self._t_write.r_or())
    self.wire(self._mem_write_addr[0],
              decode(self, self._t_write, self._s_mem_write_addr))
    self.wire(self._mem_data_in[0],
              decode(self, self._t_write, self._s_mem_data_in))
    self.wire(self._read, self._t_read.r_or())
    self.wire(self._mem_read_addr[0],
              decode(self, self._t_read, self._s_mem_read_addr))

    # ===================================
    # Instantiate config hooks...
    # ===================================
    self.fw_int = 1
    self.data_words_per_set = 2**self.config_addr_width
    self.sets = int(
        (self.fw_int * self.mem_depth) / self.data_words_per_set)
    self.sets_per_macro = max(
        1, int(self.mem_depth / self.data_words_per_set))
    self.total_sets = max(1, 1 * self.sets_per_macro)

    self._config_data_in = self.input("config_data_in",
                                      self.config_data_width)
    self._config_data_in.add_attribute(ControlSignalAttr(is_control=False))
    # Config data is wider than a data word; only the low data_width
    # bits are used.
    self._config_data_in_shrt = self.var("config_data_in_shrt",
                                         self.data_width)
    self.wire(self._config_data_in_shrt,
              self._config_data_in[self.data_width - 1, 0])
    self._config_addr_in = self.input("config_addr_in",
                                      self.config_addr_width)
    self._config_addr_in.add_attribute(ControlSignalAttr(is_control=False))
    self._config_data_out_shrt = self.var("config_data_out_shrt",
                                          self.data_width,
                                          size=self.total_sets,
                                          explicit_array=True,
                                          packed=True)
    self._config_data_out = self.output("config_data_out",
                                        self.config_data_width,
                                        size=self.total_sets,
                                        explicit_array=True,
                                        packed=True)
    self._config_data_out.add_attribute(
        ControlSignalAttr(is_control=False))
    # Extend each narrow read-back word up to the config bus width.
    for i in range(self.total_sets):
        self.wire(
            self._config_data_out[i],
            self._config_data_out_shrt[i].extend(self.config_data_width))

    self._config_read = self.input("config_read", 1)
    self._config_read.add_attribute(ControlSignalAttr(is_control=False))
    self._config_write = self.input("config_write", 1)
    self._config_write.add_attribute(ControlSignalAttr(is_control=False))
    self._config_en = self.input("config_en", self.total_sets)
    self._config_en.add_attribute(ControlSignalAttr(is_control=False))
    # Data/address driven by the config sequencer when config is active.
    self._mem_data_cfg = self.var("mem_data_cfg",
                                  self.data_width,
                                  explicit_array=True,
                                  packed=True)
    self._mem_addr_cfg = self.var("mem_addr_cfg",
                                  kts.clog2(self.mem_depth))

    # Add config...
    stg_cfg_seq = StorageConfigSeq(
        data_width=self.data_width,
        config_addr_width=self.config_addr_width,
        addr_width=kts.clog2(self.mem_depth),
        fetch_width=self.data_width,
        total_sets=self.total_sets,
        sets_per_macro=self.sets_per_macro)
    # The clock to config sequencer needs to be the normal clock or
    # if the tile is off, we bring the clock back in based on config_en
    cfg_seq_clk = self.var("cfg_seq_clk", 1)
    self._cfg_seq_clk = kts.util.clock(cfg_seq_clk)
    self.wire(cfg_seq_clk, kts.util.clock(self._gclk))
    self.add_child(f"config_seq",
                   stg_cfg_seq,
                   clk=self._cfg_seq_clk,
                   rst_n=self._rst_n,
                   clk_en=self._clk_en | self._config_en.r_or(),
                   config_data_in=self._config_data_in_shrt,
                   config_addr_in=self._config_addr_in,
                   config_wr=self._config_write,
                   config_rd=self._config_read,
                   config_en=self._config_en,
                   wr_data=self._mem_data_cfg,
                   rd_data_out=self._config_data_out_shrt,
                   addr_out=self._mem_addr_cfg)
    # Sequencer read-back taps memory read port 0.
    if self.interconnect_output_ports == 1:
        self.wire(stg_cfg_seq.ports.rd_data_stg, self._mem_data_out)
    else:
        self.wire(stg_cfg_seq.ports.rd_data_stg[0],
                  self._mem_data_out[0])

    self.RF_GEN = RegisterFile(data_width=self.data_width,
                               write_ports=self.mem_input_ports,
                               read_ports=self.mem_output_ports,
                               width_mult=1,
                               depth=self.mem_depth,
                               read_delay=0)
    # Now we can instantiate and wire up the register file
    self.add_child(f"rf",
                   self.RF_GEN,
                   clk=self._gclk,
                   rst_n=self._rst_n,
                   data_out=self._mem_data_out)

    # Opt in for config_write
    # When any config_en bit is set, port 0's write enable comes from
    # config_write and the remaining write ports are forced to 0.
    self._write_rf = self.var("write_rf", self.mem_input_ports)
    self.wire(
        self._write_rf[0],
        kts.ternary(self._config_en.r_or(), self._config_write,
                    self._write[0]))
    for i in range(self.mem_input_ports - 1):
        self.wire(
            self._write_rf[i + 1],
            kts.ternary(self._config_en.r_or(), kts.const(0, 1),
                        self._write[i + 1]))
    self.wire(self.RF_GEN.ports.wen, self._write_rf)

    # Opt in for config_data_in
    for i in range(self.interconnect_input_ports):
        self.wire(
            self._s_mem_data_in[i],
            kts.ternary(self._config_en.r_or(), self._mem_data_cfg,
                        self._data_in[i]))
    self.wire(self.RF_GEN.ports.data_in, self._mem_data_in)

    # Opt in for config_addr
    for i in range(self.interconnect_input_ports):
        self.wire(
            self._s_mem_write_addr[i],
            kts.ternary(self._config_en.r_or(), self._mem_addr_cfg,
                        self._write_addr[i]))
    self.wire(self.RF_GEN.ports.wr_addr, self._mem_write_addr[0])
    for i in range(self.interconnect_output_ports):
        self.wire(
            self._s_mem_read_addr[i],
            kts.ternary(self._config_en.r_or(), self._mem_addr_cfg,
                        self._read_addr[i]))
    self.wire(self.RF_GEN.ports.rd_addr, self._mem_read_addr[0])

    if self.add_clk_enable:
        # self.clock_en("clk_en")
        kts.passes.auto_insert_clock_enable(self.internal_generator)
        clk_en_port = self.internal_generator.get_port("clk_en")
        clk_en_port.add_attribute(ControlSignalAttr(False))

    if self.add_flush:
        self.add_attribute("sync-reset=flush")
        kts.passes.auto_insert_sync_reset(self.internal_generator)
        flush_port = self.internal_generator.get_port("flush")
        flush_port.add_attribute(ControlSignalAttr(True))

    # Finally, lift the config regs...
    lift_config_reg(self.internal_generator)
def interact(self, ack_in, data_in, valid_in, ren_in, mem_valid_data):
    '''Advance the sync-group model by one step.

    ack_in is a bitvector (bit j is the ack for output port j);
    data_in, valid_in, ren_in and mem_valid_data are indexed per
    output port. The config entries sync_group_{i} are one-hot group
    selectors (compared against (1 << group) throughout).

    Returns (data_out, valid_out, rd_sync_gate, mem_valid_data_out)
    '''
    valid_out = []
    data_out = []
    rd_sync_gate = []
    mem_valid_data_out = []
    # # Use current state of bus to set local gate reduced
    # for i in range(self.int_out_ports):
    #     # For this port, just want to check that its corresponding entry is low
    #     for j in range(self.groups):
    #         if self.config[f"sync_group_{i}"] == (1 << j):
    #             self.local_gate_reduced[i] = self.local_gate[j][i]

    # Set the ren_int, ack_in combo
    # A port's read enable only counts while its (reduced) gate is open.
    ren_int = []
    for i in range(self.int_out_ports):
        ren_int.append(ren_in[i] & self.local_gate_reduced[i])

    # Create new local mask
    for i in range(self.groups):
        for j in range(self.int_out_ports):
            self.local_mask[i][j] = 1
            # If port j is in group i, set its gate mask — the mask drops
            # to 0 once the port saw both a read enable and its ack bit.
            if self.config[f"sync_group_{j}"] == (1 << i):
                self.local_mask[i][j] = not (ren_int[j] &
                                             ((ack_in & (1 << j)) != 0))
                # self.local_mask[i][j] = not (ren_int[j] and ack_in[j])

    # Get group finished
    group_finished = []
    for i in range(self.groups):
        group_finished.append(1)
        # Check that either the bus or mask is low for all items in the group
        for j in range(self.int_out_ports):
            # Only check if the port is in the group
            if self.config[f"sync_group_{j}"] == (1 << i):
                if self.local_gate[i][j] == 1 and self.local_mask[i][j] == 1:
                    group_finished[i] = 0

    # Taken before local_gate is updated below (get_rd_sync is defined
    # elsewhere in the class).
    rd_sync_gate = self.get_rd_sync()

    # A finished group re-opens all of its gates; otherwise each gate
    # accumulates the mask (once closed, it stays closed until the
    # whole group finishes).
    for i in range(self.groups):
        for j in range(self.int_out_ports):
            if group_finished[i] == 1:
                self.local_gate[i][j] = 1
            else:
                self.local_gate[i][j] = (self.local_gate[i][j] &
                                         self.local_mask[i][j])

    # Can get the valid syncs now
    # We do this by checking each groups members
    # to all have their valid reg high
    for i in range(self.groups):
        self.sync_group_valid[i] = 1
        for j in range(self.int_out_ports):
            # If any member of the group isn't valid yet, the group isn't valid
            if (self.config[f"sync_group_{j}"] ==
                    (1 << i)) and self.valid_reg[j] == 0:
                self.sync_group_valid[i] = 0

    # Each port gets its group's sync valid
    for i in range(self.int_out_ports):
        # sync_group_{i} is one-hot, so clog2 recovers the group index.
        valid_out.append(self.sync_group_valid[kts.clog2(
            self.config[f'sync_group_{i}'])])
        # valid_out.append(self.sync_group_valid[self.config[f'sync_group_{i}']])
        data_out.append(self.data_reg[i].copy())
        mem_valid_data_out.append(self.mem_valid_data_out_reg[i])

    # Update the registered valids - we keep these around to track
    # which valids in the group already came
    for i in range(self.int_out_ports):
        group_log = kts.clog2(self.config[f"sync_group_{i}"])
        # Re-latch this port once its group went valid, or while it has
        # not captured a valid yet.
        if self.sync_group_valid[group_log] == 1 or self.valid_reg[i] == 0:
            self.valid_reg[i] = valid_in[i]
            self.data_reg[i] = data_in[i].copy()
            self.mem_valid_data_out_reg[i] = mem_valid_data[i]

    # Use current state of bus to set local gate reduced
    for i in range(self.int_out_ports):
        # For this port, just want to check that its corresponding entry is low
        for j in range(self.groups):
            if self.config[f"sync_group_{i}"] == (1 << j):
                self.local_gate_reduced[i] = self.local_gate[j][i]

    # rd_sync_gate = self.get_rd_sync()
    return (data_out.copy(), valid_out, rd_sync_gate.copy(),
            mem_valid_data_out)