def __init__(
        self,
        data_width=16,  # CGRA Params
        mem_width=16,
        mem_depth=256,
        banks=1,
        input_iterator_support=6,  # Addr Controllers
        output_iterator_support=6,
        input_config_width=16,
        output_config_width=16,
        interconnect_input_ports=1,  # Connection to int
        interconnect_output_ports=1,
        mem_input_ports=1,
        mem_output_ports=1,
        use_sram_stub=True,
        sram_macro_info=SRAMMacroInfo(),
        read_delay=1,  # Cycle delay in read (SRAM vs Register File)
        rw_same_cycle=True,  # Does the memory allow r+w in same cycle?
        agg_height=4,
        tb_sched_max=16,
        config_data_width=32,
        config_addr_width=8,
        num_tiles=1,
        remove_tb=False,
        fifo_mode=False,
        add_clk_enable=True,
        add_flush=True,
        override_name=None):
    """Build the memory core around a (possibly cached) LakeTop circuit.

    The constructor stores every parameter on the instance, generates the
    LakeTop circuit (or fetches it from the class-level circuit cache),
    mirrors the circuit's interface as ports of this core, places a
    config-register/mux pair in front of every control bit, creates one
    extra CoreFeature per SRAM set, wires the configuration bus, and
    finally dumps the sorted register names to ``mem_cfg.txt`` and
    ``mem_synth.txt`` in the current working directory.

    Notes on parameters, as far as this method shows:
      * ``mem_input_ports`` / ``mem_output_ports`` are stored on the
        instance but are not forwarded to LakeTop here.
      * ``tb_sched_max`` is accepted but not used in this method.
      * ``override_name`` - when truthy it becomes the LakeTop module name
        and ``override_name + "Core"`` becomes this core's name; otherwise
        the names are "LakeTop" and "MemCore".
      * ``sram_macro_info`` is part of the cache key and therefore has to
        be hashable.
    """
    # name
    if override_name:
        self.__name = override_name + "Core"
        lake_name = override_name
    else:
        self.__name = "MemCore"
        lake_name = "LakeTop"

    super().__init__(config_addr_width, config_data_width)

    # Capture everything to the tile object
    self.data_width = data_width
    self.mem_width = mem_width
    self.mem_depth = mem_depth
    self.banks = banks
    self.fw_int = int(self.mem_width / self.data_width)
    self.input_iterator_support = input_iterator_support
    self.output_iterator_support = output_iterator_support
    self.input_config_width = input_config_width
    self.output_config_width = output_config_width
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.use_sram_stub = use_sram_stub
    self.sram_macro_info = sram_macro_info
    self.read_delay = read_delay
    self.rw_same_cycle = rw_same_cycle
    self.agg_height = agg_height
    self.config_data_width = config_data_width
    self.config_addr_width = config_addr_width
    self.num_tiles = num_tiles
    self.remove_tb = remove_tb
    self.fifo_mode = fifo_mode
    self.add_clk_enable = add_clk_enable
    self.add_flush = add_flush

    # Typedefs for ease
    TData = magma.Bits[self.data_width]
    TBit = magma.Bits[1]

    self.__inputs = []
    self.__outputs = []

    # The key must contain every value that shapes the generated circuit.
    # input_config_width, output_config_width and the module name are all
    # handed to LakeTop below, so they have to be part of the key as well;
    # otherwise two cores differing only in those values would share one
    # cached circuit.
    cache_key = (self.data_width, self.mem_width, self.mem_depth,
                 self.banks, self.input_iterator_support,
                 self.output_iterator_support,
                 self.input_config_width, self.output_config_width,
                 self.interconnect_input_ports,
                 self.interconnect_output_ports,
                 self.use_sram_stub, self.sram_macro_info,
                 self.read_delay, self.rw_same_cycle, self.agg_height,
                 self.config_data_width, self.config_addr_width,
                 self.num_tiles, self.remove_tb, self.fifo_mode,
                 self.add_clk_enable, self.add_flush, lake_name)

    # Check for circuit caching
    if cache_key not in MemCore.__circuit_cache:
        # Instantiate core object here - will only use the object
        # representation to query for information. The circuit
        # representation will be cached and retrieved in the following
        # steps.
        lt_dut = LakeTop(
            data_width=self.data_width,
            mem_width=self.mem_width,
            mem_depth=self.mem_depth,
            banks=self.banks,
            input_iterator_support=self.input_iterator_support,
            output_iterator_support=self.output_iterator_support,
            input_config_width=self.input_config_width,
            output_config_width=self.output_config_width,
            interconnect_input_ports=self.interconnect_input_ports,
            interconnect_output_ports=self.interconnect_output_ports,
            use_sram_stub=self.use_sram_stub,
            sram_macro_info=self.sram_macro_info,
            read_delay=self.read_delay,
            rw_same_cycle=self.rw_same_cycle,
            agg_height=self.agg_height,
            config_data_width=self.config_data_width,
            config_addr_width=self.config_addr_width,
            num_tiles=self.num_tiles,
            remove_tb=self.remove_tb,
            fifo_mode=self.fifo_mode,
            add_clk_enable=self.add_clk_enable,
            add_flush=self.add_flush,
            name=lake_name,
            gen_addr=False)

        change_sram_port_pass = change_sram_port_names(
            use_sram_stub, sram_macro_info)
        circ = kts.util.to_magma(
            lt_dut,
            flatten_array=True,
            check_multiple_driver=False,
            optimize_if=False,
            check_flip_flop_always_ff=False,
            additional_passes={"change_sram_port": change_sram_port_pass})
        MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
    else:
        circ, lt_dut = MemCore.__circuit_cache[cache_key]

    # Save as underlying circuit object
    self.underlying = FromMagma(circ)

    # Enumerate input and output ports
    # (clk and reset are assumed)
    core_interface = get_interface(lt_dut)
    cfgs = extract_top_config(lt_dut)
    assert len(cfgs) > 0, "No configs?"

    # We basically add in the configuration bus differently
    # than the other ports...
    skip_names = [
        "config_data_in", "config_write", "config_addr_in",
        "config_data_out", "config_read", "config_en", "clk_en"
    ]

    # Create a list of signals that will be able to be
    # hardwired to a constant at runtime...
    control_signals = []
    # The rest of the signals to wire to the underlying representation...
    other_signals = []

    for io_info in core_interface:
        if io_info.port_name in skip_names:
            continue
        ind_ports = io_info.port_width
        intf_type = TBit
        # For our purposes, an explicit array means the inner data HAS to
        # be 16 bits
        if io_info.expl_arr:
            ind_ports = io_info.port_size[0]
            intf_type = TData
        dir_type = magma.In
        app_list = self.__inputs
        if io_info.port_dir == "PortDirection.Out":
            dir_type = magma.Out
            app_list = self.__outputs

        # A multi-element interface is split into one suffixed port per
        # element; a single element keeps the plain name.
        if ind_ports > 1:
            split_names = [f"{io_info.port_name}_{i}"
                           for i in range(ind_ports)]
        else:
            split_names = [io_info.port_name]
        for split_name in split_names:
            self.add_port(split_name, dir_type(intf_type))
            app_list.append(self.ports[split_name])

        # classify each signal for wiring to underlying representation...
        if io_info.is_ctrl:
            control_signals.append((io_info.port_name, io_info.port_width))
        else:
            for i, split_name in enumerate(split_names):
                other_signals.append(
                    (split_name, io_info.port_dir, io_info.expl_arr, i,
                     io_info.port_name))

    assert (len(self.__outputs) > 0)

    # We call clk_en stall at this level for legacy reasons????
    self.add_ports(stall=magma.In(TBit), )

    self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

    # put a 1-bit register and a mux to select the control signals
    for control_signal, width in control_signals:
        # (port/register name stem, bit index into the underlying port)
        if width == 1:
            taps = [(control_signal, 0)]
        else:
            taps = [(f"{control_signal}_{i}", i) for i in range(width)]
        for stem, bit in taps:
            mux = MuxWrapper(2, 1, name=f"{stem}_sel")
            reg_value_name = f"{stem}_reg_value"
            reg_sel_name = f"{stem}_reg_sel"
            self.add_config(reg_value_name, 1)
            self.add_config(reg_sel_name, 1)
            self.wire(mux.ports.I[0], self.ports[stem])
            self.wire(mux.ports.I[1],
                      self.registers[reg_value_name].ports.O)
            self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
            # 0 is the default wire, which takes from the routing network
            self.wire(mux.ports.O[0],
                      self.underlying.ports[control_signal][bit])

    # Wire the other signals up...
    for pname, _pdir, expl_arr, ind, uname in other_signals:
        if expl_arr is False:
            # Not an explicit array: connect bit `ind` of the underlying
            # port
            self.wire(self.ports[pname][0],
                      self.underlying.ports[uname][ind])
        else:
            # Explicit array: the underlying port carries the same name
            self.wire(self.ports[pname], self.underlying.ports[pname])

    # CLK, RESET, and STALL PER STANDARD PROCEDURE

    # Need to invert this
    self.resetInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.resetInverter.ports.I[0], self.ports.reset)
    self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
    self.wire(self.ports.clk, self.underlying.ports.clk)

    # Mem core uses clk_en (essentially active low stall)
    self.stallInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.stallInverter.ports.I, self.ports.stall)
    self.wire(self.stallInverter.ports.O[0],
              self.underlying.ports.clk_en[0])

    # Feature 0 is the tile itself; features 1..total_sets are the SRAMs
    self.__features: List[CoreFeature] = [self]
    self.num_sram_features = lt_dut.total_sets
    for sram_index in range(self.num_sram_features):
        core_feature = CoreFeature(self, sram_index + 1)
        self.__features.append(core_feature)

    # Wire the config
    for idx, core_feature in enumerate(self.__features):
        if idx > 0:
            self.add_port(
                f"config_{idx}",
                magma.In(
                    ConfigurationType(self.config_addr_width,
                                      self.config_data_width)))
            # port aliasing
            core_feature.ports["config"] = self.ports[f"config_{idx}"]
    self.add_port(
        "config",
        magma.In(
            ConfigurationType(self.config_addr_width,
                              self.config_data_width)))

    # or the signal up
    t = ConfigurationType(self.config_addr_width, self.config_data_width)
    t_names = ["config_addr", "config_data"]
    or_gates = {}
    for t_name in t_names:
        port_type = t[t_name]
        or_gate = FromMagma(
            mantle.DefineOr(len(self.__features), len(port_type)))
        or_gate.instance_name = f"OR_{t_name}_FEATURE"
        for idx, core_feature in enumerate(self.__features):
            self.wire(or_gate.ports[f"I{idx}"],
                      core_feature.ports.config[t_name])
        or_gates[t_name] = or_gate

    self.wire(
        or_gates["config_addr"].ports.O,
        self.underlying.ports.config_addr_in[0:self.config_addr_width])
    self.wire(or_gates["config_data"].ports.O,
              self.underlying.ports.config_data_in)

    # read data out
    for idx, core_feature in enumerate(self.__features):
        if idx > 0:
            self.add_port(f"read_config_data_{idx}",
                          magma.Out(magma.Bits[self.config_data_width]))
            # port aliasing
            core_feature.ports["read_config_data"] = \
                self.ports[f"read_config_data_{idx}"]

    # MEM Config
    configurations = []
    skip_cfgs = []

    for cfg_info in cfgs:
        if cfg_info.port_name in skip_cfgs:
            continue
        # Explicit arrays with more than one element get one register per
        # element; everything else gets a single register.
        if cfg_info.expl_arr and cfg_info.port_size[0] > 1:
            for i in range(cfg_info.port_size[0]):
                configurations.append(
                    (f"{cfg_info.port_name}_{i}", cfg_info.port_width))
        else:
            configurations.append(
                (cfg_info.port_name, cfg_info.port_width))

    # Do all the stuff for the main config
    main_feature = self.__features[0]
    for config_reg_name, width in configurations:
        main_feature.add_config(config_reg_name, width)
        if width == 1:
            self.wire(main_feature.registers[config_reg_name].ports.O[0],
                      self.underlying.ports[config_reg_name][0])
        else:
            self.wire(main_feature.registers[config_reg_name].ports.O,
                      self.underlying.ports[config_reg_name])

    # SRAM
    # These should also account for num features
    # NOTE(review): the two instance names below look swapped relative to
    # the signals each gate gathers (the "rd" gate is named ..._WR_... and
    # vice versa). Left untouched so emitted instance names stay stable.
    or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
    or_all_cfg_rd.instance_name = f"OR_CONFIG_WR_SRAM"
    or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
    or_all_cfg_wr.instance_name = f"OR_CONFIG_RD_SRAM"
    for sram_index in range(self.num_sram_features):
        core_feature = self.__features[sram_index + 1]
        self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
        # port aliasing
        core_feature.ports["config_en"] = \
            self.ports[f"config_en_{sram_index}"]
        # Sort of a temp hack - the name is just config_data_out
        if self.num_sram_features == 1:
            self.wire(core_feature.ports.read_config_data,
                      self.underlying.ports["config_data_out"])
        else:
            self.wire(
                core_feature.ports.read_config_data,
                self.underlying.ports[f"config_data_out_{sram_index}"])
        # also need to wire the sram signal
        # the config enable is the OR of the rd+wr
        # NOTE(review): the output of this OR gate is not connected
        # anywhere in this method; config_en is wired straight through
        # below. Confirm whether that is intended.
        or_gate_en = FromMagma(mantle.DefineOr(2, 1))
        or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"
        self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
        self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
        self.wire(core_feature.ports.config_en,
                  self.underlying.ports["config_en"][sram_index])
        # Still connect to the OR of all the config rd/wr
        self.wire(core_feature.ports.config.write,
                  or_all_cfg_wr.ports[f"I{sram_index}"])
        self.wire(core_feature.ports.config.read,
                  or_all_cfg_rd.ports[f"I{sram_index}"])

    self.wire(or_all_cfg_rd.ports.O[0],
              self.underlying.ports.config_read[0])
    self.wire(or_all_cfg_wr.ports.O[0],
              self.underlying.ports.config_write[0])
    self._setup_config()

    # Dump the register names (sorted) for downstream tooling
    conf_names = list(self.registers.keys())
    conf_names.sort()
    with open("mem_cfg.txt", "w+") as cfg_dump:
        for reg in conf_names:
            write_line = f"(\"{reg}\", 0), # {self.registers[reg].width}\n"
            cfg_dump.write(write_line)
    with open("mem_synth.txt", "w+") as cfg_dump:
        for reg in conf_names:
            write_line = f"{reg}\n"
            cfg_dump.write(write_line)
def wrap_lake_core(self):
    """Expose the interface of ``self.dut`` as ports of this core.

    Mirrors every non-configuration port of the DUT, puts a
    config-register/mux pair in front of each control bit, hooks up clock,
    reset and stall, creates one CoreFeature per SRAM set, and wires the
    configuration bus before calling ``_setup_config``.

    Reads ``self.dut``, ``self.underlying``, ``self.data_width``,
    ``self.config_addr_width`` and ``self.config_data_width``, and appends
    to the existing input/output port lists.
    """
    # Typedefs for ease; the 16-bit fallback shouldn't be used if the
    # data_width was None
    TData = magma.Bits[self.data_width if self.data_width else 16]
    TBit = magma.Bits[1]

    # Enumerate input and output ports (clk and reset are assumed)
    core_interface = get_interface(self.dut)
    cfgs = extract_top_config(self.dut)
    assert len(cfgs) > 0, "No configs?"

    # The configuration bus is added differently from the other ports
    skip_names = [
        "config_data_in", "config_write", "config_addr_in",
        "config_data_out", "config_read", "config_en", "clk_en"
    ]

    # Signals that can be hardwired to a constant at runtime...
    control_signals = []
    # ...and everything else that goes to the underlying representation
    other_signals = []

    for io_info in core_interface:
        base_name = io_info.port_name
        if base_name in skip_names:
            continue

        # An explicit array means the inner data HAS to be data-width bits
        if io_info.expl_arr:
            ind_ports, intf_type = io_info.port_size[0], TData
        else:
            ind_ports, intf_type = io_info.port_width, TBit

        if io_info.port_dir == "PortDirection.Out":
            dir_type, app_list = magma.Out, self.__outputs
        else:
            dir_type, app_list = magma.In, self.__inputs

        # One suffixed port per element, or the plain name for a single one
        if ind_ports > 1:
            port_names = [f"{base_name}_{i}" for i in range(ind_ports)]
        else:
            port_names = [base_name]
        for port_name in port_names:
            self.add_port(port_name, dir_type(intf_type))
            app_list.append(self.ports[port_name])

        # classify each signal for wiring to underlying representation
        if io_info.is_ctrl:
            control_signals.append((base_name, io_info.port_width))
            continue
        for position, port_name in enumerate(port_names):
            other_signals.append(
                (port_name, io_info.port_dir, io_info.expl_arr, position,
                 base_name))

    assert (len(self.__outputs) > 0)

    # We call clk_en stall at this level for legacy reasons????
    self.add_ports(stall=magma.In(TBit), )

    # put a 1-bit register and a mux to select the control signals
    for control_signal, width in control_signals:
        if width == 1:
            taps = [(control_signal, 0)]
        else:
            taps = [(f"{control_signal}_{i}", i) for i in range(width)]
        for stem, bit in taps:
            mux = MuxWrapper(2, 1, name=f"{stem}_sel")
            value_reg = f"{stem}_reg_value"
            select_reg = f"{stem}_reg_sel"
            self.add_config(value_reg, 1)
            self.add_config(select_reg, 1)
            self.wire(mux.ports.I[0], self.ports[stem])
            self.wire(mux.ports.I[1], self.registers[value_reg].ports.O)
            self.wire(mux.ports.S, self.registers[select_reg].ports.O)
            # 0 is the default wire, which takes from the routing network
            self.wire(mux.ports.O[0],
                      self.underlying.ports[control_signal][bit])

    # Wire the other signals up...
    for pname, _pdir, expl_arr, ind, uname in other_signals:
        if expl_arr is False:
            # not an explicit array: pick bit `ind` of the underlying port
            self.wire(self.ports[pname][0],
                      self.underlying.ports[uname][ind])
        else:
            # explicit array: the underlying port has the same name
            self.wire(self.ports[pname], self.underlying.ports[pname])

    # CLK, RESET, and STALL PER STANDARD PROCEDURE

    # Reset needs to be inverted and converted before reaching rst_n
    self.resetInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.resetInverter.ports.I[0], self.ports.reset)
    inverted_reset = self.convert(self.resetInverter.ports.O[0],
                                  magma.asyncreset)
    self.wire(inverted_reset, self.underlying.ports.rst_n)
    self.wire(self.ports.clk, self.underlying.ports.clk)

    # Mem core uses clk_en (essentially active low stall)
    self.stallInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.stallInverter.ports.I, self.ports.stall)
    self.wire(self.stallInverter.ports.O[0],
              self.underlying.ports.clk_en[0])

    # Feature 0 is the tile itself; one more feature per SRAM set
    self.__features: List[CoreFeature] = [self]
    self.num_sram_features = self.dut.total_sets
    for sram_index in range(self.num_sram_features):
        sram_feature = CoreFeature(self, sram_index + 1)
        sram_feature.skip_compression = True
        self.__features.append(sram_feature)

    # Wire the config: every non-tile feature gets its own config port
    for idx, feature in enumerate(self.__features):
        if idx == 0:
            continue
        self.add_port(
            f"config_{idx}",
            magma.In(
                ConfigurationType(self.config_addr_width,
                                  self.config_data_width)))
        # port aliasing
        feature.ports["config"] = self.ports[f"config_{idx}"]
    self.add_port(
        "config",
        magma.In(
            ConfigurationType(self.config_addr_width,
                              self.config_data_width)))

    if self.num_sram_features > 0:
        # or the signal up
        t = ConfigurationType(self.config_addr_width,
                              self.config_data_width)
        or_gates = {}
        for t_name in ["config_addr", "config_data"]:
            port_type = t[t_name]
            gate = FromMagma(
                mantle.DefineOr(len(self.__features), len(port_type)))
            gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, feature in enumerate(self.__features):
                self.wire(gate.ports[f"I{idx}"],
                          feature.ports.config[t_name])
            or_gates[t_name] = gate

        self.wire(
            or_gates["config_addr"].ports.O,
            self.underlying.ports.config_addr_in[0:self.config_addr_width])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data_in)

    # read data out
    for idx, feature in enumerate(self.__features):
        if idx == 0:
            continue
        self.add_port(f"read_config_data_{idx}",
                      magma.Out(magma.Bits[self.config_data_width]))
        # port aliasing
        feature.ports["read_config_data"] = \
            self.ports[f"read_config_data_{idx}"]

    # MEM Config
    configurations = []
    skip_cfgs = []

    for cfg_info in cfgs:
        if cfg_info.port_name in skip_cfgs:
            continue
        if cfg_info.expl_arr and cfg_info.port_size[0] > 1:
            # one register per element of a multi-element explicit array
            configurations.extend(
                (f"{cfg_info.port_name}_{i}", cfg_info.port_width)
                for i in range(cfg_info.port_size[0]))
        else:
            configurations.append(
                (cfg_info.port_name, cfg_info.port_width))

    # Do all the stuff for the main config
    main_feature = self.__features[0]
    for config_reg_name, width in configurations:
        main_feature.add_config(config_reg_name, width)
        reg_out = main_feature.registers[config_reg_name].ports.O
        target = self.underlying.ports[config_reg_name]
        if width == 1:
            self.wire(reg_out[0], target[0])
        else:
            self.wire(reg_out, target)

    # SRAM
    # These should also account for num features
    if self.num_sram_features > 0:
        # NOTE(review): the instance names look swapped relative to the
        # signals each gate gathers; reproduced unchanged.
        or_all_cfg_rd = FromMagma(
            mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_rd.instance_name = f"OR_CONFIG_WR_SRAM"
        or_all_cfg_wr = FromMagma(
            mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_wr.instance_name = f"OR_CONFIG_RD_SRAM"

        for sram_index in range(self.num_sram_features):
            sram_feature = self.__features[sram_index + 1]
            self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
            # port aliasing
            sram_feature.ports["config_en"] = \
                self.ports[f"config_en_{sram_index}"]

            # Sort of a temp hack - the name is just config_data_out
            if self.num_sram_features == 1:
                read_port_name = "config_data_out"
            else:
                read_port_name = f"config_data_out_{sram_index}"
            self.wire(sram_feature.ports.read_config_data,
                      self.underlying.ports[read_port_name])

            and_gate_en = FromMagma(mantle.DefineAnd(2, 1))
            and_gate_en.instance_name = f"AND_CONFIG_EN_SRAM_{sram_index}"
            # also need to wire the sram signal
            # the config enable is (rd OR wr) AND the feature's config_en
            or_gate_en = FromMagma(mantle.DefineOr(2, 1))
            or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"
            self.wire(or_gate_en.ports.I0, sram_feature.ports.config.write)
            self.wire(or_gate_en.ports.I1, sram_feature.ports.config.read)
            self.wire(and_gate_en.ports.I0, or_gate_en.ports.O)
            self.wire(and_gate_en.ports.I1[0],
                      sram_feature.ports.config_en)
            self.wire(and_gate_en.ports.O[0],
                      self.underlying.ports["config_en"][sram_index])
            # Still connect to the OR of all the config rd/wr
            self.wire(sram_feature.ports.config.write,
                      or_all_cfg_wr.ports[f"I{sram_index}"])
            self.wire(sram_feature.ports.config.read,
                      or_all_cfg_rd.ports[f"I{sram_index}"])

        self.wire(or_all_cfg_rd.ports.O[0],
                  self.underlying.ports.config_read[0])
        self.wire(or_all_cfg_wr.ports.O[0],
                  self.underlying.ports.config_write[0])

    self._setup_config()
def __init__(self, data_width, word_width, data_depth, num_banks,
             use_sram_stub):
    """Build the memory core around the Genesis2-generated circuit.

    Declares the tile ports, fetches or generates the underlying circuit
    (cached per parameter tuple), puts a config-register/mux pair in front
    of the four control signals, ties off chain and config strobes with
    constants, creates the tile feature plus four SRAM features, wires the
    configuration bus and registers, and dumps the sorted register list
    to ``mem_cfg.txt`` in the current working directory.

    ``use_sram_stub`` is stored as 1 when truthy and 0 otherwise.
    """
    super().__init__(8, 32)

    self.data_width = data_width
    self.data_depth = data_depth
    self.num_banks = num_banks
    self.word_width = word_width
    self.use_sram_stub = 1 if use_sram_stub else 0

    TData = magma.Bits[self.word_width]
    TBit = magma.Bits[1]

    self.add_ports(
        data_in=magma.In(TData),
        addr_in=magma.In(TData),
        data_out=magma.Out(TData),
        flush=magma.In(TBit),
        wen_in=magma.In(TBit),
        ren_in=magma.In(TBit),
        stall=magma.In(magma.Bits[4]),
        valid_out=magma.Out(TBit),
        switch_db=magma.In(TBit)
    )
    # Instead of a single read_config_data, we have multiple for each
    # "sub"-feature of this core.

    # The cache is keyed on the raw constructor arguments
    cache_key = (data_width, word_width, data_depth, num_banks,
                 use_sram_stub)
    if cache_key in MemCore.__circuit_cache:
        circ = MemCore.__circuit_cache[cache_key]
    else:
        wrapper = memory_core_genesis2.memory_core_wrapper
        param_mapping = memory_core_genesis2.param_mapping
        generator = wrapper.generator(param_mapping, mode="declare")
        circ = generator(data_width=self.data_width,
                         data_depth=self.data_depth,
                         word_width=self.word_width,
                         num_banks=self.num_banks,
                         use_sram_stub=self.use_sram_stub)
        MemCore.__circuit_cache[cache_key] = circ

    self.underlying = FromMagma(circ)

    # put a 1-bit register and a mux to select the control signals
    for control_signal in ["wen_in", "ren_in", "flush", "switch_db"]:
        # TODO: consult with Ankita to see if we can use the normal
        # mux here
        mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
        value_reg = f"{control_signal}_reg_value"
        select_reg = f"{control_signal}_reg_sel"
        self.add_config(value_reg, 1)
        self.add_config(select_reg, 1)
        self.wire(mux.ports.I[0], self.ports[control_signal])
        self.wire(mux.ports.I[1], self.registers[value_reg].ports.O)
        self.wire(mux.ports.S, self.registers[select_reg].ports.O)
        # 0 is the default wire, which takes from the routing network
        self.wire(mux.ports.O[0], self.underlying.ports[control_signal])

    self.wire(self.ports.data_in, self.underlying.ports.data_in)
    self.wire(self.ports.addr_in, self.underlying.ports.addr_in)
    self.wire(self.ports.data_out, self.underlying.ports.data_out)
    self.wire(self.ports.reset, self.underlying.ports.reset)
    self.wire(self.ports.clk, self.underlying.ports.clk)
    self.wire(self.ports.valid_out[0], self.underlying.ports.valid_out)

    # PE core uses clk_en (essentially active low stall)
    self.stallInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.stallInverter.ports.I, self.ports.stall[0:1])
    self.wire(self.stallInverter.ports.O[0], self.underlying.ports.clk_en)

    zero_signals = (
        ("chain_wen_in", 1),
        ("chain_in", self.word_width),
    )
    one_signals = (
        ("config_read", 1),
        ("config_write", 1)
    )
    # Tie the chain inputs to 0 and enable read and write by default
    for level, tied_signals in ((0, zero_signals), (1, one_signals)):
        for name, width in tied_signals:
            if width > 1:
                val = magma.bits(level, width)
            else:
                val = magma.bit(level)
            self.wire(Const(val), self.underlying.ports[name])
    self.wire(Const(magma.bits(0, 24)),
              self.underlying.ports.config_addr[0:24])

    # we have five features in total
    # 0:    TILE
    # 1-4:  SMEM
    self.__features: List[CoreFeature] = [self]
    for sram_index in range(4):
        self.__features.append(CoreFeature(self, sram_index + 1))

    # Wire the config: every non-tile feature gets its own config port
    for idx, feature in enumerate(self.__features):
        if idx == 0:
            continue
        self.add_port(f"config_{idx}",
                      magma.In(ConfigurationType(8, 32)))
        # port aliasing
        feature.ports["config"] = self.ports[f"config_{idx}"]
    self.add_port("config", magma.In(ConfigurationType(8, 32)))

    # or the signal up
    t = ConfigurationType(8, 32)
    or_gates = {}
    for t_name in ["config_addr", "config_data"]:
        port_type = t[t_name]
        gate = FromMagma(mantle.DefineOr(len(self.__features),
                                         len(port_type)))
        gate.instance_name = f"OR_{t_name}_FEATURE"
        for idx, feature in enumerate(self.__features):
            self.wire(gate.ports[f"I{idx}"], feature.ports.config[t_name])
        or_gates[t_name] = gate

    self.wire(or_gates["config_addr"].ports.O,
              self.underlying.ports.config_addr[24:32])
    self.wire(or_gates["config_data"].ports.O,
              self.underlying.ports.config_data)

    # read data out
    for idx, feature in enumerate(self.__features):
        if idx == 0:
            continue
        self.add_port(f"read_config_data_{idx}",
                      magma.Out(magma.Bits[32]))
        # port aliasing
        feature.ports["read_config_data"] = \
            self.ports[f"read_config_data_{idx}"]

    # MEM config: (register name, width)
    configurations = [
        ("stencil_width", 32),
        ("read_mode", 1),
        ("arbitrary_addr", 1),
        ("starting_addr", 32),
        ("iter_cnt", 32),
        ("dimensionality", 32),
        ("circular_en", 1),
        ("almost_count", 4),
        ("enable_chain", 1),
        ("mode", 2),
        ("tile_en", 1),
        ("chain_idx", 4),
        ("depth", 13)
    ]

    # Do all the stuff for the main config
    main_feature = self.__features[0]
    for config_reg_name, width in configurations:
        main_feature.add_config(config_reg_name, width)
        reg_out = main_feature.registers[config_reg_name].ports.O
        # single-bit registers connect through bit 0 of their output
        source = reg_out[0] if width == 1 else reg_out
        self.wire(source, self.underlying.ports[config_reg_name])

    for idx in range(8):
        stride_name = f"stride_{idx}"
        range_name = f"range_{idx}"
        main_feature.add_config(stride_name, 32)
        main_feature.add_config(range_name, 32)
        self.wire(main_feature.registers[stride_name].ports.O,
                  self.underlying.ports[stride_name])
        self.wire(main_feature.registers[range_name].ports.O,
                  self.underlying.ports[range_name])

    # SRAM
    for sram_index in range(4):
        feature = self.__features[sram_index + 1]
        self.wire(feature.ports.read_config_data,
                  self.underlying.ports[f"read_data_sram_{sram_index}"])
        # also need to wire the sram signal
        self.wire(feature.ports.config.write[0],
                  self.underlying.ports["config_en_sram"][sram_index])

    self._setup_config()

    # Dump "|name|index|width||" rows for the sorted register names
    conf_names = sorted(self.registers.keys())
    with open("mem_cfg.txt", "w+") as cfg_dump:
        for idx, reg in enumerate(conf_names):
            cfg_dump.write(f"|{reg}|{idx}|{self.registers[reg].width}||\n")
def __init__(
        self,
        data_width=16,  # CGRA Params
        mem_width=64,
        mem_depth=512,
        banks=1,
        input_iterator_support=6,  # Addr Controllers
        output_iterator_support=6,
        input_config_width=16,
        output_config_width=16,
        interconnect_input_ports=2,  # Connection to int
        interconnect_output_ports=2,
        mem_input_ports=1,
        mem_output_ports=1,
        use_sram_stub=1,
        sram_macro_info=SRAMMacroInfo("TS1N16FFCLLSBLVTC512X32M4S"),
        read_delay=1,  # Cycle delay in read (SRAM vs Register File)
        rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
        agg_height=4,
        max_agg_schedule=16,
        input_max_port_sched=16,
        output_max_port_sched=16,
        align_input=1,
        max_line_length=128,
        max_tb_height=1,
        tb_range_max=1024,
        tb_range_inner_max=64,
        tb_sched_max=16,
        max_tb_stride=15,
        num_tb=1,
        tb_iterator_support=2,
        multiwrite=1,
        max_prefetch=8,
        config_data_width=32,
        config_addr_width=8,
        num_tiles=2,
        app_ctrl_depth_width=16,
        remove_tb=False,
        fifo_mode=True,
        add_clk_enable=True,
        add_flush=True,
        core_reset_pos=False,
        stcl_valid_iter=4):
    """Build the memory core around a (cached) ``LakeTop`` circuit.

    The constructor
      1. stores every parameter on the instance,
      2. declares the interconnect-facing ports (recorded in the private
         input/output lists),
      3. generates the ``LakeTop`` circuit, or reuses the one cached in
         ``MemCore.__circuit_cache`` for an identical parameter set,
      4. wires data, control, clock/reset and stall signals to it,
      5. creates one configuration feature for the tile plus
         ``lt_dut.total_sets`` SRAM features and wires the config buses,
      6. creates the configuration registers and wires each one to the
         identically named port of the underlying circuit,
      7. calls ``_setup_config()`` and writes the sorted register names to
         ``mem_cfg.txt`` and ``mem_synth.txt`` (relative paths).

    Args:
        config_addr_width, config_data_width: passed to the base class
            constructor and forwarded to ``LakeTop``.
        interconnect_input_ports, interconnect_output_ports: number of
            port groups declared on the core. With a value greater than 1
            port names get an ``_{i}`` suffix, otherwise they are
            unsuffixed.
        use_sram_stub, sram_macro_info: forwarded to ``LakeTop`` and to
            ``change_sram_port_names``.
        mem_input_ports, mem_output_ports, core_reset_pos: stored on the
            instance only; not used elsewhere in this constructor.
        All remaining parameters are stored on the instance and forwarded
        unchanged to ``LakeTop``; several also size the configuration
        registers created here.
    """
    super().__init__(config_addr_width, config_data_width)

    # Capture everything to the tile object
    self.data_width = data_width
    self.mem_width = mem_width
    self.mem_depth = mem_depth
    self.banks = banks
    self.fw_int = int(self.mem_width / self.data_width)
    self.input_iterator_support = input_iterator_support
    self.output_iterator_support = output_iterator_support
    self.input_config_width = input_config_width
    self.output_config_width = output_config_width
    self.interconnect_input_ports = interconnect_input_ports
    self.interconnect_output_ports = interconnect_output_ports
    self.mem_input_ports = mem_input_ports
    self.mem_output_ports = mem_output_ports
    self.use_sram_stub = use_sram_stub
    self.sram_macro_info = sram_macro_info
    self.read_delay = read_delay
    self.rw_same_cycle = rw_same_cycle
    self.agg_height = agg_height
    self.max_agg_schedule = max_agg_schedule
    self.input_max_port_sched = input_max_port_sched
    self.output_max_port_sched = output_max_port_sched
    self.align_input = align_input
    self.max_line_length = max_line_length
    self.max_tb_height = max_tb_height
    self.tb_range_max = tb_range_max
    self.tb_range_inner_max = tb_range_inner_max
    self.tb_sched_max = tb_sched_max
    self.max_tb_stride = max_tb_stride
    self.num_tb = num_tb
    self.tb_iterator_support = tb_iterator_support
    self.multiwrite = multiwrite
    self.max_prefetch = max_prefetch
    self.config_data_width = config_data_width
    self.config_addr_width = config_addr_width
    self.num_tiles = num_tiles
    self.remove_tb = remove_tb
    self.fifo_mode = fifo_mode
    self.add_clk_enable = add_clk_enable
    self.add_flush = add_flush
    self.core_reset_pos = core_reset_pos
    self.app_ctrl_depth_width = app_ctrl_depth_width
    self.stcl_valid_iter = stcl_valid_iter

    # Typedefs for ease
    TData = magma.Bits[self.data_width]
    TBit = magma.Bits[1]

    self.__inputs = []
    self.__outputs = []

    def add_io(name, port_type, group):
        # Declare one port and record it in the given input/output list.
        self.add_port(name, port_type)
        group.append(self.ports[name])

    # Enumerate input and output ports (clk and reset are assumed).
    # More than one port group -> names carry an "_{i}" suffix; otherwise a
    # single unsuffixed group is declared.
    n_in = self.interconnect_input_ports
    n_out = self.interconnect_output_ports
    in_suffixes = [f"_{i}" for i in range(n_in)] if n_in > 1 else [""]
    out_suffixes = [f"_{i}" for i in range(n_out)] if n_out > 1 else [""]

    for sfx in in_suffixes:
        add_io(f"addr_in{sfx}", magma.In(TData), self.__inputs)
        add_io(f"data_in{sfx}", magma.In(TData), self.__inputs)
        add_io(f"wen_in{sfx}", magma.In(TBit), self.__inputs)

    for sfx in out_suffixes:
        add_io(f"data_out{sfx}", magma.Out(TData), self.__outputs)
        add_io(f"ren_in{sfx}", magma.In(TBit), self.__inputs)
        add_io(f"valid_out{sfx}", magma.Out(TBit), self.__outputs)
        # Chaining
        add_io(f"chain_valid_in{sfx}", magma.In(TBit), self.__inputs)
        add_io(f"chain_data_in{sfx}", magma.In(TData), self.__inputs)
        add_io(f"chain_data_out{sfx}", magma.Out(TData), self.__outputs)
        add_io(f"chain_valid_out{sfx}", magma.Out(TBit), self.__outputs)

    self.add_ports(
        flush=magma.In(TBit),
        full=magma.Out(TBit),
        empty=magma.Out(TBit),
        stall=magma.In(TBit),
        sram_ready_out=magma.Out(TBit),
    )

    # stall is deliberately not recorded in the input list.
    self.__inputs.append(self.ports.flush)
    self.__outputs.append(self.ports.full)
    self.__outputs.append(self.ports.empty)
    self.__outputs.append(self.ports.sram_ready_out)

    # The key must contain every parameter that is forwarded to LakeTop,
    # otherwise two cores with different settings would share one circuit.
    # input_config_width, output_config_width and tb_range_inner_max were
    # previously missing even though they change the generated ports.
    cache_key = (self.data_width, self.mem_width, self.mem_depth, self.banks,
                 self.input_iterator_support, self.output_iterator_support,
                 self.interconnect_input_ports, self.interconnect_output_ports,
                 self.use_sram_stub, self.sram_macro_info, self.read_delay,
                 self.rw_same_cycle, self.agg_height, self.max_agg_schedule,
                 self.input_max_port_sched, self.output_max_port_sched,
                 self.align_input, self.max_line_length, self.max_tb_height,
                 self.tb_range_max, self.tb_sched_max, self.max_tb_stride,
                 self.num_tb, self.tb_iterator_support, self.multiwrite,
                 self.max_prefetch, self.config_data_width,
                 self.config_addr_width, self.num_tiles, self.remove_tb,
                 self.fifo_mode, self.stcl_valid_iter, self.add_clk_enable,
                 self.add_flush, self.app_ctrl_depth_width,
                 self.input_config_width, self.output_config_width,
                 self.tb_range_inner_max)

    # Check for circuit caching
    if cache_key not in MemCore.__circuit_cache:
        # Instantiate core object here - will only use the object
        # representation to query for information. The circuit
        # representation will be cached and retrieved in the following steps.
        lt_dut = LakeTop(
            data_width=self.data_width,
            mem_width=self.mem_width,
            mem_depth=self.mem_depth,
            banks=self.banks,
            input_iterator_support=self.input_iterator_support,
            output_iterator_support=self.output_iterator_support,
            input_config_width=self.input_config_width,
            output_config_width=self.output_config_width,
            interconnect_input_ports=self.interconnect_input_ports,
            interconnect_output_ports=self.interconnect_output_ports,
            use_sram_stub=self.use_sram_stub,
            sram_macro_info=self.sram_macro_info,
            read_delay=self.read_delay,
            rw_same_cycle=self.rw_same_cycle,
            agg_height=self.agg_height,
            max_agg_schedule=self.max_agg_schedule,
            input_max_port_sched=self.input_max_port_sched,
            output_max_port_sched=self.output_max_port_sched,
            align_input=self.align_input,
            max_line_length=self.max_line_length,
            max_tb_height=self.max_tb_height,
            tb_range_max=self.tb_range_max,
            tb_range_inner_max=self.tb_range_inner_max,
            tb_sched_max=self.tb_sched_max,
            max_tb_stride=self.max_tb_stride,
            num_tb=self.num_tb,
            tb_iterator_support=self.tb_iterator_support,
            multiwrite=self.multiwrite,
            max_prefetch=self.max_prefetch,
            config_data_width=self.config_data_width,
            config_addr_width=self.config_addr_width,
            num_tiles=self.num_tiles,
            app_ctrl_depth_width=self.app_ctrl_depth_width,
            remove_tb=self.remove_tb,
            fifo_mode=self.fifo_mode,
            add_clk_enable=self.add_clk_enable,
            add_flush=self.add_flush,
            stcl_valid_iter=self.stcl_valid_iter)

        change_sram_port_pass = change_sram_port_names(use_sram_stub,
                                                       sram_macro_info)
        circ = kts.util.to_magma(
            lt_dut,
            flatten_array=True,
            check_multiple_driver=False,
            optimize_if=False,
            check_flip_flop_always_ff=False,
            additional_passes={"change_sram_port": change_sram_port_pass})
        MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
    else:
        circ, lt_dut = MemCore.__circuit_cache[cache_key]

    # Save as underlying circuit object
    self.underlying = FromMagma(circ)

    self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

    # put a 1-bit register and a mux to select the control signals
    # TODO: check if enable_chain_output needs to be here? I don't think so?
    control_signals = [("wen_in", self.interconnect_input_ports),
                       ("ren_in", self.interconnect_output_ports),
                       ("flush", 1),
                       ("chain_valid_in", self.interconnect_output_ports)]
    for control_signal, width in control_signals:
        # A width-1 signal uses the unsuffixed port and bit 0 of the
        # underlying bus; wider signals get one mux per "_{i}" port.
        if width == 1:
            lanes = [(control_signal, 0)]
        else:
            lanes = [(f"{control_signal}_{i}", i) for i in range(width)]
        for lane_name, bit in lanes:
            # TODO: consult with Ankita to see if we can use the normal
            # mux here
            mux = MuxWrapper(2, 1, name=f"{lane_name}_sel")
            reg_value_name = f"{lane_name}_reg_value"
            reg_sel_name = f"{lane_name}_reg_sel"
            self.add_config(reg_value_name, 1)
            self.add_config(reg_sel_name, 1)
            self.wire(mux.ports.I[0], self.ports[lane_name])
            self.wire(mux.ports.I[1], self.registers[reg_value_name].ports.O)
            self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
            # 0 is the default wire, which takes from the routing network
            self.wire(mux.ports.O[0],
                      self.underlying.ports[control_signal][bit])

    if self.interconnect_input_ports > 1:
        for i in range(self.interconnect_input_ports):
            self.wire(self.ports[f"data_in_{i}"],
                      self.underlying.ports[f"data_in_{i}"])
            self.wire(self.ports[f"addr_in_{i}"],
                      self.underlying.ports[f"addr_in_{i}"])
    else:
        self.wire(self.ports.addr_in, self.underlying.ports.addr_in)
        self.wire(self.ports.data_in, self.underlying.ports.data_in)

    if self.interconnect_output_ports > 1:
        for i in range(self.interconnect_output_ports):
            self.wire(self.ports[f"data_out_{i}"],
                      self.underlying.ports[f"data_out_{i}"])
            self.wire(self.ports[f"chain_data_in_{i}"],
                      self.underlying.ports[f"chain_data_in_{i}"])
            self.wire(self.ports[f"chain_data_out_{i}"],
                      self.underlying.ports[f"chain_data_out_{i}"])
    else:
        self.wire(self.ports.data_out, self.underlying.ports.data_out)
        self.wire(self.ports.chain_data_in,
                  self.underlying.ports.chain_data_in)
        self.wire(self.ports.chain_data_out,
                  self.underlying.ports.chain_data_out)

    # The underlying circuit takes rst_n, so the core reset is inverted.
    self.resetInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.resetInverter.ports.I[0], self.ports.reset)
    self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
    self.wire(self.ports.clk, self.underlying.ports.clk)
    if self.interconnect_output_ports == 1:
        self.wire(self.ports.valid_out[0],
                  self.underlying.ports.valid_out[0])
        self.wire(self.ports.chain_valid_out[0],
                  self.underlying.ports.chain_valid_out[0])
    else:
        for j in range(self.interconnect_output_ports):
            self.wire(self.ports[f"valid_out_{j}"][0],
                      self.underlying.ports.valid_out[j])
            self.wire(self.ports[f"chain_valid_out_{j}"][0],
                      self.underlying.ports.chain_valid_out[j])
    self.wire(self.ports.empty[0], self.underlying.ports.empty[0])
    self.wire(self.ports.full[0], self.underlying.ports.full[0])

    # PE core uses clk_en (essentially active low stall)
    self.stallInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.stallInverter.ports.I, self.ports.stall)
    self.wire(self.stallInverter.ports.O[0],
              self.underlying.ports.clk_en[0])

    self.wire(self.ports.sram_ready_out[0],
              self.underlying.ports.sram_ready_out[0])

    # Feature 0 is the tile itself; features 1..N are the SRAM features,
    # where N is reported by the LakeTop object (lt_dut.total_sets).
    self.__features: List[CoreFeature] = [self]
    self.num_sram_features = lt_dut.total_sets
    for sram_index in range(self.num_sram_features):
        core_feature = CoreFeature(self, sram_index + 1)
        self.__features.append(core_feature)

    # Wire the config: every SRAM feature gets its own config_{idx} port,
    # the tile (feature 0) uses the plain "config" port.
    for idx, core_feature in enumerate(self.__features):
        if idx > 0:
            self.add_port(f"config_{idx}",
                          magma.In(ConfigurationType(8, 32)))
            # port aliasing
            core_feature.ports["config"] = self.ports[f"config_{idx}"]
    self.add_port("config", magma.In(ConfigurationType(8, 32)))

    # or the signal up
    t = ConfigurationType(8, 32)
    t_names = ["config_addr", "config_data"]
    or_gates = {}
    for t_name in t_names:
        port_type = t[t_name]
        or_gate = FromMagma(mantle.DefineOr(len(self.__features),
                                            len(port_type)))
        or_gate.instance_name = f"OR_{t_name}_FEATURE"
        for idx, core_feature in enumerate(self.__features):
            self.wire(or_gate.ports[f"I{idx}"],
                      core_feature.ports.config[t_name])
        or_gates[t_name] = or_gate

    self.wire(or_gates["config_addr"].ports.O,
              self.underlying.ports.config_addr_in[0:8])
    self.wire(or_gates["config_data"].ports.O,
              self.underlying.ports.config_data_in)

    # read data out
    for idx, core_feature in enumerate(self.__features):
        if idx > 0:
            self.add_port(f"read_config_data_{idx}",
                          magma.Out(magma.Bits[32]))
            # port aliasing
            core_feature.ports["read_config_data"] = \
                self.ports[f"read_config_data_{idx}"]

    # MEM Config: (register name, width) pairs. Each register is wired to
    # the underlying port of the same name further down.
    # ("stencil_width", 16) is not supported yet.
    configurations = [
        ("tile_en", 1),
        ("fifo_ctrl_fifo_depth", 16),
        ("mode", 2),
        ("enable_chain_output", 1),
        ("enable_chain_input", 1),
    ]

    # (merged register name, total width, number of merged entries)
    merged_configs = []

    # Add config registers to configurations
    # TODO: Have lake spit this information out automatically from the wrapper
    configurations.append(("chain_idx_input", self.chain_idx_bits))
    configurations.append(("chain_idx_output", self.chain_idx_bits))

    for i in range(self.interconnect_input_ports):
        configurations.append((f"strg_ub_agg_align_{i}_line_length",
                               kts.clog2(self.max_line_length)))
        configurations.append((f"strg_ub_agg_in_{i}_in_period",
                               kts.clog2(self.input_max_port_sched)))
        for j in range(self.input_max_port_sched):
            configurations.append((f"strg_ub_agg_in_{i}_in_sched_{j}",
                                   kts.clog2(self.agg_height)))
        # NOTE(review): out_period is sized from input_max_port_sched while
        # the out_sched loop below uses output_max_port_sched - confirm
        # against the LakeTop port widths.
        configurations.append((f"strg_ub_agg_in_{i}_out_period",
                               kts.clog2(self.input_max_port_sched)))
        for j in range(self.output_max_port_sched):
            configurations.append((f"strg_ub_agg_in_{i}_out_sched_{j}",
                                   kts.clog2(self.agg_height)))

        configurations.append((f"strg_ub_app_ctrl_write_depth_wo_{i}",
                               self.app_ctrl_depth_width))
        configurations.append((f"strg_ub_app_ctrl_write_depth_ss_{i}",
                               self.app_ctrl_depth_width))
        configurations.append(
            (f"strg_ub_app_ctrl_coarse_write_depth_wo_{i}",
             self.app_ctrl_depth_width))
        configurations.append(
            (f"strg_ub_app_ctrl_coarse_write_depth_ss_{i}",
             self.app_ctrl_depth_width))

        configurations.append(
            (f"strg_ub_input_addr_ctrl_address_gen_{i}_dimensionality",
             1 + kts.clog2(self.input_iterator_support)))
        configurations.append(
            (f"strg_ub_input_addr_ctrl_address_gen_{i}_starting_addr",
             self.input_config_width))
        for j in range(self.input_iterator_support):
            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_ranges_{j}",
                 self.input_config_width))
            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_strides_{j}",
                 self.input_config_width))

    configurations.append(("strg_ub_app_ctrl_prefill",
                           self.interconnect_output_ports))
    configurations.append(("strg_ub_app_ctrl_coarse_prefill",
                           self.interconnect_output_ports))

    for i in range(self.stcl_valid_iter):
        configurations.append((f"strg_ub_app_ctrl_ranges_{i}", 16))
        configurations.append((f"strg_ub_app_ctrl_threshold_{i}", 16))

    for i in range(self.interconnect_output_ports):
        configurations.append((f"strg_ub_app_ctrl_input_port_{i}",
                               kts.clog2(self.interconnect_input_ports)))
        configurations.append((f"strg_ub_app_ctrl_read_depth_{i}",
                               self.app_ctrl_depth_width))
        configurations.append((f"strg_ub_app_ctrl_coarse_input_port_{i}",
                               kts.clog2(self.interconnect_input_ports)))
        configurations.append((f"strg_ub_app_ctrl_coarse_read_depth_{i}",
                               self.app_ctrl_depth_width))

        configurations.append(
            (f"strg_ub_output_addr_ctrl_address_gen_{i}_dimensionality",
             1 + kts.clog2(self.output_iterator_support)))
        configurations.append(
            (f"strg_ub_output_addr_ctrl_address_gen_{i}_starting_addr",
             self.output_config_width))
        for j in range(self.output_iterator_support):
            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_ranges_{j}",
                 self.output_config_width))
            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_strides_{j}",
                 self.output_config_width))

        configurations.append((f"strg_ub_pre_fetch_{i}_input_latency",
                               kts.clog2(self.max_prefetch) + 1))
        configurations.append((f"strg_ub_sync_grp_sync_group_{i}",
                               self.interconnect_output_ports))
        configurations.append(
            (f"strg_ub_rate_matched_{i}",
             1 + kts.clog2(self.interconnect_input_ports)))

        for j in range(self.num_tb):
            configurations.append(
                (f"strg_ub_tba_{i}_tb_{j}_dimensionality", 2))

            # The tb_range_inner_max index entries are packed into as few
            # config_data_width-wide registers as possible. The number in
            # the register name is the first index it holds.
            num_indices_bits = 1 + kts.clog2(self.fw_int)
            indices_per_feat = math.floor(self.config_data_width
                                          / num_indices_bits)
            num_feats_merge = math.ceil(self.tb_range_inner_max
                                        / indices_per_feat)
            for k in range(num_feats_merge):
                first_index = k * indices_per_feat
                # The last register may hold fewer entries than the others.
                num_idx = min(indices_per_feat,
                              self.tb_range_inner_max - first_index)
                merged_configs.append(
                    (f"strg_ub_tba_{i}_tb_{j}_indices_merged_{first_index}",
                     num_idx * num_indices_bits,
                     num_idx))

            configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_inner",
                                   kts.clog2(self.tb_range_inner_max)))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_outer",
                                   kts.clog2(self.tb_range_max)))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_stride",
                                   kts.clog2(self.max_tb_stride)))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_tb_height",
                                   max(1, kts.clog2(self.num_tb))))
            configurations.append((f"strg_ub_tba_{i}_tb_{j}_starting_addr",
                                   max(1, kts.clog2(self.fw_int))))

    # Do all the stuff for the main config
    main_feature = self.__features[0]
    for config_reg_name, width in configurations:
        main_feature.add_config(config_reg_name, width)
        if width == 1:
            self.wire(main_feature.registers[config_reg_name].ports.O[0],
                      self.underlying.ports[config_reg_name][0])
        else:
            self.wire(main_feature.registers[config_reg_name].ports.O,
                      self.underlying.ports[config_reg_name])

    # Split every merged register back into equally sized slices and wire
    # slice n to underlying port "<base_name>_<first index + n>".
    for config_reg_name, width, num_merged in merged_configs:
        main_feature.add_config(config_reg_name, width)
        base_name = config_reg_name.split("_merged")[0]
        base_indices = int(config_reg_name.split("_merged_")[1])
        num_bits = width // num_merged
        for i in range(num_merged):
            self.wire(
                main_feature.registers[config_reg_name].ports.O[
                    i * num_bits:(i + 1) * num_bits],
                self.underlying.ports[f"{base_name}_{base_indices + i}"])

    # SRAM
    # NOTE(review): the two instance names below look swapped (the read OR
    # gate is named ..._WR_... and vice versa). They are kept as-is because
    # renaming changes instance names in the generated circuit.
    or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
    or_all_cfg_rd.instance_name = "OR_CONFIG_WR_SRAM"
    or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
    or_all_cfg_wr.instance_name = "OR_CONFIG_RD_SRAM"
    for sram_index in range(self.num_sram_features):
        core_feature = self.__features[sram_index + 1]
        self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
        # port aliasing
        core_feature.ports["config_en"] = \
            self.ports[f"config_en_{sram_index}"]
        self.wire(core_feature.ports.read_config_data,
                  self.underlying.ports[f"config_data_out_{sram_index}"])
        # also need to wire the sram signal
        # NOTE(review): or_gate_en ORs config read and write, but its output
        # is never wired; the underlying config_en bit is driven by the
        # config_en_{sram_index} port instead. Confirm this is intended.
        or_gate_en = FromMagma(mantle.DefineOr(2, 1))
        or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"

        self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
        self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
        self.wire(core_feature.ports.config_en,
                  self.underlying.ports["config_en"][sram_index])
        # Still connect to the OR of all the config rd/wr
        self.wire(core_feature.ports.config.write,
                  or_all_cfg_wr.ports[f"I{sram_index}"])
        self.wire(core_feature.ports.config.read,
                  or_all_cfg_rd.ports[f"I{sram_index}"])

    self.wire(or_all_cfg_rd.ports.O[0],
              self.underlying.ports.config_read[0])
    self.wire(or_all_cfg_wr.ports.O[0],
              self.underlying.ports.config_write[0])
    self._setup_config()

    # Dump the register names in sorted order: a table with index and width
    # in mem_cfg.txt and the bare names in mem_synth.txt.
    conf_names = sorted(self.registers.keys())
    with open("mem_cfg.txt", "w+") as cfg_dump:
        for idx, reg in enumerate(conf_names):
            cfg_dump.write(f"|{reg}|{idx}|{self.registers[reg].width}||\n")
    with open("mem_synth.txt", "w+") as cfg_dump:
        for reg in conf_names:
            cfg_dump.write(f"{reg}\n")
def __init__(self, data_width, data_depth):
    """Wrap the genesis2 memory core circuit and expose its config features.

    Declares the data/control ports, instantiates the circuit produced by
    ``memory_core_genesis2.memory_core_wrapper`` and wires it up. Five
    configuration features are created (index 0 plus four SRAM features);
    each gets its own ``config_{idx}`` and ``read_config_data_{idx}`` port,
    and every SRAM feature additionally gets a ``config_en_{n}`` port.

    Args:
        data_width: width of the data_in/addr_in/data_out ports; also passed
            to the circuit generator.
        data_depth: passed to the circuit generator.
    """
    super().__init__(8, 32)

    self.data_width = data_width
    self.data_depth = data_depth

    word_t = magma.Bits[self.data_width]
    bit_t = magma.Bits[1]
    self.add_ports(
        data_in=magma.In(word_t),
        addr_in=magma.In(word_t),
        data_out=magma.Out(word_t),
        flush=magma.In(bit_t),
        wen_in=magma.In(bit_t),
        ren_in=magma.In(bit_t),
        stall=magma.In(magma.Bits[4]),
    )
    # Each sub-feature of this core gets its own read_config_data port
    # below, so the single shared one is dropped.
    self.ports.pop("read_config_data")

    genesis_wrapper = memory_core_genesis2.memory_core_wrapper
    genesis_params = memory_core_genesis2.param_mapping
    make_circuit = genesis_wrapper.generator(genesis_params, mode="declare")
    self.underlying = FromMagma(
        make_circuit(data_width=self.data_width,
                     data_depth=self.data_depth))
    inner = self.underlying

    # Ports that connect one-to-one.
    for port_name in ("data_in", "addr_in", "data_out", "reset"):
        self.wire(self.ports[port_name], inner.ports[port_name])
    # One-bit buses on the core side: bit 0 is connected to the circuit port.
    for port_name in ("flush", "wen_in", "ren_in"):
        self.wire(self.ports[port_name][0], inner.ports[port_name])

    # PE core uses clk_en (essentially active low stall)
    self.stallInverter = FromMagma(mantle.DefineInvert(1))
    self.wire(self.stallInverter.ports.I, self.ports.stall[0:1])
    self.wire(self.stallInverter.ports.O[0], inner.ports.clk_en)

    def constant(level, width):
        # Multi-bit ports take a bits constant, single-bit ports a bit.
        if width > 1:
            return magma.bits(level, width)
        return magma.bit(level)

    # TODO(rsetaluri): Actually wire these inputs.
    # (port name, width, constant level); read and write are enabled by
    # default, everything else is tied to zero.
    tie_offs = (
        ("config_en_linebuf", 1, 0),
        ("chain_wen_in", 1, 0),
        ("chain_in", self.data_width, 0),
        ("config_read", 1, 1),
        ("config_write", 1, 1),
    )
    for signal, width, level in tie_offs:
        self.wire(Const(constant(level, width)), inner.ports[signal])
    self.wire(Const(magma.bits(0, 24)), inner.ports.config_addr[0:24])

    # we have five features in total
    # 0:    LINEBUF
    # 1-4:  SMEM
    # current setup is already in line buffer mode, so we pass self in
    # notice that config_en_linebuf is to change the address in the
    # line buffer mode, which is not used in practice
    self.__features: List[CoreFeature] = [
        CoreFeature(self, feature_id) for feature_id in range(5)
    ]

    for feature_id, feature in enumerate(self.__features):
        cfg_port = f"config_{feature_id}"
        self.add_port(cfg_port, magma.In(ConfigurationType(8, 32)))
        # port aliasing
        feature.ports["config"] = self.ports[cfg_port]

    # OR the per-feature address and data buses together.
    cfg_type = ConfigurationType(8, 32)
    merged = {}
    for field in ("config_addr", "config_data"):
        field_type = cfg_type[field]
        gate = FromMagma(
            mantle.DefineOr(len(self.__features), len(field_type)))
        gate.instance_name = f"OR_{field}_FEATURE"
        for slot, feature in enumerate(self.__features):
            self.wire(gate.ports[f"I{slot}"], feature.ports.config[field])
        merged[field] = gate

    self.wire(merged["config_addr"].ports.O,
              inner.ports.config_addr[24:32])
    self.wire(merged["config_data"].ports.O, inner.ports.config_data)

    # only the first one has config_en
    self.wire(self.__features[0].ports.config.write[0],
              inner.ports.config_en)

    # read data out
    for feature_id, feature in enumerate(self.__features):
        rd_port = f"read_config_data_{feature_id}"
        self.add_port(rd_port, magma.Out(magma.Bits[32]))
        # port aliasing
        feature.ports["read_config_data"] = self.ports[rd_port]

    # MEM config
    self.wire(self.ports.read_config_data_0, inner.ports.read_data)

    # SRAM
    for sram_index, feature in enumerate(self.__features[1:]):
        self.wire(feature.ports.read_config_data,
                  inner.ports[f"read_data_sram_{sram_index}"])
        # also need to wire the sram signal
        en_port = f"config_en_{sram_index}"
        self.add_port(en_port, magma.In(magma.Bit))
        # port aliasing
        feature.ports["config_en"] = self.ports[en_port]
        self.wire(inner.ports["config_en_sram"][sram_index],
                  self.ports[en_port])