def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=16,
            mem_depth=256,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=1,  # Connection to int
            interconnect_output_ports=1,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=True,
            # NOTE(review): this default object is created once, when the
            # function is defined, and is shared by every call that omits it.
            # It is also part of the circuit cache key, so a shared default
            # presumably keeps cache hits working — confirm before changing.
            sram_macro_info=SRAMMacroInfo(),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=True,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            tb_sched_max=16,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=1,
            remove_tb=False,
            fifo_mode=False,
            add_clk_enable=True,
            add_flush=True,
            override_name=None):
        """Build the memory core by wrapping a (possibly cached) LakeTop.

        A ``LakeTop`` generator is elaborated and converted to magma with
        ``kts.util.to_magma``. The result is memoized per parameter set in
        ``MemCore.__circuit_cache``. Later cores with identical parameters
        reuse the cached ``(circuit, generator)`` pair instead of
        regenerating it.

        The constructor then does the following:
          * Exposes every LakeTop interface port as a port of this core,
            except the configuration-bus ports and ``clk_en``. Explicit
            arrays and multi-bit ports are split into ``<name>_<i>`` ports.
            Explicit-array elements are ``data_width`` bits wide; the rest
            are single bits.
          * Puts a 2:1 mux in front of every control input. The mux lets
            the input be driven from the routing network (select 0) or
            from a ``*_reg_value`` config register (select 1).
          * Drives the underlying ``rst_n`` from the inverted ``reset``.
            Drives ``clk_en`` from the inverted ``stall``.
          * Creates feature 0 (this core) plus one ``CoreFeature`` per
            LakeTop configuration set (``lt_dut.total_sets``). Each extra
            feature gets its own ``config_<i>``, ``read_config_data_<i>``
            and ``config_en_<i>`` ports.
          * Adds LakeTop's top-level configuration registers to feature 0.

        Side effect: writes ``mem_cfg.txt`` and ``mem_synth.txt`` (the
        sorted configuration-register names) into the current working
        directory.

        Args:
            data_width: Interconnect data width; also the element width of
                explicit-array ports.
            mem_width: Memory word width. ``mem_width // data_width`` is
                stored as ``fw_int``.
            mem_depth, banks, input_iterator_support,
            output_iterator_support, input_config_width,
            output_config_width, interconnect_input_ports,
            interconnect_output_ports, use_sram_stub, sram_macro_info,
            read_delay, rw_same_cycle, agg_height, config_data_width,
            config_addr_width, num_tiles, remove_tb, fifo_mode,
            add_clk_enable, add_flush: Stored on the object and forwarded
                unchanged to ``LakeTop``.
            mem_input_ports, mem_output_ports: Stored on the object but not
                forwarded to ``LakeTop`` by this constructor.
            tb_sched_max: Accepted for interface compatibility; not used by
                this constructor.
            override_name: If given, the core is named
                ``override_name + "Core"`` and the LakeTop module
                ``override_name``. Otherwise the names are ``"MemCore"`` and
                ``"LakeTop"``.
        """

        # name
        if override_name:
            self.__name = override_name + "Core"
            lake_name = override_name
        else:
            self.__name = "MemCore"
            lake_name = "LakeTop"

        super().__init__(config_addr_width, config_data_width)

        # Capture everything to the tile object
        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        # Number of data_width-sized words in one mem_width-sized word.
        self.fw_int = self.mem_width // self.data_width
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.remove_tb = remove_tb
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush

        # Typedefs for ease
        TData = magma.Bits[self.data_width]
        TBit = magma.Bits[1]

        self.__inputs = []
        self.__outputs = []

        # Every value that influences the generated LakeTop must be part of
        # the key. Otherwise two cores with different settings would silently
        # share one circuit. input_config_width, output_config_width and
        # lake_name are all forwarded to LakeTop below, so they are keyed too.
        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.config_data_width,
                     self.config_addr_width, self.num_tiles, self.remove_tb,
                     self.fifo_mode, self.add_clk_enable, self.add_flush,
                     self.input_config_width, self.output_config_width,
                     lake_name)

        # Check for circuit caching
        if cache_key not in MemCore.__circuit_cache:

            # Instantiate the generator object. It is used for its object
            # representation (interface / config queries below). Its magma
            # conversion is cached and retrieved on later constructions.
            lt_dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                remove_tb=self.remove_tb,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                name=lake_name,
                gen_addr=False)

            change_sram_port_pass = change_sram_port_names(
                use_sram_stub, sram_macro_info)
            circ = kts.util.to_magma(
                lt_dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
        else:
            circ, lt_dut = MemCore.__circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        # Enumerate input and output ports
        # (clk and reset are assumed)
        core_interface = get_interface(lt_dut)
        cfgs = extract_top_config(lt_dut)
        assert len(cfgs) > 0, "No configs?"

        # The configuration bus (and clk_en, exposed as `stall`) is wired
        # separately below, so these ports are not exposed generically.
        skip_names = [
            "config_data_in", "config_write", "config_addr_in",
            "config_data_out", "config_read", "config_en", "clk_en"
        ]

        # Signals that can be hardwired to a constant at runtime through a
        # config-register-driven mux: (name, width).
        control_signals = []
        # The rest of the signals to wire straight to the underlying circuit:
        # (exposed name, direction, explicit array?, index, underlying name).
        other_signals = []

        # Each io_info provides port_name, port_size, port_width, is_ctrl,
        # port_dir and expl_arr.
        for io_info in core_interface:
            if io_info.port_name in skip_names:
                continue
            ind_ports = io_info.port_width
            intf_type = TBit
            # For our purposes, an explicit array means the inner data HAS to
            # be data_width bits (TData).
            if io_info.expl_arr:
                ind_ports = io_info.port_size[0]
                intf_type = TData
            dir_type = magma.In
            app_list = self.__inputs
            if io_info.port_dir == "PortDirection.Out":
                dir_type = magma.Out
                app_list = self.__outputs
            if ind_ports > 1:
                for i in range(ind_ports):
                    self.add_port(f"{io_info.port_name}_{i}",
                                  dir_type(intf_type))
                    app_list.append(self.ports[f"{io_info.port_name}_{i}"])
            else:
                self.add_port(io_info.port_name, dir_type(intf_type))
                app_list.append(self.ports[io_info.port_name])

            # classify each signal for wiring to underlying representation...
            if io_info.is_ctrl:
                control_signals.append((io_info.port_name, io_info.port_width))
            else:
                if ind_ports > 1:
                    for i in range(ind_ports):
                        other_signals.append(
                            (f"{io_info.port_name}_{i}", io_info.port_dir,
                             io_info.expl_arr, i, io_info.port_name))
                else:
                    other_signals.append(
                        (io_info.port_name, io_info.port_dir, io_info.expl_arr,
                         0, io_info.port_name))

        assert (len(self.__outputs) > 0)

        # The underlying clk_en is exposed as (inverted) `stall` at this level
        # for legacy reasons.
        self.add_ports(stall=magma.In(TBit), )

        self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

        # Put a 1-bit register and a mux in front of each control signal bit:
        # select 0 takes the routed input, select 1 the *_reg_value register.
        for control_signal, width in control_signals:
            if width == 1:
                mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
                reg_value_name = f"{control_signal}_reg_value"
                reg_sel_name = f"{control_signal}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0], self.ports[control_signal])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][0])
            else:
                for i in range(width):
                    mux = MuxWrapper(2, 1, name=f"{control_signal}_{i}_sel")
                    reg_value_name = f"{control_signal}_{i}_reg_value"
                    reg_sel_name = f"{control_signal}_{i}_reg_sel"
                    self.add_config(reg_value_name, 1)
                    self.add_config(reg_sel_name, 1)
                    self.wire(mux.ports.I[0],
                              self.ports[f"{control_signal}_{i}"])
                    self.wire(mux.ports.I[1],
                              self.registers[reg_value_name].ports.O)
                    self.wire(mux.ports.S,
                              self.registers[reg_sel_name].ports.O)
                    # 0 is the default wire, which takes from the routing network
                    self.wire(mux.ports.O[0],
                              self.underlying.ports[control_signal][i])

        # Wire the other signals up...
        for pname, _pdir, expl_arr, ind, uname in other_signals:
            if expl_arr is False:
                # Bit-split port: exposed 1-bit port -> bit `ind` of the
                # underlying multi-bit port.
                self.wire(self.ports[pname][0],
                          self.underlying.ports[uname][ind])
            else:
                # Explicit array: the flattened underlying port carries the
                # same `<name>_<i>` name as the exposed port.
                self.wire(self.ports[pname], self.underlying.ports[pname])

        # CLK, RESET, and STALL PER STANDARD PROCEDURE

        # The underlying reset is active low (rst_n), so invert `reset`.
        self.resetInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.resetInverter.ports.I[0], self.ports.reset)
        self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
        self.wire(self.ports.clk, self.underlying.ports.clk)

        # Mem core uses clk_en (essentially active low stall)
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall)
        self.wire(self.stallInverter.ports.O[0],
                  self.underlying.ports.clk_en[0])

        # Feature 0 is the core itself; features 1..total_sets are the SRAM
        # configuration sets reported by LakeTop.
        self.__features: List[CoreFeature] = [self]
        self.num_sram_features = lt_dut.total_sets
        for sram_index in range(self.num_sram_features):
            core_feature = CoreFeature(self, sram_index + 1)
            self.__features.append(core_feature)

        # Give every non-main feature its own configuration port.
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(
                    f"config_{idx}",
                    magma.In(
                        ConfigurationType(self.config_addr_width,
                                          self.config_data_width)))
                # port aliasing
                core_feature.ports["config"] = self.ports[f"config_{idx}"]
        self.add_port(
            "config",
            magma.In(
                ConfigurationType(self.config_addr_width,
                                  self.config_data_width)))

        # OR the config address/data of all features together to drive the
        # single configuration bus of the underlying circuit.
        t = ConfigurationType(self.config_addr_width, self.config_data_width)
        t_names = ["config_addr", "config_data"]
        or_gates = {}
        for t_name in t_names:
            port_type = t[t_name]
            or_gate = FromMagma(
                mantle.DefineOr(len(self.__features), len(port_type)))
            or_gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, core_feature in enumerate(self.__features):
                self.wire(or_gate.ports[f"I{idx}"],
                          core_feature.ports.config[t_name])
            or_gates[t_name] = or_gate

        self.wire(
            or_gates["config_addr"].ports.O,
            self.underlying.ports.config_addr_in[0:self.config_addr_width])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data_in)

        # Read-data output port for every non-main feature.
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(f"read_config_data_{idx}",
                              magma.Out(magma.Bits[self.config_data_width]))
                # port aliasing
                core_feature.ports["read_config_data"] = \
                    self.ports[f"read_config_data_{idx}"]

        # MEM Config: flatten top-level configs into (register name, width).
        # Explicit arrays with more than one element become `<name>_<i>`.
        configurations = []
        skip_cfgs = []

        for cfg_info in cfgs:
            if cfg_info.port_name in skip_cfgs:
                continue
            if cfg_info.expl_arr and cfg_info.port_size[0] > 1:
                for i in range(cfg_info.port_size[0]):
                    configurations.append(
                        (f"{cfg_info.port_name}_{i}", cfg_info.port_width))
            else:
                configurations.append(
                    (cfg_info.port_name, cfg_info.port_width))

        # Do all the stuff for the main config
        main_feature = self.__features[0]
        for config_reg_name, width in configurations:
            main_feature.add_config(config_reg_name, width)
            if (width == 1):
                self.wire(main_feature.registers[config_reg_name].ports.O[0],
                          self.underlying.ports[config_reg_name][0])
            else:
                self.wire(main_feature.registers[config_reg_name].ports.O,
                          self.underlying.ports[config_reg_name])

        # SRAM: the underlying config_read / config_write are the OR of the
        # per-feature read / write strobes.
        or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_rd.instance_name = "OR_CONFIG_RD_SRAM"
        or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_wr.instance_name = "OR_CONFIG_WR_SRAM"
        for sram_index in range(self.num_sram_features):
            core_feature = self.__features[sram_index + 1]
            self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
            # port aliasing
            core_feature.ports["config_en"] = \
                self.ports[f"config_en_{sram_index}"]
            # With a single set, the underlying port is just config_data_out.
            if self.num_sram_features == 1:
                self.wire(core_feature.ports.read_config_data,
                          self.underlying.ports["config_data_out"])
            else:
                self.wire(
                    core_feature.ports.read_config_data,
                    self.underlying.ports[f"config_data_out_{sram_index}"])
            # NOTE(review): or_gate_en (write | read) is built, but its output
            # is never connected. config_en below is wired straight from the
            # config_en_<i> port, so this gate currently has no effect. Confirm
            # whether config_en was meant to be gated by it.
            or_gate_en = FromMagma(mantle.DefineOr(2, 1))
            or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"

            self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
            self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
            self.wire(core_feature.ports.config_en,
                      self.underlying.ports["config_en"][sram_index])
            # Still connect to the OR of all the config rd/wr
            self.wire(core_feature.ports.config.write,
                      or_all_cfg_wr.ports[f"I{sram_index}"])
            self.wire(core_feature.ports.config.read,
                      or_all_cfg_rd.ports[f"I{sram_index}"])

        self.wire(or_all_cfg_rd.ports.O[0],
                  self.underlying.ports.config_read[0])
        self.wire(or_all_cfg_wr.ports.O[0],
                  self.underlying.ports.config_write[0])
        self._setup_config()

        # Dump the sorted configuration-register names (with widths) to the
        # current working directory.
        conf_names = sorted(self.registers.keys())
        with open("mem_cfg.txt", "w+") as cfg_dump:
            for reg in conf_names:
                write_line = f"(\"{reg}\", 0),  # {self.registers[reg].width}\n"
                cfg_dump.write(write_line)
        with open("mem_synth.txt", "w+") as cfg_dump:
            for reg in conf_names:
                write_line = f"{reg}\n"
                cfg_dump.write(write_line)
# Example #2
# 0
    def wrap_lake_core(self):
        """Expose the wrapped Lake generator's ports and configuration.

        Reads ``self.dut`` (the generator) and wires its magma
        counterpart ``self.underlying``. Appends to ``self.__inputs`` and
        ``self.__outputs``. All four are assumed to be set up by the caller
        before this method runs; this method does not create them.

        The method does the following:
          * Exposes every interface port of ``self.dut`` on this core,
            except the configuration-bus ports and ``clk_en``. Explicit
            arrays and multi-bit ports are split into ``<name>_<i>`` ports.
          * Puts a 2:1 mux in front of every control input. The mux lets
            the input be driven from the routing network (select 0) or
            from a ``*_reg_value`` config register (select 1).
          * Drives ``rst_n`` from the inverted ``reset``, converted to an
            async reset. Drives ``clk_en`` from the inverted ``stall``.
          * Creates one ``CoreFeature`` per configuration set
            (``self.dut.total_sets``) in addition to feature 0 (this core).
            Each extra feature has its own ``config_<i>``,
            ``read_config_data_<i>`` and ``config_en_<i>`` ports.
          * Adds the top-level configuration registers to feature 0.

        When there are no SRAM configuration sets, the feature-OR
        configuration logic is not built.
        """
        # Typedefs for ease
        if self.data_width:
            TData = magma.Bits[self.data_width]
        else:
            # Fallback only; should not be used when data_width is None.
            TData = magma.Bits[16]
        TBit = magma.Bits[1]
        # Enumerate input and output ports
        # (clk and reset are assumed)
        core_interface = get_interface(self.dut)
        cfgs = extract_top_config(self.dut)
        assert len(cfgs) > 0, "No configs?"

        # The configuration bus (and clk_en, exposed as `stall`) is wired
        # separately below, so these ports are not exposed generically.
        skip_names = [
            "config_data_in", "config_write", "config_addr_in",
            "config_data_out", "config_read", "config_en", "clk_en"
        ]

        # Signals that can be hardwired to a constant at runtime through a
        # config-register-driven mux: (name, width).
        control_signals = []
        # The rest of the signals to wire straight to the underlying circuit:
        # (exposed name, direction, explicit array?, index, underlying name).
        other_signals = []

        # Each io_info provides port_name, port_size, port_width, is_ctrl,
        # port_dir and expl_arr.
        for io_info in core_interface:
            if io_info.port_name in skip_names:
                continue
            ind_ports = io_info.port_width
            intf_type = TBit
            # For our purposes, an explicit array means the inner data HAS to
            # be data_width bits (TData).
            if io_info.expl_arr:
                ind_ports = io_info.port_size[0]
                intf_type = TData
            dir_type = magma.In
            app_list = self.__inputs
            if io_info.port_dir == "PortDirection.Out":
                dir_type = magma.Out
                app_list = self.__outputs
            if ind_ports > 1:
                for i in range(ind_ports):
                    self.add_port(f"{io_info.port_name}_{i}",
                                  dir_type(intf_type))
                    app_list.append(self.ports[f"{io_info.port_name}_{i}"])
            else:
                self.add_port(io_info.port_name, dir_type(intf_type))
                app_list.append(self.ports[io_info.port_name])

            # classify each signal for wiring to underlying representation...
            if io_info.is_ctrl:
                control_signals.append((io_info.port_name, io_info.port_width))
            else:
                if ind_ports > 1:
                    for i in range(ind_ports):
                        other_signals.append(
                            (f"{io_info.port_name}_{i}", io_info.port_dir,
                             io_info.expl_arr, i, io_info.port_name))
                else:
                    other_signals.append(
                        (io_info.port_name, io_info.port_dir, io_info.expl_arr,
                         0, io_info.port_name))

        assert (len(self.__outputs) > 0)

        # The underlying clk_en is exposed as (inverted) `stall` at this level
        # for legacy reasons.
        self.add_ports(stall=magma.In(TBit), )

        # Put a 1-bit register and a mux in front of each control signal bit:
        # select 0 takes the routed input, select 1 the *_reg_value register.
        for control_signal, width in control_signals:
            if width == 1:
                mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
                reg_value_name = f"{control_signal}_reg_value"
                reg_sel_name = f"{control_signal}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0], self.ports[control_signal])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][0])
            else:
                for i in range(width):
                    mux = MuxWrapper(2, 1, name=f"{control_signal}_{i}_sel")
                    reg_value_name = f"{control_signal}_{i}_reg_value"
                    reg_sel_name = f"{control_signal}_{i}_reg_sel"
                    self.add_config(reg_value_name, 1)
                    self.add_config(reg_sel_name, 1)
                    self.wire(mux.ports.I[0],
                              self.ports[f"{control_signal}_{i}"])
                    self.wire(mux.ports.I[1],
                              self.registers[reg_value_name].ports.O)
                    self.wire(mux.ports.S,
                              self.registers[reg_sel_name].ports.O)
                    # 0 is the default wire, which takes from the routing network
                    self.wire(mux.ports.O[0],
                              self.underlying.ports[control_signal][i])

        # Wire the other signals up...
        for pname, _pdir, expl_arr, ind, uname in other_signals:
            if expl_arr is False:
                # Bit-split port: exposed 1-bit port -> bit `ind` of the
                # underlying multi-bit port.
                self.wire(self.ports[pname][0],
                          self.underlying.ports[uname][ind])
            else:
                # Explicit array: the flattened underlying port carries the
                # same `<name>_<i>` name as the exposed port.
                self.wire(self.ports[pname], self.underlying.ports[pname])

        # CLK, RESET, and STALL PER STANDARD PROCEDURE

        # The underlying reset is active low (rst_n), so invert `reset` and
        # present it as an async reset.
        self.resetInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.resetInverter.ports.I[0], self.ports.reset)
        self.wire(
            self.convert(self.resetInverter.ports.O[0], magma.asyncreset),
            self.underlying.ports.rst_n)
        self.wire(self.ports.clk, self.underlying.ports.clk)

        # Mem core uses clk_en (essentially active low stall)
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall)
        self.wire(self.stallInverter.ports.O[0],
                  self.underlying.ports.clk_en[0])

        # Feature 0 is the core itself; features 1..total_sets are the SRAM
        # configuration sets reported by the generator.
        self.__features: List[CoreFeature] = [self]
        self.num_sram_features = self.dut.total_sets
        for sram_index in range(self.num_sram_features):
            core_feature = CoreFeature(self, sram_index + 1)
            core_feature.skip_compression = True
            self.__features.append(core_feature)

        # Give every non-main feature its own configuration port.
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(
                    f"config_{idx}",
                    magma.In(
                        ConfigurationType(self.config_addr_width,
                                          self.config_data_width)))
                # port aliasing
                core_feature.ports["config"] = self.ports[f"config_{idx}"]
        self.add_port(
            "config",
            magma.In(
                ConfigurationType(self.config_addr_width,
                                  self.config_data_width)))

        if self.num_sram_features > 0:
            # OR the config address/data of all features together to drive
            # the single configuration bus of the underlying circuit.
            t = ConfigurationType(self.config_addr_width,
                                  self.config_data_width)
            t_names = ["config_addr", "config_data"]
            or_gates = {}
            for t_name in t_names:
                port_type = t[t_name]
                or_gate = FromMagma(
                    mantle.DefineOr(len(self.__features), len(port_type)))
                or_gate.instance_name = f"OR_{t_name}_FEATURE"
                for idx, core_feature in enumerate(self.__features):
                    self.wire(or_gate.ports[f"I{idx}"],
                              core_feature.ports.config[t_name])
                or_gates[t_name] = or_gate

            self.wire(
                or_gates["config_addr"].ports.O,
                self.underlying.ports.config_addr_in[0:self.config_addr_width])
            self.wire(or_gates["config_data"].ports.O,
                      self.underlying.ports.config_data_in)

        # Read-data output port for every non-main feature.
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(f"read_config_data_{idx}",
                              magma.Out(magma.Bits[self.config_data_width]))
                # port aliasing
                core_feature.ports["read_config_data"] = \
                    self.ports[f"read_config_data_{idx}"]

        # MEM Config: flatten top-level configs into (register name, width).
        # Explicit arrays with more than one element become `<name>_<i>`.
        configurations = []
        skip_cfgs = []

        for cfg_info in cfgs:
            if cfg_info.port_name in skip_cfgs:
                continue
            if cfg_info.expl_arr and cfg_info.port_size[0] > 1:
                for i in range(cfg_info.port_size[0]):
                    configurations.append(
                        (f"{cfg_info.port_name}_{i}", cfg_info.port_width))
            else:
                configurations.append(
                    (cfg_info.port_name, cfg_info.port_width))

        # Do all the stuff for the main config
        main_feature = self.__features[0]
        for config_reg_name, width in configurations:
            main_feature.add_config(config_reg_name, width)
            if (width == 1):
                self.wire(main_feature.registers[config_reg_name].ports.O[0],
                          self.underlying.ports[config_reg_name][0])
            else:
                self.wire(main_feature.registers[config_reg_name].ports.O,
                          self.underlying.ports[config_reg_name])

        # SRAM: the underlying config_read / config_write are the OR of the
        # per-feature read / write strobes.
        if self.num_sram_features > 0:
            or_all_cfg_rd = FromMagma(
                mantle.DefineOr(self.num_sram_features, 1))
            or_all_cfg_rd.instance_name = "OR_CONFIG_RD_SRAM"
            or_all_cfg_wr = FromMagma(
                mantle.DefineOr(self.num_sram_features, 1))
            or_all_cfg_wr.instance_name = "OR_CONFIG_WR_SRAM"

            for sram_index in range(self.num_sram_features):
                core_feature = self.__features[sram_index + 1]
                self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
                # port aliasing
                core_feature.ports["config_en"] = \
                    self.ports[f"config_en_{sram_index}"]
                # With a single set, the underlying port is just
                # config_data_out.
                if self.num_sram_features == 1:
                    self.wire(core_feature.ports.read_config_data,
                              self.underlying.ports["config_data_out"])
                else:
                    self.wire(
                        core_feature.ports.read_config_data,
                        self.underlying.ports[f"config_data_out_{sram_index}"])
                and_gate_en = FromMagma(mantle.DefineAnd(2, 1))
                and_gate_en.instance_name = f"AND_CONFIG_EN_SRAM_{sram_index}"
                # The per-set SRAM config enable is
                #   (config.write | config.read) & config_en_<i>
                or_gate_en = FromMagma(mantle.DefineOr(2, 1))
                or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"

                self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
                self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
                self.wire(and_gate_en.ports.I0, or_gate_en.ports.O)
                self.wire(and_gate_en.ports.I1[0],
                          core_feature.ports.config_en)
                self.wire(and_gate_en.ports.O[0],
                          self.underlying.ports["config_en"][sram_index])
                # Still connect to the OR of all the config rd/wr
                self.wire(core_feature.ports.config.write,
                          or_all_cfg_wr.ports[f"I{sram_index}"])
                self.wire(core_feature.ports.config.read,
                          or_all_cfg_rd.ports[f"I{sram_index}"])

            self.wire(or_all_cfg_rd.ports.O[0],
                      self.underlying.ports.config_read[0])
            self.wire(or_all_cfg_wr.ports.O[0],
                      self.underlying.ports.config_write[0])

        self._setup_config()
Beispiel #3
0
    def __init__(self, data_width, word_width, data_depth,
                 num_banks, use_sram_stub):
        """Build a memory core around the Genesis2 ``memory_core`` circuit.

        The generated circuit is cached in ``MemCore.__circuit_cache``, keyed
        on the (normalized) generator parameters, so repeated instantiations
        with the same configuration reuse a single declaration.

        Args:
            data_width: forwarded to the Genesis2 generator as ``data_width``.
            word_width: width of the ``data_in``/``addr_in``/``data_out``
                ports and of the ``chain_in`` tie-off; also forwarded to the
                generator.
            data_depth: forwarded to the generator as ``data_depth``.
            num_banks: forwarded to the generator as ``num_banks``.
            use_sram_stub: truthy to request the SRAM stub; normalized to the
                integer ``1``/``0`` before use.

        Side effects:
            Writes a ``|name|index|width||`` line for every configuration
            register (sorted by name) to ``mem_cfg.txt`` in the current
            working directory.
        """
        super().__init__(8, 32)

        self.data_width = data_width
        self.data_depth = data_depth
        self.num_banks = num_banks
        self.word_width = word_width
        # Normalize to an integer 0/1 flag; this exact value is both passed
        # to the generator and used in the cache key below.
        self.use_sram_stub = 1 if use_sram_stub else 0

        TData = magma.Bits[self.word_width]
        TBit = magma.Bits[1]

        self.add_ports(
            data_in=magma.In(TData),
            addr_in=magma.In(TData),
            data_out=magma.Out(TData),
            flush=magma.In(TBit),
            wen_in=magma.In(TBit),
            ren_in=magma.In(TBit),

            stall=magma.In(magma.Bits[4]),

            valid_out=magma.Out(TBit),

            switch_db=magma.In(TBit)
        )
        # Instead of a single read_config_data, there is one
        # read_config_data_{idx} port per "sub"-feature of this core
        # (added further below).

        # Key the cache on the normalized parameters (the same values handed
        # to the generator), so equivalent use_sram_stub arguments such as
        # None/0 or 2/1 share one circuit instead of regenerating it.
        cache_key = (self.data_width, self.word_width, self.data_depth,
                     self.num_banks, self.use_sram_stub)
        if cache_key not in MemCore.__circuit_cache:
            wrapper = memory_core_genesis2.memory_core_wrapper
            param_mapping = memory_core_genesis2.param_mapping
            generator = wrapper.generator(param_mapping, mode="declare")
            MemCore.__circuit_cache[cache_key] = generator(
                data_width=self.data_width,
                data_depth=self.data_depth,
                word_width=self.word_width,
                num_banks=self.num_banks,
                use_sram_stub=self.use_sram_stub)
        circ = MemCore.__circuit_cache[cache_key]

        self.underlying = FromMagma(circ)

        # Each control signal passes through a 2:1 mux. Input 0 is the port
        # from the routing network (the default), input 1 is a 1-bit config
        # register, and a second 1-bit config register drives the select.
        control_signals = ["wen_in", "ren_in", "flush", "switch_db"]
        for control_signal in control_signals:
            # TODO: consult with Ankita to see if we can use the normal
            # mux here
            mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
            reg_value_name = f"{control_signal}_reg_value"
            reg_sel_name = f"{control_signal}_reg_sel"
            self.add_config(reg_value_name, 1)
            self.add_config(reg_sel_name, 1)
            self.wire(mux.ports.I[0], self.ports[control_signal])
            self.wire(mux.ports.I[1], self.registers[reg_value_name].ports.O)
            self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
            self.wire(mux.ports.O[0], self.underlying.ports[control_signal])

        self.wire(self.ports.data_in, self.underlying.ports.data_in)
        self.wire(self.ports.addr_in, self.underlying.ports.addr_in)
        self.wire(self.ports.data_out, self.underlying.ports.data_out)
        self.wire(self.ports.reset, self.underlying.ports.reset)
        self.wire(self.ports.clk, self.underlying.ports.clk)
        self.wire(self.ports.valid_out[0], self.underlying.ports.valid_out)

        # PE core uses clk_en (essentially active low stall); only bit 0 of
        # the 4-bit stall port is used here.
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall[0:1])
        self.wire(self.stallInverter.ports.O[0], self.underlying.ports.clk_en)

        # Chain inputs are driven with constant 0.
        zero_signals = (
            ("chain_wen_in", 1),
            ("chain_in", self.word_width),
        )
        # Config read and write are enabled by default (constant 1).
        one_signals = (
            ("config_read", 1),
            ("config_write", 1)
        )
        for name, width in zero_signals:
            val = magma.bits(0, width) if width > 1 else magma.bit(0)
            self.wire(Const(val), self.underlying.ports[name])
        for name, width in one_signals:
            val = magma.bits(1, width) if width > 1 else magma.bit(1)
            self.wire(Const(val), self.underlying.ports[name])
        # The low 24 address bits are constant 0; bits [24:32] are driven
        # by the OR of the feature config addresses below.
        self.wire(Const(magma.bits(0, 24)),
                  self.underlying.ports.config_addr[0:24])

        # Five features in total:
        # 0:    TILE (this core itself)
        # 1-4:  SRAM
        self.__features: List[CoreFeature] = [self]
        for sram_index in range(4):
            core_feature = CoreFeature(self, sram_index + 1)
            self.__features.append(core_feature)

        # Every SRAM feature gets its own config_{idx} port (aliased onto the
        # feature); feature 0 uses the core's own "config" port.
        for idx, core_feature in enumerate(self.__features[1:], start=1):
            self.add_port(f"config_{idx}",
                          magma.In(ConfigurationType(8, 32)))
            core_feature.ports["config"] = self.ports[f"config_{idx}"]
        self.add_port("config", magma.In(ConfigurationType(8, 32)))

        # OR the address and data fields of all feature config ports
        # together before driving the underlying circuit.
        t = ConfigurationType(8, 32)
        t_names = ["config_addr", "config_data"]
        or_gates = {}
        for t_name in t_names:
            port_type = t[t_name]
            or_gate = FromMagma(mantle.DefineOr(len(self.__features),
                                                len(port_type)))
            or_gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, core_feature in enumerate(self.__features):
                self.wire(or_gate.ports[f"I{idx}"],
                          core_feature.ports.config[t_name])
            or_gates[t_name] = or_gate

        self.wire(or_gates["config_addr"].ports.O,
                  self.underlying.ports.config_addr[24:32])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data)

        # One read-data output per SRAM feature, aliased onto the feature.
        for idx, core_feature in enumerate(self.__features[1:], start=1):
            self.add_port(f"read_config_data_{idx}",
                          magma.Out(magma.Bits[32]))
            core_feature.ports["read_config_data"] = \
                self.ports[f"read_config_data_{idx}"]

        # Main (tile) configuration registers: (name, width).
        configurations = [
            ("stencil_width", 32),
            ("read_mode", 1),
            ("arbitrary_addr", 1),
            ("starting_addr", 32),
            ("iter_cnt", 32),
            ("dimensionality", 32),
            ("circular_en", 1),
            ("almost_count", 4),
            ("enable_chain", 1),
            ("mode", 2),
            ("tile_en", 1),
            ("chain_idx", 4),
            ("depth", 13)
        ]
        main_feature = self.__features[0]
        for config_reg_name, width in configurations:
            main_feature.add_config(config_reg_name, width)
            if width == 1:
                # 1-bit registers are wired by their single bit.
                self.wire(main_feature.registers[config_reg_name].ports.O[0],
                          self.underlying.ports[config_reg_name])
            else:
                self.wire(main_feature.registers[config_reg_name].ports.O,
                          self.underlying.ports[config_reg_name])

        # Eight stride/range register pairs.
        for idx in range(8):
            main_feature.add_config(f"stride_{idx}", 32)
            main_feature.add_config(f"range_{idx}", 32)
            self.wire(main_feature.registers[f"stride_{idx}"].ports.O,
                      self.underlying.ports[f"stride_{idx}"])
            self.wire(main_feature.registers[f"range_{idx}"].ports.O,
                      self.underlying.ports[f"range_{idx}"])

        # SRAM features: read data comes from read_data_sram_{i}, and the
        # feature's config write bit drives bit i of config_en_sram.
        for sram_index in range(4):
            core_feature = self.__features[sram_index + 1]
            self.wire(core_feature.ports.read_config_data,
                      self.underlying.ports[f"read_data_sram_{sram_index}"])
            self.wire(core_feature.ports.config.write[0],
                      self.underlying.ports["config_en_sram"][sram_index])

        self._setup_config()

        # Dump the configuration register table, sorted by register name.
        conf_names = sorted(self.registers.keys())
        with open("mem_cfg.txt", "w+") as cfg_dump:
            for idx, reg in enumerate(conf_names):
                write_line = f"|{reg}|{idx}|{self.registers[reg].width}||\n"
                cfg_dump.write(write_line)
Beispiel #4
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=1,
            sram_macro_info=SRAMMacroInfo("TS1N16FFCLLSBLVTC512X32M4S"),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            max_agg_schedule=16,
            input_max_port_sched=16,
            output_max_port_sched=16,
            align_input=1,
            max_line_length=128,
            max_tb_height=1,
            tb_range_max=1024,
            tb_range_inner_max=64,
            tb_sched_max=16,
            max_tb_stride=15,
            num_tb=1,
            tb_iterator_support=2,
            multiwrite=1,
            max_prefetch=8,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=2,
            app_ctrl_depth_width=16,
            remove_tb=False,
            fifo_mode=True,
            add_clk_enable=True,
            add_flush=True,
            core_reset_pos=False,
            stcl_valid_iter=4):

        super().__init__(config_addr_width, config_data_width)

        # Capture everything to the tile object
        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.fw_int = int(self.mem_width / self.data_width)
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.max_agg_schedule = max_agg_schedule
        self.input_max_port_sched = input_max_port_sched
        self.output_max_port_sched = output_max_port_sched
        self.align_input = align_input
        self.max_line_length = max_line_length
        self.max_tb_height = max_tb_height
        self.tb_range_max = tb_range_max
        self.tb_range_inner_max = tb_range_inner_max
        self.tb_sched_max = tb_sched_max
        self.max_tb_stride = max_tb_stride
        self.num_tb = num_tb
        self.tb_iterator_support = tb_iterator_support
        self.multiwrite = multiwrite
        self.max_prefetch = max_prefetch
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.remove_tb = remove_tb
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.core_reset_pos = core_reset_pos
        self.app_ctrl_depth_width = app_ctrl_depth_width
        self.stcl_valid_iter = stcl_valid_iter

        # Typedefs for ease
        TData = magma.Bits[self.data_width]
        TBit = magma.Bits[1]

        self.__inputs = []
        self.__outputs = []

        # Enumerate input and output ports
        # (clk and reset are assumed)
        if self.interconnect_input_ports > 1:
            for i in range(self.interconnect_input_ports):
                self.add_port(f"addr_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"addr_in_{i}"])
                self.add_port(f"data_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"data_in_{i}"])
                self.add_port(f"wen_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"wen_in_{i}"])
        else:
            self.add_port("addr_in", magma.In(TData))
            self.__inputs.append(self.ports[f"addr_in"])
            self.add_port("data_in", magma.In(TData))
            self.__inputs.append(self.ports[f"data_in"])
            self.add_port("wen_in", magma.In(TBit))
            self.__inputs.append(self.ports.wen_in)

        if self.interconnect_output_ports > 1:
            for i in range(self.interconnect_output_ports):
                self.add_port(f"data_out_{i}", magma.Out(TData))
                self.__outputs.append(self.ports[f"data_out_{i}"])
                self.add_port(f"ren_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"ren_in_{i}"])
                self.add_port(f"valid_out_{i}", magma.Out(TBit))
                self.__outputs.append(self.ports[f"valid_out_{i}"])
                # Chaining
                self.add_port(f"chain_valid_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"chain_valid_in_{i}"])
                self.add_port(f"chain_data_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"chain_data_in_{i}"])
                self.add_port(f"chain_data_out_{i}", magma.Out(TData))
                self.__outputs.append(self.ports[f"chain_data_out_{i}"])
                self.add_port(f"chain_valid_out_{i}", magma.Out(TBit))
                self.__outputs.append(self.ports[f"chain_valid_out_{i}"])
        else:
            self.add_port("data_out", magma.Out(TData))
            self.__outputs.append(self.ports[f"data_out"])
            self.add_port(f"ren_in", magma.In(TBit))
            self.__inputs.append(self.ports[f"ren_in"])
            self.add_port(f"valid_out", magma.Out(TBit))
            self.__outputs.append(self.ports[f"valid_out"])
            self.add_port(f"chain_valid_in", magma.In(TBit))
            self.__inputs.append(self.ports[f"chain_valid_in"])
            self.add_port(f"chain_data_in", magma.In(TData))
            self.__inputs.append(self.ports[f"chain_data_in"])
            self.add_port(f"chain_data_out", magma.Out(TData))
            self.__outputs.append(self.ports[f"chain_data_out"])
            self.add_port(f"chain_valid_out", magma.Out(TBit))
            self.__outputs.append(self.ports[f"chain_valid_out"])

        self.add_ports(flush=magma.In(TBit),
                       full=magma.Out(TBit),
                       empty=magma.Out(TBit),
                       stall=magma.In(TBit),
                       sram_ready_out=magma.Out(TBit))

        self.__inputs.append(self.ports.flush)
        # self.__inputs.append(self.ports.stall)

        self.__outputs.append(self.ports.full)
        self.__outputs.append(self.ports.empty)
        self.__outputs.append(self.ports.sram_ready_out)

        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.max_agg_schedule,
                     self.input_max_port_sched, self.output_max_port_sched,
                     self.align_input, self.max_line_length,
                     self.max_tb_height, self.tb_range_max, self.tb_sched_max,
                     self.max_tb_stride, self.num_tb, self.tb_iterator_support,
                     self.multiwrite, self.max_prefetch,
                     self.config_data_width, self.config_addr_width,
                     self.num_tiles, self.remove_tb, self.fifo_mode,
                     self.stcl_valid_iter, self.add_clk_enable, self.add_flush,
                     self.app_ctrl_depth_width)

        # Check for circuit caching
        if cache_key not in MemCore.__circuit_cache:

            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            lt_dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                max_agg_schedule=self.max_agg_schedule,
                input_max_port_sched=self.input_max_port_sched,
                output_max_port_sched=self.output_max_port_sched,
                align_input=self.align_input,
                max_line_length=self.max_line_length,
                max_tb_height=self.max_tb_height,
                tb_range_max=self.tb_range_max,
                tb_range_inner_max=self.tb_range_inner_max,
                tb_sched_max=self.tb_sched_max,
                max_tb_stride=self.max_tb_stride,
                num_tb=self.num_tb,
                tb_iterator_support=self.tb_iterator_support,
                multiwrite=self.multiwrite,
                max_prefetch=self.max_prefetch,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                app_ctrl_depth_width=self.app_ctrl_depth_width,
                remove_tb=self.remove_tb,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                stcl_valid_iter=self.stcl_valid_iter)

            change_sram_port_pass = change_sram_port_names(
                use_sram_stub, sram_macro_info)
            circ = kts.util.to_magma(
                lt_dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
        else:
            circ, lt_dut = MemCore.__circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

        # put a 1-bit register and a mux to select the control signals
        # TODO: check if enable_chain_output needs to be here? I don't think so?
        control_signals = [("wen_in", self.interconnect_input_ports),
                           ("ren_in", self.interconnect_output_ports),
                           ("flush", 1),
                           ("chain_valid_in", self.interconnect_output_ports)]
        for control_signal, width in control_signals:
            # TODO: consult with Ankita to see if we can use the normal
            # mux here
            if width == 1:
                mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
                reg_value_name = f"{control_signal}_reg_value"
                reg_sel_name = f"{control_signal}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0], self.ports[control_signal])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][0])
            else:
                for i in range(width):
                    mux = MuxWrapper(2, 1, name=f"{control_signal}_{i}_sel")
                    reg_value_name = f"{control_signal}_{i}_reg_value"
                    reg_sel_name = f"{control_signal}_{i}_reg_sel"
                    self.add_config(reg_value_name, 1)
                    self.add_config(reg_sel_name, 1)
                    self.wire(mux.ports.I[0],
                              self.ports[f"{control_signal}_{i}"])
                    self.wire(mux.ports.I[1],
                              self.registers[reg_value_name].ports.O)
                    self.wire(mux.ports.S,
                              self.registers[reg_sel_name].ports.O)
                    # 0 is the default wire, which takes from the routing network
                    self.wire(mux.ports.O[0],
                              self.underlying.ports[control_signal][i])

        if self.interconnect_input_ports > 1:
            for i in range(self.interconnect_input_ports):
                self.wire(self.ports[f"data_in_{i}"],
                          self.underlying.ports[f"data_in_{i}"])
                self.wire(self.ports[f"addr_in_{i}"],
                          self.underlying.ports[f"addr_in_{i}"])
        else:
            self.wire(self.ports.addr_in, self.underlying.ports.addr_in)
            self.wire(self.ports.data_in, self.underlying.ports.data_in)

        if self.interconnect_output_ports > 1:
            for i in range(self.interconnect_output_ports):
                self.wire(self.ports[f"data_out_{i}"],
                          self.underlying.ports[f"data_out_{i}"])
                self.wire(self.ports[f"chain_data_in_{i}"],
                          self.underlying.ports[f"chain_data_in_{i}"])
                self.wire(self.ports[f"chain_data_out_{i}"],
                          self.underlying.ports[f"chain_data_out_{i}"])
        else:
            self.wire(self.ports.data_out, self.underlying.ports.data_out)
            self.wire(self.ports.chain_data_in,
                      self.underlying.ports.chain_data_in)
            self.wire(self.ports.chain_data_out,
                      self.underlying.ports.chain_data_out)

        # Need to invert this
        self.resetInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.resetInverter.ports.I[0], self.ports.reset)
        self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
        self.wire(self.ports.clk, self.underlying.ports.clk)
        if self.interconnect_output_ports == 1:
            self.wire(self.ports.valid_out[0],
                      self.underlying.ports.valid_out[0])
            self.wire(self.ports.chain_valid_out[0],
                      self.underlying.ports.chain_valid_out[0])
        else:
            for j in range(self.interconnect_output_ports):
                self.wire(self.ports[f"valid_out_{j}"][0],
                          self.underlying.ports.valid_out[j])
                self.wire(self.ports[f"chain_valid_out_{j}"][0],
                          self.underlying.ports.chain_valid_out[j])
        self.wire(self.ports.empty[0], self.underlying.ports.empty[0])
        self.wire(self.ports.full[0], self.underlying.ports.full[0])

        # PE core uses clk_en (essentially active low stall)
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall)
        self.wire(self.stallInverter.ports.O[0],
                  self.underlying.ports.clk_en[0])

        self.wire(self.ports.sram_ready_out[0],
                  self.underlying.ports.sram_ready_out[0])

        # we have six? features in total
        # 0:    TILE
        # 1:    TILE
        # 1-4:  SMEM
        # Feature 0: Tile
        self.__features: List[CoreFeature] = [self]
        # Features 1-4: SRAM
        self.num_sram_features = lt_dut.total_sets
        for sram_index in range(self.num_sram_features):
            core_feature = CoreFeature(self, sram_index + 1)
            self.__features.append(core_feature)

        # Wire the config
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(f"config_{idx}",
                              magma.In(ConfigurationType(8, 32)))
                # port aliasing
                core_feature.ports["config"] = self.ports[f"config_{idx}"]
        self.add_port("config", magma.In(ConfigurationType(8, 32)))

        # or the signal up
        t = ConfigurationType(8, 32)
        t_names = ["config_addr", "config_data"]
        or_gates = {}
        for t_name in t_names:
            port_type = t[t_name]
            or_gate = FromMagma(
                mantle.DefineOr(len(self.__features), len(port_type)))
            or_gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, core_feature in enumerate(self.__features):
                self.wire(or_gate.ports[f"I{idx}"],
                          core_feature.ports.config[t_name])
            or_gates[t_name] = or_gate

        self.wire(or_gates["config_addr"].ports.O,
                  self.underlying.ports.config_addr_in[0:8])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data_in)

        # read data out
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                # self.add_port(f"read_config_data_{idx}",
                self.add_port(f"read_config_data_{idx}",
                              magma.Out(magma.Bits[32]))
                # port aliasing
                core_feature.ports["read_config_data"] = \
                    self.ports[f"read_config_data_{idx}"]

        # MEM Config
        configurations = [("tile_en", 1), ("fifo_ctrl_fifo_depth", 16),
                          ("mode", 2), ("enable_chain_output", 1),
                          ("enable_chain_input", 1)]
        #            ("stencil_width", 16), NOT YET

        merged_configs = []
        merged_in_sched = []
        merged_out_sched = []

        # Add config registers to configurations
        # TODO: Have lake spit this information out automatically from the wrapper

        configurations.append((f"chain_idx_input", self.chain_idx_bits))
        configurations.append((f"chain_idx_output", self.chain_idx_bits))
        for i in range(self.interconnect_input_ports):
            configurations.append((f"strg_ub_agg_align_{i}_line_length",
                                   kts.clog2(self.max_line_length)))
            configurations.append((f"strg_ub_agg_in_{i}_in_period",
                                   kts.clog2(self.input_max_port_sched)))

            # num_bits_in_sched = kts.clog2(self.agg_height)
            # sched_per_feat = math.floor(self.config_data_width / num_bits_in_sched)
            # new_width = num_bits_in_sched * sched_per_feat
            # feat_num = 0
            # num_feats_merge = math.ceil(self.input_max_port_sched / sched_per_feat)
            # for k in range(num_feats_merge):
            #    num_here = sched_per_feat
            #    if self.input_max_port_sched - (k * sched_per_feat) < sched_per_feat:
            #        num_here = self.input_max_port_sched - (k * sched_per_feat)
            #    merged_configs.append((f"strg_ub_agg_in_{i}_in_sched_merged_{k * sched_per_feat}",
            #                          num_here * num_bits_in_sched, num_here))
            for j in range(self.input_max_port_sched):
                configurations.append((f"strg_ub_agg_in_{i}_in_sched_{j}",
                                       kts.clog2(self.agg_height)))

            configurations.append((f"strg_ub_agg_in_{i}_out_period",
                                   kts.clog2(self.input_max_port_sched)))

            for j in range(self.output_max_port_sched):
                configurations.append((f"strg_ub_agg_in_{i}_out_sched_{j}",
                                       kts.clog2(self.agg_height)))

            configurations.append((f"strg_ub_app_ctrl_write_depth_wo_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append((f"strg_ub_app_ctrl_write_depth_ss_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append(
                (f"strg_ub_app_ctrl_coarse_write_depth_wo_{i}",
                 self.app_ctrl_depth_width))
            configurations.append(
                (f"strg_ub_app_ctrl_coarse_write_depth_ss_{i}",
                 self.app_ctrl_depth_width))

            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_dimensionality",
                 1 + kts.clog2(self.input_iterator_support)))
            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_starting_addr",
                 self.input_config_width))
            for j in range(self.input_iterator_support):
                configurations.append(
                    (f"strg_ub_input_addr_ctrl_address_gen_{i}_ranges_{j}",
                     self.input_config_width))
                configurations.append(
                    (f"strg_ub_input_addr_ctrl_address_gen_{i}_strides_{j}",
                     self.input_config_width))

        configurations.append(
            (f"strg_ub_app_ctrl_prefill", self.interconnect_output_ports))
        configurations.append((f"strg_ub_app_ctrl_coarse_prefill",
                               self.interconnect_output_ports))

        for i in range(self.stcl_valid_iter):
            configurations.append((f"strg_ub_app_ctrl_ranges_{i}", 16))
            configurations.append((f"strg_ub_app_ctrl_threshold_{i}", 16))

        for i in range(self.interconnect_output_ports):
            configurations.append((f"strg_ub_app_ctrl_input_port_{i}",
                                   kts.clog2(self.interconnect_input_ports)))
            configurations.append((f"strg_ub_app_ctrl_read_depth_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append((f"strg_ub_app_ctrl_coarse_input_port_{i}",
                                   kts.clog2(self.interconnect_input_ports)))
            configurations.append((f"strg_ub_app_ctrl_coarse_read_depth_{i}",
                                   self.app_ctrl_depth_width))

            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_dimensionality",
                 1 + kts.clog2(self.output_iterator_support)))
            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_starting_addr",
                 self.output_config_width))
            for j in range(self.output_iterator_support):
                configurations.append(
                    (f"strg_ub_output_addr_ctrl_address_gen_{i}_ranges_{j}",
                     self.output_config_width))
                configurations.append(
                    (f"strg_ub_output_addr_ctrl_address_gen_{i}_strides_{j}",
                     self.output_config_width))

            configurations.append((f"strg_ub_pre_fetch_{i}_input_latency",
                                   kts.clog2(self.max_prefetch) + 1))
            configurations.append((f"strg_ub_sync_grp_sync_group_{i}",
                                   self.interconnect_output_ports))
            configurations.append(
                (f"strg_ub_rate_matched_{i}",
                 1 + kts.clog2(self.interconnect_input_ports)))

            for j in range(self.num_tb):
                configurations.append(
                    (f"strg_ub_tba_{i}_tb_{j}_dimensionality", 2))
                num_indices_bits = 1 + kts.clog2(self.fw_int)
                indices_per_feat = math.floor(self.config_data_width /
                                              num_indices_bits)
                new_width = num_indices_bits * indices_per_feat
                feat_num = 0
                num_feats_merge = math.ceil(self.tb_range_inner_max /
                                            indices_per_feat)
                for k in range(num_feats_merge):
                    num_idx = indices_per_feat
                    if (self.tb_range_inner_max -
                        (k * indices_per_feat)) < indices_per_feat:
                        num_idx = self.tb_range_inner_max - (k *
                                                             indices_per_feat)
                    merged_configs.append((
                        f"strg_ub_tba_{i}_tb_{j}_indices_merged_{k * indices_per_feat}",
                        num_idx * num_indices_bits, num_idx))


#                for k in range(self.tb_range_inner_max):
#                    configurations.append((f"strg_ub_tba_{i}_tb_{j}_indices_{k}", kts.clog2(self.fw_int) + 1))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_inner",
                                       kts.clog2(self.tb_range_inner_max)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_outer",
                                       kts.clog2(self.tb_range_max)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_stride",
                                       kts.clog2(self.max_tb_stride)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_tb_height",
                                       max(1, kts.clog2(self.num_tb))))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_starting_addr",
                                       max(1, kts.clog2(self.fw_int))))

        # Do all the stuff for the main config
        main_feature = self.__features[0]
        for config_reg_name, width in configurations:
            main_feature.add_config(config_reg_name, width)
            if (width == 1):
                self.wire(main_feature.registers[config_reg_name].ports.O[0],
                          self.underlying.ports[config_reg_name][0])
            else:
                self.wire(main_feature.registers[config_reg_name].ports.O,
                          self.underlying.ports[config_reg_name])

        for config_reg_name, width, num_merged in merged_configs:
            main_feature.add_config(config_reg_name, width)
            token_under = config_reg_name.split("_")
            base_name = config_reg_name.split("_merged")[0]
            base_indices = int(config_reg_name.split("_merged_")[1])
            num_bits = width // num_merged
            for i in range(num_merged):
                self.wire(
                    main_feature.registers[config_reg_name].ports.
                    O[i * num_bits:(i + 1) * num_bits],
                    self.underlying.ports[f"{base_name}_{base_indices + i}"])

        # SRAM
        # These should also account for num features
        # or_all_cfg_rd = FromMagma(mantle.DefineOr(4, 1))
        or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_rd.instance_name = f"OR_CONFIG_WR_SRAM"
        or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_wr.instance_name = f"OR_CONFIG_RD_SRAM"
        for sram_index in range(self.num_sram_features):
            core_feature = self.__features[sram_index + 1]
            self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
            # port aliasing
            core_feature.ports["config_en"] = \
                self.ports[f"config_en_{sram_index}"]
            self.wire(core_feature.ports.read_config_data,
                      self.underlying.ports[f"config_data_out_{sram_index}"])
            # also need to wire the sram signal
            # the config enable is the OR of the rd+wr
            or_gate_en = FromMagma(mantle.DefineOr(2, 1))
            or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"

            self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
            self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
            self.wire(core_feature.ports.config_en,
                      self.underlying.ports["config_en"][sram_index])
            # Still connect to the OR of all the config rd/wr
            self.wire(core_feature.ports.config.write,
                      or_all_cfg_wr.ports[f"I{sram_index}"])
            self.wire(core_feature.ports.config.read,
                      or_all_cfg_rd.ports[f"I{sram_index}"])

        self.wire(or_all_cfg_rd.ports.O[0],
                  self.underlying.ports.config_read[0])
        self.wire(or_all_cfg_wr.ports.O[0],
                  self.underlying.ports.config_write[0])
        self._setup_config()

        conf_names = list(self.registers.keys())
        conf_names.sort()
        with open("mem_cfg.txt", "w+") as cfg_dump:
            for idx, reg in enumerate(conf_names):
                write_line = f"|{reg}|{idx}|{self.registers[reg].width}||\n"
                cfg_dump.write(write_line)
        with open("mem_synth.txt", "w+") as cfg_dump:
            for idx, reg in enumerate(conf_names):
                write_line = f"{reg}\n"
                cfg_dump.write(write_line)
Beispiel #5
0
    def __init__(self, data_width, data_depth):
        """Wrap the Genesis2 ``memory_core_wrapper`` as a configurable core.

        The configuration interface uses an 8-bit address and 32-bit data
        (``super().__init__(8, 32)`` and ``ConfigurationType(8, 32)``). It is
        split into five features:
          * feature 0 configures the core itself (line-buffer mode);
          * features 1-4 each target one SRAM (``read_data_sram_{i}`` /
            ``config_en_sram[i]`` on the underlying circuit).

        Each feature gets its own ``config_{idx}`` input and
        ``read_config_data_{idx}`` output, replacing the single default
        ``read_config_data`` port. Each SRAM feature also gets a 1-bit
        ``config_en_{sram_index}`` input.

        Data-path ports added here:
          * ``data_in``, ``addr_in``, ``data_out``: ``data_width`` bits each;
          * ``flush``, ``wen_in``, ``ren_in``: 1 bit each;
          * ``stall``: 4 bits, of which only bit 0 is used.

        Args:
            data_width: Width of the data and address ports. It is also
                passed to the Genesis2 generator.
            data_depth: Memory depth, passed through to the Genesis2
                generator.
        """
        super().__init__(8, 32)

        self.data_width = data_width
        self.data_depth = data_depth
        TData = magma.Bits[self.data_width]
        TBit = magma.Bits[1]

        # Note: addr_in uses the data width (TData), not a depth-derived width.
        self.add_ports(data_in=magma.In(TData),
                       addr_in=magma.In(TData),
                       data_out=magma.Out(TData),
                       flush=magma.In(TBit),
                       wen_in=magma.In(TBit),
                       ren_in=magma.In(TBit),
                       stall=magma.In(magma.Bits[4]))
        # Instead of a single read_config_data, we have multiple for each
        # "sub"-feature of this core.
        self.ports.pop("read_config_data")

        # Build the underlying magma circuit from the Genesis2 wrapper, sized
        # by data_width/data_depth. mode="declare" presumably declares only
        # the interface; the RTL comes from Genesis2 (confirm).
        wrapper = memory_core_genesis2.memory_core_wrapper
        param_mapping = memory_core_genesis2.param_mapping
        generator = wrapper.generator(param_mapping, mode="declare")
        circ = generator(data_width=self.data_width,
                         data_depth=self.data_depth)
        self.underlying = FromMagma(circ)

        # Pass the data path straight through. `reset` is not added in this
        # method; presumably the base class provides it.
        self.wire(self.ports.data_in, self.underlying.ports.data_in)
        self.wire(self.ports.addr_in, self.underlying.ports.addr_in)
        self.wire(self.ports.data_out, self.underlying.ports.data_out)
        self.wire(self.ports.reset, self.underlying.ports.reset)
        # The 1-bit Bits ports are indexed with [0]; the underlying ports are
        # presumably single-bit.
        self.wire(self.ports.flush[0], self.underlying.ports.flush)
        self.wire(self.ports.wen_in[0], self.underlying.ports.wen_in)
        self.wire(self.ports.ren_in[0], self.underlying.ports.ren_in)

        # The underlying memory takes a clk_en rather than a stall, so drive
        # clk_en with the inverse of stall[0]. clk_en is therefore low while
        # stalled. stall bits 1-3 are not used here.
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall[0:1])
        self.wire(self.stallInverter.ports.O[0], self.underlying.ports.clk_en)

        # TODO(rsetaluri): Actually wire these inputs.
        # Underlying inputs that are not exposed by this core are tied to
        # constants: (port name, width) pairs.
        zero_signals = (
            ("config_en_linebuf", 1),
            ("chain_wen_in", 1),
            ("chain_in", self.data_width),
        )
        one_signals = (
            ("config_read", 1),
            ("config_write", 1),
        )
        # Tie the unused line-buffer config enable and chain inputs low.
        # Single-bit ports take a Bit constant; wider ports take a Bits
        # constant.
        for name, width in zero_signals:
            val = magma.bits(0, width) if width > 1 else magma.bit(0)
            self.wire(Const(val), self.underlying.ports[name])
        # Enable read and write by default.
        for name, width in one_signals:
            val = magma.bits(1, width) if width > 1 else magma.bit(1)
            self.wire(Const(val), self.underlying.ports[name])
        # The core's config address is only 8 bits. The underlying
        # config_addr bits [0:24] are tied to zero; bits [24:32] are driven
        # below by the merged feature addresses.
        self.wire(Const(magma.bits(0, 24)),
                  self.underlying.ports.config_addr[0:24])
        # we have five features in total
        # 0:   LINEBUF
        # 1-4: SMEM (one feature per SRAM)
        # current setup is already in line buffer mode, so we pass self in
        # notice that config_en_linebuf is to change the address in the
        # line buffer mode, which is not used in practice
        self.__features: List[CoreFeature] = [CoreFeature(self, 0)]
        for sram_index in range(4):
            core_feature = CoreFeature(self, sram_index + 1)
            self.__features.append(core_feature)

        # Give each feature its own config input port, config_{idx}.
        for idx, core_feature in enumerate(self.__features):
            self.add_port(f"config_{idx}", magma.In(ConfigurationType(8, 32)))
            # port aliasing
            core_feature.ports["config"] = self.ports[f"config_{idx}"]
        # or the signal up: merge each config field (addr, data) across all
        # features with a bitwise OR. Presumably only one feature's bus is
        # non-zero at any time (confirm with the interconnect).
        t = ConfigurationType(8, 32)
        t_names = ["config_addr", "config_data"]
        or_gates = {}
        for t_name in t_names:
            port_type = t[t_name]
            # One N-input OR (N = number of features), as wide as the field.
            or_gate = FromMagma(
                mantle.DefineOr(len(self.__features), len(port_type)))
            or_gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, core_feature in enumerate(self.__features):
                self.wire(or_gate.ports[f"I{idx}"],
                          core_feature.ports.config[t_name])
            or_gates[t_name] = or_gate
        self.wire(or_gates["config_addr"].ports.O,
                  self.underlying.ports.config_addr[24:32])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data)
        # only the first one has config_en: the underlying config_en is
        # driven by feature 0's config write strobe.
        self.wire(self.__features[0].ports.config.write[0],
                  self.underlying.ports.config_en)

        # read data out: one read_config_data_{idx} output per feature
        for idx, core_feature in enumerate(self.__features):
            self.add_port(f"read_config_data_{idx}", magma.Out(magma.Bits[32]))
            # port aliasing
            core_feature.ports["read_config_data"] = \
                self.ports[f"read_config_data_{idx}"]
        # MEM config: feature 0 reads back the core's own read_data
        self.wire(self.ports.read_config_data_0,
                  self.underlying.ports.read_data)
        # SRAM: feature i+1 reads back SRAM i and gets a dedicated config_en
        # input that drives bit i of the underlying config_en_sram.
        for sram_index in range(4):
            core_feature = self.__features[sram_index + 1]
            self.wire(core_feature.ports.read_config_data,
                      self.underlying.ports[f"read_data_sram_{sram_index}"])
            # also need to wire the sram signal
            self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
            # port aliasing
            core_feature.ports["config_en"] = \
                self.ports[f"config_en_{sram_index}"]
            self.wire(self.underlying.ports["config_en_sram"][sram_index],
                      self.ports[f"config_en_{sram_index}"])