예제 #1
0
    def __init__(self, interconnect_output_ports, mem_depth, num_tiles, banks,
                 iterator_support, address_width, data_width, fetch_width,
                 chain_idx_output):
        """Initialise the read-side model state.

        Stores the generation parameters, derives ``fw_int``
        (``fetch_width / data_width`` truncated to int), ``mem_addr_width``
        (``clog2(num_tiles * mem_depth)``) and ``chain_idx_bits`` (at least 1).
        Creates one ``AddrGenModel`` per interconnect output port and a config
        dict in which every per-port address-generator field starts at 0.
        """
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_depth = mem_depth
        self.num_tiles = num_tiles
        self.banks = banks
        self.iterator_support = iterator_support
        self.address_width = address_width
        self.data_width = data_width
        self.fetch_width = fetch_width
        self.fw_int = int(self.fetch_width / self.data_width)
        self.chain_idx_output = chain_idx_output

        self.config = {}

        num_ports = self.interconnect_output_ports

        # One child address generator per output port.
        self.addr_gens = [
            AddrGenModel(iterator_support=self.iterator_support,
                         address_width=self.address_width)
            for _ in range(num_ports)
        ]

        self.mem_addr_width = kts.clog2(self.num_tiles * self.mem_depth)
        self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

        # Current address of each output port.
        self.addresses = [0] * num_ports

        # Zero every configuration field of every address generator.
        for port in range(num_ports):
            self.config[f"address_gen_{port}_starting_addr"] = 0
            self.config[f"address_gen_{port}_dimensionality"] = 0
            for dim in range(self.iterator_support):
                self.config[f"address_gen_{port}_strides_{dim}"] = 0
                self.config[f"address_gen_{port}_ranges_{dim}"] = 0

        # Read enables: one row per bank, one column per output port.
        self.ren = [[0] * num_ports for _ in range(self.banks)]
        # Memory address per output port.
        self.mem_addresses = [0] * num_ports
예제 #2
0
    def __init__(self, _params: GlobalBufferParams):
        """Build the ``glb_loop_iter`` generator.

        Declares clock/enable/reset, the ``ranges`` / ``dim`` / ``step`` inputs,
        the ``mux_sel_out`` / ``restart`` outputs and the internal counter
        state. ``is_maxed`` is driven by
        ``(dim_counter[mux_sel] == ranges[mux_sel]) & inc[mux_sel]`` and
        ``restart`` by ``step & ~not_done``. ``set_mux_sel`` is registered
        once, then ``set_clear``, ``set_inc``, ``dim_counter_update`` and
        ``max_value_update`` are registered for each loop level.
        """
        super().__init__("glb_loop_iter")
        self._params = _params

        loop_level = self._params.loop_level
        counter_width = self._params.axi_data_width
        sel_width = max(clog2(loop_level), 1)

        # Clock, clock enable and reset.
        self.clk = self.clock("clk")
        self.clk_en = self.clock_en("clk_en")
        self.reset = self.reset("reset")

        # Ports.
        self.ranges = self.input("ranges", counter_width, size=loop_level,
                                 packed=True, explicit_array=True)
        self.dim = self.input("dim", 1 + clog2(loop_level))
        self.step = self.input("step", 1)
        self.mux_sel_out = self.output("mux_sel_out", sel_width)
        self.restart = self.output("restart", 1)

        # Internal variables.
        self.dim_counter = self.var("dim_counter", counter_width,
                                    size=loop_level, packed=True,
                                    explicit_array=True)
        self.max_value = self.var("max_value", loop_level)
        self.mux_sel = self.var("mux_sel", sel_width)
        self.wire(self.mux_sel_out, self.mux_sel)

        self.not_done = self.var("not_done", 1)
        self.clear = self.var("clear", loop_level)
        self.inc = self.var("inc", loop_level)

        # The currently selected level has reached its range while incrementing.
        self.is_maxed = self.var("is_maxed", 1)
        sel = self.mux_sel
        self.wire(self.is_maxed,
                  (self.dim_counter[sel] == self.ranges[sel]) & self.inc[sel])

        self.add_code(self.set_mux_sel)
        for level in range(loop_level):
            for per_level_block in (self.set_clear, self.set_inc,
                                    self.dim_counter_update,
                                    self.max_value_update):
                self.add_code(per_level_block, idx=level)

        self.wire(self.restart, self.step & (~self.not_done))
예제 #3
0
def test_nested_scope():
    """Check the scope context recorded for statements in a nested ``if``.

    Generates a 4-bit "find highest set bit" always_comb block with debug
    info enabled. For statements inside the innermost ``if`` of the last
    top-level statement, the test asserts that the loop variable ``i`` is
    recorded as a non-variable with value ``"3"``.
    """
    from kratos import clog2
    # NOTE(review): the second Generator argument is presumably a debug flag — confirm.
    mod = Generator("FindHighestBit", True)
    width = 4
    data = mod.input("data", width)
    h_bit = mod.output("h_bit", clog2(width))
    done = mod.var("done", 1)

    # This body is translated by kratos rather than executed as Python;
    # `done`, `h_bit` and `data` refer to the generator signals above.
    @always_comb
    def find_bit():
        done = 0
        h_bit = 0
        for i in range(width):
            if ~done:
                if data[i]:
                    done = 1
                    h_bit = i

    mod.add_always(find_bit, label="block")
    verilog(mod, insert_debug_info=True)
    block = mod.get_marked_stmt("block")
    # Presumably the for loop is unrolled, so the last statement is the
    # `if ~done` produced for the final iteration (the asserts expect i == 3).
    last_if = block[-1]
    for i in range(len(last_if.then_[-1].then_)):
        # Statements inside the inner `if data[i]` branch.
        stmt = last_if.then_[-1].then_[i]
        context = stmt.scope_context
        if len(context) > 0:
            assert "i" in context
            # Each entry is an (is_var, value) pair; `i` must be a constant.
            is_var, var = context["i"]
            assert not is_var
            assert var == "3"
예제 #4
0
    def __init__(self, _params: GlobalBufferParams):
        """Build the ``glb_addr_gen`` generator.

        Declares clock/enable/reset, the ``restart``, ``strides``,
        ``start_addr``, ``step`` and ``mux_sel`` inputs and the ``addr_out``
        output. ``addr_out`` is wired to ``start_addr + current_addr``, and
        ``current_addr`` is updated by the ``calculate_address`` always block.
        """
        super().__init__("glb_addr_gen")
        self._params = _params

        addr_width = self._params.axi_data_width
        loop_level = self._params.loop_level

        # Clock, clock enable and reset.
        self.clk = self.clock("clk")
        self.clk_en = self.clock_en("clk_en")
        self.reset = self.reset("reset")

        # Ports.
        self.restart = self.input("restart", 1)
        self.strides = self.input("strides", addr_width, size=loop_level,
                                  packed=True, explicit_array=True)
        self.start_addr = self.input("start_addr", addr_width)
        self.step = self.input("step", 1)
        self.mux_sel = self.input("mux_sel", max(clog2(loop_level), 1))
        self.addr_out = self.output("addr_out", addr_width)

        # Running offset from the start address.
        self.current_addr = self.var("current_addr", addr_width)

        self.wire(self.addr_out, self.start_addr + self.current_addr)
        self.add_always(self.calculate_address)
예제 #5
0
    def __init__(self, data_width, width_mult, depth, num_tiles):
        """Initialise a zero-filled behavioral memory model.

        Args:
            data_width: stored as-is.
            width_mult: number of words per memory row and per read register.
            depth: number of memory rows.
            num_tiles: number of tiles; the address width covers
                ``num_tiles * depth`` rows.
        """
        self.data_width = data_width
        self.width_mult = width_mult
        self.depth = depth
        self.num_tiles = num_tiles
        self.address_width = kts.clog2(num_tiles * depth)

        # Always at least one bit, even for a single tile.
        self.chain_idx_bits = max(1, kts.clog2(num_tiles))
        self.chain_idx_tile = 0

        # Read register and memory rows, all words initialised to 0.
        # Each row is a distinct list.
        self.rd_reg = [0] * width_mult
        self.mem = [[0] * width_mult for _ in range(depth)]
예제 #6
0
    def __init__(self, width, depth):
        """Build a generator named "FIFO".

        Args:
            width: bit width of the ``in``/``out`` ports and of each entry.
            depth: number of entries; the read/write pointers are
                ``clog2(depth)`` bits wide.
        """
        super().__init__("FIFO")
        in_ = self.input("in", width)
        out = self.output("out", width)
        ren = self.input("ren", 1)
        wen = self.input("wen", 1)
        data = self.var("data", width, size=depth)
        clk = self.clock("clk")
        reset = self.reset("rst")
        read_ptr = self.var("read_ptr", clog2(depth))
        write_ptr = self.var("write_ptr", clog2(depth))
        full = self.output("is_full", 1)
        empty = self.output("is_empty", 1)
        full_next = self.var("is_full_next", 1)
        # Empty when not full and both pointers are equal.
        self.wire(empty, ~full & (read_ptr == write_ptr))

        # Next-state logic for the full flag. The body is translated by kratos.
        # NOTE(review): there is no decorator; presumably add_code treats an
        # undecorated function as combinational — confirm against kratos.
        def comb():
            # A write without a read that would advance write_ptr onto read_ptr.
            if wen & (~ren) & ((write_ptr + 1) == read_ptr):
                full_next = True
            # NOTE(review): this clears the full flag on a *write* while full;
            # a read (`ren`) looks like the intended condition. This branch also
            # uses Python `and` where the others use `&`. Confirm intent.
            elif wen and full:
                full_next = False
            else:
                full_next = full

        # Sequential logic, sensitive to rising clk and rising rst; the
        # `if reset` branch clears the pointers and the full flag.
        @always((posedge, "clk"), (posedge, "rst"))
        def seq():
            if reset:
                read_ptr = 0
                write_ptr = 0
                full = 0
            else:
                if ren:
                    read_ptr = read_ptr + 1
                    out = data[read_ptr]
                if wen:
                    write_ptr = write_ptr + 1
                    data[write_ptr] = in_
                full = full_next

        self.add_code(comb)
        self.add_code(seq)
예제 #7
0
 def mem_ff(self):
     # Behavioral SRAM body, translated by kratos rather than run as Python.
     # NOTE(review): the decorator/sensitivity list is outside this view;
     # presumably this is registered as an always_ff block — confirm.
     # Nothing happens unless CEB == 0.
     if self.CEB == 0:
         # Read: concatenate the four words at (A << 2) + 3 .. (A << 2) + 0.
         # Word +3 is the first argument, so it presumably lands in the most
         # significant position, as with Verilog concatenation.
         self.Q_w = concat(
             self.mem[resize((self.A << 2) + 3, self.addr_width + 2)],
             self.mem[resize((self.A << 2) + 2, self.addr_width + 2)],
             self.mem[resize((self.A << 2) + 1,
                             self.addr_width + 2)], self.mem[resize(
                                 (self.A << 2), self.addr_width + 2)])
         # Write only when WEB == 0.
         if self.WEB == 0:
             # Bit-granular write: bit i of D goes to word (A << 2) + i // 16,
             # bit i % 16, unless BWEB[i] != 0 masks it.
             # NOTE(review): the word size 16 is hard-coded here, while the
             # bit-index width is clog2(self._params.cgra_data_width). These
             # only agree when cgra_data_width == 16 — confirm.
             for i in range(self.data_width):
                 if self.BWEB[i] == 0:
                     self.mem[resize(
                         (self.A << 2) + i // 16,
                         self.addr_width + 2)][resize(
                             i % 16, clog2(
                                 self._params.cgra_data_width))] = self.D[i]
예제 #8
0
    def __init__(self,
                 width: int,
                 depth: int,
                 flatten_output=False,
                 reset_high=False):
        """Create a ``depth``-stage register pipeline that is ``width`` bits wide.

        Args:
            width: bit width of every stage.
            depth: number of stages. ``0`` produces a direct wire from ``in_``
                to ``out_``.
            flatten_output: when true and ``depth > 0``, ``out_`` is an array
                exposing all ``depth`` stages. Otherwise ``out_`` is the last
                stage only. Adds ``_array`` to the module name.
            reset_high: stored on the instance. Adds ``_reset_high`` to the
                module name.
        """
        name_suffix = "".join(tag for enabled, tag in
                              ((flatten_output, "_array"),
                               (reset_high, "_reset_high"))
                              if enabled)
        super().__init__(f"pipeline_w_{width}_d_{depth}{name_suffix}")
        self.clk = self.clock("clk")
        self.clk_en = self.clock_en("clk_en")
        self.reset = self.reset("reset")
        self.width = width
        self.depth = depth
        self.reset_high = reset_high

        # Zero stages: plain pass-through, no registers.
        if self.depth == 0:
            self.in_ = self.input("in_", self.width)
            self.out_ = self.output("out_", self.width)
            self.wire(self.out_, self.in_)
            return

        self.depth_width = max(clog2(self.depth), 1)

        self.in_ = self.input("in_", self.width)
        self.out_ = (self.output("out_", self.width, size=self.depth)
                     if flatten_output
                     else self.output("out_", self.width))

        # A 1x1 register array is declared without explicit_array;
        # every other shape passes explicit_array=True.
        if self.depth == 1 and self.width == 1:
            self.pipeline_r = self.var("pipeline_r",
                                       width=self.width,
                                       size=self.depth)
        else:
            self.pipeline_r = self.var("pipeline_r",
                                       width=self.width,
                                       size=self.depth,
                                       explicit_array=True)

        # Expose either every stage or only the final one.
        source = (self.pipeline_r if flatten_output
                  else self.pipeline_r[self.depth - 1])
        self.wire(self.out_, source)
        self.add_always(self.pipeline)
예제 #9
0
    def add_done_pulse_pipeline(self):
        """Delay ``done_pulse_w`` through a flattened pipeline and drive ``st_dma_done_pulse``.

        The pipeline has ``2 * num_glb_tiles + default_latency`` stages, and
        all of them are exposed in ``done_pulse_d_arr``.
        ``st_dma_done_pulse`` reads the stage at index
        ``resize(cfg_data_network_latency, clog2(depth)) + default_latency``.
        """
        depth = 2 * self._params.num_glb_tiles + self.default_latency
        index_width = clog2(depth)

        self.done_pulse_d_arr = self.var(
            "done_pulse_d_arr", 1, size=depth, explicit_array=True)
        self.done_pulse_pipeline = Pipeline(width=1,
                                            depth=depth,
                                            flatten_output=True)
        self.add_child("done_pulse_pipeline",
                       self.done_pulse_pipeline,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       in_=self.done_pulse_w,
                       out_=self.done_pulse_d_arr)

        # Runtime-selected tap: configured network latency plus fixed latency.
        tap = (resize(self.cfg_data_network_latency, index_width)
               + self.default_latency)
        self.wire(self.st_dma_done_pulse, self.done_pulse_d_arr[tap])
예제 #10
0
    def __init__(self,
                 use_sram_stub,
                 sram_name,
                 data_width,
                 fw_int,
                 mem_depth,
                 mem_input_ports,
                 mem_output_ports,
                 address_width,
                 bank_num,
                 num_tiles,
                 # configuration registers passed down from top level
                 enable_chain_input,
                 enable_chain_output,
                 chain_idx_input,
                 chain_idx_output):
        """Record bank parameters and build the backing SRAM model.

        All arguments are stored unchanged as attributes of the same name.
        ``chain_idx_bits`` is ``max(1, clog2(num_tiles))``, ``prev_wen`` and
        ``prev_cen`` start at 0, and ``sram`` is an ``SRAMModel`` built from
        ``(data_width, fw_int, mem_depth, num_tiles)``.
        """
        # Generation-time parameters.
        self.use_sram_stub, self.sram_name = use_sram_stub, sram_name
        self.data_width, self.fw_int = data_width, fw_int
        self.mem_depth = mem_depth
        self.mem_input_ports, self.mem_output_ports = (mem_input_ports,
                                                       mem_output_ports)
        self.address_width = address_width
        self.bank_num, self.num_tiles = bank_num, num_tiles

        # Chaining configuration supplied by the top level.
        self.enable_chain_input, self.enable_chain_output = (
            enable_chain_input, enable_chain_output)
        self.chain_idx_input, self.chain_idx_output = (chain_idx_input,
                                                       chain_idx_output)

        # Always at least one bit, even for a single tile.
        self.chain_idx_bits = max(1, kts.clog2(num_tiles))

        self.prev_wen = self.prev_cen = 0

        self.sram = SRAMModel(data_width, fw_int, mem_depth, num_tiles)
예제 #11
0
    def add_strm_rd_addr_pipeline(self):
        """Delay ``strm_rd_addr_w`` through a flattened pipeline and derive ``strm_data_sel``.

        The pipeline has ``2 * num_glb_tiles + default_latency`` stages of
        ``glb_addr_width`` bits, and all of them are exposed in
        ``strm_rd_addr_d_arr``. The stage at index
        ``resize(cfg_data_network_latency, clog2(depth)) + default_latency``
        is then sliced with ``[bank_byte_offset - 1, cgra_byte_offset]`` to
        form ``strm_data_sel``.
        """
        depth = 2 * self._params.num_glb_tiles + self.default_latency
        index_width = clog2(depth)
        addr_width = self._params.glb_addr_width

        self.strm_rd_addr_d_arr = self.var("strm_rd_addr_d_arr",
                                           width=addr_width,
                                           size=depth,
                                           explicit_array=True)
        self.strm_rd_addr_pipeline = Pipeline(width=addr_width,
                                              depth=depth,
                                              flatten_output=True)
        self.add_child("strm_rd_addr_pipeline",
                       self.strm_rd_addr_pipeline,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       in_=self.strm_rd_addr_w,
                       out_=self.strm_rd_addr_d_arr)

        # Runtime-selected tap: configured network latency plus fixed latency.
        tap = (resize(self.cfg_data_network_latency, index_width)
               + self.default_latency)
        delayed_addr = self.strm_rd_addr_d_arr[tap]
        # Presumably a [msb, lsb] bit slice of the delayed address — confirm.
        self.strm_data_sel = delayed_addr[self._params.bank_byte_offset - 1,
                                          self._params.cgra_byte_offset]
예제 #12
0
 def tile2tile_w2e_wiring(self):
     """Chain the west-to-east packet ports of neighbouring GLB tiles.

     Tile 0's west-side inputs are driven by ``proc_packet_d`` (processor
     packets) and by the constant 0 (stream and pcfg packets). Every tile
     ``t >= 1`` takes its west-side inputs from tile ``t - 1``'s east-side
     outputs, for the proc, strm and pcfg packets in that order.
     """
     num_tiles = self._params.num_glb_tiles

     # West edge of the chain.
     self.wire(self.proc_packet_w2e_wsti[0], self.proc_packet_d)
     self.wire(self.strm_packet_w2e_wsti[0], 0)
     self.wire(self.pcfg_packet_w2e_wsti[0], 0)

     for dst_tile in range(1, num_tiles):
         idx_width = clog2(num_tiles)
         src_tile = dst_tile - 1
         links = ((self.proc_packet_w2e_wsti, self.proc_packet_w2e_esto),
                  (self.strm_packet_w2e_wsti, self.strm_packet_w2e_esto),
                  (self.pcfg_packet_w2e_wsti, self.pcfg_packet_w2e_esto))
         for west_in, east_out in links:
             self.wire(west_in[const(dst_tile, idx_width)],
                       east_out[const(src_tile, idx_width)])
예제 #13
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=1,
            sram_macro_info=SRAMMacroInfo("TS1N16FFCLLSBLVTC512X32M4S"),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            max_agg_schedule=16,
            input_max_port_sched=16,
            output_max_port_sched=16,
            align_input=1,
            max_line_length=128,
            max_tb_height=1,
            tb_range_max=1024,
            tb_range_inner_max=64,
            tb_sched_max=16,
            max_tb_stride=15,
            num_tb=1,
            tb_iterator_support=2,
            multiwrite=1,
            max_prefetch=8,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=2,
            app_ctrl_depth_width=16,
            remove_tb=False,
            fifo_mode=True,
            add_clk_enable=True,
            add_flush=True,
            core_reset_pos=False,
            stcl_valid_iter=4):

        super().__init__(config_addr_width, config_data_width)

        # Capture everything to the tile object
        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.fw_int = int(self.mem_width / self.data_width)
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.max_agg_schedule = max_agg_schedule
        self.input_max_port_sched = input_max_port_sched
        self.output_max_port_sched = output_max_port_sched
        self.align_input = align_input
        self.max_line_length = max_line_length
        self.max_tb_height = max_tb_height
        self.tb_range_max = tb_range_max
        self.tb_range_inner_max = tb_range_inner_max
        self.tb_sched_max = tb_sched_max
        self.max_tb_stride = max_tb_stride
        self.num_tb = num_tb
        self.tb_iterator_support = tb_iterator_support
        self.multiwrite = multiwrite
        self.max_prefetch = max_prefetch
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.remove_tb = remove_tb
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.core_reset_pos = core_reset_pos
        self.app_ctrl_depth_width = app_ctrl_depth_width
        self.stcl_valid_iter = stcl_valid_iter

        # Typedefs for ease
        TData = magma.Bits[self.data_width]
        TBit = magma.Bits[1]

        self.__inputs = []
        self.__outputs = []

        # Enumerate input and output ports
        # (clk and reset are assumed)
        if self.interconnect_input_ports > 1:
            for i in range(self.interconnect_input_ports):
                self.add_port(f"addr_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"addr_in_{i}"])
                self.add_port(f"data_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"data_in_{i}"])
                self.add_port(f"wen_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"wen_in_{i}"])
        else:
            self.add_port("addr_in", magma.In(TData))
            self.__inputs.append(self.ports[f"addr_in"])
            self.add_port("data_in", magma.In(TData))
            self.__inputs.append(self.ports[f"data_in"])
            self.add_port("wen_in", magma.In(TBit))
            self.__inputs.append(self.ports.wen_in)

        if self.interconnect_output_ports > 1:
            for i in range(self.interconnect_output_ports):
                self.add_port(f"data_out_{i}", magma.Out(TData))
                self.__outputs.append(self.ports[f"data_out_{i}"])
                self.add_port(f"ren_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"ren_in_{i}"])
                self.add_port(f"valid_out_{i}", magma.Out(TBit))
                self.__outputs.append(self.ports[f"valid_out_{i}"])
                # Chaining
                self.add_port(f"chain_valid_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"chain_valid_in_{i}"])
                self.add_port(f"chain_data_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"chain_data_in_{i}"])
                self.add_port(f"chain_data_out_{i}", magma.Out(TData))
                self.__outputs.append(self.ports[f"chain_data_out_{i}"])
                self.add_port(f"chain_valid_out_{i}", magma.Out(TBit))
                self.__outputs.append(self.ports[f"chain_valid_out_{i}"])
        else:
            self.add_port("data_out", magma.Out(TData))
            self.__outputs.append(self.ports[f"data_out"])
            self.add_port(f"ren_in", magma.In(TBit))
            self.__inputs.append(self.ports[f"ren_in"])
            self.add_port(f"valid_out", magma.Out(TBit))
            self.__outputs.append(self.ports[f"valid_out"])
            self.add_port(f"chain_valid_in", magma.In(TBit))
            self.__inputs.append(self.ports[f"chain_valid_in"])
            self.add_port(f"chain_data_in", magma.In(TData))
            self.__inputs.append(self.ports[f"chain_data_in"])
            self.add_port(f"chain_data_out", magma.Out(TData))
            self.__outputs.append(self.ports[f"chain_data_out"])
            self.add_port(f"chain_valid_out", magma.Out(TBit))
            self.__outputs.append(self.ports[f"chain_valid_out"])

        self.add_ports(flush=magma.In(TBit),
                       full=magma.Out(TBit),
                       empty=magma.Out(TBit),
                       stall=magma.In(TBit),
                       sram_ready_out=magma.Out(TBit))

        self.__inputs.append(self.ports.flush)
        # self.__inputs.append(self.ports.stall)

        self.__outputs.append(self.ports.full)
        self.__outputs.append(self.ports.empty)
        self.__outputs.append(self.ports.sram_ready_out)

        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.max_agg_schedule,
                     self.input_max_port_sched, self.output_max_port_sched,
                     self.align_input, self.max_line_length,
                     self.max_tb_height, self.tb_range_max, self.tb_sched_max,
                     self.max_tb_stride, self.num_tb, self.tb_iterator_support,
                     self.multiwrite, self.max_prefetch,
                     self.config_data_width, self.config_addr_width,
                     self.num_tiles, self.remove_tb, self.fifo_mode,
                     self.stcl_valid_iter, self.add_clk_enable, self.add_flush,
                     self.app_ctrl_depth_width)

        # Check for circuit caching
        if cache_key not in MemCore.__circuit_cache:

            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            lt_dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                max_agg_schedule=self.max_agg_schedule,
                input_max_port_sched=self.input_max_port_sched,
                output_max_port_sched=self.output_max_port_sched,
                align_input=self.align_input,
                max_line_length=self.max_line_length,
                max_tb_height=self.max_tb_height,
                tb_range_max=self.tb_range_max,
                tb_range_inner_max=self.tb_range_inner_max,
                tb_sched_max=self.tb_sched_max,
                max_tb_stride=self.max_tb_stride,
                num_tb=self.num_tb,
                tb_iterator_support=self.tb_iterator_support,
                multiwrite=self.multiwrite,
                max_prefetch=self.max_prefetch,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                app_ctrl_depth_width=self.app_ctrl_depth_width,
                remove_tb=self.remove_tb,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                stcl_valid_iter=self.stcl_valid_iter)

            change_sram_port_pass = change_sram_port_names(
                use_sram_stub, sram_macro_info)
            circ = kts.util.to_magma(
                lt_dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
        else:
            circ, lt_dut = MemCore.__circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

        # put a 1-bit register and a mux to select the control signals
        # TODO: check if enable_chain_output needs to be here? I don't think so?
        control_signals = [("wen_in", self.interconnect_input_ports),
                           ("ren_in", self.interconnect_output_ports),
                           ("flush", 1),
                           ("chain_valid_in", self.interconnect_output_ports)]
        for control_signal, width in control_signals:
            # TODO: consult with Ankita to see if we can use the normal
            # mux here
            if width == 1:
                mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
                reg_value_name = f"{control_signal}_reg_value"
                reg_sel_name = f"{control_signal}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0], self.ports[control_signal])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][0])
            else:
                for i in range(width):
                    mux = MuxWrapper(2, 1, name=f"{control_signal}_{i}_sel")
                    reg_value_name = f"{control_signal}_{i}_reg_value"
                    reg_sel_name = f"{control_signal}_{i}_reg_sel"
                    self.add_config(reg_value_name, 1)
                    self.add_config(reg_sel_name, 1)
                    self.wire(mux.ports.I[0],
                              self.ports[f"{control_signal}_{i}"])
                    self.wire(mux.ports.I[1],
                              self.registers[reg_value_name].ports.O)
                    self.wire(mux.ports.S,
                              self.registers[reg_sel_name].ports.O)
                    # 0 is the default wire, which takes from the routing network
                    self.wire(mux.ports.O[0],
                              self.underlying.ports[control_signal][i])

        if self.interconnect_input_ports > 1:
            for i in range(self.interconnect_input_ports):
                self.wire(self.ports[f"data_in_{i}"],
                          self.underlying.ports[f"data_in_{i}"])
                self.wire(self.ports[f"addr_in_{i}"],
                          self.underlying.ports[f"addr_in_{i}"])
        else:
            self.wire(self.ports.addr_in, self.underlying.ports.addr_in)
            self.wire(self.ports.data_in, self.underlying.ports.data_in)

        if self.interconnect_output_ports > 1:
            for i in range(self.interconnect_output_ports):
                self.wire(self.ports[f"data_out_{i}"],
                          self.underlying.ports[f"data_out_{i}"])
                self.wire(self.ports[f"chain_data_in_{i}"],
                          self.underlying.ports[f"chain_data_in_{i}"])
                self.wire(self.ports[f"chain_data_out_{i}"],
                          self.underlying.ports[f"chain_data_out_{i}"])
        else:
            self.wire(self.ports.data_out, self.underlying.ports.data_out)
            self.wire(self.ports.chain_data_in,
                      self.underlying.ports.chain_data_in)
            self.wire(self.ports.chain_data_out,
                      self.underlying.ports.chain_data_out)

        # Need to invert this
        self.resetInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.resetInverter.ports.I[0], self.ports.reset)
        self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
        self.wire(self.ports.clk, self.underlying.ports.clk)
        if self.interconnect_output_ports == 1:
            self.wire(self.ports.valid_out[0],
                      self.underlying.ports.valid_out[0])
            self.wire(self.ports.chain_valid_out[0],
                      self.underlying.ports.chain_valid_out[0])
        else:
            for j in range(self.interconnect_output_ports):
                self.wire(self.ports[f"valid_out_{j}"][0],
                          self.underlying.ports.valid_out[j])
                self.wire(self.ports[f"chain_valid_out_{j}"][0],
                          self.underlying.ports.chain_valid_out[j])
        self.wire(self.ports.empty[0], self.underlying.ports.empty[0])
        self.wire(self.ports.full[0], self.underlying.ports.full[0])

        # PE core uses clk_en (essentially active low stall)
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall)
        self.wire(self.stallInverter.ports.O[0],
                  self.underlying.ports.clk_en[0])

        self.wire(self.ports.sram_ready_out[0],
                  self.underlying.ports.sram_ready_out[0])

        # we have six? features in total
        # 0:    TILE
        # 1:    TILE
        # 1-4:  SMEM
        # Feature 0: Tile
        self.__features: List[CoreFeature] = [self]
        # Features 1-4: SRAM
        self.num_sram_features = lt_dut.total_sets
        for sram_index in range(self.num_sram_features):
            core_feature = CoreFeature(self, sram_index + 1)
            self.__features.append(core_feature)

        # Wire the config
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(f"config_{idx}",
                              magma.In(ConfigurationType(8, 32)))
                # port aliasing
                core_feature.ports["config"] = self.ports[f"config_{idx}"]
        self.add_port("config", magma.In(ConfigurationType(8, 32)))

        # or the signal up
        t = ConfigurationType(8, 32)
        t_names = ["config_addr", "config_data"]
        or_gates = {}
        for t_name in t_names:
            port_type = t[t_name]
            or_gate = FromMagma(
                mantle.DefineOr(len(self.__features), len(port_type)))
            or_gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, core_feature in enumerate(self.__features):
                self.wire(or_gate.ports[f"I{idx}"],
                          core_feature.ports.config[t_name])
            or_gates[t_name] = or_gate

        self.wire(or_gates["config_addr"].ports.O,
                  self.underlying.ports.config_addr_in[0:8])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data_in)

        # read data out
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                # self.add_port(f"read_config_data_{idx}",
                self.add_port(f"read_config_data_{idx}",
                              magma.Out(magma.Bits[32]))
                # port aliasing
                core_feature.ports["read_config_data"] = \
                    self.ports[f"read_config_data_{idx}"]

        # MEM Config
        configurations = [("tile_en", 1), ("fifo_ctrl_fifo_depth", 16),
                          ("mode", 2), ("enable_chain_output", 1),
                          ("enable_chain_input", 1)]
        #            ("stencil_width", 16), NOT YET

        merged_configs = []
        merged_in_sched = []
        merged_out_sched = []

        # Add config registers to configurations
        # TODO: Have lake spit this information out automatically from the wrapper

        configurations.append((f"chain_idx_input", self.chain_idx_bits))
        configurations.append((f"chain_idx_output", self.chain_idx_bits))
        for i in range(self.interconnect_input_ports):
            configurations.append((f"strg_ub_agg_align_{i}_line_length",
                                   kts.clog2(self.max_line_length)))
            configurations.append((f"strg_ub_agg_in_{i}_in_period",
                                   kts.clog2(self.input_max_port_sched)))

            # num_bits_in_sched = kts.clog2(self.agg_height)
            # sched_per_feat = math.floor(self.config_data_width / num_bits_in_sched)
            # new_width = num_bits_in_sched * sched_per_feat
            # feat_num = 0
            # num_feats_merge = math.ceil(self.input_max_port_sched / sched_per_feat)
            # for k in range(num_feats_merge):
            #    num_here = sched_per_feat
            #    if self.input_max_port_sched - (k * sched_per_feat) < sched_per_feat:
            #        num_here = self.input_max_port_sched - (k * sched_per_feat)
            #    merged_configs.append((f"strg_ub_agg_in_{i}_in_sched_merged_{k * sched_per_feat}",
            #                          num_here * num_bits_in_sched, num_here))
            for j in range(self.input_max_port_sched):
                configurations.append((f"strg_ub_agg_in_{i}_in_sched_{j}",
                                       kts.clog2(self.agg_height)))

            configurations.append((f"strg_ub_agg_in_{i}_out_period",
                                   kts.clog2(self.input_max_port_sched)))

            for j in range(self.output_max_port_sched):
                configurations.append((f"strg_ub_agg_in_{i}_out_sched_{j}",
                                       kts.clog2(self.agg_height)))

            configurations.append((f"strg_ub_app_ctrl_write_depth_wo_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append((f"strg_ub_app_ctrl_write_depth_ss_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append(
                (f"strg_ub_app_ctrl_coarse_write_depth_wo_{i}",
                 self.app_ctrl_depth_width))
            configurations.append(
                (f"strg_ub_app_ctrl_coarse_write_depth_ss_{i}",
                 self.app_ctrl_depth_width))

            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_dimensionality",
                 1 + kts.clog2(self.input_iterator_support)))
            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_starting_addr",
                 self.input_config_width))
            for j in range(self.input_iterator_support):
                configurations.append(
                    (f"strg_ub_input_addr_ctrl_address_gen_{i}_ranges_{j}",
                     self.input_config_width))
                configurations.append(
                    (f"strg_ub_input_addr_ctrl_address_gen_{i}_strides_{j}",
                     self.input_config_width))

        configurations.append(
            (f"strg_ub_app_ctrl_prefill", self.interconnect_output_ports))
        configurations.append((f"strg_ub_app_ctrl_coarse_prefill",
                               self.interconnect_output_ports))

        for i in range(self.stcl_valid_iter):
            configurations.append((f"strg_ub_app_ctrl_ranges_{i}", 16))
            configurations.append((f"strg_ub_app_ctrl_threshold_{i}", 16))

        for i in range(self.interconnect_output_ports):
            configurations.append((f"strg_ub_app_ctrl_input_port_{i}",
                                   kts.clog2(self.interconnect_input_ports)))
            configurations.append((f"strg_ub_app_ctrl_read_depth_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append((f"strg_ub_app_ctrl_coarse_input_port_{i}",
                                   kts.clog2(self.interconnect_input_ports)))
            configurations.append((f"strg_ub_app_ctrl_coarse_read_depth_{i}",
                                   self.app_ctrl_depth_width))

            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_dimensionality",
                 1 + kts.clog2(self.output_iterator_support)))
            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_starting_addr",
                 self.output_config_width))
            for j in range(self.output_iterator_support):
                configurations.append(
                    (f"strg_ub_output_addr_ctrl_address_gen_{i}_ranges_{j}",
                     self.output_config_width))
                configurations.append(
                    (f"strg_ub_output_addr_ctrl_address_gen_{i}_strides_{j}",
                     self.output_config_width))

            configurations.append((f"strg_ub_pre_fetch_{i}_input_latency",
                                   kts.clog2(self.max_prefetch) + 1))
            configurations.append((f"strg_ub_sync_grp_sync_group_{i}",
                                   self.interconnect_output_ports))
            configurations.append(
                (f"strg_ub_rate_matched_{i}",
                 1 + kts.clog2(self.interconnect_input_ports)))

            for j in range(self.num_tb):
                configurations.append(
                    (f"strg_ub_tba_{i}_tb_{j}_dimensionality", 2))
                num_indices_bits = 1 + kts.clog2(self.fw_int)
                indices_per_feat = math.floor(self.config_data_width /
                                              num_indices_bits)
                new_width = num_indices_bits * indices_per_feat
                feat_num = 0
                num_feats_merge = math.ceil(self.tb_range_inner_max /
                                            indices_per_feat)
                for k in range(num_feats_merge):
                    num_idx = indices_per_feat
                    if (self.tb_range_inner_max -
                        (k * indices_per_feat)) < indices_per_feat:
                        num_idx = self.tb_range_inner_max - (k *
                                                             indices_per_feat)
                    merged_configs.append((
                        f"strg_ub_tba_{i}_tb_{j}_indices_merged_{k * indices_per_feat}",
                        num_idx * num_indices_bits, num_idx))


#                for k in range(self.tb_range_inner_max):
#                    configurations.append((f"strg_ub_tba_{i}_tb_{j}_indices_{k}", kts.clog2(self.fw_int) + 1))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_inner",
                                       kts.clog2(self.tb_range_inner_max)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_outer",
                                       kts.clog2(self.tb_range_max)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_stride",
                                       kts.clog2(self.max_tb_stride)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_tb_height",
                                       max(1, kts.clog2(self.num_tb))))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_starting_addr",
                                       max(1, kts.clog2(self.fw_int))))

        # Do all the stuff for the main config
        main_feature = self.__features[0]
        for config_reg_name, width in configurations:
            main_feature.add_config(config_reg_name, width)
            if (width == 1):
                self.wire(main_feature.registers[config_reg_name].ports.O[0],
                          self.underlying.ports[config_reg_name][0])
            else:
                self.wire(main_feature.registers[config_reg_name].ports.O,
                          self.underlying.ports[config_reg_name])

        for config_reg_name, width, num_merged in merged_configs:
            main_feature.add_config(config_reg_name, width)
            token_under = config_reg_name.split("_")
            base_name = config_reg_name.split("_merged")[0]
            base_indices = int(config_reg_name.split("_merged_")[1])
            num_bits = width // num_merged
            for i in range(num_merged):
                self.wire(
                    main_feature.registers[config_reg_name].ports.
                    O[i * num_bits:(i + 1) * num_bits],
                    self.underlying.ports[f"{base_name}_{base_indices + i}"])

        # SRAM
        # These should also account for num features
        # or_all_cfg_rd = FromMagma(mantle.DefineOr(4, 1))
        or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_rd.instance_name = f"OR_CONFIG_WR_SRAM"
        or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_wr.instance_name = f"OR_CONFIG_RD_SRAM"
        for sram_index in range(self.num_sram_features):
            core_feature = self.__features[sram_index + 1]
            self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
            # port aliasing
            core_feature.ports["config_en"] = \
                self.ports[f"config_en_{sram_index}"]
            self.wire(core_feature.ports.read_config_data,
                      self.underlying.ports[f"config_data_out_{sram_index}"])
            # also need to wire the sram signal
            # the config enable is the OR of the rd+wr
            or_gate_en = FromMagma(mantle.DefineOr(2, 1))
            or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"

            self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
            self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
            self.wire(core_feature.ports.config_en,
                      self.underlying.ports["config_en"][sram_index])
            # Still connect to the OR of all the config rd/wr
            self.wire(core_feature.ports.config.write,
                      or_all_cfg_wr.ports[f"I{sram_index}"])
            self.wire(core_feature.ports.config.read,
                      or_all_cfg_rd.ports[f"I{sram_index}"])

        self.wire(or_all_cfg_rd.ports.O[0],
                  self.underlying.ports.config_read[0])
        self.wire(or_all_cfg_wr.ports.O[0],
                  self.underlying.ports.config_write[0])
        self._setup_config()

        conf_names = list(self.registers.keys())
        conf_names.sort()
        with open("mem_cfg.txt", "w+") as cfg_dump:
            for idx, reg in enumerate(conf_names):
                write_line = f"|{reg}|{idx}|{self.registers[reg].width}||\n"
                cfg_dump.write(write_line)
        with open("mem_synth.txt", "w+") as cfg_dump:
            for idx, reg in enumerate(conf_names):
                write_line = f"{reg}\n"
                cfg_dump.write(write_line)
예제 #14
0
    def __init__(self, _params: GlobalBufferParams):
        """Build the ``glb_core_load_dma`` generator.

        Declares the load-DMA ports and internal variables, registers the
        always blocks, and instantiates the shared loop iterator together
        with the cycle-stride schedule/address generators and the
        data-stride address generator. All of them are driven by the DMA
        header currently selected from the configuration queue.

        Args:
            _params: Global-buffer parameters. ``bank_data_width`` must be
                exactly four times ``cgra_data_width`` (checked below).
        """
        super().__init__("glb_core_load_dma")
        self._params = _params
        self.header = GlbHeader(self._params)
        # One bank word must hold exactly four CGRA words.
        assert self._params.bank_data_width == self._params.cgra_data_width * 4

        self.clk = self.clock("clk")
        self.clk_en = self.clock_en("clk_en")
        # NOTE: this rebinds ``self.reset`` from the port-creating method to
        # the port object it returns.
        self.reset = self.reset("reset")

        # Streamed output data and its valid flag.
        self.data_g2f = self.output("data_g2f",
                                    width=self._params.cgra_data_width)
        self.data_valid_g2f = self.output("data_valid_g2f", width=1)

        # Read request (out) / read response (in) packet interface.
        self.rdrq_packet = self.output("rdrq_packet",
                                       self.header.rdrq_packet_t)
        self.rdrs_packet = self.input("rdrs_packet", self.header.rdrs_packet_t)

        # Configuration inputs; the DMA header is a queue of
        # ``queue_depth`` entries.
        self.cfg_ld_dma_num_repeat = self.input(
            "cfg_ld_dma_num_repeat",
            clog2(self._params.queue_depth) + 1)
        self.cfg_ld_dma_ctrl_use_valid = self.input(
            "cfg_ld_dma_ctrl_use_valid", 1)
        self.cfg_ld_dma_ctrl_mode = self.input("cfg_ld_dma_ctrl_mode", 2)
        self.cfg_data_network_latency = self.input("cfg_data_network_latency",
                                                   self._params.latency_width)
        self.cfg_ld_dma_header = self.input("cfg_ld_dma_header",
                                            self.header.cfg_dma_header_t,
                                            size=self._params.queue_depth)

        self.ld_dma_start_pulse = self.input("ld_dma_start_pulse", 1)
        self.ld_dma_done_pulse = self.output("ld_dma_done_pulse", 1)

        # local parameter
        self.default_latency = (
            self._params.glb_switch_pipeline_depth +
            self._params.glb_bank_memory_pipeline_depth +
            self._params.sram_gen_pipeline_depth +
            self._params.sram_gen_output_pipeline_depth +
            1  # SRAM macro read latency
            + self._params.glb_switch_pipeline_depth +
            2  # FIXME: Unnecessary delay of moving back and forth btw switch and router
            + 1  # load_dma cache register delay
        )

        # local variables
        self.strm_data = self.var("strm_data", self._params.cgra_data_width)
        self.strm_data_r = self.var("strm_data_r",
                                    self._params.cgra_data_width)
        self.strm_data_valid = self.var("strm_data_valid", 1)
        self.strm_data_valid_r = self.var("strm_data_valid_r", 1)
        self.strm_data_sel = self.var(
            "strm_data_sel",
            self._params.bank_byte_offset - self._params.cgra_byte_offset)

        self.strm_rd_en_w = self.var("strm_rd_en_w", 1)
        self.strm_rd_addr_w = self.var("strm_rd_addr_w",
                                       self._params.glb_addr_width)
        self.last_strm_rd_addr_r = self.var("last_strm_rd_addr_r",
                                            self._params.glb_addr_width)

        self.ld_dma_start_pulse_next = self.var("ld_dma_start_pulse_next", 1)
        self.ld_dma_start_pulse_r = self.var("ld_dma_start_pulse_r", 1)
        self.is_first = self.var("is_first", 1)

        self.ld_dma_done_pulse_w = self.var("ld_dma_done_pulse_w", 1)

        self.bank_addr_match = self.var("bank_addr_match", 1)
        self.bank_rdrq_rd_en = self.var("bank_rdrq_rd_en", 1)
        self.bank_rdrq_rd_addr = self.var("bank_rdrq_rd_addr",
                                          self._params.glb_addr_width)
        self.bank_rdrs_data_cache_r = self.var("bank_rdrs_data_cache_r",
                                               self._params.bank_data_width)

        self.strm_run = self.var("strm_run", 1)
        self.loop_done = self.var("loop_done", 1)
        self.cycle_valid = self.var("cycle_valid", 1)
        self.cycle_count = self.var("cycle_count", self._params.axi_data_width)
        self.cycle_current_addr = self.var("cycle_current_addr",
                                           self._params.axi_data_width)
        self.data_current_addr = self.var("data_current_addr",
                                          self._params.axi_data_width)
        self.loop_mux_sel = self.var("loop_mux_sel",
                                     clog2(self._params.loop_level))
        self.repeat_cnt = self.var("repeat_cnt",
                                   clog2(self._params.queue_depth) + 1)

        if self._params.queue_depth != 1:
            # Index of the active entry in ``cfg_ld_dma_header``. It must be
            # wide enough to address all ``queue_depth`` entries. Sizing it
            # from ``repeat_cnt.width`` (a log of a log) truncated the index
            # for queue_depth >= 5; widths for depths 2-4 are unchanged.
            self.queue_sel_r = self.var(
                "queue_sel_r", max(1, clog2(self._params.queue_depth)))

        # Current dma header
        self.current_dma_header = self.var("current_dma_header",
                                           self.header.cfg_dma_header_t)
        if self._params.queue_depth == 1:
            self.wire(self.cfg_ld_dma_header, self.current_dma_header)
        else:
            self.wire(self.cfg_ld_dma_header[self.queue_sel_r],
                      self.current_dma_header)

        if self._params.queue_depth != 1:
            self.add_always(self.queue_sel_ff)

        self.add_always(self.repeat_cnt_ff)
        self.add_always(self.cycle_counter)
        self.add_always(self.is_first_ff)
        self.add_always(self.strm_run_ff)
        self.add_always(self.strm_data_ff)
        self.add_strm_data_start_pulse_pipeline()
        self.add_ld_dma_done_pulse_pipeline()
        self.add_strm_rd_en_pipeline()
        self.add_strm_rd_addr_pipeline()
        self.add_always(self.ld_dma_start_pulse_logic)
        self.add_always(self.ld_dma_start_pulse_ff)
        self.add_always(self.strm_data_mux)
        self.add_always(self.ld_dma_done_pulse_logic)
        self.add_always(self.strm_rdrq_packet_ff)
        self.add_always(self.last_strm_rd_addr_ff)
        self.add_always(self.bank_rdrq_packet_logic)
        self.add_always(self.bank_rdrs_data_cache_ff)
        self.add_always(self.strm_data_logic)

        # Loop iteration shared for cycle and data
        self.loop_iter = GlbLoopIter(self._params)
        self.add_child("loop_iter",
                       self.loop_iter,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       step=self.cycle_valid,
                       mux_sel_out=self.loop_mux_sel,
                       restart=self.loop_done)
        self.wire(self.loop_iter.dim, self.current_dma_header[f"dim"])
        for i in range(self._params.loop_level):
            self.wire(self.loop_iter.ranges[i],
                      self.current_dma_header[f"range_{i}"])

        # Cycle stride: the schedule generator produces cycle_valid, which
        # also steps the loop iterator and both address generators.
        self.cycle_stride_sched_gen = GlbSchedGen(self._params)
        self.add_child("cycle_stride_sched_gen",
                       self.cycle_stride_sched_gen,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       restart=self.ld_dma_start_pulse_r,
                       cycle_count=self.cycle_count,
                       current_addr=self.cycle_current_addr,
                       finished=self.loop_done,
                       valid_output=self.cycle_valid)

        self.cycle_stride_addr_gen = GlbAddrGen(self._params)
        self.add_child("cycle_stride_addr_gen",
                       self.cycle_stride_addr_gen,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       restart=self.ld_dma_start_pulse_r,
                       step=self.cycle_valid,
                       mux_sel=self.loop_mux_sel,
                       addr_out=self.cycle_current_addr)
        self.wire(
            self.cycle_stride_addr_gen.start_addr,
            ext(self.current_dma_header[f"cycle_start_addr"],
                self._params.axi_data_width))
        for i in range(self._params.loop_level):
            self.wire(self.cycle_stride_addr_gen.strides[i],
                      self.current_dma_header[f"cycle_stride_{i}"])

        # Data stride
        self.data_stride_addr_gen = GlbAddrGen(self._params)
        self.add_child("data_stride_addr_gen",
                       self.data_stride_addr_gen,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       restart=self.ld_dma_start_pulse_r,
                       step=self.cycle_valid,
                       mux_sel=self.loop_mux_sel,
                       addr_out=self.data_current_addr)
        self.wire(
            self.data_stride_addr_gen.start_addr,
            ext(self.current_dma_header[f"start_addr"],
                self._params.axi_data_width))
        for i in range(self._params.loop_level):
            self.wire(self.data_stride_addr_gen.strides[i],
                      self.current_dma_header[f"stride_{i}"])
예제 #15
0
파일: test_kratos.py 프로젝트: mfkiwl/hgdb
def test_kratos_single_instance(find_free_port, simulator):
    """End-to-end hgdb debugging test for a single kratos module instance.

    Builds a small generator ``mod``. It holds a ``buffer_size``-entry data
    buffer that is filled from ``in`` on each rising clock edge and summed
    combinationally into ``out``. The test emits SystemVerilog with debug
    information, starts ``simulator`` without blocking, and drives an hgdb
    client that:

    * stops four times on the accumulation statement in ``sum_data`` and
      checks the loop index ``i``;
    * stops four more times and checks the partial sum ``out``;
    * swaps the breakpoint for a conditional one and checks the reported
      values when it fires.

    Skipped when ``simulator`` is not available. ``find_free_port`` and
    ``simulator`` are supplied by the test harness (presumably pytest
    fixtures / parametrization).
    """
    if not simulator.available():
        pytest.skip(simulator.__name__ + " not available")
    from kratos import Generator, clog2, always_ff, always_comb, verilog, posedge
    input_width = 16
    buffer_size = 4
    mod = Generator("mod", debug=True)
    clk = mod.clock("clk")
    rst = mod.reset("rst")
    in_ = mod.input("in", input_width)
    out = mod.output("out", input_width)
    counter = mod.var("count", clog2(buffer_size))
    # notice that verilator does not support probing on packed array!!!
    data = mod.var("data", input_width, size=buffer_size)

    # Combinational sum of all buffer entries into ``out``.
    # NOTE: kratos parses this function's source, and the breakpoint below is
    # located by the exact text of its accumulation line; do not edit it.
    @always_comb
    def sum_data():
        out = 0
        for i in range(buffer_size):
            out = out + data[i]

    # Sequential fill: on reset clear every entry and the write index;
    # otherwise store ``in`` at ``counter`` and advance it.
    @always_ff((posedge, clk), (posedge, rst))
    def buffer_logic():
        if rst:
            for i in range(buffer_size):
                data[i] = 0
            counter = 0
        else:
            data[counter] = in_
            counter += 1

    mod.add_always(sum_data, ssa_transform=True)
    mod.add_always(buffer_logic)
    # Line number of the accumulation statement inside ``sum_data`` within
    # ``py_filename`` (presumably this test file); used as the breakpoint.
    py_line_num = get_line_num(py_filename, "            out = out + data[i]")

    with tempfile.TemporaryDirectory() as temp:
        temp = os.path.abspath(temp)
        db_filename = os.path.join(temp, "debug.db")
        sv_filename = os.path.join(temp, "mod.sv")
        verilog(mod, filename=sv_filename, insert_debug_info=True,
                debug_db_filename=db_filename, ssa_transform=True,
                insert_verilator_info=True)
        # pick the testbench for the simulator: C++ driver for Verilator,
        # SystemVerilog testbench otherwise
        tb = "test_kratos.cc" if simulator == VerilatorTester else "test_kratos.sv"
        main_file = get_vector_file(tb)
        with simulator(sv_filename, main_file, cwd=temp) as tester:
            port = find_free_port()
            uri = get_uri(port)
            # set the port; run without blocking so the client below can
            # connect while the simulation is running
            tester.run(blocking=False, DEBUG_PORT=port, DEBUG_LOG=True)

            async def client_logic():
                client = hgdb.HGDBClient(uri, db_filename)
                await client.connect()
                # set breakpoint
                await client.set_breakpoint(py_filename, py_line_num)
                await client.continue_()
                # first pass: one stop per loop iteration, checking ``i``
                for i in range(4):
                    bp = await client.recv()
                    assert bp["payload"]["instances"][0]["local"]["i"] == str(i)
                    if simulator == XceliumTester:
                        assert bp["payload"]["time"] == 10
                    await client.continue_()

                for i in range(4):
                    # the first breakpoint, out is not calculated yet
                    # so it should be 0
                    # after that, it should be 1
                    bp = await client.recv()
                    if simulator == XceliumTester:
                        assert bp["payload"]["time"] == 30
                    if i == 0:
                        assert bp["payload"]["instances"][0]["local"]["out"] == "0"
                    else:
                        assert bp["payload"]["instances"][0]["local"]["out"] == "1"
                    await client.continue_()

                # remove the breakpoint and set a conditional breakpoint
                # discard the current breakpoint information
                await client.recv()
                # remove the current breakpoint
                await client.remove_breakpoint(py_filename, py_line_num)
                await client.set_breakpoint(py_filename, py_line_num, cond="out == 6 && i == 3")
                await client.continue_()
                bp = await client.recv()
                assert bp["payload"]["instances"][0]["local"]["out"] == "6"
                assert bp["payload"]["instances"][0]["local"]["i"] == "3"
                assert bp["payload"]["instances"][0]["local"]["data.2"] == "3"

            # NOTE(review): asyncio.get_event_loop() without a running loop is
            # deprecated (Python 3.12+). asyncio.run() would also clear the
            # thread's current loop, which could break sibling tests that
            # still call get_event_loop(); migrate the whole file together.
            asyncio.get_event_loop().run_until_complete(client_logic())
예제 #16
0
    def __init__(self, _params: GlobalBufferParams):
        """Build the ``glb_core_store_dma`` generator.

        Declares the store-DMA ports and internal variables, registers the
        always blocks, and instantiates the shared loop iterator together
        with the cycle-stride schedule/address generators and the
        data-stride address generator. All of them are driven by the DMA
        header currently selected from the configuration queue. The
        schedule generator produces ``cycle_valid``. The loop iterator and
        both address generators advance on ``cycle_valid_muxed``.

        Args:
            _params: Global-buffer parameters. ``bank_data_width`` must be
                exactly four times ``cgra_data_width`` (checked below).
        """
        super().__init__("glb_core_store_dma")
        self._params = _params
        self.header = GlbHeader(self._params)
        # One bank word must hold exactly four CGRA words.
        assert self._params.bank_data_width == self._params.cgra_data_width * 4

        self.clk = self.clock("clk")
        self.clk_en = self.clock_en("clk_en")
        # NOTE: this rebinds ``self.reset`` from the port-creating method to
        # the port object it returns.
        self.reset = self.reset("reset")

        # Streamed input data and its valid flag.
        self.data_f2g = self.input(
            "data_f2g", width=self._params.cgra_data_width)
        self.data_valid_f2g = self.input("data_valid_f2g", width=1)

        # Outgoing write packet.
        self.wr_packet = self.output(
            "wr_packet", self.header.wr_packet_t)

        # Configuration inputs; the DMA header is a queue of
        # ``queue_depth`` entries.
        self.cfg_st_dma_num_repeat = self.input("cfg_st_dma_num_repeat", clog2(self._params.queue_depth) + 1)
        self.cfg_st_dma_ctrl_mode = self.input("cfg_st_dma_ctrl_mode", 2)
        self.cfg_st_dma_ctrl_use_valid = self.input("cfg_st_dma_ctrl_use_valid", 1)
        self.cfg_data_network_latency = self.input(
            "cfg_data_network_latency", self._params.latency_width)
        self.cfg_st_dma_header = self.input(
            "cfg_st_dma_header", self.header.cfg_dma_header_t, size=self._params.queue_depth, explicit_array=True)

        self.st_dma_start_pulse = self.input("st_dma_start_pulse", 1)
        self.st_dma_done_pulse = self.output("st_dma_done_pulse", 1)

        # localparam
        self.default_latency = (self._params.glb_bank_memory_pipeline_depth
                                + self._params.sram_gen_pipeline_depth
                                + self._params.glb_switch_pipeline_depth
                                )
        # Byte-strobe width of one CGRA word and the all-ones strobe value.
        self.cgra_strb_width = self._params.cgra_data_width // 8
        self.cgra_strb_value = 2 ** (self._params.cgra_data_width // 8) - 1

        # local variables
        self.strm_wr_data_w = self.var("strm_wr_data_w", width=self._params.cgra_data_width)
        self.strm_wr_addr_w = self.var("strm_wr_addr_w", width=self._params.glb_addr_width)
        self.last_strm_wr_addr_r = self.var("last_strm_wr_addr_r", width=self._params.glb_addr_width)
        self.strm_wr_en_w = self.var("strm_wr_en_w", width=1)
        self.strm_data_sel = self.var("strm_data_sel", self._params.bank_byte_offset - self._params.cgra_byte_offset)

        self.bank_addr_match = self.var("bank_addr_match", 1)
        self.bank_wr_en = self.var("bank_wr_en", 1)
        self.bank_wr_addr = self.var("bank_wr_addr", width=self._params.glb_addr_width)
        self.bank_wr_data_cache_r = self.var("bank_wr_data_cache_r", self._params.bank_data_width)
        self.bank_wr_data_cache_w = self.var("bank_wr_data_cache_w", self._params.bank_data_width)
        self.bank_wr_strb_cache_r = self.var("bank_wr_strb_cache_r", math.ceil(self._params.bank_data_width / 8))
        self.bank_wr_strb_cache_w = self.var("bank_wr_strb_cache_w", math.ceil(self._params.bank_data_width / 8))

        self.done_pulse_w = self.var("done_pulse_w", 1)
        self.st_dma_start_pulse_next = self.var("st_dma_start_pulse_next", 1)
        self.st_dma_start_pulse_r = self.var("st_dma_start_pulse_r", 1)
        self.is_first = self.var("is_first", 1)
        self.is_last = self.var("is_last", 1)
        self.strm_run = self.var("strm_run", 1)
        self.loop_done = self.var("loop_done", 1)
        self.cycle_valid = self.var("cycle_valid", 1)
        self.cycle_valid_muxed = self.var("cycle_valid_muxed", 1)
        self.cycle_count = self.var("cycle_count", self._params.axi_data_width)
        self.cycle_current_addr = self.var("cycle_current_addr", self._params.axi_data_width)
        self.data_current_addr = self.var("data_current_addr", self._params.axi_data_width)
        self.loop_mux_sel = self.var("loop_mux_sel", clog2(self._params.loop_level))
        self.repeat_cnt = self.var("repeat_cnt", clog2(self._params.queue_depth) + 1)

        if self._params.queue_depth != 1:
            # Index of the active entry in ``cfg_st_dma_header``. It must be
            # wide enough to address all ``queue_depth`` entries. Sizing it
            # from ``repeat_cnt.width`` (a log of a log) truncated the index
            # for queue_depth >= 5; widths for depths 2-4 are unchanged.
            self.queue_sel_r = self.var("queue_sel_r", max(1, clog2(self._params.queue_depth)))

        # Current dma header
        self.current_dma_header = self.var("current_dma_header", self.header.cfg_dma_header_t)
        if self._params.queue_depth == 1:
            self.wire(self.cfg_st_dma_header, self.current_dma_header)
        else:
            self.wire(self.cfg_st_dma_header[self.queue_sel_r], self.current_dma_header)

        if self._params.queue_depth != 1:
            self.add_always(self.queue_sel_ff)

        self.add_always(self.repeat_cnt_ff)
        self.add_always(self.is_first_ff)
        self.add_always(self.is_last_ff)
        self.add_always(self.strm_run_ff)
        self.add_always(self.st_dma_start_pulse_logic)
        self.add_always(self.st_dma_start_pulse_ff)
        self.add_always(self.cycle_counter)
        self.add_always(self.cycle_valid_comb)
        self.add_always(self.strm_wr_packet_comb)
        self.add_always(self.last_strm_wr_addr_ff)
        self.add_always(self.strm_data_sel_comb)
        self.add_always(self.bank_wr_packet_cache_comb)
        self.add_always(self.bank_wr_packet_cache_ff)
        self.add_always(self.bank_wr_packet_logic)
        self.add_always(self.wr_packet_logic)
        self.add_always(self.strm_done_pulse_logic)
        self.add_done_pulse_pipeline()

        # Loop iteration shared for cycle and data
        self.loop_iter = GlbLoopIter(self._params)
        self.add_child("loop_iter",
                       self.loop_iter,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       step=self.cycle_valid_muxed,
                       mux_sel_out=self.loop_mux_sel,
                       restart=self.loop_done)
        self.wire(self.loop_iter.dim, self.current_dma_header[f"dim"])
        for i in range(self._params.loop_level):
            self.wire(self.loop_iter.ranges[i], self.current_dma_header[f"range_{i}"])

        # Cycle stride
        self.cycle_stride_sched_gen = GlbSchedGen(self._params)
        self.add_child("cycle_stride_sched_gen",
                       self.cycle_stride_sched_gen,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       restart=self.st_dma_start_pulse_r,
                       cycle_count=self.cycle_count,
                       current_addr=self.cycle_current_addr,
                       finished=self.loop_done,
                       valid_output=self.cycle_valid)

        self.cycle_stride_addr_gen = GlbAddrGen(self._params)
        self.add_child("cycle_stride_addr_gen",
                       self.cycle_stride_addr_gen,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       restart=self.st_dma_start_pulse_r,
                       step=self.cycle_valid_muxed,
                       mux_sel=self.loop_mux_sel,
                       addr_out=self.cycle_current_addr)
        self.wire(self.cycle_stride_addr_gen.start_addr, ext(
            self.current_dma_header[f"cycle_start_addr"], self._params.axi_data_width))
        for i in range(self._params.loop_level):
            self.wire(self.cycle_stride_addr_gen.strides[i],
                      self.current_dma_header[f"cycle_stride_{i}"])

        # Data stride
        self.data_stride_addr_gen = GlbAddrGen(self._params)
        self.add_child("data_stride_addr_gen",
                       self.data_stride_addr_gen,
                       clk=self.clk,
                       clk_en=self.clk_en,
                       reset=self.reset,
                       restart=self.st_dma_start_pulse_r,
                       step=self.cycle_valid_muxed,
                       mux_sel=self.loop_mux_sel,
                       addr_out=self.data_current_addr)
        self.wire(self.data_stride_addr_gen.start_addr, ext(
            self.current_dma_header[f"start_addr"], self._params.axi_data_width))
        for i in range(self._params.loop_level):
            self.wire(self.data_stride_addr_gen.strides[i], self.current_dma_header[f"stride_{i}"])
예제 #17
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=2,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            interconnect_input_ports=1,  # Connection to int
            interconnect_output_ports=3,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=1,
            sram_name="default_name",
            read_delay=1,
            agg_height=8,
            max_agg_schedule=64,
            input_max_port_sched=64,
            output_max_port_sched=64,
            align_input=1,
            max_line_length=2048,
            tb_height=1,
            tb_range_max=2048,
            tb_range_inner_max=5,
            tb_sched_max=64,
            num_tb=1,
            multiwrite=2,
            max_prefetch=64,
            num_tiles=1,
            stcl_valid_iter=4):
        """Build the behavioral model of the buffer and zero its configuration.

        Stores every generator parameter on ``self``, instantiates one
        sub-model per component (fine-grained and coarse application
        controllers, optional aggregation aligners, aggregation buffers,
        input and output address controllers, per-bank read/write arbiters,
        SRAM-wrapper or register-file banks, read demux, sync groups,
        prefetchers and TBA models) and registers every configuration field
        in ``self.config`` with a default value of 0.

        Notable parameters:
            data_width: Width of one interconnect word.
            mem_width: Width of one memory fetch; must be strictly larger
                than ``data_width``. ``fw_int = mem_width / data_width``.
            mem_depth: Memory depth; together with ``num_tiles`` it sets
                ``address_width = clog2(mem_depth * num_tiles)``.
            input_iterator_support / output_iterator_support: Number of
                loop levels in the input / output address generators.
            read_delay: 1 selects SRAM-wrapper models for the banks; any
                other value selects register-file models.
            agg_height: Aggregation aligners are only built when > 0.

        Raises:
            AssertionError: if ``mem_width <= data_width``.
        """

        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_name = sram_name
        self.agg_height = agg_height
        self.max_agg_schedule = max_agg_schedule
        self.input_max_port_sched = input_max_port_sched
        self.output_max_port_sched = output_max_port_sched
        self.input_port_sched_width = kts.clog2(self.interconnect_input_ports)
        self.align_input = align_input
        self.max_line_length = max_line_length
        assert self.mem_width > self.data_width, "Data width needs to be smaller than mem"
        # Number of data words packed into one memory fetch.
        self.fw_int = int(self.mem_width / self.data_width)
        self.num_tb = num_tb
        self.tb_height = tb_height
        self.tb_range_max = tb_range_max
        self.tb_range_inner_max = tb_range_inner_max
        self.tb_sched_max = tb_sched_max
        self.multiwrite = multiwrite
        self.max_prefetch = max_prefetch
        self.read_delay = read_delay
        self.num_tiles = num_tiles
        self.stcl_valid_iter = stcl_valid_iter

        self.chain_idx_bits = max(1, kts.clog2(num_tiles))
        self.address_width = kts.clog2(self.mem_depth * num_tiles)

        self.config = {}

        # top level configuration registers

        # chaining
        self.config["enable_chain_input"] = 0
        self.config["enable_chain_output"] = 0
        self.config["chain_idx_input"] = 0
        self.config["chain_idx_output"] = 0

        self.config["tile_en"] = 0
        self.config["mode"] = 0
        for i in range(self.interconnect_output_ports):
            self.config[f"rate_matched_{i}"] = 0

        # Set up model..
        ### APP CTRL (FINE-GRAINED)
        self.app_ctrl = AppCtrlModel(
            int_in_ports=self.interconnect_input_ports,
            int_out_ports=self.interconnect_output_ports,
            sprt_stcl_valid=True,
            stcl_iter_support=self.stcl_valid_iter)
        for i in range(self.interconnect_input_ports):
            self.config[f"app_ctrl_write_depth_{i}"] = 0
        for i in range(self.interconnect_output_ports):
            self.config[f"app_ctrl_input_port_{i}"] = 0
            self.config[f"app_ctrl_read_depth_{i}"] = 0
            self.config[f"app_ctrl_prefill_{i}"] = 0

        # calculate stencil_valid with fine-grained application controller
        for i in range(self.stcl_valid_iter):
            self.config[f'app_ctrl_ranges_{i}'] = 0
            self.config[f'app_ctrl_app_ctrl_threshold_{i}'] = 0

        ### COARSE APP CTRL
        self.app_ctrl_coarse = AppCtrlModel(
            int_in_ports=self.interconnect_input_ports,
            int_out_ports=self.interconnect_output_ports,
            sprt_stcl_valid=False,
            # unused, stencil valid comes from app ctrl,
            # not coarse app ctrl
            stcl_iter_support=0)

        for i in range(self.interconnect_input_ports):
            self.config[f"app_ctrl_coarse_write_depth_{i}"] = 0
        for i in range(self.interconnect_output_ports):
            self.config[f"app_ctrl_coarse_input_port_{i}"] = 0
            self.config[f"app_ctrl_coarse_read_depth_{i}"] = 0
            self.config[f"app_ctrl_coarse_prefill_{i}"] = 0

        ### INST AGG ALIGNER
        if self.agg_height > 0:
            self.agg_aligners = []
            for i in range(self.interconnect_input_ports):
                self.agg_aligners.append(
                    AggAlignerModel(data_width=self.data_width,
                                    max_line_length=self.max_line_length))
                self.config[f"agg_align_{i}_line_length"] = 0

        ### AGG BUFF
        self.agg_buffs = []
        for port in range(self.interconnect_input_ports):
            self.agg_buffs.append(
                AggBuffModel(agg_height=self.agg_height,
                             data_width=self.data_width,
                             mem_width=self.mem_width,
                             max_agg_schedule=self.max_agg_schedule))

            # One set of period/schedule registers per input port: index with
            # this loop's ``port``, not a loop variable from an earlier loop.
            self.config[f"agg_in_{port}_in_period"] = 0
            self.config[f"agg_in_{port}_out_period"] = 0
            for j in range(self.max_agg_schedule):
                self.config[f"agg_in_{port}_in_sched_{j}"] = 0
                self.config[f"agg_in_{port}_out_sched_{j}"] = 0

        ### INPUT ADDR CTRL
        self.iac = InputAddrCtrlModel(
            interconnect_input_ports=self.interconnect_input_ports,
            data_width=self.data_width,
            fetch_width=self.mem_width,
            mem_depth=self.mem_depth,
            num_tiles=self.num_tiles,
            banks=self.banks,
            iterator_support=self.input_iterator_support,
            max_port_schedule=self.input_max_port_sched,
            address_width=self.address_width)

        for i in range(self.interconnect_input_ports):
            self.config[f"input_addr_ctrl_address_gen_{i}_dimensionality"] = 0
            self.config[f"input_addr_ctrl_address_gen_{i}_starting_addr"] = 0
            for j in range(self.input_iterator_support):
                self.config[f"input_addr_ctrl_address_gen_{i}_ranges_{j}"] = 0
                self.config[f"input_addr_ctrl_address_gen_{i}_strides_{j}"] = 0
            for j in range(self.multiwrite):
                self.config[f"input_addr_ctrl_offsets_cfg_{i}_{j}"] = 0

        ### OUTPUT ADDR CTRL
        # NOTE(review): the chaining values passed to sub-models here and in
        # the SRAM wrappers below are the current config values (all 0 at this
        # point), not live references to ``self.config`` -- confirm that the
        # sub-models are reconfigured elsewhere when chaining is enabled.
        self.oac = OutputAddrCtrlModel(
            interconnect_output_ports=self.interconnect_output_ports,
            mem_depth=self.mem_depth,
            num_tiles=self.num_tiles,
            data_width=self.data_width,
            fetch_width=self.mem_width,
            banks=self.banks,
            iterator_support=self.output_iterator_support,
            address_width=self.address_width,
            chain_idx_output=self.config["chain_idx_output"])
        for i in range(self.interconnect_output_ports):
            self.config[f"output_addr_ctrl_address_gen_{i}_dimensionality"] = 0
            self.config[f"output_addr_ctrl_address_gen_{i}_starting_addr"] = 0
            # Must match the iterator_support given to OutputAddrCtrlModel.
            for j in range(self.output_iterator_support):
                self.config[f"output_addr_ctrl_address_gen_{i}_ranges_{j}"] = 0
                self.config[
                    f"output_addr_ctrl_address_gen_{i}_strides_{j}"] = 0

        ### RW ARBITER
        # Per bank allocation
        self.rw_arbs = []
        for bank in range(self.banks):
            self.rw_arbs.append(
                RWArbiterModel(fetch_width=self.mem_width,
                               data_width=self.data_width,
                               memory_depth=self.mem_depth,
                               int_out_ports=self.interconnect_output_ports,
                               read_delay=self.read_delay))

        self.mems = []
        if self.read_delay == 1:
            ### SRAMS
            for bank in range(self.banks):
                self.mems.append(
                    SRAMWrapperModel(
                        use_sram_stub=self.use_sram_stub,
                        sram_name=self.sram_name,
                        data_width=self.data_width,
                        fw_int=self.fw_int,
                        mem_depth=self.mem_depth,
                        mem_input_ports=self.mem_input_ports,
                        mem_output_ports=self.mem_output_ports,
                        address_width=self.address_width,
                        bank_num=bank,
                        num_tiles=self.num_tiles,
                        enable_chain_input=self.config["enable_chain_input"],
                        enable_chain_output=self.config["enable_chain_output"],
                        chain_idx_input=self.config["chain_idx_input"],
                        chain_idx_output=self.config["chain_idx_output"]))
        else:
            ### REGFILES
            for bank in range(self.banks):
                self.mems.append(
                    RegisterFileModel(data_width=self.data_width,
                                      write_ports=self.mem_input_ports,
                                      read_ports=self.mem_output_ports,
                                      width_mult=self.fw_int,
                                      depth=self.mem_depth))

        ### DEMUX READS
        self.demux_reads = DemuxReadsModel(
            fetch_width=self.mem_width,
            data_width=self.data_width,
            banks=self.banks,
            int_out_ports=self.interconnect_output_ports)

        ### SYNC GROUPS
        self.sync_groups = SyncGroupsModel(
            fetch_width=self.mem_width,
            data_width=self.data_width,
            int_out_ports=self.interconnect_output_ports)
        for i in range(self.interconnect_output_ports):
            self.config[f"sync_grp_sync_group_{i}"] = 0

        ### PREFETCHERS
        self.prefetchers = []
        for port in range(self.interconnect_output_ports):
            self.prefetchers.append(
                PrefetcherModel(fetch_width=self.mem_width,
                                data_width=self.data_width,
                                max_prefetch=self.max_prefetch))
            self.config[f"pre_fetch_{port}_input_latency"] = 0

        ### TBAS
        self.tbas = []
        for port in range(self.interconnect_output_ports):
            self.tbas.append(
                TBAModel(word_width=self.data_width,
                         fetch_width=self.fw_int,
                         num_tb=self.num_tb,
                         tb_height=self.tb_height,
                         max_range=self.tb_range_max,
                         max_range_inner=self.tb_range_inner_max))
            for i in range(self.tb_height):
                self.config[f"tba_{port}_tb_{i}_range_inner"] = 0
                self.config[f"tba_{port}_tb_{i}_range_outer"] = 0
                self.config[f"tba_{port}_tb_{i}_stride"] = 0
                self.config[f"tba_{port}_tb_{i}_dimensionality"] = 0
                self.config[f"tba_{port}_tb_{i}_starting_addr"] = 0
                for j in range(self.tb_sched_max):
                    self.config[f"tba_{port}_tb_{i}_indices_{j}"] = 0
예제 #18
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_depth=32,
            default_iterator_support=3,
            interconnect_input_ports=1,  # Connection to int
            interconnect_output_ports=1,
            config_data_width=32,
            config_addr_width=8,
            cycle_count_width=16,
            add_clk_enable=True,
            add_flush=True):
        """Build the PondCore wrapper around a (possibly cached) Pond circuit.

        The Pond generator is elaborated and converted to magma only once per
        unique parameter tuple; later instances with the same parameters
        reuse the ``(circuit, generator)`` pair stored in
        ``LakeCoreBase._circuit_cache``.

        Side effects: after wrapping, writes (overwriting) two files in the
        current working directory:
            pond_cfg.txt   -- one ``("<reg>", 0),  # <width>`` line per entry
                              of ``self.registers``, sorted by name.
            pond_synth.txt -- one register name per line, same order.
        """
        super().__init__(config_data_width=config_data_width,
                         config_addr_width=config_addr_width,
                         data_width=data_width,
                         name="PondCore")

        # Capture everything to the tile object
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_depth = mem_depth
        self.data_width = data_width
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.cycle_count_width = cycle_count_width
        self.default_iterator_support = default_iterator_support
        self.default_config_width = kts.clog2(self.mem_depth)

        # The key covers every argument forwarded to Pond below, so two
        # instances share a circuit only if they would generate the same one.
        cache_key = (self.data_width, self.mem_depth,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.config_data_width,
                     self.config_addr_width, self.add_clk_enable,
                     self.add_flush, self.cycle_count_width,
                     self.default_iterator_support)

        # Check for circuit caching
        if cache_key not in LakeCoreBase._circuit_cache:
            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            self.dut = Pond(
                data_width=data_width,  # CGRA Params
                mem_depth=mem_depth,
                default_iterator_support=default_iterator_support,
                interconnect_input_ports=interconnect_input_ports,  # Connection to int
                interconnect_output_ports=interconnect_output_ports,
                config_data_width=config_data_width,
                config_addr_width=config_addr_width,
                cycle_count_width=cycle_count_width,
                add_clk_enable=add_clk_enable,
                add_flush=add_flush)

            circ = kts.util.to_magma(self.dut,
                                     flatten_array=True,
                                     check_multiple_driver=False,
                                     optimize_if=False,
                                     check_flip_flop_always_ff=False)
            LakeCoreBase._circuit_cache[cache_key] = (circ, self.dut)
        else:
            circ, self.dut = LakeCoreBase._circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        # NOTE(review): self.registers is read below; presumably it is
        # populated by wrap_lake_core() -- confirm in LakeCoreBase.
        self.wrap_lake_core()

        # Dump the configuration register names, sorted, for downstream tools.
        conf_names = sorted(self.registers.keys())
        with open("pond_cfg.txt", "w", encoding="utf-8") as cfg_dump:
            cfg_dump.writelines(
                f"(\"{reg}\", 0),  # {self.registers[reg].width}\n"
                for reg in conf_names)
        with open("pond_synth.txt", "w", encoding="utf-8") as synth_dump:
            synth_dump.writelines(f"{reg}\n" for reg in conf_names)
예제 #19
0
    def __init__(self, _params: GlobalBufferParams):
        """Declare the packed-struct types used by the global buffer.

        Every struct is derived from widths in ``_params``:
          * network config structs (data and pcfg networks),
          * the DMA control and DMA header structs (header fields are
            flattened per loop level),
          * the pcfg DMA control and header structs,
          * read/write packet structs, plus lists of their field names,
          * the CGRA configuration struct.
        """
        self._params = _params
        params = self._params

        def _network_fields():
            # A fresh list per struct, just like two separate literals.
            return [("tile_connected", 1), ("latency", params.latency_width)]

        def _field_names(fields):
            return [name for name, _ in fields]

        self.cfg_data_network_t = PackedStruct("cfg_data_network_t", _network_fields())
        self.cfg_pcfg_network_t = PackedStruct("cfg_pcfg_network_t", _network_fields())

        repeat_width = clog2(params.queue_depth) + 1
        self.cfg_dma_ctrl_t = PackedStruct(
            "dma_ctrl_t",
            [("mode", 2), ("use_valid", 1), ("data_mux", 2), ("num_repeat", repeat_width)])

        # Kratos has no nested structs, so each loop level contributes its
        # range/stride/cycle_stride fields directly to the header.
        header_fields = [
            ("start_addr", params.glb_addr_width),
            ("cycle_start_addr", params.glb_addr_width),
            ("dim", 1 + clog2(params.loop_level)),
        ]
        header_fields += [
            (f"{prefix}_{level}", params.axi_data_width)
            for level in range(params.loop_level)
            for prefix in ("range", "stride", "cycle_stride")
        ]
        self.cfg_dma_header_t = PackedStruct("dma_header_t", header_fields)

        # pcfg dma header
        self.cfg_pcfg_dma_ctrl_t = PackedStruct("pcfg_dma_ctrl_t", [("mode", 1)])
        self.cfg_pcfg_dma_header_t = PackedStruct(
            "pcfg_dma_header_t",
            [("start_addr", params.glb_addr_width),
             ("num_cfg", params.max_num_cfg_width)])

        bank_width = params.bank_data_width
        wr_fields = [
            ("wr_en", 1),
            ("wr_strb", math.ceil(bank_width / 8)),  # one strobe bit per byte
            ("wr_addr", params.glb_addr_width),
            ("wr_data", bank_width),
        ]
        rdrq_fields = [("rd_en", 1), ("rd_addr", params.glb_addr_width)]
        rdrs_fields = [("rd_data", bank_width), ("rd_data_valid", 1)]

        self.packet_t = PackedStruct("packet_t", wr_fields + rdrq_fields + rdrs_fields)
        self.rd_packet_t = PackedStruct("rd_packet_t", rdrq_fields + rdrs_fields)
        self.rdrq_packet_t = PackedStruct("rdrq_packet_t", rdrq_fields)
        self.rdrs_packet_t = PackedStruct("rdrs_packet_t", rdrs_fields)
        self.wr_packet_t = PackedStruct("wr_packet_t", wr_fields)

        # Kratos has no struct-of-struct, so field names are also exposed as
        # plain lists for wiring packets field by field.
        self.wr_packet_ports = _field_names(wr_fields)
        self.rdrq_packet_ports = _field_names(rdrq_fields)
        self.rdrs_packet_ports = _field_names(rdrs_fields)
        self.rd_packet_ports = _field_names(rdrq_fields + rdrs_fields)
        # NOTE(review): this order (rdrq, rdrs, wr) differs from packet_t's
        # field order (wr, rdrq, rdrs); fine only if consumers look fields up
        # by name -- confirm.
        self.packet_ports = _field_names(rdrq_fields + rdrs_fields + wr_fields)

        self.cgra_cfg_t = PackedStruct(
            "cgra_cfg_t",
            [("rd_en", 1),
             ("wr_en", 1),
             ("addr", params.cgra_cfg_addr_width),
             ("data", params.cgra_cfg_data_width)])
예제 #20
0
    def __init__(self,
                 data_width=16,  # CGRA Params
                 mem_width=64,
                 mem_depth=512,
                 banks=1,
                 input_iterator_support=6,  # Addr Controllers
                 output_iterator_support=6,
                 input_config_width=16,
                 output_config_width=16,
                 interconnect_input_ports=2,  # Connection to int
                 interconnect_output_ports=2,
                 mem_input_ports=1,
                 mem_output_ports=1,
                 read_delay=1,  # Cycle delay in read (SRAM vs Register File)
                 rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
                 agg_height=4,
                 max_agg_schedule=32,
                 input_max_port_sched=32,
                 output_max_port_sched=32,
                 align_input=1,
                 max_line_length=128,
                 max_tb_height=1,
                 tb_range_max=128,
                 tb_range_inner_max=5,
                 tb_sched_max=64,
                 max_tb_stride=15,
                 num_tb=1,
                 tb_iterator_support=2,
                 multiwrite=1,
                 num_tiles=1,
                 max_prefetch=8,
                 app_ctrl_depth_width=16,
                 remove_tb=False,
                 stcl_valid_iter=4):
        super().__init__("strg_ub")

        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.agg_height = agg_height
        self.max_agg_schedule = max_agg_schedule
        self.input_max_port_sched = input_max_port_sched
        self.output_max_port_sched = output_max_port_sched
        self.input_port_sched_width = clog2(self.interconnect_input_ports)
        self.align_input = align_input
        self.max_line_length = max_line_length
        assert self.mem_width >= self.data_width, "Data width needs to be smaller than mem"
        self.fw_int = int(self.mem_width / self.data_width)
        self.num_tb = num_tb
        self.max_tb_height = max_tb_height
        self.tb_range_max = tb_range_max
        self.tb_range_inner_max = tb_range_inner_max
        self.max_tb_stride = max_tb_stride
        self.tb_sched_max = tb_sched_max
        self.tb_iterator_support = tb_iterator_support
        self.multiwrite = multiwrite
        self.max_prefetch = max_prefetch
        self.num_tiles = num_tiles
        self.app_ctrl_depth_width = app_ctrl_depth_width
        self.remove_tb = remove_tb
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.stcl_valid_iter = stcl_valid_iter
        # phases = [] TODO

        self.address_width = clog2(self.num_tiles * self.mem_depth)

        # CLK and RST
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # INPUTS
        self._data_in = self.input("data_in",
                                   self.data_width,
                                   size=self.interconnect_input_ports,
                                   packed=True,
                                   explicit_array=True)

        self._wen_in = self.input("wen_in", self.interconnect_input_ports)
        self._ren_input = self.input("ren_in", self.interconnect_output_ports)
        # Post rate matched
        self._ren_in = self.var("ren_in_muxed", self.interconnect_output_ports)
        # Processed versions of wen and ren from the app ctrl
        self._wen = self.var("wen", self.interconnect_input_ports)
        self._ren = self.var("ren", self.interconnect_output_ports)

        # Add rate matched
        # If one input port, let any output port use the wen_in as the ren_in
        # If more, do the same thing but also provide port selection
        if self.interconnect_input_ports == 1:
            self._rate_matched = self.input("rate_matched", self.interconnect_output_ports)
            self._rate_matched.add_attribute(ConfigRegAttr("Rate matched - 1 or 0"))
            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_in[i],
                          kts.ternary(self._rate_matched[i],
                                      self._wen_in,
                                      self._ren_input[i]))
        else:
            self._rate_matched = self.input("rate_matched", 1 + kts.clog2(self.interconnect_input_ports),
                                            size=self.interconnect_output_ports,
                                            explicit_array=True,
                                            packed=True)
            self._rate_matched.add_attribute(ConfigRegAttr("Rate matched [input port | on/off]"))
            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_in[i],
                          kts.ternary(self._rate_matched[i][0],
                                      self._wen_in[self._rate_matched[i][kts.clog2(self.interconnect_input_ports), 1]],
                                      self._ren_input[i]))

        self._arb_wen_en = self.var("arb_wen_en", self.interconnect_input_ports)
        self._arb_ren_en = self.var("arb_ren_en", self.interconnect_output_ports)

        self._data_from_strg = self.input("data_from_strg",
                                          self.data_width,
                                          size=(self.banks,
                                                self.mem_output_ports,
                                                self.fw_int),
                                          packed=True,
                                          explicit_array=True)

        self._mem_valid_data = self.input("mem_valid_data",
                                          self.mem_output_ports,
                                          size=self.banks,
                                          explicit_array=True,
                                          packed=True)

        self._out_mem_valid_data = self.var("out_mem_valid_data",
                                            self.mem_output_ports,
                                            size=self.banks,
                                            explicit_array=True,
                                            packed=True)

        # We need to signal valids out of the agg buff, only if one exists...
        if self.agg_height > 0:
            self._to_iac_valid = self.var("ab_to_mem_valid",
                                          self.interconnect_input_ports)

        self._data_out = self.output("data_out",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     packed=True,
                                     explicit_array=True)

        self._valid_out = self.output("valid_out",
                                      self.interconnect_output_ports)

        self._valid_out_alt = self.var("valid_out_alt",
                                       self.interconnect_output_ports)

        self._data_to_strg = self.output("data_to_strg",
                                         self.data_width,
                                         size=(self.banks,
                                               self.mem_input_ports,
                                               self.fw_int),
                                         packed=True,
                                         explicit_array=True)

        # If we can perform a read and a write on the same cycle,
        # this will necessitate a separate read and write address...
        if self.rw_same_cycle:
            self._wr_addr_out = self.output("wr_addr_out",
                                            self.address_width,
                                            size=(self.banks,
                                                  self.mem_input_ports),
                                            explicit_array=True,
                                            packed=True)

            self._rd_addr_out = self.output("rd_addr_out",
                                            self.address_width,
                                            size=(self.banks,
                                                  self.mem_output_ports),
                                            explicit_array=True,
                                            packed=True)

        else:
            self._addr_out = self.output("addr_out",
                                         self.address_width,
                                         size=(self.banks,
                                               self.mem_input_ports),
                                         packed=True,
                                         explicit_array=True)

        self._cen_to_strg = self.output("cen_to_strg", self.mem_output_ports,
                                        size=self.banks,
                                        explicit_array=True,
                                        packed=True)
        self._wen_to_strg = self.output("wen_to_strg", self.mem_input_ports,
                                        size=self.banks,
                                        explicit_array=True,
                                        packed=True)
        if self.num_tb > 0:
            self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)

        self._port_wens = self.var("port_wens", self.interconnect_input_ports)

        ####################
        ##### APP CTRL #####
        ####################
        self._ack_transpose = self.var("ack_transpose",
                                       self.banks,
                                       size=self.interconnect_output_ports,
                                       explicit_array=True,
                                       packed=True)

        self._ack_reduced = self.var("ack_reduced",
                                     self.interconnect_output_ports)

        self.app_ctrl = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                                interconnect_output_ports=self.interconnect_output_ports,
                                depth_width=self.app_ctrl_depth_width,
                                sprt_stcl_valid=True,
                                stcl_iter_support=self.stcl_valid_iter)

        # Some refactoring here for pond to get rid of app controllers...
        # This is honestly pretty messy and should clean up nicely when we have the spec...
        self._ren_out_reduced = self.var("ren_out_reduced",
                                         self.interconnect_output_ports)

        if self.num_tb == 0 or self.remove_tb:
            self.wire(self._wen, self._wen_in)
            self.wire(self._ren, self._ren_in)
            self.wire(self._valid_out, self._valid_out_alt)
            self.wire(self._arb_wen_en, self._wen)
            self.wire(self._arb_ren_en, self._ren)
        else:
            self.add_child("app_ctrl", self.app_ctrl,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._wen_in,
                           ren_in=self._ren_in,
                           #    ren_update=self._tb_valid_out,
                           valid_out_data=self._valid_out,
                           # valid_out_stencil=,
                           wen_out=self._wen,
                           ren_out=self._ren)

            self.wire(self.app_ctrl.ports.tb_valid, self._tb_valid_out)
            self.wire(self.app_ctrl.ports.ren_update, self._tb_valid_out)

            self.app_ctrl_coarse = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                                           interconnect_output_ports=self.interconnect_output_ports,
                                           depth_width=self.app_ctrl_depth_width)
            self.add_child("app_ctrl_coarse", self.app_ctrl_coarse,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._to_iac_valid,  # self._port_wens & self._to_iac_valid,  # Gets valid and the ack
                           ren_in=self._ren_out_reduced,
                           tb_valid=kts.const(0, 1),
                           ren_update=self._ack_reduced,
                           wen_out=self._arb_wen_en,
                           ren_out=self._arb_ren_en)

        ###########################
        ##### INPUT AGG SCHED #####
        ###########################
        ###########################################
        ##### AGGREGATION ALIGNERS (OPTIONAL) #####
        ###########################################
        # These variables are holders and can be swapped out if needed
        self._data_consume = self._data_in
        self._valid_consume = self._wen
        # Zero out if not aligning
        if(self.agg_height > 0):
            self._align_to_agg = self.var("align_input",
                                          self.interconnect_input_ports)
        # Add the aggregation buffer aligners
        if(self.align_input):
            self._data_consume = self.var("data_consume",
                                          self.data_width,
                                          size=self.interconnect_input_ports,
                                          packed=True,
                                          explicit_array=True)
            self._valid_consume = self.var("valid_consume",
                                           self.interconnect_input_ports)

            # Make new aggregation aligners for each port
            for i in range(self.interconnect_input_ports):
                new_child = AggAligner(self.data_width, self.max_line_length)
                self.add_child(f"agg_align_{i}", new_child,
                               clk=self._clk,
                               rst_n=self._rst_n,
                               in_dat=self._data_in[i],
                               in_valid=self._wen[i],
                               align=self._align_to_agg[i],
                               out_valid=self._valid_consume[i],
                               out_dat=self._data_consume[i])
        else:
            if self.agg_height > 0:
                self.wire(self._align_to_agg, const(0, self._align_to_agg.width))
        ################################################
        ##### END: AGGREGATION ALIGNERS (OPTIONAL) #####
        ################################################

        if self.agg_height == 0:
            self._to_iac_dat = self._data_consume
            self._to_iac_valid = self._valid_consume

        ##################################
        ##### AGG BUFFERS (OPTIONAL) #####
        ##################################
        # Only instantiate agg_buffer if needed
        if(self.agg_height > 0):
            self._to_iac_dat = self.var("ab_to_mem_dat",
                                        self.mem_width,
                                        size=self.interconnect_input_ports,
                                        packed=True,
                                        explicit_array=True)

            # self._to_iac_valid = self.var("ab_to_mem_valid",
            #                               self.interconnect_input_ports)

            self._agg_buffers = []
            # Add input aggregations buffers
            for i in range(self.interconnect_input_ports):
                # add children aggregator buffers...
                agg_buffer_new = AggregationBuffer(self.agg_height,
                                                   self.data_width,
                                                   self.mem_width,
                                                   self.max_agg_schedule)
                self._agg_buffers.append(agg_buffer_new)
                self.add_child(f"agg_in_{i}", agg_buffer_new,
                               clk=self._clk,
                               rst_n=self._rst_n,
                               data_in=self._data_consume[i],
                               valid_in=self._valid_consume[i],
                               align=self._align_to_agg[i],
                               data_out=self._to_iac_dat[i],
                               valid_out=self._to_iac_valid[i])

        #######################################
        ##### END: AGG BUFFERS (OPTIONAL) #####
        #######################################

        self._ready_tba = self.var("ready_tba", self.interconnect_output_ports)
        ####################################
        ##### INPUT ADDRESS CONTROLLER #####
        ####################################
        self._wen_to_arb = self.var("wen_to_arb", self.mem_input_ports,
                                    size=self.banks,
                                    explicit_array=True,
                                    packed=True)
        self._addr_to_arb = self.var("addr_to_arb",
                                     self.address_width,
                                     size=(self.banks,
                                           self.mem_input_ports),
                                     explicit_array=True,
                                     packed=True)
        self._data_to_arb = self.var("data_to_arb",
                                     self.data_width,
                                     size=(self.banks,
                                           self.mem_input_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        # Connect these inputs ports to an address generator
        iac = InputAddrCtrl(interconnect_input_ports=self.interconnect_input_ports,
                            mem_depth=self.mem_depth,
                            num_tiles=self.num_tiles,
                            banks=self.banks,
                            iterator_support=self.input_iterator_support,
                            address_width=self.address_width,
                            data_width=self.data_width,
                            fetch_width=self.mem_width,
                            multiwrite=self.multiwrite,
                            strg_wr_ports=self.mem_input_ports,
                            config_width=self.input_config_width)
        self.add_child(f"input_addr_ctrl", iac,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       valid_in=self._to_iac_valid,
                       # wen_en=kts.concat(*([kts.const(1, 1)] * self.interconnect_input_ports)),
                       wen_en=self._arb_wen_en,
                       data_in=self._to_iac_dat,
                       wen_to_sram=self._wen_to_arb,
                       addr_out=self._addr_to_arb,
                       port_out=self._port_wens,
                       data_out=self._data_to_arb)

        #########################################
        ##### END: INPUT ADDRESS CONTROLLER #####
        #########################################
        self._arb_acks = self.var("arb_acks",
                                  self.interconnect_output_ports,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        self._prefetch_step = self.var("prefetch_step", self.interconnect_output_ports)
        self._oac_step = self.var("oac_step", self.interconnect_output_ports)
        self._oac_valid = self.var("oac_valid", self.interconnect_output_ports)
        self._ren_out = self.var("ren_out",
                                 self.interconnect_output_ports,
                                 size=self.banks,
                                 explicit_array=True,
                                 packed=True)
        self._ren_out_tpose = self.var("ren_out_tpose",
                                       self.banks,
                                       size=self.interconnect_output_ports,
                                       explicit_array=True,
                                       packed=True)

        self._oac_addr_out = self.var("oac_addr_out",
                                      self.address_width,
                                      size=self.interconnect_output_ports,
                                      explicit_array=True,
                                      packed=True)
        #####################################
        ##### OUTPUT ADDRESS CONTROLLER #####
        #####################################
        oac = OutputAddrCtrl(interconnect_output_ports=self.interconnect_output_ports,
                             mem_depth=self.mem_depth,
                             num_tiles=self.num_tiles,
                             banks=self.banks,
                             iterator_support=self.output_iterator_support,
                             address_width=self.address_width,
                             config_width=self.output_config_width)

        if self.remove_tb:
            self.wire(self._oac_valid, self._ren)
            self.wire(self._oac_step, self._ren)
        else:
            self.wire(self._oac_valid, self._prefetch_step)
            self.wire(self._oac_step, self._ack_reduced)

        self.chain_idx_bits = max(1, clog2(num_tiles))
        self._enable_chain_output = self.input("enable_chain_output", 1)
        self._chain_idx_output = self.input("chain_idx_output", self.chain_idx_bits)

        self.add_child(f"output_addr_ctrl", oac,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       valid_in=self._oac_valid,
                       ren=self._ren_out,
                       addr_out=self._oac_addr_out,
                       step_in=self._oac_step)

        for i in range(self.interconnect_output_ports):
            for j in range(self.banks):
                self.wire(self._ren_out_tpose[i][j], self._ren_out[j][i])
        ##############################
        ##### READ/WRITE ARBITER #####
        ##############################
        # Hook up the read write arbiters for each bank
        self._arb_dat_out = self.var("arb_dat_out",
                                     self.data_width,
                                     size=(self.banks,
                                           self.mem_output_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        self._arb_port_out = self.var("arb_port_out",
                                      self.interconnect_output_ports,
                                      size=(self.banks,
                                            self.mem_output_ports),
                                      explicit_array=True,
                                      packed=True)
        self._arb_valid_out = self.var("arb_valid_out", self.mem_output_ports,
                                       size=self.banks,
                                       explicit_array=True,
                                       packed=True)

        self._rd_sync_gate = self.var("rd_sync_gate",
                                      self.interconnect_output_ports)

        self.arbiters = []
        for i in range(self.banks):
            rw_arb = RWArbiter(fetch_width=self.mem_width,
                               data_width=self.data_width,
                               memory_depth=self.mem_depth,
                               num_tiles=self.num_tiles,
                               int_in_ports=self.interconnect_input_ports,
                               int_out_ports=self.interconnect_output_ports,
                               strg_wr_ports=self.mem_input_ports,
                               strg_rd_ports=self.mem_output_ports,
                               read_delay=self.read_delay,
                               rw_same_cycle=self.rw_same_cycle,
                               separate_addresses=self.rw_same_cycle)
            self.arbiters.append(rw_arb)
            self.add_child(f"rw_arb_{i}", rw_arb,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._wen_to_arb[i],
                           w_data=self._data_to_arb[i],
                           w_addr=self._addr_to_arb[i],
                           data_from_mem=self._data_from_strg[i],
                           mem_valid_data=self._mem_valid_data[i],
                           out_mem_valid_data=self._out_mem_valid_data[i],
                           ren_en=self._arb_ren_en,
                           rd_addr=self._oac_addr_out,
                           out_data=self._arb_dat_out[i],
                           out_port=self._arb_port_out[i],
                           out_valid=self._arb_valid_out[i],
                           cen_mem=self._cen_to_strg[i],
                           wen_mem=self._wen_to_strg[i],
                           data_to_mem=self._data_to_strg[i],
                           out_ack=self._arb_acks[i])

            # Bind the separate addrs
            if self.rw_same_cycle:
                self.wire(rw_arb.ports.wr_addr_to_mem, self._wr_addr_out[i])
                self.wire(rw_arb.ports.rd_addr_to_mem, self._rd_addr_out[i])
            else:
                self.wire(rw_arb.ports.addr_to_mem, self._addr_out[i])

            if self.remove_tb:
                self.wire(rw_arb.ports.ren_in, self._ren_out[i])
            else:
                self.wire(rw_arb.ports.ren_in, self._ren_out[i] & self._rd_sync_gate)

        self.num_tb_bits = max(1, clog2(self.num_tb))

        self._data_to_sync = self.var("data_to_sync",
                                      self.data_width,
                                      size=(self.interconnect_output_ports,
                                            self.fw_int),
                                      explicit_array=True,
                                      packed=True)

        self._valid_to_sync = self.var("valid_to_sync", self.interconnect_output_ports)

        self._data_to_tba = self.var("data_to_tba",
                                     self.data_width,
                                     size=(self.interconnect_output_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        self._valid_to_tba = self.var("valid_to_tba", self.interconnect_output_ports)

        self._data_to_pref = self.var("data_to_pref",
                                      self.data_width,
                                      size=(self.interconnect_output_ports,
                                            self.fw_int),
                                      explicit_array=True,
                                      packed=True)

        self._valid_to_pref = self.var("valid_to_pref", self.interconnect_output_ports)
        #######################
        ##### DEMUX READS #####
        #######################
        dmux_rd = DemuxReads(fetch_width=self.mem_width,
                             data_width=self.data_width,
                             banks=self.banks,
                             int_out_ports=self.interconnect_output_ports,
                             strg_rd_ports=self.mem_output_ports)

        self._arb_dat_out_f = self.var("arb_dat_out_f",
                                       self.data_width,
                                       size=(self.banks * self.mem_output_ports,
                                             self.fw_int),
                                       explicit_array=True,
                                       packed=True)

        self._arb_port_out_f = self.var("arb_port_out_f",
                                        self.interconnect_output_ports,
                                        size=(self.banks * self.mem_output_ports),
                                        explicit_array=True,
                                        packed=True)
        self._arb_valid_out_f = self.var("arb_valid_out_f", self.mem_output_ports * self.banks)
        self._arb_mem_valid_data_f = self.var("arb_mem_valid_data_f", self.mem_output_ports * self.banks)

        self._arb_mem_valid_data_out = self.var("arb_mem_valid_data_out",
                                                self.interconnect_output_ports)

        self._mem_valid_data_sync = self.var("mem_valid_data_sync",
                                             self.interconnect_output_ports)

        self._mem_valid_data_pref = self.var("mem_valid_data_pref",
                                             self.interconnect_output_ports)

        tmp_cnt = 0
        for i in range(self.banks):
            for j in range(self.mem_output_ports):
                self.wire(self._arb_dat_out_f[tmp_cnt], self._arb_dat_out[i][j])
                self.wire(self._arb_port_out_f[tmp_cnt], self._arb_port_out[i][j])
                self.wire(self._arb_valid_out_f[tmp_cnt], self._arb_valid_out[i][j])
                self.wire(self._arb_mem_valid_data_f[tmp_cnt], self._out_mem_valid_data[i][j])
                tmp_cnt = tmp_cnt + 1

        # If this is end of the road...
        if self.remove_tb:
            assert self.fw_int == 1, "Make it easier on me now..."
            self.add_child("demux_rds", dmux_rd,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._arb_dat_out_f,
                           mem_valid_data=self._arb_mem_valid_data_f,
                           mem_valid_data_out=self._arb_mem_valid_data_out,
                           valid_in=self._arb_valid_out_f,
                           port_in=self._arb_port_out_f,
                           valid_out=self._valid_out_alt)
            for i in range(self.interconnect_output_ports):
                self.wire(self._data_out[i], dmux_rd.ports.data_out[i])

        else:
            self.add_child("demux_rds", dmux_rd,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._arb_dat_out_f,
                           mem_valid_data=self._arb_mem_valid_data_f,
                           mem_valid_data_out=self._arb_mem_valid_data_out,
                           valid_in=self._arb_valid_out_f,
                           port_in=self._arb_port_out_f,
                           data_out=self._data_to_sync,
                           valid_out=self._valid_to_sync)

            #######################
            ##### SYNC GROUPS #####
            #######################
            sync_group = SyncGroups(fetch_width=self.mem_width,
                                    data_width=self.data_width,
                                    int_out_ports=self.interconnect_output_ports)

            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_out_reduced[i], self._ren_out_tpose[i].r_or())

            self.add_child("sync_grp", sync_group,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_to_sync,
                           mem_valid_data=self._arb_mem_valid_data_out,
                           mem_valid_data_out=self._mem_valid_data_sync,
                           valid_in=self._valid_to_sync,
                           data_out=self._data_to_pref,
                           valid_out=self._valid_to_pref,
                           ren_in=self._ren_out_reduced,
                           rd_sync_gate=self._rd_sync_gate,
                           ack_in=self._ack_reduced)

            # This is the end of the line if we aren't using tb
            ######################
            ##### PREFETCHER #####
            ######################
            prefetchers = []
            for i in range(self.interconnect_output_ports):

                pref = Prefetcher(fetch_width=self.mem_width,
                                  data_width=self.data_width,
                                  max_prefetch=self.max_prefetch)

                prefetchers.append(pref)

                if self.num_tb == 0:
                    assert self.fw_int == 1, \
                        "If no transpose buffer, data width needs match memory width"
                    self.add_child(f"pre_fetch_{i}", pref,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   data_in=self._data_to_pref[i],
                                   mem_valid_data=self._mem_valid_data_sync[i],
                                   mem_valid_data_out=self._mem_valid_data_pref[i],
                                   valid_read=self._valid_to_pref[i],
                                   tba_rdy_in=self._ren[i],
                                   #    data_out=self._data_out[i],
                                   valid_out=self._valid_out_alt[i],
                                   prefetch_step=self._prefetch_step[i])
                    self.wire(self._data_out[i], pref.ports.data_out[0])
                else:
                    self.add_child(f"pre_fetch_{i}", pref,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   data_in=self._data_to_pref[i],
                                   mem_valid_data=self._mem_valid_data_sync[i],
                                   mem_valid_data_out=self._mem_valid_data_pref[i],
                                   valid_read=self._valid_to_pref[i],
                                   tba_rdy_in=self._ready_tba[i],
                                   data_out=self._data_to_tba[i],
                                   valid_out=self._valid_to_tba[i],
                                   prefetch_step=self._prefetch_step[i])

                    #############################
                    ##### TRANSPOSE BUFFERS #####
                    #############################
            if self.num_tb > 0:

                self._tb_data_out = self.var("tb_data_out", self.data_width,
                                             size=self.interconnect_output_ports,
                                             explicit_array=True,
                                             packed=True)
                self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)

                for i in range(self.interconnect_output_ports):

                    tba = TransposeBufferAggregation(word_width=self.data_width,
                                                     fetch_width=self.fw_int,
                                                     num_tb=self.num_tb,
                                                     max_tb_height=self.max_tb_height,
                                                     max_range=self.tb_range_max,
                                                     max_range_inner=self.tb_range_inner_max,
                                                     max_stride=self.max_tb_stride,
                                                     tb_iterator_support=self.tb_iterator_support)

                    self.add_child(f"tba_{i}", tba,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   SRAM_to_tb_data=self._data_to_tba[i],
                                   valid_data=self._valid_to_tba[i],
                                   tb_index_for_data=0,
                                   ack_in=self._valid_to_tba[i],
                                   mem_valid_data=self._mem_valid_data_pref[i],
                                   tb_to_interconnect_data=self._tb_data_out[i],
                                   tb_to_interconnect_valid=self._tb_valid_out[i],
                                   tb_arbiter_rdy=self._ready_tba[i],
                                   tba_ren=self._ren[i])

                for i in range(self.interconnect_output_ports):
                    self.wire(self._data_out[i], self._tb_data_out[i])
                    # self.wire(self._valid_out[i], self._tb_valid_out[i])
            else:
                self.wire(self._valid_out, self._valid_out_alt)

        ####################
        ##### ADD CODE #####
        ####################
        self.add_code(self.transpose_acks)
        self.add_code(self.reduce_acks)
예제 #21
0
파일: tb_only.py 프로젝트: StanfordAHA/lake
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=1,
            input_addr_iterator_support=6,
            output_addr_iterator_support=6,
            input_sched_iterator_support=6,
            output_sched_iterator_support=6,
            config_width=16,
            #  output_config_width=16,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            tb_height=2):

        super().__init__("strg_ub_tb_only")

        ##################################################################################
        # Capture constructor parameter...
        ##################################################################################
        self.fetch_width = mem_width // data_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.agg_height = agg_height
        self.tb_height = tb_height
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.config_width = config_width
        self.data_width = data_width
        self.input_addr_iterator_support = input_addr_iterator_support
        self.input_sched_iterator_support = input_sched_iterator_support

        self.default_iterator_support = 6
        self.default_config_width = 16
        self.sram_iterator_support = 6
        self.agg_rd_addr_gen_width = 8

        ##################################################################################
        # IO
        ##################################################################################
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        self._cycle_count = self.input("cycle_count", 16)

        # data from SRAM
        self._sram_read_data = self.input("sram_read_data",
                                          self.data_width,
                                          size=self.fetch_width,
                                          packed=True,
                                          explicit_array=True)
        # read enable from SRAM
        self._t_read = self.input("t_read", self.interconnect_output_ports)

        # sram to tb for loop
        self._loops_sram2tb_mux_sel = self.input(
            "loops_sram2tb_mux_sel",
            width=max(clog2(self.default_iterator_support), 1),
            size=self.interconnect_output_ports,
            explicit_array=True,
            packed=True)

        self._loops_sram2tb_restart = self.input(
            "loops_sram2tb_restart",
            width=1,
            size=self.interconnect_output_ports,
            explicit_array=True,
            packed=True)

        self._valid_out = self.output("accessor_output",
                                      self.interconnect_output_ports)
        self._data_out = self.output("data_out",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     packed=True,
                                     explicit_array=True)

        ##################################################################################
        # TB RELEVANT SIGNALS
        ##################################################################################
        self._tb = self.var("tb",
                            width=self.data_width,
                            size=(self.interconnect_output_ports,
                                  self.tb_height, self.fetch_width),
                            packed=True,
                            explicit_array=True)

        self._tb_write_addr = self.var("tb_write_addr",
                                       2 + max(1, clog2(self.tb_height)),
                                       size=self.interconnect_output_ports,
                                       packed=True,
                                       explicit_array=True)

        self._tb_read_addr = self.var("tb_read_addr",
                                      2 + max(1, clog2(self.tb_height)),
                                      size=self.interconnect_output_ports,
                                      packed=True,
                                      explicit_array=True)

        # write enable to tb, delayed 1 cycle from SRAM reads
        self._t_read_d1 = self.var("t_read_d1", self.interconnect_output_ports)
        # read enable for reads from tb
        self._tb_read = self.var("tb_read", self.interconnect_output_ports)

        # Break out valids...
        self.wire(self._valid_out, self._tb_read)

        # delayed input mux_sel and restart signals from sram read/tb write
        # for loop and scheduling
        self._mux_sel_d1 = self.var("mux_sel_d1",
                                    kts.clog2(self.default_iterator_support),
                                    size=self.interconnect_output_ports,
                                    packed=True,
                                    explicit_array=True)

        self._restart_d1 = self.var("restart_d1",
                                    width=1,
                                    size=self.interconnect_output_ports,
                                    explicit_array=True,
                                    packed=True)

        for i in range(self.interconnect_output_ports):
            # signals delayed by 1 cycle from SRAM
            @always_ff((posedge, "clk"), (negedge, "rst_n"))
            def delay_read():
                if ~self._rst_n:
                    self._t_read_d1[i] = 0
                    self._mux_sel_d1[i] = 0
                    self._restart_d1[i] = 0
                else:
                    self._t_read_d1[i] = self._t_read[i]
                    self._mux_sel_d1[i] = self._loops_sram2tb_mux_sel[i]
                    self._restart_d1[i] = self._loops_sram2tb_restart[i]

            self.add_code(delay_read)

        ##################################################################################
        # TB PATHS
        ##################################################################################
        for i in range(self.interconnect_output_ports):

            self.tb_iter_support = 6
            self.tb_addr_width = 4
            self.tb_range_width = 16

            _AG = AddrGen(iterator_support=self.default_iterator_support,
                          config_width=self.tb_addr_width)

            self.add_child(f"tb_write_addr_gen_{i}",
                           _AG,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           step=self._t_read_d1[i],
                           mux_sel=self._mux_sel_d1[i],
                           restart=self._restart_d1[i])
            safe_wire(gen=self,
                      w_to=self._tb_write_addr[i],
                      w_from=_AG.ports.addr_out)

            @always_ff((posedge, "clk"))
            def tb_ctrl():
                if self._t_read_d1[i]:
                    self._tb[i][self._tb_write_addr[i][0]] = \
                        self._sram_read_data

            self.add_code(tb_ctrl)

            # READ FROM TB

            fl_ctr_tb_rd = ForLoop(iterator_support=self.tb_iter_support,
                                   config_width=self.tb_range_width)

            self.add_child(f"loops_buf2out_read_{i}",
                           fl_ctr_tb_rd,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           step=self._tb_read[i])

            _AG = AddrGen(iterator_support=self.tb_iter_support,
                          config_width=self.tb_addr_width)
            self.add_child(
                f"tb_read_addr_gen_{i}",
                _AG,
                clk=self._clk,
                rst_n=self._rst_n,
                step=self._tb_read[i],
                # addr_out=self._tb_read_addr[i])
                mux_sel=fl_ctr_tb_rd.ports.mux_sel_out,
                restart=fl_ctr_tb_rd.ports.restart)
            safe_wire(gen=self,
                      w_to=self._tb_read_addr[i],
                      w_from=_AG.ports.addr_out)

            self.add_child(
                f"tb_read_sched_gen_{i}",
                SchedGen(
                    iterator_support=self.tb_iter_support,
                    # config_width=self.tb_addr_width),
                    config_width=16),
                clk=self._clk,
                rst_n=self._rst_n,
                cycle_count=self._cycle_count,
                mux_sel=fl_ctr_tb_rd.ports.mux_sel_out,
                finished=fl_ctr_tb_rd.ports.restart,
                valid_output=self._tb_read[i])

            @always_comb
            def tb_to_out():
                self._data_out[i] = self._tb[i][self._tb_read_addr[i][
                    clog2(self.tb_height) + clog2(self.fetch_width) - 1,
                    clog2(self.fetch_width)]][self._tb_read_addr[i][
                        clog2(self.fetch_width) - 1, 0]]

            self.add_code(tb_to_out)
예제 #22
0
    def __init__(self,
                 data_width=16,
                 banks=1,
                 memory_width=64,
                 rw_same_cycle=False,
                 read_delay=1,
                 addr_width=9):
        """Build the ``strg_fifo`` generator: a FIFO layered over banked storage.

        Incoming words are pushed into a small front register FIFO
        (``front_rf``) whose contents can be read out in parallel as
        ``fw_int`` words toward storage (``data_to_strg``).  Words read back
        from storage (``data_from_strg``) are parallel-loaded into a back
        register FIFO (``back_rf``) that feeds ``data_out``.  Single words
        can also move directly from the front FIFO into the back FIFO.

        Args:
            data_width: Width in bits of one FIFO word.
            banks: Number of storage banks driven by this controller.
            memory_width: Width in bits of one storage word; ``fw_int =
                memory_width / data_width`` FIFO words are packed per
                storage word.
            rw_same_cycle: If True, expose separate ``wen_addr_out`` and
                ``ren_addr_out`` ports; otherwise expose one shared
                ``addr_out`` per bank.
            read_delay: Storage read delay (presumably in cycles).  When 1,
                the back-FIFO parallel load is driven from the registered
                ``ren_delay`` and the previously read bank is tracked;
                otherwise the load is driven directly from the OR of
                ``ren_to_strg``.
            addr_width: Width of the per-bank storage addresses.
        """
        super().__init__("strg_fifo")

        # Generation parameters
        self.banks = banks
        self.data_width = data_width
        self.memory_width = memory_width
        self.rw_same_cycle = rw_same_cycle
        self.read_delay = read_delay
        self.addr_width = addr_width
        # Number of FIFO words per storage word
        self.fw_int = int(self.memory_width / self.data_width)

        # assert banks > 1 or rw_same_cycle is True or self.fw_int > 1, \
        #     "Can't sustain throughput with this setup. Need potential bandwidth for " + \
        #     "1 write and 1 read in a cycle - try using more banks or a macro that supports 1R1W"

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # Inputs + Outputs
        self._push = self.input("push", 1)
        self._data_in = self.input("data_in", self.data_width)
        self._pop = self.input("pop", 1)

        self._data_out = self.output("data_out", self.data_width)
        self._valid_out = self.output("valid_out", 1)
        self._empty = self.output("empty", 1)
        self._full = self.output("full", 1)

        # get relevant signals from the storage banks
        self._data_from_strg = self.input("data_from_strg", self.data_width,
                                          size=(self.banks,
                                                self.fw_int),
                                          explicit_array=True,
                                          packed=True)

        # Per-bank write/read addresses (driven by set_wen_addr / set_ren_addr below)
        self._wen_addr = self.var("wen_addr", self.addr_width,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        self._ren_addr = self.var("ren_addr", self.addr_width,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        # fw_int words gathered from the front FIFO's parallel output,
        # reordered by its read pointer (see the wiring loop below)
        self._front_combined = self.var("front_combined", self.data_width,
                                        size=self.fw_int,
                                        explicit_array=True,
                                        packed=True)

        self._data_to_strg = self.output("data_to_strg", self.data_width,
                                         size=(self.banks,
                                               self.fw_int),
                                         explicit_array=True,
                                         packed=True)

        # One write-enable / read-enable bit per bank
        self._wen_to_strg = self.output("wen_to_strg", self.banks)
        self._ren_to_strg = self.output("ren_to_strg", self.banks)

        # Words currently held in storage (driven by set_num_words_mem)
        self._num_words_mem = self.var("num_words_mem", self.data_width)

        # With one bank the current read/write bank is tied to 0; with more,
        # set_curr_bank_wr / set_curr_bank_rd drive these (added further below).
        if self.banks == 1:
            self._curr_bank_wr = self.var("curr_bank_wr", 1)
            self.wire(self._curr_bank_wr, kts.const(0, 1))
            self._curr_bank_rd = self.var("curr_bank_rd", 1)
            self.wire(self._curr_bank_rd, kts.const(0, 1))
        else:
            self._curr_bank_wr = self.var("curr_bank_wr", kts.clog2(self.banks))
            self._curr_bank_rd = self.var("curr_bank_rd", kts.clog2(self.banks))

        # Per-bank storage word held for a write that could not be sent yet
        # (presumably filled by set_write_queue — TODO confirm)
        self._write_queue = self.var("write_queue", self.data_width,
                                     size=(self.banks,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        # Lets us know if the bank has a write queued up
        self._queued_write = self.var("queued_write", self.banks)

        # Signals connected to the front register FIFO
        self._front_data_out = self.var("front_data_out", self.data_width)
        self._front_pop = self.var("front_pop", 1)
        self._front_empty = self.var("front_empty", 1)
        self._front_full = self.var("front_full", 1)
        self._front_valid = self.var("front_valid", 1)
        self._front_par_read = self.var("front_par_read", 1)
        self._front_par_out = self.var("front_par_out", self.data_width,
                                       size=(self.fw_int,
                                             1),
                                       explicit_array=True,
                                       packed=True)

        self._front_rd_ptr = self.var("front_rd_ptr", max(1, clog2(self.fw_int)))

        # Accept a push when the FIFO is not full, or when a pop happens in the same cycle
        self._front_push = self.var("front_push", 1)
        self.wire(self._front_push, self._push & (~self._full | self._pop))

        self._front_rf = RegFIFO(data_width=self.data_width,
                                 width_mult=1,
                                 depth=self.fw_int,
                                 parallel=True,
                                 break_out_rd_ptr=True)

        # This one breaks out the read pointer so we can properly
        # reorder the data to storage

        self.add_child("front_rf", self._front_rf,
                       clk=self._clk,
                       clk_en=kts.const(1, 1),
                       rst_n=self._rst_n,
                       push=self._front_push,
                       pop=self._front_pop,
                       empty=self._front_empty,
                       full=self._front_full,
                       valid=self._front_valid,
                       parallel_read=self._front_par_read,
                       parallel_load=kts.const(0, 1),  # We don't need to parallel load the front
                       parallel_in=0,  # Same reason as above
                       parallel_out=self._front_par_out,
                       num_load=0,
                       rd_ptr_out=self._front_rd_ptr)
        self.wire(self._front_rf.ports.data_in[0], self._data_in)
        self.wire(self._front_data_out, self._front_rf.ports.data_out[0])

        # Signals connected to the back register FIFO
        self._back_data_in = self.var("back_data_in", self.data_width)
        self._back_data_out = self.var("back_data_out", self.data_width)
        self._back_push = self.var("back_push", 1)
        self._back_empty = self.var("back_empty", 1)
        self._back_full = self.var("back_full", 1)
        self._back_valid = self.var("back_valid", 1)
        # Back-FIFO parallel load strobe (storage read data is arriving)
        self._back_pl = self.var("back_pl", 1)
        self._back_par_in = self.var("back_par_in", self.data_width,
                                     size=(self.fw_int,
                                           1),
                                     explicit_array=True,
                                     packed=True)
        # Number of words to parallel-load into the back FIFO (0..fw_int)
        self._back_num_load = self.var("back_num_load", clog2(self.fw_int) + 1)

        # Occupancy of the two small FIFOs (0..fw_int)
        self._back_occ = self.var("back_occ", clog2(self.fw_int) + 1)
        self._front_occ = self.var("front_occ", clog2(self.fw_int) + 1)

        self._back_rf = RegFIFO(data_width=self.data_width,
                                width_mult=1,
                                depth=self.fw_int,
                                parallel=True,
                                break_out_rd_ptr=False)

        # Constant flag: 1 when a storage word holds exactly one FIFO word
        self._fw_is_1 = self.var("fw_is_1", 1)
        self.wire(self._fw_is_1, kts.const(self.fw_int == 1, 1))

        # Pop the back FIFO on an external pop when data is available (or is being
        # pushed in). With fw_int == 1 the pop is also suppressed during a parallel
        # load, in which case data_out comes straight from storage (see below).
        self._back_pop = self.var("back_pop", 1)
        if self.fw_int == 1:
            self.wire(self._back_pop, self._pop & (~self._empty | self._push) & ~self._back_pl)
        else:
            self.wire(self._back_pop, self._pop & (~self._empty | self._push))

        self.add_child("back_rf", self._back_rf,
                       clk=self._clk,
                       clk_en=kts.const(1, 1),
                       rst_n=self._rst_n,
                       push=self._back_push,
                       pop=self._back_pop,
                       empty=self._back_empty,
                       full=self._back_full,
                       valid=self._back_valid,
                       parallel_read=kts.const(0, 1),
                       # Only do back load when data is going there
                       parallel_load=self._back_pl & self._back_num_load.r_or(),
                       parallel_in=self._back_par_in,
                       num_load=self._back_num_load)
        self.wire(self._back_rf.ports.data_in[0], self._back_data_in)
        self.wire(self._back_data_out, self._back_rf.ports.data_out[0])
        # send the writes through when a read isn't happening
        for i in range(self.banks):
            self.add_code(self.send_writes, idx=i)
            self.add_code(self.send_reads, idx=i)

        # Set the parallel load to back bank - if no delay it's immediate
        # if not, it's delayed :)
        if self.read_delay == 1:
            self._ren_delay = self.var("ren_delay", 1)
            self.add_code(self.set_parallel_ld_delay_1)
            self.wire(self._back_pl, self._ren_delay)
        else:
            self.wire(self._back_pl, self._ren_to_strg.r_or())

        # Combine front end data - just the items + incoming
        # this data is actually based on the rd_ptr from the front fifo
        # NOTE(review): rd_ptr + i presumably wraps within the pointer width; that
        # equals wrapping modulo fw_int only when fw_int is a power of two - confirm.
        for i in range(self.fw_int):
            self.wire(self._front_combined[i], self._front_par_out[self._front_rd_ptr + i])
        # This is always true
        # self.wire(self._front_combined[self.fw_int - 1], self._data_in)

        # prioritize queued writes, otherwise send combined data
        for i in range(self.banks):
            self.wire(self._data_to_strg[i],
                      kts.ternary(self._queued_write[i], self._write_queue[i], self._front_combined))

        # Wire the thin output from front to thin input to back
        self.wire(self._back_data_in, self._front_data_out)
        self.wire(self._back_push, self._front_valid)
        self.add_code(self.set_front_pop)

        # Queue writes
        for i in range(self.banks):
            self.add_code(self.set_write_queue, idx=i)

        # Track number of words in memory
        # if self.read_delay == 1:
        #     self.add_code(self.set_num_words_mem_delay)
        # else:
        self.add_code(self.set_num_words_mem)

        # Track occupancy of the two small fifos
        self.add_code(self.set_front_occ)
        self.add_code(self.set_back_occ)

        if self.banks > 1:
            self.add_code(self.set_curr_bank_wr)
            self.add_code(self.set_curr_bank_rd)
        # With a one-cycle read delay, remember which bank the returning data came from
        if self.read_delay == 1:
            self._prev_bank_rd = self.var("prev_bank_rd", max(1, kts.clog2(self.banks)))
            self.add_code(self.set_prev_bank_rd)

        # Parallel load data to back - based on num_load
        index_into = self._curr_bank_rd
        if self.read_delay == 1:
            index_into = self._prev_bank_rd
        for i in range(self.fw_int - 1):
            # Shift data over if you bypassed from the memory output
            # (when fewer than fw_int words are loaded, word 0 presumably went
            # straight to data_out, so the remaining words shift down by one)
            self.wire(self._back_par_in[i],
                      kts.ternary(self._back_num_load == self.fw_int,
                                  self._data_from_strg[index_into][i],
                                  self._data_from_strg[index_into][i + 1]))
        # The last slot is either the last storage word (full load) or zero (shifted)
        self.wire(self._back_par_in[self.fw_int - 1],
                  kts.ternary(self._back_num_load == self.fw_int,
                              self._data_from_strg[index_into][self.fw_int - 1],
                              kts.const(0, self.data_width)))

        # Set the parallel read to the front fifo - analogous with trying to write to the memory
        self.add_code(self.set_front_par_read)

        # Set the number being parallely loaded into the register
        self.add_code(self.set_back_num_load)

        # Data out and valid out are (in the general case) just the data and valid from the back fifo
        # In the case where we have a fresh memory read, it would be from that
        bank_idx_read = self._curr_bank_rd
        if self.read_delay == 1:
            bank_idx_read = self._prev_bank_rd
        self.wire(self._data_out,
                  kts.ternary(self._back_pl, self._data_from_strg[bank_idx_read][0], self._back_data_out))
        self.wire(self._valid_out, kts.ternary(self._back_pl, self._pop, self._back_valid))

        # Set addresses to storage
        for i in range(self.banks):
            self.add_code(self.set_wen_addr, idx=i)
            self.add_code(self.set_ren_addr, idx=i)
        # Now deal with a shared address vs separate addresses
        if self.rw_same_cycle:
            # Separate
            self._wen_addr_out = self.output("wen_addr_out", self.addr_width,
                                             size=self.banks,
                                             explicit_array=True,
                                             packed=True)
            self._ren_addr_out = self.output("ren_addr_out", self.addr_width,
                                             size=self.banks,
                                             explicit_array=True,
                                             packed=True)
            self.wire(self._wen_addr_out, self._wen_addr)
            self.wire(self._ren_addr_out, self._ren_addr)
        else:
            self._addr_out = self.output("addr_out", self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
            # If sharing the addresses, the write address takes priority: send
            # wen_addr when the bank's write enable is set, otherwise ren_addr
            for i in range(self.banks):
                self.wire(self._addr_out[i],
                          kts.ternary(self._wen_to_strg[i], self._wen_addr[i], self._ren_addr[i]))

        # Do final empty/full
        # num_items (driven by set_num_items) is compared against 0 and the
        # configured fifo_depth to produce empty/full
        self._num_items = self.var("num_items", self.data_width)
        self.add_code(self.set_num_items)
        self._fifo_depth = self.input("fifo_depth", self.data_width)
        self._fifo_depth.add_attribute(ConfigRegAttr("Fifo depth..."))
        self.wire(self._empty, self._num_items == 0)
        self.wire(self._full, self._num_items == (self._fifo_depth))
예제 #23
0
def _flat_port_name(base, port, lane, num_ports, width_mult):
    """Return the flattened circuit port name for element ``[port][lane]`` of ``base``.

    Naming this test expects from ``kts.util.to_magma(..., flatten_array=True)``:

    * one port, one lane      -> ``base``
    * one port, several lanes -> ``base_<lane>``
    * several ports           -> ``base_<port>_<lane>`` (lane is 0 when
      ``width_mult == 1``)
    """
    if num_ports == 1 and width_mult == 1:
        return base
    if num_ports == 1:
        return f"{base}_{lane}"
    return f"{base}_{port}_{lane}"


def _flat_addr_name(base, port, num_ports):
    """Return the flattened circuit port name for address ``port`` of ``base``."""
    return base if num_ports == 1 else f"{base}_{port}"


def test_reg_file_basic(data_width, depth, width_mult, write_ports,
                        read_ports):
    """Drive ``RegisterFile`` with random traffic and check it against ``RegisterFileModel``.

    Every cycle, each write port gets a random write enable, address and
    ``width_mult`` data lanes, and each read port gets a random address. The
    DUT outputs are compared with the model's read data before the clock
    advances. The random stream is seeded with 0, and draws happen in a fixed
    order, so runs are reproducible.

    Args:
        data_width: Bit width of one data lane.
        depth: Number of register-file entries.
        width_mult: Data lanes per entry.
        write_ports: Number of write ports.
        read_ports: Number of read ports.
    """
    # Set up model...
    model_rf = RegisterFileModel(data_width=data_width,
                                 write_ports=write_ports,
                                 read_ports=read_ports,
                                 width_mult=width_mult,
                                 depth=depth)
    new_config = {}
    model_rf.set_config(new_config=new_config)

    # Set up dut...
    dut = RegisterFile(data_width=data_width,
                       write_ports=write_ports,
                       read_ports=read_ports,
                       width_mult=width_mult,
                       depth=depth)

    magma_dut = kts.util.to_magma(dut,
                                  flatten_array=True,
                                  check_flip_flop_always_ff=False)
    tester = fault.Tester(magma_dut, magma_dut.clk)

    for key, value in new_config.items():
        setattr(tester.circuit, key, value)

    # initial reset
    tester.circuit.clk = 0
    tester.circuit.rst_n = 0
    tester.step(2)
    tester.circuit.rst_n = 1
    tester.step(2)

    rand.seed(0)
    max_data = 2 ** data_width - 1

    for _ in range(1000):
        # Draw stimulus. The order (per write port: wen, address, data lanes;
        # then read addresses) fixes the random sequence for a given seed.
        wen, wr_addr, wr_data = [], [], []
        for _port in range(write_ports):
            wen.append(rand.randint(0, 1))
            wr_addr.append(rand.randint(0, depth - 1))
            wr_data.append([rand.randint(0, max_data) for _lane in range(width_mult)])
        rd_addr = [rand.randint(0, depth - 1) for _port in range(read_ports)]

        # Poke the DUT inputs
        for port in range(write_ports):
            setattr(tester.circuit, _flat_addr_name("wr_addr", port, write_ports),
                    wr_addr[port])
        for port in range(write_ports):
            tester.circuit.wen[port] = wen[port]
        for port in range(write_ports):
            for lane in range(width_mult):
                setattr(tester.circuit,
                        _flat_port_name("data_in", port, lane, write_ports, width_mult),
                        wr_data[port][lane])
        for port in range(read_ports):
            setattr(tester.circuit, _flat_addr_name("rd_addr", port, read_ports),
                    rd_addr[port])

        model_dat_out = model_rf.interact(wen, wr_addr, rd_addr, wr_data)

        tester.eval()

        # Compare the DUT's read data against the model
        for port in range(read_ports):
            for lane in range(width_mult):
                getattr(tester.circuit,
                        _flat_port_name("data_out", port, lane, read_ports, width_mult)
                        ).expect(model_dat_out[port][lane])

        tester.step(2)

    with tempfile.TemporaryDirectory() as tempdir:
        tester.compile_and_run(target="verilator",
                               directory=tempdir,
                               magma_output="verilog",
                               flags=["-Wno-fatal"])
예제 #24
0
    def __init__(self,
                 interconnect_input_ports,
                 interconnect_output_ports,
                 depth_width=16,
                 sprt_stcl_valid=False,
                 stcl_cnt_width=16,
                 stcl_iter_support=4):
        """Build the ``app_ctrl`` generator that gates per-port read/write enables.

        ``wen_out`` is ``wen_in`` masked by ``~write_done_ff``.  ``ren_out`` is
        ``ren_in`` masked by ``~read_done_ff``, by ``read_on`` (nonzero
        ``read_depth``) and by ``wr_delay_state_n | prefill``.
        ``valid_out_data`` is ``tb_valid`` gated by ``valid_out_stencil``.

        Args:
            interconnect_input_ports: Number of input (write) ports.
            interconnect_output_ports: Number of output (read) ports.
            depth_width: Bit width of the depth configuration registers and
                of the read/write counters.
            sprt_stcl_valid: If True, generate a stencil-valid counter chain
                configured by ``ranges``/``threshold``; otherwise
                ``valid_out_stencil`` simply mirrors ``tb_valid``.
            stcl_cnt_width: Bit width of the stencil-valid counters and
                their configuration registers.
            stcl_iter_support: Number of dimensions in the stencil-valid
                counter chain.
        """
        super().__init__("app_ctrl", debug=True)

        self.int_in_ports = interconnect_input_ports
        self.int_out_ports = interconnect_output_ports
        self.depth_width = depth_width
        self.sprt_stcl_valid = sprt_stcl_valid
        self.stcl_cnt_width = stcl_cnt_width
        self.stcl_iter_support = stcl_iter_support

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # IO
        self._wen_in = self.input("wen_in", self.int_in_ports)
        self._ren_in = self.input("ren_in", self.int_out_ports)

        self._ren_update = self.input("ren_update", self.int_out_ports)

        self._tb_valid = self.input("tb_valid", self.int_out_ports)

        self._valid_out_data = self.output("valid_out_data", self.int_out_ports)

        self._valid_out_stencil = self.output("valid_out_stencil", self.int_out_ports)

        # Stencil valid: either computed from a chain of per-dimension counters
        # (sprt_stcl_valid) or passed straight through from tb_valid.

        if self.sprt_stcl_valid:
            # Add the config registers to watch
            self._ranges = self.input("ranges", self.stcl_cnt_width,
                                      size=self.stcl_iter_support,
                                      packed=True,
                                      explicit_array=True)
            self._ranges.add_attribute(ConfigRegAttr("Ranges of stencil valid generator"))

            self._threshold = self.input("threshold", self.stcl_cnt_width,
                                         size=self.stcl_iter_support,
                                         packed=True,
                                         explicit_array=True)
            self._threshold.add_attribute(ConfigRegAttr("Threshold of stencil valid generator"))

            # One counter per stencil dimension (updated by dim_counter_update)
            self._dim_counter = self.var("dim_counter", self.stcl_cnt_width,
                                         size=self.stcl_iter_support,
                                         packed=True,
                                         explicit_array=True)

            # update[0] is always 1; update[i + 1] is set when counter i is at its
            # last value (ranges[i] - 1) and update[i] is set, like a carry chain.
            self._update = self.var("update", self.stcl_iter_support)

            self.wire(self._update[0], const(1, 1))
            for i in range(self.stcl_iter_support - 1):
                self.wire(self._update[i + 1],
                          (self._dim_counter[i] == (self._ranges[i] - 1)) & self._update[i])

            for i in range(self.stcl_iter_support):
                self.add_code(self.dim_counter_update, idx=i)

            # Now we need to just compute stencil valid
            # (valid when every dimension counter has reached its threshold)
            threshold_comps = [self._dim_counter[_i] >= self._threshold[_i] for _i in range(self.stcl_iter_support)]
            self.wire(self._valid_out_stencil[0], kts.concat(*threshold_comps).r_and())
            for i in range(self.int_out_ports - 1):
                # self.wire(self._valid_out_stencil[i + 1], 0)
                # for multiple ports
                # (every output port shares the same stencil-valid result)
                self.wire(self._valid_out_stencil[i + 1], kts.concat(*threshold_comps).r_and())

        else:
            self.wire(self._valid_out_stencil, self._tb_valid)

        # Now gate the valid with stencil valid
        self.wire(self._valid_out_data, self._tb_valid & self._valid_out_stencil)
        # NOTE(review): wr_delay_state_n has int_out_ports bits, but it is indexed
        # by input-port index i when selecting write_depth below; this goes out of
        # range if int_in_ports > int_out_ports - confirm the intended mapping.
        self._wr_delay_state_n = self.var("wr_delay_state_n", self.int_out_ports)
        self._wen_out = self.output("wen_out", self.int_in_ports)
        self._ren_out = self.output("ren_out", self.int_out_ports)

        self._write_depth_wo = self.input("write_depth_wo", self.depth_width,
                                          size=self.int_in_ports,
                                          explicit_array=True,
                                          packed=True)
        self._write_depth_wo.add_attribute(ConfigRegAttr("Depth of writes"))

        self._write_depth_ss = self.input("write_depth_ss", self.depth_width,
                                          size=self.int_in_ports,
                                          explicit_array=True,
                                          packed=True)
        self._write_depth_ss.add_attribute(ConfigRegAttr("Depth of writes"))

        # Effective write depth per input port: write_depth_ss when
        # wr_delay_state_n[i] is set, otherwise write_depth_wo
        self._write_depth = self.var("write_depth", self.depth_width,
                                     size=self.int_in_ports,
                                     explicit_array=True,
                                     packed=True)

        for i in range(self.int_in_ports):
            self.wire(self._write_depth[i],
                      kts.ternary(self._wr_delay_state_n[i],
                                  self._write_depth_ss[i],
                                  self._write_depth_wo[i]))

        self._read_depth = self.input("read_depth", self.depth_width,
                                      size=self.int_out_ports,
                                      explicit_array=True,
                                      packed=True)
        self._read_depth.add_attribute(ConfigRegAttr("Depth of reads"))

        # Per-port transfer counters and done flags (combinational and registered)
        self._write_count = self.var("write_count", self.depth_width,
                                     size=self.int_in_ports,
                                     explicit_array=True,
                                     packed=True)

        self._read_count = self.var("read_count", self.depth_width,
                                    size=self.int_out_ports,
                                    explicit_array=True,
                                    packed=True)
        self._write_done = self.var("write_done", self.int_in_ports)
        self._write_done_ff = self.var("write_done_ff", self.int_in_ports)
        self._read_done = self.var("read_done", self.int_out_ports)
        self._read_done_ff = self.var("read_done_ff", self.int_out_ports)

        self.in_port_bits = max(1, kts.clog2(self.int_in_ports))
        self._input_port = self.input("input_port", self.in_port_bits,
                                      size=self.int_out_ports,
                                      explicit_array=True,
                                      packed=True)
        self._input_port.add_attribute(ConfigRegAttr("Relative input port for an output port"))

        self.out_port_bits = max(1, kts.clog2(self.int_out_ports))
        self._output_port = self.input("output_port", self.out_port_bits,
                                       size=self.int_in_ports,
                                       explicit_array=True,
                                       packed=True)
        self._output_port.add_attribute(ConfigRegAttr("Relative output port for an input port"))

        self._prefill = self.input("prefill", self.int_out_ports)
        self._prefill.add_attribute(ConfigRegAttr("Is the input stream prewritten?"))

        # The *_one_wr variants are used when there is a single input port
        for i in range(self.int_out_ports):
            self.add_code(self.set_read_done, idx=i)
            if self.int_in_ports == 1:
                self.add_code(self.set_read_done_ff_one_wr, idx=i)
            else:
                self.add_code(self.set_read_done_ff, idx=i)

        # self._write_done_comb = self.var("write_done_comb", self.int_in_ports)
        for i in range(self.int_in_ports):
            self.add_code(self.set_write_done, idx=i)
            self.add_code(self.set_write_done_ff, idx=i)

        for i in range(self.int_in_ports):
            self.add_code(self.set_write_cnt, idx=i)
        for i in range(self.int_out_ports):
            if self.int_in_ports == 1:
                self.add_code(self.set_read_cnt_one_wr, idx=i)
            else:
                self.add_code(self.set_read_cnt, idx=i)

        for i in range(self.int_out_ports):
            if self.int_in_ports == 1:
                self.add_code(self.set_wr_delay_state_one_wr, idx=i)
            else:
                self.add_code(self.set_wr_delay_state, idx=i)

        # A read port is only enabled when its configured read_depth is nonzero
        self._read_on = self.var("read_on", self.int_out_ports)
        for i in range(self.int_out_ports):
            self.wire(self._read_on[i], self._read_depth[i].r_or())

        # If we have prefill enabled, we are skipping the initial delay step...
        self.wire(self._ren_out,
                  (self._wr_delay_state_n | self._prefill) & ~self._read_done_ff & self._ren_in &
                  self._read_on)
        self.wire(self._wen_out, ~self._write_done_ff & self._wen_in)
예제 #25
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=16,
            mem_depth=256,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=1,  # Connection to int
            interconnect_output_ports=1,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=True,
            sram_macro_info=SRAMMacroInfo(),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=True,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            tb_sched_max=16,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=1,
            remove_tb=False,
            fifo_mode=False,
            add_clk_enable=True,
            add_flush=True,
            override_name=None):

        # name
        if override_name:
            self.__name = override_name + "Core"
            lake_name = override_name
        else:
            self.__name = "MemCore"
            lake_name = "LakeTop"

        super().__init__(config_addr_width, config_data_width)

        # Capture everything to the tile object
        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.fw_int = int(self.mem_width / self.data_width)
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.remove_tb = remove_tb
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        # self.app_ctrl_depth_width = app_ctrl_depth_width
        # self.stcl_valid_iter = stcl_valid_iter

        # Typedefs for ease
        TData = magma.Bits[self.data_width]
        TBit = magma.Bits[1]

        self.__inputs = []
        self.__outputs = []

        # cache_key = (self.data_width, self.mem_width, self.mem_depth, self.banks,
        #              self.input_iterator_support, self.output_iterator_support,
        #              self.interconnect_input_ports, self.interconnect_output_ports,
        #              self.use_sram_stub, self.sram_macro_info, self.read_delay,
        #              self.rw_same_cycle, self.agg_height, self.max_agg_schedule,
        #              self.input_max_port_sched, self.output_max_port_sched,
        #              self.align_input, self.max_line_length, self.max_tb_height,
        #              self.tb_range_max, self.tb_sched_max, self.max_tb_stride,
        #              self.num_tb, self.tb_iterator_support, self.multiwrite,
        #              self.max_prefetch, self.config_data_width, self.config_addr_width,
        #              self.num_tiles, self.remove_tb, self.fifo_mode, self.stcl_valid_iter,
        #              self.add_clk_enable, self.add_flush, self.app_ctrl_depth_width)

        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.config_data_width,
                     self.config_addr_width, self.num_tiles, self.remove_tb,
                     self.fifo_mode, self.add_clk_enable, self.add_flush)

        # Check for circuit caching
        if cache_key not in MemCore.__circuit_cache:

            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            # lt_dut = LakeTop(data_width=self.data_width,
            #                  mem_width=self.mem_width,
            #                  mem_depth=self.mem_depth,
            #                  banks=self.banks,
            #                  input_iterator_support=self.input_iterator_support,
            #                  output_iterator_support=self.output_iterator_support,
            #                  input_config_width=self.input_config_width,
            #                  output_config_width=self.output_config_width,
            #                  interconnect_input_ports=self.interconnect_input_ports,
            #                  interconnect_output_ports=self.interconnect_output_ports,
            #                  use_sram_stub=self.use_sram_stub,
            #                  sram_macro_info=self.sram_macro_info,
            #                  read_delay=self.read_delay,
            #                  rw_same_cycle=self.rw_same_cycle,
            #                  agg_height=self.agg_height,
            #                  max_agg_schedule=self.max_agg_schedule,
            #                  input_max_port_sched=self.input_max_port_sched,
            #                  output_max_port_sched=self.output_max_port_sched,
            #                  align_input=self.align_input,
            #                  max_line_length=self.max_line_length,
            #                  max_tb_height=self.max_tb_height,
            #                  tb_range_max=self.tb_range_max,
            #                  tb_range_inner_max=self.tb_range_inner_max,
            #                  tb_sched_max=self.tb_sched_max,
            #                  max_tb_stride=self.max_tb_stride,
            #                  num_tb=self.num_tb,
            #                  tb_iterator_support=self.tb_iterator_support,
            #                  multiwrite=self.multiwrite,
            #                  max_prefetch=self.max_prefetch,
            #                  config_data_width=self.config_data_width,
            #                  config_addr_width=self.config_addr_width,
            #                  num_tiles=self.num_tiles,
            #                  app_ctrl_depth_width=self.app_ctrl_depth_width,
            #                  remove_tb=self.remove_tb,
            #                  fifo_mode=self.fifo_mode,
            #                  add_clk_enable=self.add_clk_enable,
            #                  add_flush=self.add_flush,
            #                  stcl_valid_iter=self.stcl_valid_iter)

            lt_dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                remove_tb=self.remove_tb,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                name=lake_name,
                gen_addr=False)

            change_sram_port_pass = change_sram_port_names(
                use_sram_stub, sram_macro_info)
            circ = kts.util.to_magma(
                lt_dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
        else:
            circ, lt_dut = MemCore.__circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        # Enumerate input and output ports
        # (clk and reset are assumed)
        core_interface = get_interface(lt_dut)
        cfgs = extract_top_config(lt_dut)
        assert len(cfgs) > 0, "No configs?"

        # We basically add in the configuration bus differently
        # than the other ports...
        skip_names = [
            "config_data_in", "config_write", "config_addr_in",
            "config_data_out", "config_read", "config_en", "clk_en"
        ]

        # Create a list of signals that will be able to be
        # hardwired to a constant at runtime...
        control_signals = []
        # The rest of the signals to wire to the underlying representation...
        other_signals = []

        # for port_name, port_size, port_width, is_ctrl, port_dir, explicit_array in core_interface:
        for io_info in core_interface:
            if io_info.port_name in skip_names:
                continue
            ind_ports = io_info.port_width
            intf_type = TBit
            # For our purposes, an explicit array means the inner data HAS to be 16 bits
            if io_info.expl_arr:
                ind_ports = io_info.port_size[0]
                intf_type = TData
            dir_type = magma.In
            app_list = self.__inputs
            if io_info.port_dir == "PortDirection.Out":
                dir_type = magma.Out
                app_list = self.__outputs
            if ind_ports > 1:
                for i in range(ind_ports):
                    self.add_port(f"{io_info.port_name}_{i}",
                                  dir_type(intf_type))
                    app_list.append(self.ports[f"{io_info.port_name}_{i}"])
            else:
                self.add_port(io_info.port_name, dir_type(intf_type))
                app_list.append(self.ports[io_info.port_name])

            # classify each signal for wiring to underlying representation...
            if io_info.is_ctrl:
                control_signals.append((io_info.port_name, io_info.port_width))
            else:
                if ind_ports > 1:
                    for i in range(ind_ports):
                        other_signals.append(
                            (f"{io_info.port_name}_{i}", io_info.port_dir,
                             io_info.expl_arr, i, io_info.port_name))
                else:
                    other_signals.append(
                        (io_info.port_name, io_info.port_dir, io_info.expl_arr,
                         0, io_info.port_name))

        assert (len(self.__outputs) > 0)

        # We call clk_en stall at this level for legacy reasons????
        self.add_ports(stall=magma.In(TBit), )

        self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

        # put a 1-bit register and a mux to select the control signals
        for control_signal, width in control_signals:
            if width == 1:
                mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
                reg_value_name = f"{control_signal}_reg_value"
                reg_sel_name = f"{control_signal}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0], self.ports[control_signal])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][0])
            else:
                for i in range(width):
                    mux = MuxWrapper(2, 1, name=f"{control_signal}_{i}_sel")
                    reg_value_name = f"{control_signal}_{i}_reg_value"
                    reg_sel_name = f"{control_signal}_{i}_reg_sel"
                    self.add_config(reg_value_name, 1)
                    self.add_config(reg_sel_name, 1)
                    self.wire(mux.ports.I[0],
                              self.ports[f"{control_signal}_{i}"])
                    self.wire(mux.ports.I[1],
                              self.registers[reg_value_name].ports.O)
                    self.wire(mux.ports.S,
                              self.registers[reg_sel_name].ports.O)
                    # 0 is the default wire, which takes from the routing network
                    self.wire(mux.ports.O[0],
                              self.underlying.ports[control_signal][i])

        # Wire the other signals up...
        for pname, pdir, expl_arr, ind, uname in other_signals:
            # If we are in an explicit array moment, use the given wire name...
            if expl_arr is False:
                # And if not, use the index
                self.wire(self.ports[pname][0],
                          self.underlying.ports[uname][ind])
            else:
                self.wire(self.ports[pname], self.underlying.ports[pname])

        # CLK, RESET, and STALL PER STANDARD PROCEDURE

        # Need to invert this
        self.resetInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.resetInverter.ports.I[0], self.ports.reset)
        self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
        self.wire(self.ports.clk, self.underlying.ports.clk)

        # Mem core uses clk_en (essentially active low stall)
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall)
        self.wire(self.stallInverter.ports.O[0],
                  self.underlying.ports.clk_en[0])

        # we have six? features in total
        # 0:    TILE
        # 1:    TILE
        # 1-4:  SMEM
        # Feature 0: Tile
        self.__features: List[CoreFeature] = [self]
        # Features 1-4: SRAM
        self.num_sram_features = lt_dut.total_sets
        for sram_index in range(self.num_sram_features):
            core_feature = CoreFeature(self, sram_index + 1)
            self.__features.append(core_feature)

        # Wire the config
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(
                    f"config_{idx}",
                    magma.In(
                        ConfigurationType(self.config_addr_width,
                                          self.config_data_width)))
                # port aliasing
                core_feature.ports["config"] = self.ports[f"config_{idx}"]
        self.add_port(
            "config",
            magma.In(
                ConfigurationType(self.config_addr_width,
                                  self.config_data_width)))

        # or the signal up
        t = ConfigurationType(self.config_addr_width, self.config_data_width)
        t_names = ["config_addr", "config_data"]
        or_gates = {}
        for t_name in t_names:
            port_type = t[t_name]
            or_gate = FromMagma(
                mantle.DefineOr(len(self.__features), len(port_type)))
            or_gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, core_feature in enumerate(self.__features):
                self.wire(or_gate.ports[f"I{idx}"],
                          core_feature.ports.config[t_name])
            or_gates[t_name] = or_gate

        self.wire(
            or_gates["config_addr"].ports.O,
            self.underlying.ports.config_addr_in[0:self.config_addr_width])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data_in)

        # read data out
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                # self.add_port(f"read_config_data_{idx}",
                self.add_port(f"read_config_data_{idx}",
                              magma.Out(magma.Bits[self.config_data_width]))
                # port aliasing
                core_feature.ports["read_config_data"] = \
                    self.ports[f"read_config_data_{idx}"]

        # MEM Config
        configurations = []
        # merged_configs = []
        skip_cfgs = []

        for cfg_info in cfgs:
            if cfg_info.port_name in skip_cfgs:
                continue
            if cfg_info.expl_arr:
                if cfg_info.port_size[0] > 1:
                    for i in range(cfg_info.port_size[0]):
                        configurations.append(
                            (f"{cfg_info.port_name}_{i}", cfg_info.port_width))
                else:
                    configurations.append(
                        (cfg_info.port_name, cfg_info.port_width))
            else:
                configurations.append(
                    (cfg_info.port_name, cfg_info.port_width))

        # Do all the stuff for the main config
        main_feature = self.__features[0]
        for config_reg_name, width in configurations:
            main_feature.add_config(config_reg_name, width)
            if (width == 1):
                self.wire(main_feature.registers[config_reg_name].ports.O[0],
                          self.underlying.ports[config_reg_name][0])
            else:
                self.wire(main_feature.registers[config_reg_name].ports.O,
                          self.underlying.ports[config_reg_name])

        # SRAM
        # These should also account for num features
        # or_all_cfg_rd = FromMagma(mantle.DefineOr(4, 1))
        or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_rd.instance_name = f"OR_CONFIG_WR_SRAM"
        or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_wr.instance_name = f"OR_CONFIG_RD_SRAM"
        for sram_index in range(self.num_sram_features):
            core_feature = self.__features[sram_index + 1]
            self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
            # port aliasing
            core_feature.ports["config_en"] = \
                self.ports[f"config_en_{sram_index}"]
            # Sort of a temp hack - the name is just config_data_out
            if self.num_sram_features == 1:
                self.wire(core_feature.ports.read_config_data,
                          self.underlying.ports["config_data_out"])
            else:
                self.wire(
                    core_feature.ports.read_config_data,
                    self.underlying.ports[f"config_data_out_{sram_index}"])
            # also need to wire the sram signal
            # the config enable is the OR of the rd+wr
            or_gate_en = FromMagma(mantle.DefineOr(2, 1))
            or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"

            self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
            self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
            self.wire(core_feature.ports.config_en,
                      self.underlying.ports["config_en"][sram_index])
            # Still connect to the OR of all the config rd/wr
            self.wire(core_feature.ports.config.write,
                      or_all_cfg_wr.ports[f"I{sram_index}"])
            self.wire(core_feature.ports.config.read,
                      or_all_cfg_rd.ports[f"I{sram_index}"])

        self.wire(or_all_cfg_rd.ports.O[0],
                  self.underlying.ports.config_read[0])
        self.wire(or_all_cfg_wr.ports.O[0],
                  self.underlying.ports.config_write[0])
        self._setup_config()

        conf_names = list(self.registers.keys())
        conf_names.sort()
        with open("mem_cfg.txt", "w+") as cfg_dump:
            for idx, reg in enumerate(conf_names):
                write_line = f"(\"{reg}\", 0),  # {self.registers[reg].width}\n"
                cfg_dump.write(write_line)
        with open("mem_synth.txt", "w+") as cfg_dump:
            for idx, reg in enumerate(conf_names):
                write_line = f"{reg}\n"
                cfg_dump.write(write_line)
예제 #26
0
def _glb_single_field_reg(name, width):
    """Return a ``Reg`` called *name* holding one ``Field`` of the same name and *width*."""
    reg = Reg(name)
    reg.add_child(Field(name, width=width))
    return reg


def _glb_network_ctrl_reg(name, params):
    """Build a network control register with ``tile_connected`` and ``latency`` fields.

    The data network and the pcfg network control registers share this layout.
    """
    ctrl_r = Reg(name)
    ctrl_r.add_children(
        [Field("tile_connected", 1),
         Field("latency", params.latency_width)])
    return ctrl_r


def _glb_dma_ctrl_reg(name, params):
    """Build a DMA control register (mode, use_valid, data_mux, num_repeat).

    The store and load DMA controllers use this identical field layout.
    """
    ctrl_r = Reg(name)
    ctrl_r.add_child(Field("mode", 2))
    ctrl_r.add_child(Field("use_valid", 1))
    ctrl_r.add_child(Field("data_mux", 2))
    # Width is clog2(queue_depth) + 1 bits; presumably so queue_depth itself
    # is representable -- confirm against the hardware spec.
    ctrl_r.add_child(Field("num_repeat", clog2(params.queue_depth) + 1))
    return ctrl_r


def _glb_dma_header_regfile(prefix, params):
    """Build the DMA header register file for *prefix* (``"st"`` or ``"ld"``).

    The register file has ``params.queue_depth`` entries. It is named
    ``{prefix}_dma_header_0`` when ``queue_depth == 1`` and
    ``{prefix}_dma_header`` otherwise. It contains these registers, in order:
    dim, start_addr, cycle_start_addr, and then the per-loop-level range,
    stride and cycle_stride arrays (each of size ``params.loop_level``).
    """
    if params.queue_depth == 1:
        rf_name = f"{prefix}_dma_header_0"
    else:
        rf_name = f"{prefix}_dma_header"
    header_rf = RegFile(rf_name, size=params.queue_depth)

    header_rf.add_child(
        _glb_single_field_reg("dim", clog2(params.loop_level) + 1))
    header_rf.add_child(
        _glb_single_field_reg("start_addr", params.glb_addr_width))
    header_rf.add_child(
        _glb_single_field_reg("cycle_start_addr", params.glb_addr_width))

    # Loop range/stride registers: one entry per loop level.
    for reg_name in ("range", "stride", "cycle_stride"):
        loop_r = Reg(reg_name, size=params.loop_level)
        loop_r.add_child(Field(reg_name, width=params.axi_data_width))
        header_rf.add_child(loop_r)

    return header_rf


def gen_global_buffer_rdl(name, params):
    """Build the global buffer register description as an ``Rdl`` object.

    Args:
        name: name given to the top-level ``AddrMap``.
        params: parameter object. This function reads ``latency_width``,
            ``queue_depth``, ``loop_level``, ``glb_addr_width``,
            ``axi_data_width`` and ``max_num_cfg_width``.

    Returns:
        An ``Rdl`` wrapping the populated address map. Children are added in
        this order: data_network, pcfg_network, st_dma_ctrl, st DMA header,
        ld_dma_ctrl, ld DMA header, pcfg_dma_ctrl, pcfg_dma_header.
    """
    addr_map = AddrMap(name)

    # Data and pcfg network control registers share one layout.
    addr_map.add_child(_glb_network_ctrl_reg("data_network", params))
    addr_map.add_child(_glb_network_ctrl_reg("pcfg_network", params))

    # Store DMA: control register followed by its header register file.
    addr_map.add_child(_glb_dma_ctrl_reg("st_dma_ctrl", params))
    addr_map.add_child(_glb_dma_header_regfile("st", params))

    # Load DMA: same layout as the store DMA.
    addr_map.add_child(_glb_dma_ctrl_reg("ld_dma_ctrl", params))
    addr_map.add_child(_glb_dma_header_regfile("ld", params))

    # Pcfg DMA control register
    pcfg_dma_ctrl_r = Reg("pcfg_dma_ctrl")
    pcfg_dma_ctrl_r.add_child(Field("mode", 1))
    addr_map.add_child(pcfg_dma_ctrl_r)

    # Pcfg DMA header register file (single, unsized)
    pcfg_dma_header_rf = RegFile("pcfg_dma_header")
    pcfg_dma_header_rf.add_child(
        _glb_single_field_reg("start_addr", params.glb_addr_width))
    pcfg_dma_header_rf.add_child(
        _glb_single_field_reg("num_cfg", params.max_num_cfg_width))
    addr_map.add_child(pcfg_dma_header_rf)

    return Rdl(addr_map)
예제 #27
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_depth=32,
            default_iterator_support=3,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            config_data_width=32,
            config_addr_width=8,
            cycle_count_width=16,
            add_clk_enable=True,
            add_flush=True):
        super().__init__("pond", debug=True)

        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.mem_depth = mem_depth
        self.data_width = data_width
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.cycle_count_width = cycle_count_width
        self.default_iterator_support = default_iterator_support
        self.default_config_width = kts.clog2(self.mem_depth)
        # inputs
        self._clk = self.clock("clk")
        self._clk.add_attribute(
            FormalAttr(f"{self._clk.name}", FormalSignalConstraint.CLK))
        self._rst_n = self.reset("rst_n")
        self._rst_n.add_attribute(
            FormalAttr(f"{self._rst_n.name}", FormalSignalConstraint.RSTN))
        self._clk_en = self.clock_en("clk_en", 1)

        # Enable/Disable tile
        self._tile_en = self.input("tile_en", 1)
        self._tile_en.add_attribute(
            ConfigRegAttr("Tile logic enable manifested as clock gate"))

        gclk = self.var("gclk", 1)
        self._gclk = kts.util.clock(gclk)
        self.wire(gclk, kts.util.clock(self._clk & self._tile_en))

        self._cycle_count = add_counter(self, "cycle_count",
                                        self.cycle_count_width)

        # Create write enable + addr, same for read.
        # self._write = self.input("write", self.interconnect_input_ports)
        self._write = self.var("write", self.mem_input_ports)
        # self._write.add_attribute(ControlSignalAttr(is_control=True))

        self._write_addr = self.var("write_addr",
                                    kts.clog2(self.mem_depth),
                                    size=self.interconnect_input_ports,
                                    explicit_array=True,
                                    packed=True)

        # Add "_pond" suffix to avoid error during garnet RTL generation
        self._data_in = self.input("data_in_pond",
                                   self.data_width,
                                   size=self.interconnect_input_ports,
                                   explicit_array=True,
                                   packed=True)
        self._data_in.add_attribute(
            FormalAttr(f"{self._data_in.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._data_in.add_attribute(ControlSignalAttr(is_control=False))

        self._read = self.var("read", self.mem_output_ports)
        self._t_write = self.var("t_write", self.interconnect_input_ports)
        self._t_read = self.var("t_read", self.interconnect_output_ports)
        # self._read.add_attribute(ControlSignalAttr(is_control=True))

        self._read_addr = self.var("read_addr",
                                   kts.clog2(self.mem_depth),
                                   size=self.interconnect_output_ports,
                                   explicit_array=True,
                                   packed=True)

        self._s_read_addr = self.var("s_read_addr",
                                     kts.clog2(self.mem_depth),
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)

        self._data_out = self.output("data_out_pond",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)
        self._data_out.add_attribute(
            FormalAttr(f"{self._data_out.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._data_out.add_attribute(ControlSignalAttr(is_control=False))

        self._valid_out = self.output("valid_out_pond",
                                      self.interconnect_output_ports)
        self._valid_out.add_attribute(
            FormalAttr(f"{self._valid_out.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._valid_out.add_attribute(ControlSignalAttr(is_control=False))

        self._mem_data_out = self.var("mem_data_out",
                                      self.data_width,
                                      size=self.mem_output_ports,
                                      explicit_array=True,
                                      packed=True)

        self._s_mem_data_in = self.var("s_mem_data_in",
                                       self.data_width,
                                       size=self.interconnect_input_ports,
                                       explicit_array=True,
                                       packed=True)

        self._mem_data_in = self.var("mem_data_in",
                                     self.data_width,
                                     size=self.mem_input_ports,
                                     explicit_array=True,
                                     packed=True)

        self._s_mem_write_addr = self.var("s_mem_write_addr",
                                          kts.clog2(self.mem_depth),
                                          size=self.interconnect_input_ports,
                                          explicit_array=True,
                                          packed=True)

        self._s_mem_read_addr = self.var("s_mem_read_addr",
                                         kts.clog2(self.mem_depth),
                                         size=self.interconnect_output_ports,
                                         explicit_array=True,
                                         packed=True)

        self._mem_write_addr = self.var("mem_write_addr",
                                        kts.clog2(self.mem_depth),
                                        size=self.mem_input_ports,
                                        explicit_array=True,
                                        packed=True)

        self._mem_read_addr = self.var("mem_read_addr",
                                       kts.clog2(self.mem_depth),
                                       size=self.mem_output_ports,
                                       explicit_array=True,
                                       packed=True)

        if self.interconnect_output_ports == 1:
            self.wire(self._data_out[0], self._mem_data_out[0])
        else:
            for i in range(self.interconnect_output_ports):
                self.wire(self._data_out[i], self._mem_data_out[0])

        # Valid out is simply passing the read signal through...
        self.wire(self._valid_out, self._t_read)

        # Create write addressors
        for wr_port in range(self.interconnect_input_ports):

            RF_WRITE_ITER = ForLoop(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width)
            RF_WRITE_ADDR = AddrGen(
                iterator_support=self.default_iterator_support,
                config_width=self.default_config_width)
            RF_WRITE_SCHED = SchedGen(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width,
                use_enable=True)

            self.add_child(f"rf_write_iter_{wr_port}",
                           RF_WRITE_ITER,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_write[wr_port])
            # Whatever comes through here should hopefully just pipe through seamlessly
            # addressor modules
            self.add_child(f"rf_write_addr_{wr_port}",
                           RF_WRITE_ADDR,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_write[wr_port],
                           mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                           restart=RF_WRITE_ITER.ports.restart)
            safe_wire(self, self._write_addr[wr_port],
                      RF_WRITE_ADDR.ports.addr_out)

            self.add_child(f"rf_write_sched_{wr_port}",
                           RF_WRITE_SCHED,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                           finished=RF_WRITE_ITER.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self._t_write[wr_port])

        # Create read addressors
        for rd_port in range(self.interconnect_output_ports):

            RF_READ_ITER = ForLoop(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width)
            RF_READ_ADDR = AddrGen(
                iterator_support=self.default_iterator_support,
                config_width=self.default_config_width)
            RF_READ_SCHED = SchedGen(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width,
                use_enable=True)

            self.add_child(f"rf_read_iter_{rd_port}",
                           RF_READ_ITER,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_read[rd_port])

            self.add_child(f"rf_read_addr_{rd_port}",
                           RF_READ_ADDR,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_read[rd_port],
                           mux_sel=RF_READ_ITER.ports.mux_sel_out,
                           restart=RF_READ_ITER.ports.restart)
            if self.interconnect_output_ports > 1:
                safe_wire(self, self._read_addr[rd_port],
                          RF_READ_ADDR.ports.addr_out)
            else:
                safe_wire(self, self._read_addr[rd_port],
                          RF_READ_ADDR.ports.addr_out)

            self.add_child(f"rf_read_sched_{rd_port}",
                           RF_READ_SCHED,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           mux_sel=RF_READ_ITER.ports.mux_sel_out,
                           finished=RF_READ_ITER.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self._t_read[rd_port])

        self.wire(self._write, self._t_write.r_or())
        self.wire(self._mem_write_addr[0],
                  decode(self, self._t_write, self._s_mem_write_addr))

        self.wire(self._mem_data_in[0],
                  decode(self, self._t_write, self._s_mem_data_in))

        self.wire(self._read, self._t_read.r_or())
        self.wire(self._mem_read_addr[0],
                  decode(self, self._t_read, self._s_mem_read_addr))
        # ===================================
        # Instantiate config hooks...
        # ===================================
        self.fw_int = 1
        self.data_words_per_set = 2**self.config_addr_width
        self.sets = int(
            (self.fw_int * self.mem_depth) / self.data_words_per_set)

        self.sets_per_macro = max(
            1, int(self.mem_depth / self.data_words_per_set))
        self.total_sets = max(1, 1 * self.sets_per_macro)

        self._config_data_in = self.input("config_data_in",
                                          self.config_data_width)
        self._config_data_in.add_attribute(ControlSignalAttr(is_control=False))

        self._config_data_in_shrt = self.var("config_data_in_shrt",
                                             self.data_width)

        self.wire(self._config_data_in_shrt,
                  self._config_data_in[self.data_width - 1, 0])

        self._config_addr_in = self.input("config_addr_in",
                                          self.config_addr_width)
        self._config_addr_in.add_attribute(ControlSignalAttr(is_control=False))

        self._config_data_out_shrt = self.var("config_data_out_shrt",
                                              self.data_width,
                                              size=self.total_sets,
                                              explicit_array=True,
                                              packed=True)

        self._config_data_out = self.output("config_data_out",
                                            self.config_data_width,
                                            size=self.total_sets,
                                            explicit_array=True,
                                            packed=True)
        self._config_data_out.add_attribute(
            ControlSignalAttr(is_control=False))

        for i in range(self.total_sets):
            self.wire(
                self._config_data_out[i],
                self._config_data_out_shrt[i].extend(self.config_data_width))

        self._config_read = self.input("config_read", 1)
        self._config_read.add_attribute(ControlSignalAttr(is_control=False))

        self._config_write = self.input("config_write", 1)
        self._config_write.add_attribute(ControlSignalAttr(is_control=False))

        self._config_en = self.input("config_en", self.total_sets)
        self._config_en.add_attribute(ControlSignalAttr(is_control=False))

        self._mem_data_cfg = self.var("mem_data_cfg",
                                      self.data_width,
                                      explicit_array=True,
                                      packed=True)

        self._mem_addr_cfg = self.var("mem_addr_cfg",
                                      kts.clog2(self.mem_depth))

        # Add config...
        stg_cfg_seq = StorageConfigSeq(
            data_width=self.data_width,
            config_addr_width=self.config_addr_width,
            addr_width=kts.clog2(self.mem_depth),
            fetch_width=self.data_width,
            total_sets=self.total_sets,
            sets_per_macro=self.sets_per_macro)

        # The clock to config sequencer needs to be the normal clock or
        # if the tile is off, we bring the clock back in based on config_en
        cfg_seq_clk = self.var("cfg_seq_clk", 1)
        self._cfg_seq_clk = kts.util.clock(cfg_seq_clk)
        self.wire(cfg_seq_clk, kts.util.clock(self._gclk))

        self.add_child(f"config_seq",
                       stg_cfg_seq,
                       clk=self._cfg_seq_clk,
                       rst_n=self._rst_n,
                       clk_en=self._clk_en | self._config_en.r_or(),
                       config_data_in=self._config_data_in_shrt,
                       config_addr_in=self._config_addr_in,
                       config_wr=self._config_write,
                       config_rd=self._config_read,
                       config_en=self._config_en,
                       wr_data=self._mem_data_cfg,
                       rd_data_out=self._config_data_out_shrt,
                       addr_out=self._mem_addr_cfg)

        if self.interconnect_output_ports == 1:
            self.wire(stg_cfg_seq.ports.rd_data_stg, self._mem_data_out)
        else:
            self.wire(stg_cfg_seq.ports.rd_data_stg[0], self._mem_data_out[0])

        self.RF_GEN = RegisterFile(data_width=self.data_width,
                                   write_ports=self.mem_input_ports,
                                   read_ports=self.mem_output_ports,
                                   width_mult=1,
                                   depth=self.mem_depth,
                                   read_delay=0)

        # Now we can instantiate and wire up the register file
        self.add_child(f"rf",
                       self.RF_GEN,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       data_out=self._mem_data_out)

        # Opt in for config_write
        self._write_rf = self.var("write_rf", self.mem_input_ports)
        self.wire(
            self._write_rf[0],
            kts.ternary(self._config_en.r_or(), self._config_write,
                        self._write[0]))
        for i in range(self.mem_input_ports - 1):
            self.wire(
                self._write_rf[i + 1],
                kts.ternary(self._config_en.r_or(), kts.const(0, 1),
                            self._write[i + 1]))
        self.wire(self.RF_GEN.ports.wen, self._write_rf)

        # Opt in for config_data_in
        for i in range(self.interconnect_input_ports):
            self.wire(
                self._s_mem_data_in[i],
                kts.ternary(self._config_en.r_or(), self._mem_data_cfg,
                            self._data_in[i]))
        self.wire(self.RF_GEN.ports.data_in, self._mem_data_in)

        # Opt in for config_addr
        for i in range(self.interconnect_input_ports):
            self.wire(
                self._s_mem_write_addr[i],
                kts.ternary(self._config_en.r_or(), self._mem_addr_cfg,
                            self._write_addr[i]))

        self.wire(self.RF_GEN.ports.wr_addr, self._mem_write_addr[0])

        for i in range(self.interconnect_output_ports):
            self.wire(
                self._s_mem_read_addr[i],
                kts.ternary(self._config_en.r_or(), self._mem_addr_cfg,
                            self._read_addr[i]))

        self.wire(self.RF_GEN.ports.rd_addr, self._mem_read_addr[0])

        if self.add_clk_enable:
            # self.clock_en("clk_en")
            kts.passes.auto_insert_clock_enable(self.internal_generator)
            clk_en_port = self.internal_generator.get_port("clk_en")
            clk_en_port.add_attribute(ControlSignalAttr(False))

        if self.add_flush:
            self.add_attribute("sync-reset=flush")
            kts.passes.auto_insert_sync_reset(self.internal_generator)
            flush_port = self.internal_generator.get_port("flush")
            flush_port.add_attribute(ControlSignalAttr(True))

        # Finally, lift the config regs...
        lift_config_reg(self.internal_generator)
예제 #28
0
    def interact(self, ack_in, data_in, valid_in, ren_in, mem_valid_data):
        '''
        Step the read-sync-group model by one cycle.

        Output port ``j`` is a member of group ``i`` when
        ``self.config[f"sync_group_{j}"] == 1 << i`` (one-hot membership).

        Args:
            ack_in: Integer bitmask; bit ``j`` is the acknowledge for port j.
            data_in: Per-port data; each entry must provide ``.copy()``.
            valid_in: Per-port valid flags, registered into ``valid_reg``.
            ren_in: Per-port read enables, ANDed with each port's
                ``local_gate_reduced`` value as it stands on entry.
            mem_valid_data: Per-port flags registered into
                ``mem_valid_data_out_reg``.

        Returns (data_out, valid_out, rd_sync_gate, mem_valid_data_out)

        Ordering matters here:
        - ``data_out`` and ``mem_valid_data_out`` are read from the registers
          as they were on entry, before this call overwrites them.
        - ``valid_out[j]`` is the group-valid of port j's group. That value
          is computed from the entry-time ``valid_reg``.
        - ``rd_sync_gate`` is a copy of ``self.get_rd_sync()``. It is sampled
          after the local mask is rebuilt but before ``local_gate`` is
          updated.
        '''
        num_ports = self.int_out_ports
        num_groups = self.groups
        # The config does not change during this call, so read each port's
        # one-hot group assignment once.
        sync_groups = [self.config[f"sync_group_{j}"] for j in range(num_ports)]

        def in_group(port, group):
            # True when `port` is assigned to sync group `group`.
            return sync_groups[port] == (1 << group)

        # Qualify the read enables with each port's reduced gate (entry value).
        ren_int = [ren_in[j] & self.local_gate_reduced[j]
                   for j in range(num_ports)]

        # Rebuild the local mask. Ports outside group i get 1. A member
        # port's mask drops when the port both reads and is acknowledged.
        for i in range(num_groups):
            for j in range(num_ports):
                if in_group(j, i):
                    acked = (ack_in & (1 << j)) != 0
                    self.local_mask[i][j] = not (ren_int[j] & acked)
                else:
                    self.local_mask[i][j] = 1

        # A group is finished when none of its member ports still has both
        # its gate and its mask high.
        group_finished = [
            not any(in_group(j, i)
                    and self.local_gate[i][j] == 1
                    and self.local_mask[i][j] == 1
                    for j in range(num_ports))
            for i in range(num_groups)
        ]

        # Sample before local_gate is updated below.
        rd_sync_gate = self.get_rd_sync()

        # A finished group sets its whole gate row back to 1. Otherwise each
        # gate is ANDed with the new mask.
        for i in range(num_groups):
            for j in range(num_ports):
                if group_finished[i]:
                    self.local_gate[i][j] = 1
                else:
                    self.local_gate[i][j] = (self.local_gate[i][j]
                                             & self.local_mask[i][j])

        # A group is valid only when every member port has its valid reg set.
        for i in range(num_groups):
            self.sync_group_valid[i] = 0 if any(
                in_group(j, i) and self.valid_reg[j] == 0
                for j in range(num_ports)) else 1

        # Group index of each port. This is kts.clog2 of its sync_group
        # value, which is presumably the bit position for a one-hot value.
        group_idx = [kts.clog2(sync_groups[j]) for j in range(num_ports)]

        # Outputs are taken from the registers before they are updated.
        valid_out = [self.sync_group_valid[g] for g in group_idx]
        data_out = [self.data_reg[j].copy() for j in range(num_ports)]
        mem_valid_data_out = [self.mem_valid_data_out_reg[j]
                              for j in range(num_ports)]

        # Load new inputs for ports whose group is valid this cycle or whose
        # valid reg is empty. Other ports keep their pending value.
        for j in range(num_ports):
            if (self.sync_group_valid[group_idx[j]] == 1
                    or self.valid_reg[j] == 0):
                self.valid_reg[j] = valid_in[j]
                self.data_reg[j] = data_in[j].copy()
                self.mem_valid_data_out_reg[j] = mem_valid_data[j]

        # Each port's reduced gate is its entry in its own group's gate row.
        # It is used to qualify ren_in on the next call.
        for j in range(num_ports):
            for i in range(num_groups):
                if in_group(j, i):
                    self.local_gate_reduced[j] = self.local_gate[i][j]

        return (data_out, valid_out, rd_sync_gate.copy(), mem_valid_data_out)