def bank_wr_packet_cache_comb(self):
    """Combinational update of the bank write-packet cache.

    Narrow stream writes are merged into one bank-wide data/strobe pair:
    each write lands in the CGRA-word lane selected by `strm_data_sel`,
    and the cache is cleared in the cycle it is drained (`bank_wr_en`).
    """
    # Default: hold the registered cache values.
    self.bank_wr_strb_cache_w = self.bank_wr_strb_cache_r
    self.bank_wr_data_cache_w = self.bank_wr_data_cache_r
    # First, if cached data is written to memory, clear it.
    if self.bank_wr_en:
        self.bank_wr_strb_cache_w = 0
        self.bank_wr_data_cache_w = 0
    # Next, save data to cache
    if self.strm_wr_en_w:
        if self.strm_data_sel == 0:
            # Lane 0: lowest CGRA-width slice of strobe and data.
            self.bank_wr_strb_cache_w[self.cgra_strb_width - 1, 0] = const(self.cgra_strb_value, self.cgra_strb_width)
            self.bank_wr_data_cache_w[self._params.cgra_data_width - 1, 0] = self.strm_wr_data_w
        elif self.strm_data_sel == 1:
            self.bank_wr_strb_cache_w[self.cgra_strb_width * 2 - 1, self.cgra_strb_width] = const(self.cgra_strb_value, self.cgra_strb_width)
            self.bank_wr_data_cache_w[self._params.cgra_data_width * 2 - 1, self._params.cgra_data_width] = self.strm_wr_data_w
        elif self.strm_data_sel == 2:
            self.bank_wr_strb_cache_w[self.cgra_strb_width * 3 - 1, self.cgra_strb_width * 2] = const(self.cgra_strb_value, self.cgra_strb_width)
            self.bank_wr_data_cache_w[self._params.cgra_data_width * 3 - 1, self._params.cgra_data_width * 2] = self.strm_wr_data_w
        elif self.strm_data_sel == 3:
            self.bank_wr_strb_cache_w[self.cgra_strb_width * 4 - 1, self.cgra_strb_width * 3] = const(self.cgra_strb_value, self.cgra_strb_width)
            self.bank_wr_data_cache_w[self._params.cgra_data_width * 4 - 1, self._params.cgra_data_width * 3] = self.strm_wr_data_w
        else:
            # NOTE(review): this fallback re-asserts the registered values,
            # which would also undo the bank_wr_en clear above in the same
            # cycle — confirm strm_data_sel can never exceed 3 here.
            self.bank_wr_strb_cache_w = self.bank_wr_strb_cache_r
            self.bank_wr_data_cache_w = self.bank_wr_data_cache_r
def set_back_num_load(self):
    """Combinationally set how many words the back end should load.

    When a parallel load is active (`_back_pl`), load a full fetch-wide
    word — minus one if a pop consumes a word in the same cycle;
    otherwise load nothing.
    """
    if self._back_pl:
        self._back_num_load = kts.ternary(self._pop,
                                          kts.const(self.fw_int - 1, self._back_num_load.width),
                                          kts.const(self.fw_int, self._back_num_load.width))
    else:
        self._back_num_load = 0
def add_sram_macro(self):
    """Instantiate each SRAM hard macro and connect its pins.

    All macros share the pipelined clock/address/data/byte-write-enable
    signals; chip-enable and write-enable are demuxed per macro.
    """
    for macro_idx in range(self.num_sram_macros):
        macro_inst = TS1N16FFCLLSBLVTC2048X64M8SW()
        self.add_child(f"sram_array_{macro_idx}",
                       macro_inst,
                       CLK=self.CLK,
                       A=self.a_sram_d,
                       BWEB=self.BWEB_d,
                       CEB=self.ceb_demux_d[macro_idx],
                       WEB=self.web_demux_d[macro_idx],
                       D=self.D_d,
                       Q=self.q_sram2mux[macro_idx],
                       # Fixed read/write timing-option selects for this macro.
                       RTSEL=const(0b01, 2),
                       WTSEL=const(0b00, 2))
def set_dryer(dryer, mod):
    """Populate the dryer FSM: Heating -> Spin -> CoolDown -> Done.

    Only the terminal Done state asserts the dryer_done output.
    """
    dryer.output(mod.ports.dryer_done)
    # Build the states in pipeline order, then chain unconditional transitions.
    states = [dryer.add_state(name)
              for name in ("Heating", "Spin", "CoolDown", "Done")]
    for src, dst in zip(states, states[1:]):
        src.next(dst, None)
    # Every non-terminal state drives dryer_done low.
    for st in states[:-1]:
        st.output(mod.ports.dryer_done, const(0, 1))
    states[-1].output(mod.ports.dryer_done, const(1, 1))
def set_washer(washer, mod):
    """Populate the washer FSM: WaterFill -> Spin -> Drain -> Done.

    Only the terminal Done state asserts the washer_done output.
    """
    washer.output(mod.ports.washer_done)
    # Build the states in pipeline order, then chain unconditional transitions.
    states = [washer.add_state(name)
              for name in ("WaterFill", "Spin", "Drain", "Done")]
    for src, dst in zip(states, states[1:]):
        src.next(dst, None)
    # Every non-terminal state drives washer_done low.
    for st in states[:-1]:
        st.output(mod.ports.washer_done, const(0, 1))
    states[-1].output(mod.ports.washer_done, const(1, 1))
def sram_ctrl_logic(self):
    """Demux the active-low write/chip enables to the selected SRAM macro.

    When the incoming enable is asserted (low), only the macro indexed by
    `sram_sel` gets a 0 bit; all other bits (and the idle case) are all-ones.
    """
    if ~self.WEB:
        # One-cold write enable: shift a 1 to the selected macro, invert.
        self.web_demux = ~(const(1, width=self.num_sram_macros) << resize(
            self.sram_sel, self.num_sram_macros))
    else:
        self.web_demux = const(2**self.num_sram_macros - 1, self.num_sram_macros)
    if ~self.CEB:
        # One-cold chip enable, same pattern as above.
        self.ceb_demux = ~(const(1, width=self.num_sram_macros) << resize(
            self.sram_sel, self.num_sram_macros))
    else:
        self.ceb_demux = const(2**self.num_sram_macros - 1, self.num_sram_macros)
def zext(gen, wire, size):
    """Zero-extend *wire* to *size* bits inside generator *gen*.

    Returns *wire* unchanged when it is already wide enough; otherwise
    creates a new `<name>_zext` variable wired to the zero-padded value.
    """
    if wire.width >= size:
        # Nothing to do — hand back the original signal.
        return wire
    pad_bits = size - wire.width
    extended = gen.var(f"{wire.name}_zext", size)
    gen.wire(extended, kts.concat(kts.const(0, pad_bits), wire))
    return extended
def send_writes(self, idx):
    """Drive the write enable for storage bank *idx*."""
    # Send a write to a bank if the read to that bank isn't happening (unless you can do both)
    # further, we make sure there is a queued write, or current buffer is
    # full and there is an incoming push
    # and there is stuff in memory to be read
    self._wen_to_strg[idx] = ((~self._ren_to_strg[idx] | kts.const(int(self.rw_same_cycle), 1)) &
                              (self._queued_write[idx] |  # Already queued a write...
                               (((self._front_occ == self.fw_int) & self._push &  # Grossness
                                 (~self._front_pop)) &
                                (self._curr_bank_wr == idx))))
def main():
    """Build the LaundryMachine generator and emit SystemVerilog plus a dot graph."""
    m = Generator("LaundryMachine")
    # Top-level I/O.
    washer_door = m.input("washer_door", 1)
    dryer_door = m.input("dryer_door", 1)
    clothes = m.input("clothes", 1)
    washer_done = m.output("washer_done", 1)
    dryer_done = m.output("dryer_done", 1)
    done_laundry = m.output("done", 1)
    m.clock("clk")
    m.reset("rst")
    # FSM hierarchy: one top-level machine with washer/dryer child FSMs.
    main_fsm = m.add_fsm("Laundry")
    washer = m.add_fsm("Washer")
    dryer = m.add_fsm("Dryer")
    main_fsm.add_child_fsm(washer)
    main_fsm.add_child_fsm(dryer)
    # add sub states
    set_washer(washer, m)
    set_dryer(dryer, m)
    # set the main state
    reset = main_fsm.add_state("Reset")
    main_fsm.set_start_state(reset)
    start_washing = main_fsm.add_state("Washing")
    start_drying = main_fsm.add_state("Drying")
    done = main_fsm.add_state("Done")
    # Transitions; hand-offs drop into the child FSMs' entry states.
    reset.next(start_washing, washer_door == 1)
    start_washing.next(washer["WaterFill"], None)
    washer["Done"].next(start_drying, None)
    start_drying.next(dryer["Heating"], dryer_door == 1)
    dryer["Done"].next(done, None)
    done.next(reset, None)
    main_fsm.output(done_laundry)
    # done is asserted only in the terminal state.
    for st in (reset, start_washing, start_drying):
        st.output(done_laundry, const(0, 1))
    done.output(done_laundry, const(1, 1))
    verilog(m, filename="laundry.sv")
    main_fsm.dot_graph("laundry.dot")
def add_pipeline(self):
    """Pipeline SRAM control/address/data inputs and the read-data output.

    Control signals are pipelined with reset_high=True (they reset to
    all-ones); address/select/data reset low; the SRAM output gets its
    own pipeline of independent depth.
    """
    # Bundle of control signals that must reset to 1.
    sram_signals_reset_high_in = concat(self.WEB, self.CEB, self.web_demux,
                                        self.ceb_demux, self.BWEB)
    sram_signals_reset_high_out = concat(self.WEB_d, self.CEB_d, self.web_demux_d,
                                         self.ceb_demux_d, self.BWEB_d)
    self.sram_signals_reset_high_pipeline = Pipeline(
        width=sram_signals_reset_high_in.width,
        depth=self._params.sram_gen_pipeline_depth,
        reset_high=True)
    self.add_child("sram_signals_reset_high_pipeline",
                   self.sram_signals_reset_high_pipeline,
                   clk=self.CLK,
                   clk_en=const(1, 1),
                   reset=self.RESET,
                   in_=sram_signals_reset_high_in,
                   out_=sram_signals_reset_high_out)
    # Address / macro-select / write-data bundle (normal reset).
    sram_signals_in = concat(self.a_sram, self.sram_sel, self.D)
    sram_signals_out = concat(self.a_sram_d, self.sram_sel_d, self.D_d)
    self.sram_signals_pipeline = Pipeline(
        width=sram_signals_in.width,
        depth=self._params.sram_gen_pipeline_depth)
    self.add_child("sram_signals_pipeline",
                   self.sram_signals_pipeline,
                   clk=self.CLK,
                   clk_en=const(1, 1),
                   reset=self.RESET,
                   in_=sram_signals_in,
                   out_=sram_signals_out)
    # Read-data output pipeline, separately configurable depth.
    self.sram_signals_output_pipeline = Pipeline(
        width=self.sram_macro_width,
        depth=self._params.sram_gen_output_pipeline_depth)
    self.add_child("sram_signals_output_pipeline",
                   self.sram_signals_output_pipeline,
                   clk=self.CLK,
                   clk_en=const(1, 1),
                   reset=self.RESET,
                   in_=self.Q_w,
                   out_=self.Q)
def add_sram_cfg_rd_addr_sel_pipeline(self):
    """Delay the SRAM-config read-address select bit by the bank ctrl latency."""
    self.sram_cfg_rd_addr_sel_d = self.var("sram_cfg_rd_addr_sel_d", 1)
    self.sram_cfg_rd_addr_sel_pipeline = Pipeline(
        width=1,
        depth=self.bank_ctrl_pipeline_depth)
    self.add_child(
        "sram_cfg_rd_addr_sel_pipeline",
        self.sram_cfg_rd_addr_sel_pipeline,
        clk=self.clk,
        clk_en=const(1, 1),
        reset=self.reset,
        # NOTE(review): presumably selects which AXI-width word of a bank
        # entry the config read targets — confirm against the read mux.
        in_=self.if_sram_cfg_s.rd_addr[self._params.bank_byte_offset - 1],
        out_=self.sram_cfg_rd_addr_sel_d)
def add_counter(generator, name, bitwidth, increment=kts.const(1, 1)):
    """Attach a counter register named *name* to *generator* and return it.

    The counter resets to 0 on active-low rst_n and adds 1 each clk edge
    while *increment* is asserted; by default it counts every cycle.
    """
    ctr = generator.var(name, bitwidth)

    @always_ff((posedge, "clk"), (negedge, "rst_n"))
    def ctr_inc_code():
        if ~generator._rst_n:
            ctr = 0
        elif increment:
            ctr = ctr + 1
    generator.add_code(ctr_inc_code)
    return ctr
def add_pipeline(self):
    """Pipeline the memory read data Q_w -> Q.

    Depth matches the sum of the SRAM generator's input and output
    pipeline depths so this model lines up with the macro-based path.
    """
    self.mem_pipeline = Pipeline(
        width=self.data_width,
        depth=(self._params.sram_gen_pipeline_depth +
               self._params.sram_gen_output_pipeline_depth))
    self.add_child("mem_pipeline",
                   self.mem_pipeline,
                   clk=self.CLK,
                   clk_en=const(1, 1),
                   reset=self.RESET,
                   in_=self.Q_w,
                   out_=self.Q)
def mem_signal_logic(self):
    """Priority-mux the memory port between requesters.

    Priority order: SRAM-config write, SRAM-config read, packet write,
    packet read, then idle. Config accesses are AXI-width, so the
    bank-width write data/bit-select are built with the AXI word placed
    in the lane chosen by the address's bank_byte_offset-1 bit.
    """
    if self.if_sram_cfg_s.wr_en:
        if self.if_sram_cfg_s.wr_addr[self._params.bank_byte_offset - 1] == 0:
            # Config write to the LOWER AXI-width lane of the bank word.
            self.mem_wr_en = 1
            self.mem_rd_en_w = 0
            self.mem_addr = self.if_sram_cfg_s.wr_addr
            self.mem_data_in = concat(
                const(
                    0, self._params.bank_data_width - self._params.axi_data_width),
                self.if_sram_cfg_s.wr_data)
            # Bit-select masks only the lower lane.
            self.mem_data_in_bit_sel = concat(
                const(
                    0, self._params.bank_data_width - self._params.axi_data_width),
                const(2**self._params.axi_data_width - 1,
                      self._params.axi_data_width))
        else:
            # Config write to the UPPER lane(s) of the bank word.
            self.mem_wr_en = 1
            self.mem_rd_en_w = 0
            self.mem_addr = self.if_sram_cfg_s.wr_addr
            self.mem_data_in = concat(
                self.if_sram_cfg_s.wr_data[self._params.bank_data_width -
                                           self._params.axi_data_width - 1, 0],
                const(0, self._params.axi_data_width))
            # Bit-select masks only the upper lane(s).
            self.mem_data_in_bit_sel = concat(
                const(
                    2**(self._params.bank_data_width -
                        self._params.axi_data_width) - 1,
                    self._params.bank_data_width - self._params.axi_data_width),
                const(0, self._params.axi_data_width))
    elif self.if_sram_cfg_s.rd_en:
        # Config read.
        self.mem_wr_en = 0
        self.mem_rd_en_w = 1
        self.mem_addr = self.if_sram_cfg_s.rd_addr
        self.mem_data_in = 0
        self.mem_data_in_bit_sel = 0
    elif self.packet_wr_en:
        # Packet write: full-width data and strobe come pre-formed.
        self.mem_wr_en = 1
        self.mem_rd_en_w = 0
        self.mem_addr = self.packet_wr_addr
        self.mem_data_in = self.packet_wr_data
        self.mem_data_in_bit_sel = self.packet_wr_data_bit_sel
    elif self.packet_rd_en:
        # Packet read.
        self.mem_wr_en = 0
        self.mem_rd_en_w = 1
        self.mem_addr = self.packet_rd_addr
        self.mem_data_in = 0
        self.mem_data_in_bit_sel = 0
    else:
        # Idle: drive everything inactive/zero.
        self.mem_wr_en = 0
        self.mem_rd_en_w = 0
        self.mem_addr = 0
        self.mem_data_in = 0
        self.mem_data_in_bit_sel = 0
def pipeline(self):
    """Sequential body of the Pipeline module: a depth-stage shift register.

    On reset every stage is forced to all-ones when reset_high is set,
    otherwise to zero; while clk_en is asserted, stage 0 captures in_
    and each later stage takes the previous stage's value.
    """
    if self.reset:
        for i in range(self.depth):
            if self.reset_high:
                self.pipeline_r[i] = const(2**self.width - 1, self.width)
            else:
                self.pipeline_r[i] = 0
    elif self.clk_en:
        for i in range(self.depth):
            if i == 0:
                self.pipeline_r[i] = self.in_
            else:
                # Index must be resized to the register-array index width.
                self.pipeline_r[i] = self.pipeline_r[resize(
                    i - 1, self.depth_width)]
def add_rd_en_pipeline(self):
    """Delay the three read-enable strobes by the bank ctrl pipeline depth.

    Delayed copies (mem / sram-config / packet) let downstream logic know
    which requester a returning read word belongs to.
    """
    self.mem_rd_en_w = self.var("mem_rd_en_w", 1)
    self.mem_rd_en_d = self.var("mem_rd_en_d", 1)
    self.sram_cfg_rd_en_d = self.var("sram_cfg_rd_en_d", 1)
    self.packet_rd_en_d = self.var("packet_rd_en_d", 1)
    self.wire(self.mem_rd_en_w, self.mem_rd_en)
    self.mem_rd_en_pipeline = Pipeline(width=1,
                                       depth=self.bank_ctrl_pipeline_depth)
    self.add_child("mem_rd_en_pipeline",
                   self.mem_rd_en_pipeline,
                   clk=self.clk,
                   clk_en=const(1, 1),
                   reset=self.reset,
                   in_=self.mem_rd_en_w,
                   out_=self.mem_rd_en_d)
    self.sram_cfg_rd_en_pipeline = Pipeline(
        width=1,
        depth=self.bank_ctrl_pipeline_depth)
    self.add_child("sram_cfg_rd_en_pipeline",
                   self.sram_cfg_rd_en_pipeline,
                   clk=self.clk,
                   clk_en=const(1, 1),
                   reset=self.reset,
                   in_=self.if_sram_cfg_s.rd_en,
                   out_=self.sram_cfg_rd_en_d)
    self.packet_rd_en_pipeline = Pipeline(
        width=1,
        depth=self.bank_ctrl_pipeline_depth)
    self.add_child("packet_rd_en_pipeline",
                   self.packet_rd_en_pipeline,
                   clk=self.clk,
                   clk_en=const(1, 1),
                   reset=self.reset,
                   in_=self.packet_rd_en,
                   out_=self.packet_rd_en_d)
def add_pcfg_dma_done_pulse_pipeline(self):
    """Delay the pcfg done pulse by a runtime-selectable latency.

    A pipeline of the worst-case depth is built, and the output tap is
    chosen by the configured network latency plus the fixed default
    latency and the tile count.
    """
    # Worst-case total delay the pulse can ever need.
    maximum_latency = 3 * self._params.num_glb_tiles + self.default_latency
    latency_width = clog2(maximum_latency)
    # One bit per pipeline stage, exposed as a flat array of taps.
    self.done_pulse_d_arr = self.var(
        "done_pulse_d_arr", 1, size=maximum_latency, explicit_array=True)
    self.done_pulse_pipeline = Pipeline(width=1,
                                        depth=maximum_latency,
                                        flatten_output=True)
    self.add_child("done_pulse_pipeline",
                   self.done_pulse_pipeline,
                   clk=self.clk,
                   clk_en=const(1, 1),
                   reset=self.reset,
                   in_=self.done_pulse_r,
                   out_=self.done_pulse_d_arr)
    # Select the tap matching the configured latency.
    self.wire(self.pcfg_done_pulse,
              self.done_pulse_d_arr[resize(self.cfg_pcfg_network_latency,
                                           latency_width) +
                                    self.default_latency +
                                    self._params.num_glb_tiles])
def safe_wire(gen, w_to, w_from):
    '''
    Wire together two signals of (potentially) mismatched width to
    avoid the exception that Kratos throws.
    '''
    # Only works in one dimension...
    if w_to.width == w_from.width:
        # Widths agree — a plain wire is safe.
        gen.wire(w_to, w_from)
        return
    if lake_util_verbose_trim:
        print(
            f"SAFEWIRE: WIDTH MISMATCH: {w_to.name} width {w_to.width} <-> {w_from.name} width {w_from.width}"
        )
    if w_to.width < w_from.width:
        # Destination is narrower: take only the source's low bits.
        gen.wire(w_to, w_from[w_to.width - 1, 0])
    else:
        # Destination is wider: drive its low bits from the source and
        # tie the remaining high bits to zero.
        gen.wire(w_to[w_from.width - 1, 0], w_from)
        pad_bits = w_to.width - w_from.width
        gen.wire(w_to[w_to.width - 1, w_from.width], kts.const(0, pad_bits))
def set_read_bank(self):
    """Drive _rd_bank from the bank field of the read address.

    Single-bank configs tie it to 0. Otherwise it is registered when
    read_delay == 1 (SRAM-style read) and combinational when not.
    """
    if self.banks == 1:
        self.wire(self._rd_bank, kts.const(0, 1))
    else:
        # The read bank is comb if no delay, otherwise delayed
        if self.read_delay == 1:
            @always_ff((posedge, "clk"), (negedge, "rst_n"))
            def read_bank_ff(self):
                if ~self._rst_n:
                    self._rd_bank = 0
                else:
                    # Bank bits live at [b_a_off + bank_width - 1 : b_a_off].
                    self._rd_bank = \
                        self._rd_addr[self.b_a_off + self.bank_width - 1, self.b_a_off]
            self.add_code(read_bank_ff)
        else:
            @always_comb
            def read_bank_comb(self):
                self._rd_bank = \
                    self._rd_addr[self.b_a_off + self.bank_width - 1, self.b_a_off]
            self.add_code(read_bank_comb)
def tile2tile_w2e_wiring(self):
    """Chain the west-to-east packet channels across the row of GLB tiles.

    Tile 0's west input takes this module's processed packet (stream and
    pcfg channels are tied off to 0); every later tile's west input is
    the previous tile's east output.
    """
    self.wire(self.proc_packet_w2e_wsti[0], self.proc_packet_d)
    self.wire(self.strm_packet_w2e_wsti[0], 0)
    self.wire(self.pcfg_packet_w2e_wsti[0], 0)
    idx_width = clog2(self._params.num_glb_tiles)
    for tile in range(1, self._params.num_glb_tiles):
        cur = const(tile, idx_width)
        prev = const(tile - 1, idx_width)
        self.wire(self.proc_packet_w2e_wsti[cur],
                  self.proc_packet_w2e_esto[prev])
        self.wire(self.strm_packet_w2e_wsti[cur],
                  self.strm_packet_w2e_esto[prev])
        self.wire(self.pcfg_packet_w2e_wsti[cur],
                  self.pcfg_packet_w2e_esto[prev])
def __init__(self,
             data_width=16,  # CGRA Params
             mem_width=64,
             mem_depth=512,
             banks=1,
             input_iterator_support=6,  # Addr Controllers
             output_iterator_support=6,
             input_config_width=16,
             output_config_width=16,
             interconnect_input_ports=2,  # Connection to int
             interconnect_output_ports=2,
             mem_input_ports=1,
             mem_output_ports=1,
             read_delay=1,  # Cycle delay in read (SRAM vs Register File)
             rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
             agg_height=4,
             max_agg_schedule=32,
             input_max_port_sched=32,
             output_max_port_sched=32,
             align_input=1,
             max_line_length=128,
             max_tb_height=1,
             tb_range_max=128,
             tb_range_inner_max=5,
             tb_sched_max=64,
             max_tb_stride=15,
             num_tb=1,
             tb_iterator_support=2,
             multiwrite=1,
             num_tiles=1,
             max_prefetch=8,
             app_ctrl_depth_width=16,
             remove_tb=False,
             stcl_valid_iter=4):
    """Elaborate the strg_ub unified-buffer generator.

    Builds, in order: app controllers, optional aggregation aligners and
    aggregation buffers, input/output address controllers, per-bank
    read/write arbiters, the read demux, sync groups, prefetchers, and
    (optionally) transpose buffers, wiring them into one storage pipeline.
    """
    super().__init__("strg_ub")
    # Stash all generation parameters on self.
    self.data_width = data_width
    self.mem_width = mem_width
    self.mem_depth = mem_depth
    self.banks = banks
    self.input_iterator_support = input_iterator_support
    self.output_iterator_support = output_iterator_support
    self.input_config_width = input_config_width
    self.output_config_width = output_config_width
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.agg_height = agg_height
    self.max_agg_schedule = max_agg_schedule
    self.input_max_port_sched = input_max_port_sched
    self.output_max_port_sched = output_max_port_sched
    self.input_port_sched_width = clog2(self.interconnect_input_ports)
    self.align_input = align_input
    self.max_line_length = max_line_length
    assert self.mem_width >= self.data_width, "Data width needs to be smaller than mem"
    # fw_int: how many data words fit in one memory word (fetch width).
    self.fw_int = int(self.mem_width / self.data_width)
    self.num_tb = num_tb
    self.max_tb_height = max_tb_height
    self.tb_range_max = tb_range_max
    self.tb_range_inner_max = tb_range_inner_max
    self.max_tb_stride = max_tb_stride
    self.tb_sched_max = tb_sched_max
    self.tb_iterator_support = tb_iterator_support
    self.multiwrite = multiwrite
    self.max_prefetch = max_prefetch
    self.num_tiles = num_tiles
    self.app_ctrl_depth_width = app_ctrl_depth_width
    self.remove_tb = remove_tb
    self.read_delay = read_delay
    self.rw_same_cycle = rw_same_cycle
    self.stcl_valid_iter = stcl_valid_iter

    # phases = [] TODO

    self.address_width = clog2(self.num_tiles * self.mem_depth)

    # CLK and RST
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")

    # INPUTS
    self._data_in = self.input("data_in",
                               self.data_width,
                               size=self.interconnect_input_ports,
                               packed=True,
                               explicit_array=True)
    self._wen_in = self.input("wen_in", self.interconnect_input_ports)
    self._ren_input = self.input("ren_in", self.interconnect_output_ports)
    # Post rate matched
    self._ren_in = self.var("ren_in_muxed", self.interconnect_output_ports)
    # Processed versions of wen and ren from the app ctrl
    self._wen = self.var("wen", self.interconnect_input_ports)
    self._ren = self.var("ren", self.interconnect_output_ports)

    # Add rate matched
    # If one input port, let any output port use the wen_in as the ren_in
    # If more, do the same thing but also provide port selection
    if self.interconnect_input_ports == 1:
        self._rate_matched = self.input("rate_matched", self.interconnect_output_ports)
        self._rate_matched.add_attribute(ConfigRegAttr("Rate matched - 1 or 0"))
        for i in range(self.interconnect_output_ports):
            self.wire(self._ren_in[i],
                      kts.ternary(self._rate_matched[i],
                                  self._wen_in,
                                  self._ren_input[i]))
    else:
        # Config field layout: [input port select | enable bit].
        self._rate_matched = self.input("rate_matched",
                                        1 + kts.clog2(self.interconnect_input_ports),
                                        size=self.interconnect_output_ports,
                                        explicit_array=True,
                                        packed=True)
        self._rate_matched.add_attribute(ConfigRegAttr("Rate matched [input port | on/off]"))
        for i in range(self.interconnect_output_ports):
            self.wire(self._ren_in[i],
                      kts.ternary(self._rate_matched[i][0],
                                  self._wen_in[self._rate_matched[i][kts.clog2(self.interconnect_input_ports), 1]],
                                  self._ren_input[i]))

    self._arb_wen_en = self.var("arb_wen_en", self.interconnect_input_ports)
    self._arb_ren_en = self.var("arb_ren_en", self.interconnect_output_ports)

    self._data_from_strg = self.input("data_from_strg",
                                      self.data_width,
                                      size=(self.banks,
                                            self.mem_output_ports,
                                            self.fw_int),
                                      packed=True,
                                      explicit_array=True)
    self._mem_valid_data = self.input("mem_valid_data",
                                      self.mem_output_ports,
                                      size=self.banks,
                                      explicit_array=True,
                                      packed=True)
    self._out_mem_valid_data = self.var("out_mem_valid_data",
                                        self.mem_output_ports,
                                        size=self.banks,
                                        explicit_array=True,
                                        packed=True)
    # We need to signal valids out of the agg buff, only if one exists...
    if self.agg_height > 0:
        self._to_iac_valid = self.var("ab_to_mem_valid", self.interconnect_input_ports)
    self._data_out = self.output("data_out",
                                 self.data_width,
                                 size=self.interconnect_output_ports,
                                 packed=True,
                                 explicit_array=True)
    self._valid_out = self.output("valid_out", self.interconnect_output_ports)
    self._valid_out_alt = self.var("valid_out_alt", self.interconnect_output_ports)
    self._data_to_strg = self.output("data_to_strg",
                                     self.data_width,
                                     size=(self.banks,
                                           self.mem_input_ports,
                                           self.fw_int),
                                     packed=True,
                                     explicit_array=True)
    # If we can perform a read and a write on the same cycle,
    # this will necessitate a separate read and write address...
    if self.rw_same_cycle:
        self._wr_addr_out = self.output("wr_addr_out",
                                        self.address_width,
                                        size=(self.banks,
                                              self.mem_input_ports),
                                        explicit_array=True,
                                        packed=True)
        self._rd_addr_out = self.output("rd_addr_out",
                                        self.address_width,
                                        size=(self.banks,
                                              self.mem_output_ports),
                                        explicit_array=True,
                                        packed=True)
    else:
        self._addr_out = self.output("addr_out",
                                     self.address_width,
                                     size=(self.banks,
                                           self.mem_input_ports),
                                     packed=True,
                                     explicit_array=True)
    self._cen_to_strg = self.output("cen_to_strg",
                                    self.mem_output_ports,
                                    size=self.banks,
                                    explicit_array=True,
                                    packed=True)
    self._wen_to_strg = self.output("wen_to_strg",
                                    self.mem_input_ports,
                                    size=self.banks,
                                    explicit_array=True,
                                    packed=True)
    if self.num_tb > 0:
        self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)
    self._port_wens = self.var("port_wens", self.interconnect_input_ports)

    ####################
    ##### APP CTRL #####
    ####################
    self._ack_transpose = self.var("ack_transpose",
                                   self.banks,
                                   size=self.interconnect_output_ports,
                                   explicit_array=True,
                                   packed=True)
    self._ack_reduced = self.var("ack_reduced", self.interconnect_output_ports)
    self.app_ctrl = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                            interconnect_output_ports=self.interconnect_output_ports,
                            depth_width=self.app_ctrl_depth_width,
                            sprt_stcl_valid=True,
                            stcl_iter_support=self.stcl_valid_iter)
    # Some refactoring here for pond to get rid of app controllers...
    # This is honestly pretty messy and should clean up nicely when we have the spec...
    self._ren_out_reduced = self.var("ren_out_reduced", self.interconnect_output_ports)
    if self.num_tb == 0 or self.remove_tb:
        # No transpose buffers: bypass the app controllers entirely.
        self.wire(self._wen, self._wen_in)
        self.wire(self._ren, self._ren_in)
        self.wire(self._valid_out, self._valid_out_alt)
        self.wire(self._arb_wen_en, self._wen)
        self.wire(self._arb_ren_en, self._ren)
    else:
        self.add_child("app_ctrl", self.app_ctrl,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       wen_in=self._wen_in,
                       ren_in=self._ren_in,
                       # ren_update=self._tb_valid_out,
                       valid_out_data=self._valid_out,
                       # valid_out_stencil=,
                       wen_out=self._wen,
                       ren_out=self._ren)
        self.wire(self.app_ctrl.ports.tb_valid, self._tb_valid_out)
        self.wire(self.app_ctrl.ports.ren_update, self._tb_valid_out)

        # Coarse-grained controller paces the memory-side accesses.
        self.app_ctrl_coarse = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                                       interconnect_output_ports=self.interconnect_output_ports,
                                       depth_width=self.app_ctrl_depth_width)
        self.add_child("app_ctrl_coarse", self.app_ctrl_coarse,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       wen_in=self._to_iac_valid,  # self._port_wens & self._to_iac_valid,
                       # Gets valid and the ack
                       ren_in=self._ren_out_reduced,
                       tb_valid=kts.const(0, 1),
                       ren_update=self._ack_reduced,
                       wen_out=self._arb_wen_en,
                       ren_out=self._arb_ren_en)

    ###########################
    ##### INPUT AGG SCHED #####
    ###########################

    ###########################################
    ##### AGGREGATION ALIGNERS (OPTIONAL) #####
    ###########################################
    # These variables are holders and can be swapped out if needed
    self._data_consume = self._data_in
    self._valid_consume = self._wen
    # Zero out if not aligning
    if(self.agg_height > 0):
        self._align_to_agg = self.var("align_input", self.interconnect_input_ports)
    # Add the aggregation buffer aligners
    if(self.align_input):
        self._data_consume = self.var("data_consume",
                                      self.data_width,
                                      size=self.interconnect_input_ports,
                                      packed=True,
                                      explicit_array=True)
        self._valid_consume = self.var("valid_consume", self.interconnect_input_ports)
        # Make new aggregation aligners for each port
        for i in range(self.interconnect_input_ports):
            new_child = AggAligner(self.data_width, self.max_line_length)
            self.add_child(f"agg_align_{i}", new_child,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           in_dat=self._data_in[i],
                           in_valid=self._wen[i],
                           align=self._align_to_agg[i],
                           out_valid=self._valid_consume[i],
                           out_dat=self._data_consume[i])
    else:
        if self.agg_height > 0:
            self.wire(self._align_to_agg, const(0, self._align_to_agg.width))
    ################################################
    ##### END: AGGREGATION ALIGNERS (OPTIONAL) #####
    ################################################
    if self.agg_height == 0:
        self._to_iac_dat = self._data_consume
        self._to_iac_valid = self._valid_consume

    ##################################
    ##### AGG BUFFERS (OPTIONAL) #####
    ##################################
    # Only instantiate agg_buffer if needed
    if(self.agg_height > 0):
        self._to_iac_dat = self.var("ab_to_mem_dat",
                                    self.mem_width,
                                    size=self.interconnect_input_ports,
                                    packed=True,
                                    explicit_array=True)
        # self._to_iac_valid = self.var("ab_to_mem_valid",
        #                               self.interconnect_input_ports)
        self._agg_buffers = []
        # Add input aggregations buffers
        for i in range(self.interconnect_input_ports):
            # add children aggregator buffers...
            agg_buffer_new = AggregationBuffer(self.agg_height,
                                               self.data_width,
                                               self.mem_width,
                                               self.max_agg_schedule)
            self._agg_buffers.append(agg_buffer_new)
            self.add_child(f"agg_in_{i}", agg_buffer_new,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_consume[i],
                           valid_in=self._valid_consume[i],
                           align=self._align_to_agg[i],
                           data_out=self._to_iac_dat[i],
                           valid_out=self._to_iac_valid[i])
    #######################################
    ##### END: AGG BUFFERS (OPTIONAL) #####
    #######################################

    self._ready_tba = self.var("ready_tba", self.interconnect_output_ports)

    ####################################
    ##### INPUT ADDRESS CONTROLLER #####
    ####################################
    self._wen_to_arb = self.var("wen_to_arb",
                                self.mem_input_ports,
                                size=self.banks,
                                explicit_array=True,
                                packed=True)
    self._addr_to_arb = self.var("addr_to_arb",
                                 self.address_width,
                                 size=(self.banks,
                                       self.mem_input_ports),
                                 explicit_array=True,
                                 packed=True)
    self._data_to_arb = self.var("data_to_arb",
                                 self.data_width,
                                 size=(self.banks,
                                       self.mem_input_ports,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    # Connect these inputs ports to an address generator
    iac = InputAddrCtrl(interconnect_input_ports=self.interconnect_input_ports,
                        mem_depth=self.mem_depth,
                        num_tiles=self.num_tiles,
                        banks=self.banks,
                        iterator_support=self.input_iterator_support,
                        address_width=self.address_width,
                        data_width=self.data_width,
                        fetch_width=self.mem_width,
                        multiwrite=self.multiwrite,
                        strg_wr_ports=self.mem_input_ports,
                        config_width=self.input_config_width)
    self.add_child(f"input_addr_ctrl", iac,
                   clk=self._clk,
                   rst_n=self._rst_n,
                   valid_in=self._to_iac_valid,
                   # wen_en=kts.concat(*([kts.const(1, 1)] * self.interconnect_input_ports)),
                   wen_en=self._arb_wen_en,
                   data_in=self._to_iac_dat,
                   wen_to_sram=self._wen_to_arb,
                   addr_out=self._addr_to_arb,
                   port_out=self._port_wens,
                   data_out=self._data_to_arb)
    #########################################
    ##### END: INPUT ADDRESS CONTROLLER #####
    #########################################

    self._arb_acks = self.var("arb_acks",
                              self.interconnect_output_ports,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)
    self._prefetch_step = self.var("prefetch_step", self.interconnect_output_ports)
    self._oac_step = self.var("oac_step", self.interconnect_output_ports)
    self._oac_valid = self.var("oac_valid", self.interconnect_output_ports)
    self._ren_out = self.var("ren_out",
                             self.interconnect_output_ports,
                             size=self.banks,
                             explicit_array=True,
                             packed=True)
    self._ren_out_tpose = self.var("ren_out_tpose",
                                   self.banks,
                                   size=self.interconnect_output_ports,
                                   explicit_array=True,
                                   packed=True)
    self._oac_addr_out = self.var("oac_addr_out",
                                  self.address_width,
                                  size=self.interconnect_output_ports,
                                  explicit_array=True,
                                  packed=True)

    #####################################
    ##### OUTPUT ADDRESS CONTROLLER #####
    #####################################
    oac = OutputAddrCtrl(interconnect_output_ports=self.interconnect_output_ports,
                         mem_depth=self.mem_depth,
                         num_tiles=self.num_tiles,
                         banks=self.banks,
                         iterator_support=self.output_iterator_support,
                         address_width=self.address_width,
                         config_width=self.output_config_width)
    if self.remove_tb:
        self.wire(self._oac_valid, self._ren)
        self.wire(self._oac_step, self._ren)
    else:
        self.wire(self._oac_valid, self._prefetch_step)
        self.wire(self._oac_step, self._ack_reduced)

    self.chain_idx_bits = max(1, clog2(num_tiles))
    self._enable_chain_output = self.input("enable_chain_output", 1)
    self._chain_idx_output = self.input("chain_idx_output", self.chain_idx_bits)
    self.add_child(f"output_addr_ctrl", oac,
                   clk=self._clk,
                   rst_n=self._rst_n,
                   valid_in=self._oac_valid,
                   ren=self._ren_out,
                   addr_out=self._oac_addr_out,
                   step_in=self._oac_step)

    # Transpose ren_out from [bank][port] to [port][bank].
    for i in range(self.interconnect_output_ports):
        for j in range(self.banks):
            self.wire(self._ren_out_tpose[i][j], self._ren_out[j][i])

    ##############################
    ##### READ/WRITE ARBITER #####
    ##############################
    # Hook up the read write arbiters for each bank
    self._arb_dat_out = self.var("arb_dat_out",
                                 self.data_width,
                                 size=(self.banks,
                                       self.mem_output_ports,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    self._arb_port_out = self.var("arb_port_out",
                                  self.interconnect_output_ports,
                                  size=(self.banks,
                                        self.mem_output_ports),
                                  explicit_array=True,
                                  packed=True)
    self._arb_valid_out = self.var("arb_valid_out",
                                   self.mem_output_ports,
                                   size=self.banks,
                                   explicit_array=True,
                                   packed=True)
    self._rd_sync_gate = self.var("rd_sync_gate", self.interconnect_output_ports)

    self.arbiters = []
    for i in range(self.banks):
        rw_arb = RWArbiter(fetch_width=self.mem_width,
                           data_width=self.data_width,
                           memory_depth=self.mem_depth,
                           num_tiles=self.num_tiles,
                           int_in_ports=self.interconnect_input_ports,
                           int_out_ports=self.interconnect_output_ports,
                           strg_wr_ports=self.mem_input_ports,
                           strg_rd_ports=self.mem_output_ports,
                           read_delay=self.read_delay,
                           rw_same_cycle=self.rw_same_cycle,
                           separate_addresses=self.rw_same_cycle)
        self.arbiters.append(rw_arb)
        self.add_child(f"rw_arb_{i}", rw_arb,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       wen_in=self._wen_to_arb[i],
                       w_data=self._data_to_arb[i],
                       w_addr=self._addr_to_arb[i],
                       data_from_mem=self._data_from_strg[i],
                       mem_valid_data=self._mem_valid_data[i],
                       out_mem_valid_data=self._out_mem_valid_data[i],
                       ren_en=self._arb_ren_en,
                       rd_addr=self._oac_addr_out,
                       out_data=self._arb_dat_out[i],
                       out_port=self._arb_port_out[i],
                       out_valid=self._arb_valid_out[i],
                       cen_mem=self._cen_to_strg[i],
                       wen_mem=self._wen_to_strg[i],
                       data_to_mem=self._data_to_strg[i],
                       out_ack=self._arb_acks[i])
        # Bind the separate addrs
        if self.rw_same_cycle:
            self.wire(rw_arb.ports.wr_addr_to_mem, self._wr_addr_out[i])
            self.wire(rw_arb.ports.rd_addr_to_mem, self._rd_addr_out[i])
        else:
            self.wire(rw_arb.ports.addr_to_mem, self._addr_out[i])
        if self.remove_tb:
            self.wire(rw_arb.ports.ren_in, self._ren_out[i])
        else:
            # Reads are additionally gated by the sync groups.
            self.wire(rw_arb.ports.ren_in, self._ren_out[i] & self._rd_sync_gate)

    self.num_tb_bits = max(1, clog2(self.num_tb))

    self._data_to_sync = self.var("data_to_sync",
                                  self.data_width,
                                  size=(self.interconnect_output_ports,
                                        self.fw_int),
                                  explicit_array=True,
                                  packed=True)
    self._valid_to_sync = self.var("valid_to_sync", self.interconnect_output_ports)
    self._data_to_tba = self.var("data_to_tba",
                                 self.data_width,
                                 size=(self.interconnect_output_ports,
                                       self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    self._valid_to_tba = self.var("valid_to_tba", self.interconnect_output_ports)
    self._data_to_pref = self.var("data_to_pref",
                                  self.data_width,
                                  size=(self.interconnect_output_ports,
                                        self.fw_int),
                                  explicit_array=True,
                                  packed=True)
    self._valid_to_pref = self.var("valid_to_pref", self.interconnect_output_ports)

    #######################
    ##### DEMUX READS #####
    #######################
    dmux_rd = DemuxReads(fetch_width=self.mem_width,
                         data_width=self.data_width,
                         banks=self.banks,
                         int_out_ports=self.interconnect_output_ports,
                         strg_rd_ports=self.mem_output_ports)
    # Flattened (bank x rd-port) views of the arbiter outputs.
    self._arb_dat_out_f = self.var("arb_dat_out_f",
                                   self.data_width,
                                   size=(self.banks * self.mem_output_ports,
                                         self.fw_int),
                                   explicit_array=True,
                                   packed=True)
    self._arb_port_out_f = self.var("arb_port_out_f",
                                    self.interconnect_output_ports,
                                    size=(self.banks * self.mem_output_ports),
                                    explicit_array=True,
                                    packed=True)
    self._arb_valid_out_f = self.var("arb_valid_out_f",
                                     self.mem_output_ports * self.banks)
    self._arb_mem_valid_data_f = self.var("arb_mem_valid_data_f",
                                          self.mem_output_ports * self.banks)
    self._arb_mem_valid_data_out = self.var("arb_mem_valid_data_out",
                                            self.interconnect_output_ports)
    self._mem_valid_data_sync = self.var("mem_valid_data_sync",
                                         self.interconnect_output_ports)
    self._mem_valid_data_pref = self.var("mem_valid_data_pref",
                                         self.interconnect_output_ports)
    tmp_cnt = 0
    for i in range(self.banks):
        for j in range(self.mem_output_ports):
            self.wire(self._arb_dat_out_f[tmp_cnt], self._arb_dat_out[i][j])
            self.wire(self._arb_port_out_f[tmp_cnt], self._arb_port_out[i][j])
            self.wire(self._arb_valid_out_f[tmp_cnt], self._arb_valid_out[i][j])
            self.wire(self._arb_mem_valid_data_f[tmp_cnt],
                      self._out_mem_valid_data[i][j])
            tmp_cnt = tmp_cnt + 1
    # If this is end of the road...
    if self.remove_tb:
        assert self.fw_int == 1, "Make it easier on me now..."
        self.add_child("demux_rds", dmux_rd,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       data_in=self._arb_dat_out_f,
                       mem_valid_data=self._arb_mem_valid_data_f,
                       mem_valid_data_out=self._arb_mem_valid_data_out,
                       valid_in=self._arb_valid_out_f,
                       port_in=self._arb_port_out_f,
                       valid_out=self._valid_out_alt)
        for i in range(self.interconnect_output_ports):
            self.wire(self._data_out[i], dmux_rd.ports.data_out[i])
    else:
        self.add_child("demux_rds", dmux_rd,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       data_in=self._arb_dat_out_f,
                       mem_valid_data=self._arb_mem_valid_data_f,
                       mem_valid_data_out=self._arb_mem_valid_data_out,
                       valid_in=self._arb_valid_out_f,
                       port_in=self._arb_port_out_f,
                       data_out=self._data_to_sync,
                       valid_out=self._valid_to_sync)

    #######################
    ##### SYNC GROUPS #####
    #######################
    sync_group = SyncGroups(fetch_width=self.mem_width,
                            data_width=self.data_width,
                            int_out_ports=self.interconnect_output_ports)
    for i in range(self.interconnect_output_ports):
        self.wire(self._ren_out_reduced[i], self._ren_out_tpose[i].r_or())
    self.add_child("sync_grp", sync_group,
                   clk=self._clk,
                   rst_n=self._rst_n,
                   data_in=self._data_to_sync,
                   mem_valid_data=self._arb_mem_valid_data_out,
                   mem_valid_data_out=self._mem_valid_data_sync,
                   valid_in=self._valid_to_sync,
                   data_out=self._data_to_pref,
                   valid_out=self._valid_to_pref,
                   ren_in=self._ren_out_reduced,
                   rd_sync_gate=self._rd_sync_gate,
                   ack_in=self._ack_reduced)
    # This is the end of the line if we aren't using tb

    ######################
    ##### PREFETCHER #####
    ######################
    prefetchers = []
    for i in range(self.interconnect_output_ports):
        pref = Prefetcher(fetch_width=self.mem_width,
                          data_width=self.data_width,
                          max_prefetch=self.max_prefetch)
        prefetchers.append(pref)
        if self.num_tb == 0:
            assert self.fw_int == 1, \
                "If no transpose buffer, data width needs match memory width"
            self.add_child(f"pre_fetch_{i}", pref,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_to_pref[i],
                           mem_valid_data=self._mem_valid_data_sync[i],
                           mem_valid_data_out=self._mem_valid_data_pref[i],
                           valid_read=self._valid_to_pref[i],
                           tba_rdy_in=self._ren[i],
                           # data_out=self._data_out[i],
                           valid_out=self._valid_out_alt[i],
                           prefetch_step=self._prefetch_step[i])
            self.wire(self._data_out[i], pref.ports.data_out[0])
        else:
            self.add_child(f"pre_fetch_{i}", pref,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_to_pref[i],
                           mem_valid_data=self._mem_valid_data_sync[i],
                           mem_valid_data_out=self._mem_valid_data_pref[i],
                           valid_read=self._valid_to_pref[i],
                           tba_rdy_in=self._ready_tba[i],
                           data_out=self._data_to_tba[i],
                           valid_out=self._valid_to_tba[i],
                           prefetch_step=self._prefetch_step[i])

    #############################
    ##### TRANSPOSE BUFFERS #####
    #############################
    if self.num_tb > 0:
        self._tb_data_out = self.var("tb_data_out",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)
        self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)
        for i in range(self.interconnect_output_ports):
            tba = TransposeBufferAggregation(word_width=self.data_width,
                                             fetch_width=self.fw_int,
                                             num_tb=self.num_tb,
                                             max_tb_height=self.max_tb_height,
                                             max_range=self.tb_range_max,
                                             max_range_inner=self.tb_range_inner_max,
                                             max_stride=self.max_tb_stride,
                                             tb_iterator_support=self.tb_iterator_support)
            self.add_child(f"tba_{i}", tba,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           SRAM_to_tb_data=self._data_to_tba[i],
                           valid_data=self._valid_to_tba[i],
                           tb_index_for_data=0,
                           ack_in=self._valid_to_tba[i],
                           mem_valid_data=self._mem_valid_data_pref[i],
                           tb_to_interconnect_data=self._tb_data_out[i],
                           tb_to_interconnect_valid=self._tb_valid_out[i],
                           tb_arbiter_rdy=self._ready_tba[i],
                           tba_ren=self._ren[i])
        for i in range(self.interconnect_output_ports):
            self.wire(self._data_out[i], self._tb_data_out[i])
            # self.wire(self._valid_out[i], self._tb_valid_out[i])
    else:
        self.wire(self._valid_out, self._valid_out_alt)

    ####################
    ##### ADD CODE #####
    ####################
    self.add_code(self.transpose_acks)
    self.add_code(self.reduce_acks)
def decrement(var, value):
    """Build a kratos expression computing ``var - value``.

    The Python integer ``value`` is wrapped in a constant sized to
    ``var.width`` so the subtraction stays width-matched in the
    generated RTL.
    """
    amount = kts.const(value, var.width)
    return var - amount
def set_inc(self, idx):
    """Combinational block: drive the increment bit for loop level ``idx``.

    ``idx`` is a Python int fixed at elaboration time; ``self.step``,
    ``self.dim``, and ``self.mux_sel`` are hardware signals, so the
    comparisons below elaborate into RTL conditions.
    """
    # Default: no increment this cycle.
    self.inc[idx] = 0
    # Innermost level fires whenever we step and the level is inside the
    # configured dimensionality.
    # NOTE(review): the first arm builds a 5-bit constant from the static
    # `idx` and compares it to 0 — presumably equivalent to a plain
    # `idx == 0` elaboration-time check; confirm intent.
    if (const(idx, 5) == 0) & self.step & (idx < self.dim):
        self.inc[idx] = 1
    # Other levels fire when the loop mux currently selects them.
    elif (idx == self.mux_sel) & self.step & (idx < self.dim):
        self.inc[idx] = 1
def code(self):
    """Kratos code block: feed an explicit 2-bit constant 1 through
    ``func`` into both ``self._reg`` and ``self._out``.
    """
    # because the types are inferred
    # implicit const conversion doesn't work here
    self._reg = func(const(1, 2))
    self._out = func(const(1, 2))
def __init__(self, data_width=16, banks=1, memory_width=64, rw_same_cycle=False, read_delay=1, addr_width=9):
    """Elaborate the storage-backed FIFO ("strg_fifo") generator.

    Builds a FIFO out of wide storage banks plus two small register
    FIFOs: a "front" RegFIFO that aggregates thin pushes into wide
    words destined for storage, and a "back" RegFIFO that breaks wide
    storage reads back into thin pops.

    All parameters are Python-time generation parameters:
        data_width: width of each pushed/popped word.
        banks: number of storage banks behind the FIFO.
        memory_width: width of one storage word; fw_int is the number
            of thin words per storage word.
        rw_same_cycle: if True, expose separate read/write address
            outputs per bank; otherwise one shared address per bank.
        read_delay: 1 if the storage has one-cycle read latency.
        addr_width: width of the addresses driven to storage.
    """
    super().__init__("strg_fifo")
    # Generation parameters
    self.banks = banks
    self.data_width = data_width
    self.memory_width = memory_width
    self.rw_same_cycle = rw_same_cycle
    self.read_delay = read_delay
    self.addr_width = addr_width
    # Thin words per wide storage word.
    self.fw_int = int(self.memory_width / self.data_width)
    # assert banks > 1 or rw_same_cycle is True or self.fw_int > 1, \
    #     "Can't sustain throughput with this setup. Need potential bandwidth for " + \
    #     "1 write and 1 read in a cycle - try using more banks or a macro that supports 1R1W"
    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    # Inputs + Outputs
    self._push = self.input("push", 1)
    self._data_in = self.input("data_in", self.data_width)
    self._pop = self.input("pop", 1)
    self._data_out = self.output("data_out", self.data_width)
    self._valid_out = self.output("valid_out", 1)
    self._empty = self.output("empty", 1)
    self._full = self.output("full", 1)
    # get relevant signals from the storage banks
    self._data_from_strg = self.input("data_from_strg",
                                      self.data_width,
                                      size=(self.banks, self.fw_int),
                                      explicit_array=True,
                                      packed=True)
    self._wen_addr = self.var("wen_addr",
                              self.addr_width,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)
    self._ren_addr = self.var("ren_addr",
                              self.addr_width,
                              size=self.banks,
                              explicit_array=True,
                              packed=True)
    # Wide word assembled from the front FIFO contents.
    self._front_combined = self.var("front_combined",
                                    self.data_width,
                                    size=self.fw_int,
                                    explicit_array=True,
                                    packed=True)
    self._data_to_strg = self.output("data_to_strg",
                                     self.data_width,
                                     size=(self.banks, self.fw_int),
                                     explicit_array=True,
                                     packed=True)
    self._wen_to_strg = self.output("wen_to_strg", self.banks)
    self._ren_to_strg = self.output("ren_to_strg", self.banks)
    self._num_words_mem = self.var("num_words_mem", self.data_width)
    # With a single bank the bank-select signals collapse to constant 0.
    if self.banks == 1:
        self._curr_bank_wr = self.var("curr_bank_wr", 1)
        self.wire(self._curr_bank_wr, kts.const(0, 1))
        self._curr_bank_rd = self.var("curr_bank_rd", 1)
        self.wire(self._curr_bank_rd, kts.const(0, 1))
    else:
        self._curr_bank_wr = self.var("curr_bank_wr", kts.clog2(self.banks))
        self._curr_bank_rd = self.var("curr_bank_rd", kts.clog2(self.banks))
    self._write_queue = self.var("write_queue",
                                 self.data_width,
                                 size=(self.banks, self.fw_int),
                                 explicit_array=True,
                                 packed=True)
    # Lets us know if the bank has a write queued up
    self._queued_write = self.var("queued_write", self.banks)
    # Front (push-side) FIFO bookkeeping signals.
    self._front_data_out = self.var("front_data_out", self.data_width)
    self._front_pop = self.var("front_pop", 1)
    self._front_empty = self.var("front_empty", 1)
    self._front_full = self.var("front_full", 1)
    self._front_valid = self.var("front_valid", 1)
    self._front_par_read = self.var("front_par_read", 1)
    self._front_par_out = self.var("front_par_out",
                                   self.data_width,
                                   size=(self.fw_int, 1),
                                   explicit_array=True,
                                   packed=True)
    self._front_rd_ptr = self.var("front_rd_ptr", max(1, clog2(self.fw_int)))
    self._front_push = self.var("front_push", 1)
    # Accept a push when not full, or when a simultaneous pop frees a slot.
    self.wire(self._front_push, self._push & (~self._full | self._pop))
    self._front_rf = RegFIFO(data_width=self.data_width,
                             width_mult=1,
                             depth=self.fw_int,
                             parallel=True,
                             break_out_rd_ptr=True)
    # This one breaks out the read pointer so we can properly
    # reorder the data to storage
    self.add_child("front_rf", self._front_rf,
                   clk=self._clk,
                   clk_en=kts.const(1, 1),
                   rst_n=self._rst_n,
                   push=self._front_push,
                   pop=self._front_pop,
                   empty=self._front_empty,
                   full=self._front_full,
                   valid=self._front_valid,
                   parallel_read=self._front_par_read,
                   # We don't need to parallel load the front
                   parallel_load=kts.const(0, 1),
                   # Same reason as above
                   parallel_in=0,
                   parallel_out=self._front_par_out,
                   num_load=0,
                   rd_ptr_out=self._front_rd_ptr)
    self.wire(self._front_rf.ports.data_in[0], self._data_in)
    self.wire(self._front_data_out, self._front_rf.ports.data_out[0])
    # Back (pop-side) FIFO bookkeeping signals.
    self._back_data_in = self.var("back_data_in", self.data_width)
    self._back_data_out = self.var("back_data_out", self.data_width)
    self._back_push = self.var("back_push", 1)
    self._back_empty = self.var("back_empty", 1)
    self._back_full = self.var("back_full", 1)
    self._back_valid = self.var("back_valid", 1)
    # Parallel-load strobe for the back FIFO (fresh wide read arriving).
    self._back_pl = self.var("back_pl", 1)
    self._back_par_in = self.var("back_par_in",
                                 self.data_width,
                                 size=(self.fw_int, 1),
                                 explicit_array=True,
                                 packed=True)
    self._back_num_load = self.var("back_num_load", clog2(self.fw_int) + 1)
    self._back_occ = self.var("back_occ", clog2(self.fw_int) + 1)
    self._front_occ = self.var("front_occ", clog2(self.fw_int) + 1)
    self._back_rf = RegFIFO(data_width=self.data_width,
                            width_mult=1,
                            depth=self.fw_int,
                            parallel=True,
                            break_out_rd_ptr=False)
    self._fw_is_1 = self.var("fw_is_1", 1)
    self.wire(self._fw_is_1, kts.const(self.fw_int == 1, 1))
    self._back_pop = self.var("back_pop", 1)
    # When fw_int == 1 a parallel load and a pop would collide, so gate
    # the pop with ~back_pl in that configuration.
    if self.fw_int == 1:
        self.wire(self._back_pop, self._pop & (~self._empty | self._push) & ~self._back_pl)
    else:
        self.wire(self._back_pop, self._pop & (~self._empty | self._push))
    self.add_child("back_rf", self._back_rf,
                   clk=self._clk,
                   clk_en=kts.const(1, 1),
                   rst_n=self._rst_n,
                   push=self._back_push,
                   pop=self._back_pop,
                   empty=self._back_empty,
                   full=self._back_full,
                   valid=self._back_valid,
                   parallel_read=kts.const(0, 1),
                   # Only do back load when data is going there
                   parallel_load=self._back_pl & self._back_num_load.r_or(),
                   parallel_in=self._back_par_in,
                   num_load=self._back_num_load)
    self.wire(self._back_rf.ports.data_in[0], self._back_data_in)
    self.wire(self._back_data_out, self._back_rf.ports.data_out[0])
    # send the writes through when a read isn't happening
    for i in range(self.banks):
        self.add_code(self.send_writes, idx=i)
        self.add_code(self.send_reads, idx=i)
    # Set the parallel load to back bank - if no delay it's immediate
    # if not, it's delayed :)
    if self.read_delay == 1:
        self._ren_delay = self.var("ren_delay", 1)
        self.add_code(self.set_parallel_ld_delay_1)
        self.wire(self._back_pl, self._ren_delay)
    else:
        self.wire(self._back_pl, self._ren_to_strg.r_or())
    # Combine front end data - just the items + incoming
    # this data is actually based on the rd_ptr from the front fifo
    for i in range(self.fw_int):
        self.wire(self._front_combined[i], self._front_par_out[self._front_rd_ptr + i])
    # This is always true
    # self.wire(self._front_combined[self.fw_int - 1], self._data_in)
    # prioritize queued writes, otherwise send combined data
    for i in range(self.banks):
        self.wire(self._data_to_strg[i],
                  kts.ternary(self._queued_write[i],
                              self._write_queue[i],
                              self._front_combined))
    # Wire the thin output from front to thin input to back
    self.wire(self._back_data_in, self._front_data_out)
    self.wire(self._back_push, self._front_valid)
    self.add_code(self.set_front_pop)
    # Queue writes
    for i in range(self.banks):
        self.add_code(self.set_write_queue, idx=i)
    # Track number of words in memory
    # if self.read_delay == 1:
    #     self.add_code(self.set_num_words_mem_delay)
    # else:
    self.add_code(self.set_num_words_mem)
    # Track occupancy of the two small fifos
    self.add_code(self.set_front_occ)
    self.add_code(self.set_back_occ)
    if self.banks > 1:
        self.add_code(self.set_curr_bank_wr)
        self.add_code(self.set_curr_bank_rd)
    # With a read delay, the data arriving now belongs to the bank read
    # last cycle, so remember that bank.
    if self.read_delay == 1:
        self._prev_bank_rd = self.var("prev_bank_rd", max(1, kts.clog2(self.banks)))
        self.add_code(self.set_prev_bank_rd)
    # Parallel load data to back - based on num_load
    index_into = self._curr_bank_rd
    if self.read_delay == 1:
        index_into = self._prev_bank_rd
    for i in range(self.fw_int - 1):
        # Shift data over if you bypassed from the memory output
        self.wire(self._back_par_in[i],
                  kts.ternary(self._back_num_load == self.fw_int,
                              self._data_from_strg[index_into][i],
                              self._data_from_strg[index_into][i + 1]))
    self.wire(self._back_par_in[self.fw_int - 1],
              kts.ternary(self._back_num_load == self.fw_int,
                          self._data_from_strg[index_into][self.fw_int - 1],
                          kts.const(0, self.data_width)))
    # Set the parallel read to the front fifo - analogous with trying to write to the memory
    self.add_code(self.set_front_par_read)
    # Set the number being parallely loaded into the register
    self.add_code(self.set_back_num_load)
    # Data out and valid out are (in the general case) just the data and valid from the back fifo
    # In the case where we have a fresh memory read, it would be from that
    bank_idx_read = self._curr_bank_rd
    if self.read_delay == 1:
        bank_idx_read = self._prev_bank_rd
    self.wire(self._data_out,
              kts.ternary(self._back_pl,
                          self._data_from_strg[bank_idx_read][0],
                          self._back_data_out))
    self.wire(self._valid_out,
              kts.ternary(self._back_pl,
                          self._pop,
                          self._back_valid))
    # Set addresses to storage
    for i in range(self.banks):
        self.add_code(self.set_wen_addr, idx=i)
        self.add_code(self.set_ren_addr, idx=i)
    # Now deal with a shared address vs separate addresses
    if self.rw_same_cycle:
        # Separate
        self._wen_addr_out = self.output("wen_addr_out",
                                         self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
        self._ren_addr_out = self.output("ren_addr_out",
                                         self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
        self.wire(self._wen_addr_out, self._wen_addr)
        self.wire(self._ren_addr_out, self._ren_addr)
    else:
        self._addr_out = self.output("addr_out",
                                     self.addr_width,
                                     size=self.banks,
                                     explicit_array=True,
                                     packed=True)
        # If sharing the addresses, send read addr with priority
        # NOTE(review): the ternary below actually selects the WRITE
        # address when wen is asserted — the comment and code disagree;
        # confirm which is intended.
        for i in range(self.banks):
            self.wire(self._addr_out[i],
                      kts.ternary(self._wen_to_strg[i],
                                  self._wen_addr[i],
                                  self._ren_addr[i]))
    # Do final empty/full
    self._num_items = self.var("num_items", self.data_width)
    self.add_code(self.set_num_items)
    self._fifo_depth = self.input("fifo_depth", self.data_width)
    self._fifo_depth.add_attribute(ConfigRegAttr("Fifo depth..."))
    self.wire(self._empty, self._num_items == 0)
    self.wire(self._full, self._num_items == (self._fifo_depth))
def __init__(self, data_width, config_addr_width, addr_width, fetch_width, total_sets, sets_per_macro):
    """Elaborate the configuration sequencer ("storage_config_seq").

    Routes word-wide config reads/writes onto fetch-width storage:
    builds the storage address from the config address (plus a set
    index when a macro holds several sets), serializes/deserializes
    data with word counters when fetch_width > data_width, and decodes
    per-bank wen/ren from the one-hot-per-set config_en input.

    All parameters are Python-time generation parameters; banks is
    derived as total_sets / sets_per_macro.
    """
    super().__init__("storage_config_seq")
    self.data_width = data_width
    self.config_addr_width = config_addr_width
    self.addr_width = addr_width
    self.fetch_width = fetch_width
    # Data words per fetch-width storage word.
    self.fw_int = int(self.fetch_width / self.data_width)
    self.total_sets = total_sets
    self.sets_per_macro = sets_per_macro
    self.banks = int(self.total_sets / self.sets_per_macro)
    self.set_addr_width = clog2(total_sets)
    # self.storage_addr_width = self.
    # Clock and Reset
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    # Inputs
    # phases = [] TODO
    # Take in the valid and data and attach an address + direct to a port
    self._config_data_in = self.input("config_data_in", self.data_width)
    self._config_addr_in = self.input("config_addr_in", self.config_addr_width)
    self._config_wr = self.input("config_wr", 1)
    self._config_rd = self.input("config_rd", 1)
    self._config_en = self.input("config_en", self.total_sets)
    self._clk_en = self.input("clk_en", 1)
    self._rd_data_stg = self.input("rd_data_stg",
                                   self.data_width,
                                   size=(self.banks, self.fw_int),
                                   explicit_array=True,
                                   packed=True)
    self._wr_data = self.output("wr_data",
                                self.data_width,
                                size=self.fw_int,
                                explicit_array=True,
                                packed=True)
    self._rd_data_out = self.output("rd_data_out",
                                    self.data_width,
                                    size=self.total_sets,
                                    explicit_array=True,
                                    packed=True)
    self._addr_out = self.output("addr_out", self.addr_width)
    # One set per macro means we directly send the config address through
    if self.sets_per_macro == 1:
        width = self.addr_width - self.config_addr_width
        if width > 0:
            # Zero-extend the config address up to the storage address width.
            self.wire(self._addr_out, kts.concat(kts.const(0, width), self._config_addr_in))
        else:
            # Storage address is narrower or equal: truncate.
            self.wire(self._addr_out, self._config_addr_in[self.addr_width - 1, 0])
    else:
        # Several sets share one macro: prepend a set index to the address.
        width = self.addr_width - self.config_addr_width - clog2(self.sets_per_macro)
        self._set_to_addr = self.var("set_to_addr", clog2(self.sets_per_macro))
        self._reduce_en = self.var("reduce_en", self.sets_per_macro)
        # reduce_en[i] = OR of config_en for set i across every bank.
        for i in range(self.sets_per_macro):
            reduce_var = self._config_en[i]
            for j in range(self.banks - 1):
                reduce_var = kts.concat(reduce_var, self._config_en[i + (self.sets_per_macro * (j + 1))])
            self.wire(self._reduce_en[i], reduce_var.r_or())
        self.add_code(self.demux_set_addr)
        if width > 0:
            self.wire(self._addr_out, kts.concat(kts.const(0, width), self._set_to_addr, self._config_addr_in))
        else:
            self.wire(self._addr_out, kts.concat(self._set_to_addr, self._config_addr_in))
    self._wen_out = self.output("wen_out", self.banks)
    self._ren_out = self.output("ren_out", self.banks)
    # Handle data passing
    if self.fw_int == 1:
        # If word width is same as data width, just pass everything through
        self.wire(self._wr_data[0], self._config_data_in)
        # self.wire(self._rd_data_out, self._rd_data_stg[0])
        num = 0
        for i in range(self.banks):
            for j in range(self.sets_per_macro):
                self.wire(self._rd_data_out[num], self._rd_data_stg[i])
                num = num + 1
    else:
        # Need to gather fw_int words per storage access: buffer the
        # first fw_int - 1 writes in a register.
        self._data_wr_reg = self.var("data_wr_reg",
                                     self.data_width,
                                     size=self.fw_int - 1,
                                     packed=True,
                                     explicit_array=True)
        # self._data_rd_reg = self.var("data_rd_reg",
        #                              self.data_width,
        #                              size=self.fw_int - 1,
        #                              packed=True,
        #                              explicit_array=True)
        # Have word counter for repeated reads/writes
        self._cnt = self.var("cnt", clog2(self.fw_int))
        self._rd_cnt = self.var("rd_cnt", clog2(self.fw_int))
        self.add_code(self.update_cnt)
        self.add_code(self.update_rd_cnt)
        # Gate wen if not about to finish the word
        num = 0
        for i in range(self.banks):
            for j in range(self.sets_per_macro):
                # Reads mux out the word selected by the read counter.
                self.wire(self._rd_data_out[num], self._rd_data_stg[i][self._rd_cnt])
                num = num + 1
        # Deal with writing to the data buffer
        self.add_code(self.write_buffer)
        # Wire the reg + such to this guy
        for i in range(self.fw_int - 1):
            self.wire(self._wr_data[i], self._data_wr_reg[i])
        # Last word of the wide write comes straight from the input.
        self.wire(self._wr_data[self.fw_int - 1], self._config_data_in)
    # If we have one bank, we can just always rd/wr from that one
    if self.banks == 1:
        if self.fw_int == 1:
            self.wire(self._wen_out, self._config_wr)
        else:
            # Only commit the wide write on the final word of the burst.
            self.wire(self._wen_out, self._config_wr & (self._cnt == (self.fw_int - 1)))
        self.wire(self._ren_out, self._config_rd)
    # Otherwise we need to extract the bank from the set
    else:
        if self.fw_int == 1:
            for i in range(self.banks):
                width = self.sets_per_macro
                self.wire(self._wen_out[i],
                          self._config_wr & self._config_en[(i + 1) * width - 1, i * width].r_or())
        else:
            for i in range(self.banks):
                width = self.sets_per_macro
                self.wire(self._wen_out[i],
                          self._config_wr & self._config_en[(i + 1) * width - 1, i * width].r_or() & (self._cnt == (self.fw_int - 1)))
        for i in range(self.banks):
            width = self.sets_per_macro
            self.wire(self._ren_out[i],
                      self._config_rd & self._config_en[(i + 1) * width - 1, i * width].r_or())
def increment(var, value):
    """Build a kratos expression computing ``var + value``.

    The Python integer ``value`` is wrapped in a constant sized to
    ``var.width`` so the addition stays width-matched in the
    generated RTL.
    """
    amount = kts.const(value, var.width)
    return var + amount
def __init__(self, iterator_support=6, config_width=16, use_enable=True):
    """Elaborate the schedule generator ("sched_gen").

    Wraps an AddrGen whose address output is presumably interpreted as
    the cycle count at which the controller should fire (compared
    against cycle_count by set_valid_out / set_valid_output, defined
    elsewhere on the class — confirm there). A one-shot gate flop
    (valid_gate_inv) permanently disables the output once `finished`
    is observed.

    Parameters (Python-time generation parameters):
        iterator_support: loop nesting depth of the wrapped AddrGen.
        config_width: width of the cycle/address comparison values.
        use_enable: expose `enable` as a config-register input rather
            than tying it to constant 1.
    """
    super().__init__(f"sched_gen_{iterator_support}_{config_width}")
    self.iterator_support = iterator_support
    self.config_width = config_width
    self.use_enable = use_enable
    # PORT DEFS: begin
    # INPUTS
    self._clk = self.clock("clk")
    self._rst_n = self.reset("rst_n")
    # OUTPUTS
    self._valid_output = self.output("valid_output", 1)
    # VARS
    self._valid_out = self.var("valid_out", 1)
    self._cycle_count = self.input("cycle_count", self.config_width)
    self._mux_sel = self.input("mux_sel", max(clog2(self.iterator_support), 1))
    self._addr_out = self.var("addr_out", self.config_width)
    # Receive signal on last iteration of looping structure and
    # gate the output...
    self._finished = self.input("finished", 1)
    self._valid_gate_inv = self.var("valid_gate_inv", 1)
    self._valid_gate = self.var("valid_gate", 1)
    self.wire(self._valid_gate, ~self._valid_gate_inv)
    # Since dim = 0 is not sufficient, we need a way to prevent
    # the controllers from firing on the starting offset
    if self.use_enable:
        self._enable = self.input("enable", 1)
        self._enable.add_attribute(
            ConfigRegAttr("Disable the controller so it never fires..."))
        self._enable.add_attribute(
            FormalAttr(f"{self._enable.name}", FormalSignalConstraint.SOLVE))
    # Otherwise we set it as a 1 and leave it up to synthesis...
    else:
        self._enable = self.var("enable", 1)
        self.wire(self._enable, kratos.const(1, 1))

    @always_ff((posedge, "clk"), (negedge, "rst_n"))
    def valid_gate_inv_ff():
        # Sticky one-shot: set on finish, cleared only by reset.
        if ~self._rst_n:
            self._valid_gate_inv = 0
        # If we are finishing the looping structure, turn this off to implement one-shot
        elif self._finished:
            self._valid_gate_inv = 1

    self.add_code(valid_gate_inv_ff)
    # Compare based on minimum of addr + global cycle...
    self.c_a_cmp = min(self._cycle_count.width, self._addr_out.width)
    # PORT DEFS: end
    self.add_child(f"sched_addr_gen",
                   AddrGen(iterator_support=self.iterator_support,
                           config_width=self.config_width),
                   clk=self._clk,
                   rst_n=self._rst_n,
                   step=self._valid_out,
                   mux_sel=self._mux_sel,
                   addr_out=self._addr_out,
                   restart=const(0, 1))
    self.add_code(self.set_valid_out)
    self.add_code(self.set_valid_output)
def __init__(
        self,
        data_width=16,
        # CGRA Params
        mem_depth=32,
        default_iterator_support=3,
        interconnect_input_ports=2,  # Connection to int
        interconnect_output_ports=2,
        mem_input_ports=1,
        mem_output_ports=1,
        config_data_width=32,
        config_addr_width=8,
        cycle_count_width=16,
        add_clk_enable=True,
        add_flush=True):
    """Elaborate the Pond register-file tile.

    Builds a small RegisterFile fronted by per-port controllers
    (ForLoop + AddrGen + SchedGen) for both write and read ports,
    plus a StorageConfigSeq configuration path that takes over the
    memory's data/address ports whenever config_en is asserted.
    Optionally runs the clock-enable and flush (sync-reset) insertion
    passes and finally lifts child config registers to the top level.

    All parameters are Python-time generation parameters.
    """
    super().__init__("pond", debug=True)
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.mem_depth = mem_depth
    self.data_width = data_width
    self.config_data_width = config_data_width
    self.config_addr_width = config_addr_width
    self.add_clk_enable = add_clk_enable
    self.add_flush = add_flush
    self.cycle_count_width = cycle_count_width
    self.default_iterator_support = default_iterator_support
    self.default_config_width = kts.clog2(self.mem_depth)
    # inputs
    self._clk = self.clock("clk")
    self._clk.add_attribute(
        FormalAttr(f"{self._clk.name}", FormalSignalConstraint.CLK))
    self._rst_n = self.reset("rst_n")
    self._rst_n.add_attribute(
        FormalAttr(f"{self._rst_n.name}", FormalSignalConstraint.RSTN))
    self._clk_en = self.clock_en("clk_en", 1)
    # Enable/Disable tile
    self._tile_en = self.input("tile_en", 1)
    self._tile_en.add_attribute(
        ConfigRegAttr("Tile logic enable manifested as clock gate"))
    # Gated clock: the tile clock ANDed with tile_en.
    gclk = self.var("gclk", 1)
    self._gclk = kts.util.clock(gclk)
    self.wire(gclk, kts.util.clock(self._clk & self._tile_en))
    # Free-running cycle counter used by the schedule generators.
    self._cycle_count = add_counter(self, "cycle_count", self.cycle_count_width)
    # Create write enable + addr, same for read.
    # self._write = self.input("write", self.interconnect_input_ports)
    self._write = self.var("write", self.mem_input_ports)
    # self._write.add_attribute(ControlSignalAttr(is_control=True))
    self._write_addr = self.var("write_addr",
                                kts.clog2(self.mem_depth),
                                size=self.interconnect_input_ports,
                                explicit_array=True,
                                packed=True)
    # Add "_pond" suffix to avoid error during garnet RTL generation
    self._data_in = self.input("data_in_pond",
                               self.data_width,
                               size=self.interconnect_input_ports,
                               explicit_array=True,
                               packed=True)
    self._data_in.add_attribute(
        FormalAttr(f"{self._data_in.name}", FormalSignalConstraint.SEQUENCE))
    self._data_in.add_attribute(ControlSignalAttr(is_control=False))
    self._read = self.var("read", self.mem_output_ports)
    # Per-interconnect-port transaction strobes driven by the SchedGens.
    self._t_write = self.var("t_write", self.interconnect_input_ports)
    self._t_read = self.var("t_read", self.interconnect_output_ports)
    # self._read.add_attribute(ControlSignalAttr(is_control=True))
    self._read_addr = self.var("read_addr",
                               kts.clog2(self.mem_depth),
                               size=self.interconnect_output_ports,
                               explicit_array=True,
                               packed=True)
    self._s_read_addr = self.var("s_read_addr",
                                 kts.clog2(self.mem_depth),
                                 size=self.interconnect_output_ports,
                                 explicit_array=True,
                                 packed=True)
    self._data_out = self.output("data_out_pond",
                                 self.data_width,
                                 size=self.interconnect_output_ports,
                                 explicit_array=True,
                                 packed=True)
    self._data_out.add_attribute(
        FormalAttr(f"{self._data_out.name}", FormalSignalConstraint.SEQUENCE))
    self._data_out.add_attribute(ControlSignalAttr(is_control=False))
    self._valid_out = self.output("valid_out_pond", self.interconnect_output_ports)
    self._valid_out.add_attribute(
        FormalAttr(f"{self._valid_out.name}", FormalSignalConstraint.SEQUENCE))
    self._valid_out.add_attribute(ControlSignalAttr(is_control=False))
    self._mem_data_out = self.var("mem_data_out",
                                  self.data_width,
                                  size=self.mem_output_ports,
                                  explicit_array=True,
                                  packed=True)
    # "s_" variants are the pre-config-mux (selected) signals.
    self._s_mem_data_in = self.var("s_mem_data_in",
                                   self.data_width,
                                   size=self.interconnect_input_ports,
                                   explicit_array=True,
                                   packed=True)
    self._mem_data_in = self.var("mem_data_in",
                                 self.data_width,
                                 size=self.mem_input_ports,
                                 explicit_array=True,
                                 packed=True)
    self._s_mem_write_addr = self.var("s_mem_write_addr",
                                      kts.clog2(self.mem_depth),
                                      size=self.interconnect_input_ports,
                                      explicit_array=True,
                                      packed=True)
    self._s_mem_read_addr = self.var("s_mem_read_addr",
                                     kts.clog2(self.mem_depth),
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)
    self._mem_write_addr = self.var("mem_write_addr",
                                    kts.clog2(self.mem_depth),
                                    size=self.mem_input_ports,
                                    explicit_array=True,
                                    packed=True)
    self._mem_read_addr = self.var("mem_read_addr",
                                   kts.clog2(self.mem_depth),
                                   size=self.mem_output_ports,
                                   explicit_array=True,
                                   packed=True)
    # Every interconnect output port sees memory read port 0.
    if self.interconnect_output_ports == 1:
        self.wire(self._data_out[0], self._mem_data_out[0])
    else:
        for i in range(self.interconnect_output_ports):
            self.wire(self._data_out[i], self._mem_data_out[0])
    # Valid out is simply passing the read signal through...
    self.wire(self._valid_out, self._t_read)
    # Create write addressors
    for wr_port in range(self.interconnect_input_ports):
        RF_WRITE_ITER = ForLoop(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width)
        RF_WRITE_ADDR = AddrGen(
            iterator_support=self.default_iterator_support,
            config_width=self.default_config_width)
        RF_WRITE_SCHED = SchedGen(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width,
            use_enable=True)
        self.add_child(f"rf_write_iter_{wr_port}",
                       RF_WRITE_ITER,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_write[wr_port])
        # Whatever comes through here should hopefully just pipe through seamlessly
        # addressor modules
        self.add_child(f"rf_write_addr_{wr_port}",
                       RF_WRITE_ADDR,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_write[wr_port],
                       mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                       restart=RF_WRITE_ITER.ports.restart)
        safe_wire(self, self._write_addr[wr_port], RF_WRITE_ADDR.ports.addr_out)
        self.add_child(f"rf_write_sched_{wr_port}",
                       RF_WRITE_SCHED,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                       finished=RF_WRITE_ITER.ports.restart,
                       cycle_count=self._cycle_count,
                       valid_output=self._t_write[wr_port])
    # Create read addressors
    for rd_port in range(self.interconnect_output_ports):
        RF_READ_ITER = ForLoop(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width)
        RF_READ_ADDR = AddrGen(
            iterator_support=self.default_iterator_support,
            config_width=self.default_config_width)
        RF_READ_SCHED = SchedGen(
            iterator_support=self.default_iterator_support,
            config_width=self.cycle_count_width,
            use_enable=True)
        self.add_child(f"rf_read_iter_{rd_port}",
                       RF_READ_ITER,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_read[rd_port])
        self.add_child(f"rf_read_addr_{rd_port}",
                       RF_READ_ADDR,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       step=self._t_read[rd_port],
                       mux_sel=RF_READ_ITER.ports.mux_sel_out,
                       restart=RF_READ_ITER.ports.restart)
        # NOTE(review): both branches below are identical — presumably a
        # leftover from a port-count-specific wiring; confirm intent.
        if self.interconnect_output_ports > 1:
            safe_wire(self, self._read_addr[rd_port], RF_READ_ADDR.ports.addr_out)
        else:
            safe_wire(self, self._read_addr[rd_port], RF_READ_ADDR.ports.addr_out)
        self.add_child(f"rf_read_sched_{rd_port}",
                       RF_READ_SCHED,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       mux_sel=RF_READ_ITER.ports.mux_sel_out,
                       finished=RF_READ_ITER.ports.restart,
                       cycle_count=self._cycle_count,
                       valid_output=self._t_read[rd_port])
    # Collapse the per-port strobes/addresses onto the single memory port.
    self.wire(self._write, self._t_write.r_or())
    self.wire(self._mem_write_addr[0],
              decode(self, self._t_write, self._s_mem_write_addr))
    self.wire(self._mem_data_in[0],
              decode(self, self._t_write, self._s_mem_data_in))
    self.wire(self._read, self._t_read.r_or())
    self.wire(self._mem_read_addr[0],
              decode(self, self._t_read, self._s_mem_read_addr))
    # ===================================
    # Instantiate config hooks...
    # ===================================
    self.fw_int = 1
    self.data_words_per_set = 2**self.config_addr_width
    self.sets = int(
        (self.fw_int * self.mem_depth) / self.data_words_per_set)
    self.sets_per_macro = max(
        1, int(self.mem_depth / self.data_words_per_set))
    self.total_sets = max(1, 1 * self.sets_per_macro)
    self._config_data_in = self.input("config_data_in", self.config_data_width)
    self._config_data_in.add_attribute(ControlSignalAttr(is_control=False))
    # Truncated view of the config data, matching the tile data width.
    self._config_data_in_shrt = self.var("config_data_in_shrt", self.data_width)
    self.wire(self._config_data_in_shrt,
              self._config_data_in[self.data_width - 1, 0])
    self._config_addr_in = self.input("config_addr_in", self.config_addr_width)
    self._config_addr_in.add_attribute(ControlSignalAttr(is_control=False))
    self._config_data_out_shrt = self.var("config_data_out_shrt",
                                          self.data_width,
                                          size=self.total_sets,
                                          explicit_array=True,
                                          packed=True)
    self._config_data_out = self.output("config_data_out",
                                        self.config_data_width,
                                        size=self.total_sets,
                                        explicit_array=True,
                                        packed=True)
    self._config_data_out.add_attribute(
        ControlSignalAttr(is_control=False))
    # Zero-extend each per-set readback word to the full config width.
    for i in range(self.total_sets):
        self.wire(
            self._config_data_out[i],
            self._config_data_out_shrt[i].extend(self.config_data_width))
    self._config_read = self.input("config_read", 1)
    self._config_read.add_attribute(ControlSignalAttr(is_control=False))
    self._config_write = self.input("config_write", 1)
    self._config_write.add_attribute(ControlSignalAttr(is_control=False))
    self._config_en = self.input("config_en", self.total_sets)
    self._config_en.add_attribute(ControlSignalAttr(is_control=False))
    self._mem_data_cfg = self.var("mem_data_cfg",
                                  self.data_width,
                                  explicit_array=True,
                                  packed=True)
    self._mem_addr_cfg = self.var("mem_addr_cfg", kts.clog2(self.mem_depth))
    # Add config...
    stg_cfg_seq = StorageConfigSeq(
        data_width=self.data_width,
        config_addr_width=self.config_addr_width,
        addr_width=kts.clog2(self.mem_depth),
        fetch_width=self.data_width,
        total_sets=self.total_sets,
        sets_per_macro=self.sets_per_macro)
    # The clock to config sequencer needs to be the normal clock or
    # if the tile is off, we bring the clock back in based on config_en
    cfg_seq_clk = self.var("cfg_seq_clk", 1)
    self._cfg_seq_clk = kts.util.clock(cfg_seq_clk)
    self.wire(cfg_seq_clk, kts.util.clock(self._gclk))
    self.add_child(f"config_seq",
                   stg_cfg_seq,
                   clk=self._cfg_seq_clk,
                   rst_n=self._rst_n,
                   clk_en=self._clk_en | self._config_en.r_or(),
                   config_data_in=self._config_data_in_shrt,
                   config_addr_in=self._config_addr_in,
                   config_wr=self._config_write,
                   config_rd=self._config_read,
                   config_en=self._config_en,
                   wr_data=self._mem_data_cfg,
                   rd_data_out=self._config_data_out_shrt,
                   addr_out=self._mem_addr_cfg)
    if self.interconnect_output_ports == 1:
        self.wire(stg_cfg_seq.ports.rd_data_stg, self._mem_data_out)
    else:
        self.wire(stg_cfg_seq.ports.rd_data_stg[0], self._mem_data_out[0])
    self.RF_GEN = RegisterFile(data_width=self.data_width,
                               write_ports=self.mem_input_ports,
                               read_ports=self.mem_output_ports,
                               width_mult=1,
                               depth=self.mem_depth,
                               read_delay=0)
    # Now we can instantiate and wire up the register file
    self.add_child(f"rf",
                   self.RF_GEN,
                   clk=self._gclk,
                   rst_n=self._rst_n,
                   data_out=self._mem_data_out)
    # Opt in for config_write
    self._write_rf = self.var("write_rf", self.mem_input_ports)
    # Port 0 is taken over by config_write during configuration ...
    self.wire(
        self._write_rf[0],
        kts.ternary(self._config_en.r_or(),
                    self._config_write,
                    self._write[0]))
    # ... and any additional memory write ports are forced idle.
    for i in range(self.mem_input_ports - 1):
        self.wire(
            self._write_rf[i + 1],
            kts.ternary(self._config_en.r_or(),
                        kts.const(0, 1),
                        self._write[i + 1]))
    self.wire(self.RF_GEN.ports.wen, self._write_rf)
    # Opt in for config_data_in
    for i in range(self.interconnect_input_ports):
        self.wire(
            self._s_mem_data_in[i],
            kts.ternary(self._config_en.r_or(),
                        self._mem_data_cfg,
                        self._data_in[i]))
    self.wire(self.RF_GEN.ports.data_in, self._mem_data_in)
    # Opt in for config_addr
    for i in range(self.interconnect_input_ports):
        self.wire(
            self._s_mem_write_addr[i],
            kts.ternary(self._config_en.r_or(),
                        self._mem_addr_cfg,
                        self._write_addr[i]))
    self.wire(self.RF_GEN.ports.wr_addr, self._mem_write_addr[0])
    for i in range(self.interconnect_output_ports):
        self.wire(
            self._s_mem_read_addr[i],
            kts.ternary(self._config_en.r_or(),
                        self._mem_addr_cfg,
                        self._read_addr[i]))
    self.wire(self.RF_GEN.ports.rd_addr, self._mem_read_addr[0])
    if self.add_clk_enable:
        # self.clock_en("clk_en")
        kts.passes.auto_insert_clock_enable(self.internal_generator)
        clk_en_port = self.internal_generator.get_port("clk_en")
        clk_en_port.add_attribute(ControlSignalAttr(False))
    if self.add_flush:
        self.add_attribute("sync-reset=flush")
        kts.passes.auto_insert_sync_reset(self.internal_generator)
        flush_port = self.internal_generator.get_port("flush")
        flush_port.add_attribute(ControlSignalAttr(True))
    # Finally, lift the config regs...
    lift_config_reg(self.internal_generator)
def code():
    """Kratos code block: drive ``out_`` with ``test_add`` applied to
    ``in_`` and a 16-bit constant 1.

    ``out_``, ``in_``, and ``test_add`` are free names captured from
    the enclosing scope — presumably ports/functions of the generator
    under test; confirm against the surrounding setup.
    """
    out_ = test_add(in_, const(1, 16))