def bank_wr_packet_cache_comb(self):
     # Computes the next values (`*_w`) of the bank write strobe/data caches
     # from their current values (`*_r`).
     # Default: hold the current cache contents.
     self.bank_wr_strb_cache_w = self.bank_wr_strb_cache_r
     self.bank_wr_data_cache_w = self.bank_wr_data_cache_r
     # First, if cached data is written to memory, clear it.
     if self.bank_wr_en:
         self.bank_wr_strb_cache_w = 0
         self.bank_wr_data_cache_w = 0
     # Next, save data to cache
     # The slot write below comes after the clear above, so a stream write
     # arriving together with bank_wr_en still lands in the (cleared) cache.
     if self.strm_wr_en_w:
         # strm_data_sel picks which of the four slots receives the incoming
         # word: slot k covers strobe bits [cgra_strb_width*(k+1)-1 : cgra_strb_width*k]
         # and data bits [cgra_data_width*(k+1)-1 : cgra_data_width*k].
         if self.strm_data_sel == 0:
             self.bank_wr_strb_cache_w[self.cgra_strb_width - 1,
                                       0] = const(self.cgra_strb_value, self.cgra_strb_width)
             self.bank_wr_data_cache_w[self._params.cgra_data_width - 1, 0] = self.strm_wr_data_w
         elif self.strm_data_sel == 1:
             self.bank_wr_strb_cache_w[self.cgra_strb_width * 2 - 1,
                                       self.cgra_strb_width] = const(self.cgra_strb_value,
                                                                     self.cgra_strb_width)
             self.bank_wr_data_cache_w[self._params.cgra_data_width * 2 - 1,
                                       self._params.cgra_data_width] = self.strm_wr_data_w
         elif self.strm_data_sel == 2:
             self.bank_wr_strb_cache_w[self.cgra_strb_width * 3 - 1,
                                       self.cgra_strb_width * 2] = const(self.cgra_strb_value,
                                                                         self.cgra_strb_width)
             self.bank_wr_data_cache_w[self._params.cgra_data_width * 3 - 1,
                                       self._params.cgra_data_width * 2] = self.strm_wr_data_w
         elif self.strm_data_sel == 3:
             self.bank_wr_strb_cache_w[self.cgra_strb_width * 4 - 1,
                                       self.cgra_strb_width * 3] = const(self.cgra_strb_value,
                                                                         self.cgra_strb_width)
             self.bank_wr_data_cache_w[self._params.cgra_data_width * 4 - 1,
                                       self._params.cgra_data_width * 3] = self.strm_wr_data_w
         else:
             # Any other selector value: fall back to holding the registered
             # contents (this also overrides the clear above).
             self.bank_wr_strb_cache_w = self.bank_wr_strb_cache_r
             self.bank_wr_data_cache_w = self.bank_wr_data_cache_r
Exemple #2
0
 def set_back_num_load(self):
     # Drives _back_num_load.
     # With _back_pl set: fw_int - 1 when _pop is asserted, otherwise fw_int
     # (both constants are sized to _back_num_load's width).
     # Without _back_pl: 0.
     if self._back_pl:
         self._back_num_load = kts.ternary(self._pop,
                                           kts.const(self.fw_int - 1, self._back_num_load.width),
                                           kts.const(self.fw_int, self._back_num_load.width))
     else:
         self._back_num_load = 0
Exemple #3
0
 def add_sram_macro(self):
     """Instantiate ``num_sram_macros`` SRAM macro children.

     Each child is named ``sram_array_<i>``. All macros share the clock,
     address, bit-write-enable and data-in signals; the chip-enable,
     write-enable and data-out connections are taken per macro from index
     ``i`` of ``ceb_demux_d``, ``web_demux_d`` and ``q_sram2mux``.
     """
     for macro_idx in range(self.num_sram_macros):
         instance_name = f"sram_array_{macro_idx}"
         macro = TS1N16FFCLLSBLVTC2048X64M8SW()
         # Port map in the macro's own port order.
         port_map = {
             "CLK": self.CLK,
             "A": self.a_sram_d,
             "BWEB": self.BWEB_d,
             "CEB": self.ceb_demux_d[macro_idx],
             "WEB": self.web_demux_d[macro_idx],
             "D": self.D_d,
             "Q": self.q_sram2mux[macro_idx],
             # Timing-select pins are tied to fixed 2-bit constants.
             "RTSEL": const(0b01, 2),
             "WTSEL": const(0b00, 2),
         }
         self.add_child(instance_name, macro, **port_map)
Exemple #4
0
def set_dryer(dryer, mod):
    """Populate the dryer FSM: Heating -> Spin -> CoolDown -> Done.

    ``mod.ports.dryer_done`` is registered as the FSM output. It is driven
    to 0 in every state except "Done", where it is driven to 1. All
    transitions are unconditional (condition ``None``).
    """
    dryer.output(mod.ports.dryer_done)

    stages = [
        dryer.add_state(label)
        for label in ("Heating", "Spin", "CoolDown", "Done")
    ]

    # Chain every state to its successor.
    for current, following in zip(stages, stages[1:]):
        current.next(following, None)

    *busy_stages, finished = stages
    for busy in busy_stages:
        busy.output(mod.ports.dryer_done, const(0, 1))
    finished.output(mod.ports.dryer_done, const(1, 1))
Exemple #5
0
def set_washer(washer, mod):
    """Populate the washer FSM: WaterFill -> Spin -> Drain -> Done.

    ``mod.ports.washer_done`` is registered as the FSM output. It is driven
    to 0 in every state except "Done", where it is driven to 1. All
    transitions are unconditional (condition ``None``).
    """
    washer.output(mod.ports.washer_done)

    stages = [
        washer.add_state(label)
        for label in ("WaterFill", "Spin", "Drain", "Done")
    ]

    # Chain every state to its successor.
    for current, following in zip(stages, stages[1:]):
        current.next(following, None)

    *busy_stages, finished = stages
    for busy in busy_stages:
        busy.output(mod.ports.washer_done, const(0, 1))
    finished.output(mod.ports.washer_done, const(1, 1))
Exemple #6
0
    def sram_ctrl_logic(self):
        # Expands the single WEB / CEB inputs into per-macro vectors
        # (web_demux / ceb_demux), each num_sram_macros bits wide.
        # When WEB is 0, web_demux is all ones except for a single 0 at bit
        # position sram_sel; otherwise it is all ones.
        # NOTE(review): the naming suggests active-low enables -- confirm
        # against the SRAM macro's port definitions.
        if ~self.WEB:
            self.web_demux = ~(const(1, width=self.num_sram_macros) << resize(
                self.sram_sel, self.num_sram_macros))
        else:
            self.web_demux = const(2**self.num_sram_macros - 1,
                                   self.num_sram_macros)

        # Same decoding for CEB -> ceb_demux.
        if ~self.CEB:
            self.ceb_demux = ~(const(1, width=self.num_sram_macros) << resize(
                self.sram_sel, self.num_sram_macros))
        else:
            self.ceb_demux = const(2**self.num_sram_macros - 1,
                                   self.num_sram_macros)
Exemple #7
0
def zext(gen, wire, size):
    """Return ``wire`` zero-extended to ``size`` bits.

    If ``wire`` is already at least ``size`` bits wide it is returned
    unchanged. Otherwise a new variable named ``<wire.name>_zext`` is
    created on ``gen``, wired to ``wire`` with a zero constant concatenated
    in front of it, and returned.
    """
    pad_bits = size - wire.width
    if pad_bits <= 0:
        return wire

    extended = gen.var(f"{wire.name}_zext", size)
    padded = kts.concat(kts.const(0, pad_bits), wire)
    gen.wire(extended, padded)
    return extended
Exemple #8
0
 def send_writes(self, idx):
     # Drives the write enable towards storage for bank `idx`.
     # Send a write to a bank if the read to that bank isn't happening (unless you can do both)
     # further, we make sure there is a queued write, or current buffer is
     # full and there is an incoming push
     # and there is stuff in memory to be read
     #
     # Structure of the expression:
     #   (no read to this bank OR rw_same_cycle)
     #   AND (a write is already queued for this bank
     #        OR (front buffer occupancy == fw_int AND push AND NOT front_pop
     #            AND the current write bank is `idx`))
     self._wen_to_strg[idx] = ((~self._ren_to_strg[idx] | kts.const(int(self.rw_same_cycle), 1)) &
                               (self._queued_write[idx] |  # Already queued a write...
                               (((self._front_occ == self.fw_int) & self._push &  # Grossness
                                 (~self._front_pop)) &
                                (self._curr_bank_wr == idx))))
Exemple #9
0
def main():
    """Build the "LaundryMachine" generator and emit its outputs.

    Creates a top-level "Laundry" FSM with "Washer" and "Dryer" child FSMs,
    wires up the transitions between them, then writes the generated
    SystemVerilog to ``laundry.sv`` and the FSM graph to ``laundry.dot``.
    """
    mod = Generator("LaundryMachine")
    washer_door = mod.input("washer_door", 1)
    dryer_door = mod.input("dryer_door", 1)
    # These ports are declared for the module interface only; the handles
    # returned by the calls are not needed below.
    mod.input("clothes", 1)
    mod.output("washer_done", 1)
    mod.output("dryer_done", 1)
    done_laundry = mod.output("done", 1)
    mod.clock("clk")
    mod.reset("rst")

    top_fsm = mod.add_fsm("Laundry")
    washer = mod.add_fsm("Washer")
    dryer = mod.add_fsm("Dryer")
    for child_fsm in (washer, dryer):
        top_fsm.add_child_fsm(child_fsm)

    # Fill in the states of the two child FSMs.
    set_washer(washer, mod)
    set_dryer(dryer, mod)

    # States of the top-level FSM; "Reset" is the start state.
    idle = top_fsm.add_state("Reset")
    top_fsm.set_start_state(idle)
    washing = top_fsm.add_state("Washing")
    drying = top_fsm.add_state("Drying")
    finished = top_fsm.add_state("Done")

    # Top-level flow: wait for the washer door, run the washer FSM, wait for
    # the dryer door, run the dryer FSM, report done, then start over.
    idle.next(washing, washer_door == 1)
    washing.next(washer["WaterFill"], None)
    washer["Done"].next(drying, None)
    drying.next(dryer["Heating"], dryer_door == 1)
    dryer["Done"].next(finished, None)
    finished.next(idle, None)

    # "done" is high only in the final state.
    top_fsm.output(done_laundry)
    for state, level in ((idle, 0), (washing, 0), (drying, 0), (finished, 1)):
        state.output(done_laundry, const(level, 1))

    verilog(mod, filename="laundry.sv")
    top_fsm.dot_graph("laundry.dot")
Exemple #10
0
    def add_pipeline(self):
        """Insert the three pipeline children around the SRAM macros.

        * ``sram_signals_reset_high_pipeline`` delays WEB, CEB, web_demux,
          ceb_demux and BWEB (created with ``reset_high=True``).
        * ``sram_signals_pipeline`` delays a_sram, sram_sel and D.
        * ``sram_signals_output_pipeline`` delays Q_w into Q.

        The first two use ``sram_gen_pipeline_depth``; the output pipeline
        uses ``sram_gen_output_pipeline_depth``.
        """

        def attach(instance_name, stage, source, sink):
            # All three pipelines share clock/reset and are always enabled.
            self.add_child(instance_name,
                           stage,
                           clk=self.CLK,
                           clk_en=const(1, 1),
                           reset=self.RESET,
                           in_=source,
                           out_=sink)

        # Control signals, bundled into one bus per side.
        ctrl_in = concat(self.WEB, self.CEB, self.web_demux, self.ceb_demux,
                         self.BWEB)
        ctrl_out = concat(self.WEB_d, self.CEB_d, self.web_demux_d,
                          self.ceb_demux_d, self.BWEB_d)
        self.sram_signals_reset_high_pipeline = Pipeline(
            width=ctrl_in.width,
            depth=self._params.sram_gen_pipeline_depth,
            reset_high=True)
        attach("sram_signals_reset_high_pipeline",
               self.sram_signals_reset_high_pipeline, ctrl_in, ctrl_out)

        # Address, macro select and write data.
        payload_in = concat(self.a_sram, self.sram_sel, self.D)
        payload_out = concat(self.a_sram_d, self.sram_sel_d, self.D_d)
        self.sram_signals_pipeline = Pipeline(
            width=payload_in.width,
            depth=self._params.sram_gen_pipeline_depth)
        attach("sram_signals_pipeline", self.sram_signals_pipeline,
               payload_in, payload_out)

        # Read data on the way out.
        self.sram_signals_output_pipeline = Pipeline(
            width=self.sram_macro_width,
            depth=self._params.sram_gen_output_pipeline_depth)
        attach("sram_signals_output_pipeline",
               self.sram_signals_output_pipeline, self.Q_w, self.Q)
Exemple #11
0
 def add_sram_cfg_rd_addr_sel_pipeline(self):
     """Delay the SRAM-config read-address select bit.

     Creates the 1-bit variable ``sram_cfg_rd_addr_sel_d`` and a 1-bit
     pipeline of depth ``bank_ctrl_pipeline_depth`` that feeds it from bit
     ``bank_byte_offset - 1`` of ``if_sram_cfg_s.rd_addr``.
     """
     delayed_sel = self.var("sram_cfg_rd_addr_sel_d", 1)
     self.sram_cfg_rd_addr_sel_d = delayed_sel

     stage = Pipeline(width=1, depth=self.bank_ctrl_pipeline_depth)
     self.sram_cfg_rd_addr_sel_pipeline = stage

     sel_bit = self.if_sram_cfg_s.rd_addr[self._params.bank_byte_offset - 1]
     self.add_child("sram_cfg_rd_addr_sel_pipeline",
                    stage,
                    clk=self.clk,
                    clk_en=const(1, 1),
                    reset=self.reset,
                    in_=sel_bit,
                    out_=delayed_sel)
Exemple #12
0
def add_counter(generator, name, bitwidth, increment=kts.const(1, 1)):
    """Add a counter variable to ``generator`` and return it.

    Creates a variable called ``name`` that is ``bitwidth`` bits wide and
    attaches a sequential block (posedge "clk", negedge "rst_n") that clears
    it while ``generator._rst_n`` is low and otherwise adds 1 whenever
    ``increment`` is true.

    NOTE(review): ``increment`` only gates the update; the step is always 1.
    NOTE(review): the default ``increment`` constant is created once, when
    this function is defined.
    """
    ctr = generator.var(name, bitwidth)

    @always_ff((posedge, "clk"), (negedge, "rst_n"))
    def ctr_inc_code():
        # Reset branch: clear the counter while _rst_n is low.
        if ~generator._rst_n:
            ctr = 0
        # Count branch: advance by one when `increment` is asserted;
        # otherwise the counter keeps its value.
        elif increment:
            ctr = ctr + 1

    generator.add_code(ctr_inc_code)
    return ctr
 def add_pipeline(self):
     """Add ``mem_pipeline``, which delays ``Q_w`` into ``Q``.

     The pipeline is ``data_width`` bits wide and its depth is the sum of
     ``sram_gen_pipeline_depth`` and ``sram_gen_output_pipeline_depth``.
     It is clocked by ``CLK``, reset by ``RESET`` and always enabled.
     """
     params = self._params
     total_depth = (params.sram_gen_pipeline_depth +
                    params.sram_gen_output_pipeline_depth)
     stage = Pipeline(width=self.data_width, depth=total_depth)
     self.mem_pipeline = stage

     connections = {
         "clk": self.CLK,
         "clk_en": const(1, 1),
         "reset": self.RESET,
         "in_": self.Q_w,
         "out_": self.Q,
     }
     self.add_child("mem_pipeline", stage, **connections)
Exemple #14
0
 def mem_signal_logic(self):
     # Selects what drives the memory interface (mem_wr_en, mem_rd_en_w,
     # mem_addr, mem_data_in, mem_data_in_bit_sel). Priority, highest first:
     #   1. SRAM config write   2. SRAM config read
     #   3. packet write        4. packet read        5. idle (all zeros)
     if self.if_sram_cfg_s.wr_en:
         # Config writes carry axi_data_width bits; bit (bank_byte_offset - 1)
         # of the write address picks which part of the bank word they update.
         if self.if_sram_cfg_s.wr_addr[self._params.bank_byte_offset -
                                       1] == 0:
             # Select bit is 0: zero padding concatenated in front of wr_data,
             # and the bit-select mask enables only the wr_data positions.
             self.mem_wr_en = 1
             self.mem_rd_en_w = 0
             self.mem_addr = self.if_sram_cfg_s.wr_addr
             self.mem_data_in = concat(
                 const(
                     0, self._params.bank_data_width -
                     self._params.axi_data_width),
                 self.if_sram_cfg_s.wr_data)
             self.mem_data_in_bit_sel = concat(
                 const(
                     0, self._params.bank_data_width -
                     self._params.axi_data_width),
                 const(2**self._params.axi_data_width - 1,
                       self._params.axi_data_width))
         else:
             # Select bit is 1: the low (bank_data_width - axi_data_width) bits
             # of wr_data are placed in front of axi_data_width zero bits, and
             # the mask enables only those leading positions.
             self.mem_wr_en = 1
             self.mem_rd_en_w = 0
             self.mem_addr = self.if_sram_cfg_s.wr_addr
             self.mem_data_in = concat(
                 self.if_sram_cfg_s.wr_data[self._params.bank_data_width -
                                            self._params.axi_data_width - 1,
                                            0],
                 const(0, self._params.axi_data_width))
             self.mem_data_in_bit_sel = concat(
                 const(
                     2**(self._params.bank_data_width -
                         self._params.axi_data_width) - 1,
                     self._params.bank_data_width -
                     self._params.axi_data_width),
                 const(0, self._params.axi_data_width))
     elif self.if_sram_cfg_s.rd_en:
         # Config read: read enable with the config read address, no write data.
         self.mem_wr_en = 0
         self.mem_rd_en_w = 1
         self.mem_addr = self.if_sram_cfg_s.rd_addr
         self.mem_data_in = 0
         self.mem_data_in_bit_sel = 0
     elif self.packet_wr_en:
         # Packet write: address, data and bit-select pass straight through.
         self.mem_wr_en = 1
         self.mem_rd_en_w = 0
         self.mem_addr = self.packet_wr_addr
         self.mem_data_in = self.packet_wr_data
         self.mem_data_in_bit_sel = self.packet_wr_data_bit_sel
     elif self.packet_rd_en:
         # Packet read: read enable with the packet read address.
         self.mem_wr_en = 0
         self.mem_rd_en_w = 1
         self.mem_addr = self.packet_rd_addr
         self.mem_data_in = 0
         self.mem_data_in_bit_sel = 0
     else:
         # No request: drive everything to zero.
         self.mem_wr_en = 0
         self.mem_rd_en_w = 0
         self.mem_addr = 0
         self.mem_data_in = 0
         self.mem_data_in_bit_sel = 0
Exemple #15
0
 def pipeline(self):
     # Shift-register update for the `depth` stages held in pipeline_r.
     # NOTE(review): the loops and the reset_high test look like Python-level
     # values unrolled at elaboration time -- confirm against how this
     # function is registered with the generator.
     if self.reset:
         # On reset every stage is set to all ones when reset_high is set,
         # otherwise to 0.
         for i in range(self.depth):
             if self.reset_high:
                 self.pipeline_r[i] = const(2**self.width - 1, self.width)
             else:
                 self.pipeline_r[i] = 0
     elif self.clk_en:
         # When enabled, stage 0 takes the input and every later stage takes
         # the value of the stage before it.
         for i in range(self.depth):
             if i == 0:
                 self.pipeline_r[i] = self.in_
             else:
                 # The source index is resized to depth_width bits.
                 self.pipeline_r[i] = self.pipeline_r[resize(
                     i - 1, self.depth_width)]
Exemple #16
0
    def add_rd_en_pipeline(self):
        """Create the delayed read-enable signals and their pipelines.

        Declares the 1-bit variables ``mem_rd_en_w``, ``mem_rd_en_d``,
        ``sram_cfg_rd_en_d`` and ``packet_rd_en_d``, wires ``mem_rd_en_w``
        to ``mem_rd_en``, and adds three 1-bit pipelines of depth
        ``bank_ctrl_pipeline_depth``:

        * ``mem_rd_en_pipeline``:      mem_rd_en_w         -> mem_rd_en_d
        * ``sram_cfg_rd_en_pipeline``: if_sram_cfg_s.rd_en -> sram_cfg_rd_en_d
        * ``packet_rd_en_pipeline``:   packet_rd_en        -> packet_rd_en_d
        """

        def add_delay(instance_name, source, sink):
            # Each pipeline is stored on self under its instance name.
            stage = Pipeline(width=1, depth=self.bank_ctrl_pipeline_depth)
            setattr(self, instance_name, stage)
            self.add_child(instance_name,
                           stage,
                           clk=self.clk,
                           clk_en=const(1, 1),
                           reset=self.reset,
                           in_=source,
                           out_=sink)

        for signal_name in ("mem_rd_en_w", "mem_rd_en_d", "sram_cfg_rd_en_d",
                            "packet_rd_en_d"):
            setattr(self, signal_name, self.var(signal_name, 1))
        self.wire(self.mem_rd_en_w, self.mem_rd_en)

        add_delay("mem_rd_en_pipeline", self.mem_rd_en_w, self.mem_rd_en_d)
        add_delay("sram_cfg_rd_en_pipeline", self.if_sram_cfg_s.rd_en,
                  self.sram_cfg_rd_en_d)
        add_delay("packet_rd_en_pipeline", self.packet_rd_en,
                  self.packet_rd_en_d)
Exemple #17
0
 def add_pcfg_dma_done_pulse_pipeline(self):
     """Delay ``done_pulse_r`` and tap it at a configurable latency.

     A 1-bit pipeline of depth ``3 * num_glb_tiles + default_latency`` is
     created with a flattened output, exposing every stage through the
     array ``done_pulse_d_arr``. ``pcfg_done_pulse`` is wired to the stage
     at index ``cfg_pcfg_network_latency`` (resized to the index width)
     plus ``default_latency`` plus ``num_glb_tiles``.
     """
     stage_count = 3 * self._params.num_glb_tiles + self.default_latency
     index_width = clog2(stage_count)

     self.done_pulse_d_arr = self.var("done_pulse_d_arr",
                                      1,
                                      size=stage_count,
                                      explicit_array=True)
     self.done_pulse_pipeline = Pipeline(width=1,
                                         depth=stage_count,
                                         flatten_output=True)
     self.add_child("done_pulse_pipeline",
                    self.done_pulse_pipeline,
                    clk=self.clk,
                    clk_en=const(1, 1),
                    reset=self.reset,
                    in_=self.done_pulse_r,
                    out_=self.done_pulse_d_arr)

     # Same left-to-right sum as before, so the generated index expression
     # keeps its shape.
     tap_index = (resize(self.cfg_pcfg_network_latency, index_width)
                  + self.default_latency
                  + self._params.num_glb_tiles)
     self.wire(self.pcfg_done_pulse, self.done_pulse_d_arr[tap_index])
Exemple #18
0
def safe_wire(gen, w_to, w_from):
    """Wire ``w_from`` into ``w_to`` even when their widths differ.

    Equal widths are wired directly. If the destination is narrower, only
    the low ``w_to.width`` bits of the source are connected. If it is
    wider, the source drives the low bits and the remaining high bits are
    tied to zero. A mismatch is reported on stdout when
    ``lake_util_verbose_trim`` is set. Only one-dimensional signals are
    handled.
    """
    dst_width = w_to.width
    src_width = w_from.width

    if dst_width == src_width:
        gen.wire(w_to, w_from)
        return

    if lake_util_verbose_trim:
        print(
            f"SAFEWIRE: WIDTH MISMATCH: {w_to.name} width {w_to.width} <-> {w_from.name} width {w_from.width}"
        )

    if dst_width < src_width:
        # Destination is narrower: drop the source's upper bits.
        gen.wire(w_to, w_from[dst_width - 1, 0])
        return

    # Destination is wider: connect the low bits, zero-fill the rest.
    gen.wire(w_to[src_width - 1, 0], w_from)
    gen.wire(w_to[dst_width - 1, src_width],
             kts.const(0, dst_width - src_width))
Exemple #19
0
    def set_read_bank(self):
        """Drive ``_rd_bank`` with the bank-select field of the read address.

        With a single bank, ``_rd_bank`` is tied to constant 0. Otherwise it
        takes bits ``[b_a_off + bank_width - 1 : b_a_off]`` of ``_rd_addr``:
        through a flop when ``read_delay == 1``, combinationally for any
        other ``read_delay`` value.
        """
        if self.banks == 1:
            self.wire(self._rd_bank, kts.const(0, 1))
        else:
            # The read bank is comb if no delay, otherwise delayed
            if self.read_delay == 1:

                @always_ff((posedge, "clk"), (negedge, "rst_n"))
                def read_bank_ff(self):
                    # Cleared to 0 while _rst_n is low; otherwise captures the
                    # bank field of the read address on the clock edge.
                    if ~self._rst_n:
                        self._rd_bank = 0
                    else:
                        self._rd_bank = \
                            self._rd_addr[self.b_a_off + self.bank_width - 1, self.b_a_off]

                self.add_code(read_bank_ff)
            else:

                @always_comb
                def read_bank_comb(self):
                    # Same bit slice as the registered variant, without a flop.
                    self._rd_bank = \
                        self._rd_addr[self.b_a_off + self.bank_width - 1, self.b_a_off]

                self.add_code(read_bank_comb)
Exemple #20
0
 def tile2tile_w2e_wiring(self):
     """Chain the west-to-east packet buses from tile to tile.

     Entry 0 of the ``*_w2e_wsti`` arrays is the chain's start: the proc
     bus gets ``proc_packet_d`` and the strm/pcfg buses are tied to 0. For
     every tile ``i >= 1``, entry ``i`` of each ``*_w2e_wsti`` array is
     wired to entry ``i - 1`` of the matching ``*_w2e_esto`` array.
     """
     self.wire(self.proc_packet_w2e_wsti[0], self.proc_packet_d)
     self.wire(self.strm_packet_w2e_wsti[0], 0)
     self.wire(self.pcfg_packet_w2e_wsti[0], 0)

     def tile_index(value):
         # Array indices are constants sized to address num_glb_tiles entries.
         return const(value, clog2(self._params.num_glb_tiles))

     for tile in range(1, self._params.num_glb_tiles):
         for kind in ("proc", "strm", "pcfg"):
             self.wire(
                 getattr(self, f"{kind}_packet_w2e_wsti")[tile_index(tile)],
                 getattr(self, f"{kind}_packet_w2e_esto")[tile_index(tile - 1)])
Exemple #21
0
    def __init__(self,
                 data_width=16,  # CGRA Params
                 mem_width=64,
                 mem_depth=512,
                 banks=1,
                 input_iterator_support=6,  # Addr Controllers
                 output_iterator_support=6,
                 input_config_width=16,
                 output_config_width=16,
                 interconnect_input_ports=2,  # Connection to int
                 interconnect_output_ports=2,
                 mem_input_ports=1,
                 mem_output_ports=1,
                 read_delay=1,  # Cycle delay in read (SRAM vs Register File)
                 rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
                 agg_height=4,
                 max_agg_schedule=32,
                 input_max_port_sched=32,
                 output_max_port_sched=32,
                 align_input=1,
                 max_line_length=128,
                 max_tb_height=1,
                 tb_range_max=128,
                 tb_range_inner_max=5,
                 tb_sched_max=64,
                 max_tb_stride=15,
                 num_tb=1,
                 tb_iterator_support=2,
                 multiwrite=1,
                 num_tiles=1,
                 max_prefetch=8,
                 app_ctrl_depth_width=16,
                 remove_tb=False,
                 stcl_valid_iter=4):
        super().__init__("strg_ub")

        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.agg_height = agg_height
        self.max_agg_schedule = max_agg_schedule
        self.input_max_port_sched = input_max_port_sched
        self.output_max_port_sched = output_max_port_sched
        self.input_port_sched_width = clog2(self.interconnect_input_ports)
        self.align_input = align_input
        self.max_line_length = max_line_length
        assert self.mem_width >= self.data_width, "Data width needs to be smaller than mem"
        self.fw_int = int(self.mem_width / self.data_width)
        self.num_tb = num_tb
        self.max_tb_height = max_tb_height
        self.tb_range_max = tb_range_max
        self.tb_range_inner_max = tb_range_inner_max
        self.max_tb_stride = max_tb_stride
        self.tb_sched_max = tb_sched_max
        self.tb_iterator_support = tb_iterator_support
        self.multiwrite = multiwrite
        self.max_prefetch = max_prefetch
        self.num_tiles = num_tiles
        self.app_ctrl_depth_width = app_ctrl_depth_width
        self.remove_tb = remove_tb
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.stcl_valid_iter = stcl_valid_iter
        # phases = [] TODO

        self.address_width = clog2(self.num_tiles * self.mem_depth)

        # CLK and RST
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # INPUTS
        self._data_in = self.input("data_in",
                                   self.data_width,
                                   size=self.interconnect_input_ports,
                                   packed=True,
                                   explicit_array=True)

        self._wen_in = self.input("wen_in", self.interconnect_input_ports)
        self._ren_input = self.input("ren_in", self.interconnect_output_ports)
        # Post rate matched
        self._ren_in = self.var("ren_in_muxed", self.interconnect_output_ports)
        # Processed versions of wen and ren from the app ctrl
        self._wen = self.var("wen", self.interconnect_input_ports)
        self._ren = self.var("ren", self.interconnect_output_ports)

        # Add rate matched
        # If one input port, let any output port use the wen_in as the ren_in
        # If more, do the same thing but also provide port selection
        if self.interconnect_input_ports == 1:
            self._rate_matched = self.input("rate_matched", self.interconnect_output_ports)
            self._rate_matched.add_attribute(ConfigRegAttr("Rate matched - 1 or 0"))
            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_in[i],
                          kts.ternary(self._rate_matched[i],
                                      self._wen_in,
                                      self._ren_input[i]))
        else:
            self._rate_matched = self.input("rate_matched", 1 + kts.clog2(self.interconnect_input_ports),
                                            size=self.interconnect_output_ports,
                                            explicit_array=True,
                                            packed=True)
            self._rate_matched.add_attribute(ConfigRegAttr("Rate matched [input port | on/off]"))
            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_in[i],
                          kts.ternary(self._rate_matched[i][0],
                                      self._wen_in[self._rate_matched[i][kts.clog2(self.interconnect_input_ports), 1]],
                                      self._ren_input[i]))

        self._arb_wen_en = self.var("arb_wen_en", self.interconnect_input_ports)
        self._arb_ren_en = self.var("arb_ren_en", self.interconnect_output_ports)

        self._data_from_strg = self.input("data_from_strg",
                                          self.data_width,
                                          size=(self.banks,
                                                self.mem_output_ports,
                                                self.fw_int),
                                          packed=True,
                                          explicit_array=True)

        self._mem_valid_data = self.input("mem_valid_data",
                                          self.mem_output_ports,
                                          size=self.banks,
                                          explicit_array=True,
                                          packed=True)

        self._out_mem_valid_data = self.var("out_mem_valid_data",
                                            self.mem_output_ports,
                                            size=self.banks,
                                            explicit_array=True,
                                            packed=True)

        # We need to signal valids out of the agg buff, only if one exists...
        if self.agg_height > 0:
            self._to_iac_valid = self.var("ab_to_mem_valid",
                                          self.interconnect_input_ports)

        self._data_out = self.output("data_out",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     packed=True,
                                     explicit_array=True)

        self._valid_out = self.output("valid_out",
                                      self.interconnect_output_ports)

        self._valid_out_alt = self.var("valid_out_alt",
                                       self.interconnect_output_ports)

        self._data_to_strg = self.output("data_to_strg",
                                         self.data_width,
                                         size=(self.banks,
                                               self.mem_input_ports,
                                               self.fw_int),
                                         packed=True,
                                         explicit_array=True)

        # If we can perform a read and a write on the same cycle,
        # this will necessitate a separate read and write address...
        if self.rw_same_cycle:
            self._wr_addr_out = self.output("wr_addr_out",
                                            self.address_width,
                                            size=(self.banks,
                                                  self.mem_input_ports),
                                            explicit_array=True,
                                            packed=True)

            self._rd_addr_out = self.output("rd_addr_out",
                                            self.address_width,
                                            size=(self.banks,
                                                  self.mem_output_ports),
                                            explicit_array=True,
                                            packed=True)

        else:
            self._addr_out = self.output("addr_out",
                                         self.address_width,
                                         size=(self.banks,
                                               self.mem_input_ports),
                                         packed=True,
                                         explicit_array=True)

        self._cen_to_strg = self.output("cen_to_strg", self.mem_output_ports,
                                        size=self.banks,
                                        explicit_array=True,
                                        packed=True)
        self._wen_to_strg = self.output("wen_to_strg", self.mem_input_ports,
                                        size=self.banks,
                                        explicit_array=True,
                                        packed=True)
        if self.num_tb > 0:
            self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)

        self._port_wens = self.var("port_wens", self.interconnect_input_ports)

        ####################
        ##### APP CTRL #####
        ####################
        self._ack_transpose = self.var("ack_transpose",
                                       self.banks,
                                       size=self.interconnect_output_ports,
                                       explicit_array=True,
                                       packed=True)

        self._ack_reduced = self.var("ack_reduced",
                                     self.interconnect_output_ports)

        self.app_ctrl = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                                interconnect_output_ports=self.interconnect_output_ports,
                                depth_width=self.app_ctrl_depth_width,
                                sprt_stcl_valid=True,
                                stcl_iter_support=self.stcl_valid_iter)

        # Some refactoring here for pond to get rid of app controllers...
        # This is honestly pretty messy and should clean up nicely when we have the spec...
        self._ren_out_reduced = self.var("ren_out_reduced",
                                         self.interconnect_output_ports)

        if self.num_tb == 0 or self.remove_tb:
            self.wire(self._wen, self._wen_in)
            self.wire(self._ren, self._ren_in)
            self.wire(self._valid_out, self._valid_out_alt)
            self.wire(self._arb_wen_en, self._wen)
            self.wire(self._arb_ren_en, self._ren)
        else:
            self.add_child("app_ctrl", self.app_ctrl,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._wen_in,
                           ren_in=self._ren_in,
                           #    ren_update=self._tb_valid_out,
                           valid_out_data=self._valid_out,
                           # valid_out_stencil=,
                           wen_out=self._wen,
                           ren_out=self._ren)

            self.wire(self.app_ctrl.ports.tb_valid, self._tb_valid_out)
            self.wire(self.app_ctrl.ports.ren_update, self._tb_valid_out)

            self.app_ctrl_coarse = AppCtrl(interconnect_input_ports=self.interconnect_input_ports,
                                           interconnect_output_ports=self.interconnect_output_ports,
                                           depth_width=self.app_ctrl_depth_width)
            self.add_child("app_ctrl_coarse", self.app_ctrl_coarse,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._to_iac_valid,  # self._port_wens & self._to_iac_valid,  # Gets valid and the ack
                           ren_in=self._ren_out_reduced,
                           tb_valid=kts.const(0, 1),
                           ren_update=self._ack_reduced,
                           wen_out=self._arb_wen_en,
                           ren_out=self._arb_ren_en)

        ###########################
        ##### INPUT AGG SCHED #####
        ###########################
        ###########################################
        ##### AGGREGATION ALIGNERS (OPTIONAL) #####
        ###########################################
        # These variables are holders and can be swapped out if needed
        self._data_consume = self._data_in
        self._valid_consume = self._wen
        # Zero out if not aligning
        if(self.agg_height > 0):
            self._align_to_agg = self.var("align_input",
                                          self.interconnect_input_ports)
        # Add the aggregation buffer aligners
        if(self.align_input):
            self._data_consume = self.var("data_consume",
                                          self.data_width,
                                          size=self.interconnect_input_ports,
                                          packed=True,
                                          explicit_array=True)
            self._valid_consume = self.var("valid_consume",
                                           self.interconnect_input_ports)

            # Make new aggregation aligners for each port
            for i in range(self.interconnect_input_ports):
                new_child = AggAligner(self.data_width, self.max_line_length)
                self.add_child(f"agg_align_{i}", new_child,
                               clk=self._clk,
                               rst_n=self._rst_n,
                               in_dat=self._data_in[i],
                               in_valid=self._wen[i],
                               align=self._align_to_agg[i],
                               out_valid=self._valid_consume[i],
                               out_dat=self._data_consume[i])
        else:
            if self.agg_height > 0:
                self.wire(self._align_to_agg, const(0, self._align_to_agg.width))
        ################################################
        ##### END: AGGREGATION ALIGNERS (OPTIONAL) #####
        ################################################

        if self.agg_height == 0:
            self._to_iac_dat = self._data_consume
            self._to_iac_valid = self._valid_consume

        ##################################
        ##### AGG BUFFERS (OPTIONAL) #####
        ##################################
        # Only instantiate agg_buffer if needed
        if(self.agg_height > 0):
            self._to_iac_dat = self.var("ab_to_mem_dat",
                                        self.mem_width,
                                        size=self.interconnect_input_ports,
                                        packed=True,
                                        explicit_array=True)

            # self._to_iac_valid = self.var("ab_to_mem_valid",
            #                               self.interconnect_input_ports)

            self._agg_buffers = []
            # Add input aggregations buffers
            for i in range(self.interconnect_input_ports):
                # add children aggregator buffers...
                agg_buffer_new = AggregationBuffer(self.agg_height,
                                                   self.data_width,
                                                   self.mem_width,
                                                   self.max_agg_schedule)
                self._agg_buffers.append(agg_buffer_new)
                self.add_child(f"agg_in_{i}", agg_buffer_new,
                               clk=self._clk,
                               rst_n=self._rst_n,
                               data_in=self._data_consume[i],
                               valid_in=self._valid_consume[i],
                               align=self._align_to_agg[i],
                               data_out=self._to_iac_dat[i],
                               valid_out=self._to_iac_valid[i])

        #######################################
        ##### END: AGG BUFFERS (OPTIONAL) #####
        #######################################

        self._ready_tba = self.var("ready_tba", self.interconnect_output_ports)
        ####################################
        ##### INPUT ADDRESS CONTROLLER #####
        ####################################
        self._wen_to_arb = self.var("wen_to_arb", self.mem_input_ports,
                                    size=self.banks,
                                    explicit_array=True,
                                    packed=True)
        self._addr_to_arb = self.var("addr_to_arb",
                                     self.address_width,
                                     size=(self.banks,
                                           self.mem_input_ports),
                                     explicit_array=True,
                                     packed=True)
        self._data_to_arb = self.var("data_to_arb",
                                     self.data_width,
                                     size=(self.banks,
                                           self.mem_input_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        # Connect these inputs ports to an address generator
        iac = InputAddrCtrl(interconnect_input_ports=self.interconnect_input_ports,
                            mem_depth=self.mem_depth,
                            num_tiles=self.num_tiles,
                            banks=self.banks,
                            iterator_support=self.input_iterator_support,
                            address_width=self.address_width,
                            data_width=self.data_width,
                            fetch_width=self.mem_width,
                            multiwrite=self.multiwrite,
                            strg_wr_ports=self.mem_input_ports,
                            config_width=self.input_config_width)
        self.add_child(f"input_addr_ctrl", iac,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       valid_in=self._to_iac_valid,
                       # wen_en=kts.concat(*([kts.const(1, 1)] * self.interconnect_input_ports)),
                       wen_en=self._arb_wen_en,
                       data_in=self._to_iac_dat,
                       wen_to_sram=self._wen_to_arb,
                       addr_out=self._addr_to_arb,
                       port_out=self._port_wens,
                       data_out=self._data_to_arb)

        #########################################
        ##### END: INPUT ADDRESS CONTROLLER #####
        #########################################
        self._arb_acks = self.var("arb_acks",
                                  self.interconnect_output_ports,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        self._prefetch_step = self.var("prefetch_step", self.interconnect_output_ports)
        self._oac_step = self.var("oac_step", self.interconnect_output_ports)
        self._oac_valid = self.var("oac_valid", self.interconnect_output_ports)
        self._ren_out = self.var("ren_out",
                                 self.interconnect_output_ports,
                                 size=self.banks,
                                 explicit_array=True,
                                 packed=True)
        self._ren_out_tpose = self.var("ren_out_tpose",
                                       self.banks,
                                       size=self.interconnect_output_ports,
                                       explicit_array=True,
                                       packed=True)

        self._oac_addr_out = self.var("oac_addr_out",
                                      self.address_width,
                                      size=self.interconnect_output_ports,
                                      explicit_array=True,
                                      packed=True)
        #####################################
        ##### OUTPUT ADDRESS CONTROLLER #####
        #####################################
        oac = OutputAddrCtrl(interconnect_output_ports=self.interconnect_output_ports,
                             mem_depth=self.mem_depth,
                             num_tiles=self.num_tiles,
                             banks=self.banks,
                             iterator_support=self.output_iterator_support,
                             address_width=self.address_width,
                             config_width=self.output_config_width)

        if self.remove_tb:
            self.wire(self._oac_valid, self._ren)
            self.wire(self._oac_step, self._ren)
        else:
            self.wire(self._oac_valid, self._prefetch_step)
            self.wire(self._oac_step, self._ack_reduced)

        self.chain_idx_bits = max(1, clog2(num_tiles))
        self._enable_chain_output = self.input("enable_chain_output", 1)
        self._chain_idx_output = self.input("chain_idx_output", self.chain_idx_bits)

        self.add_child(f"output_addr_ctrl", oac,
                       clk=self._clk,
                       rst_n=self._rst_n,
                       valid_in=self._oac_valid,
                       ren=self._ren_out,
                       addr_out=self._oac_addr_out,
                       step_in=self._oac_step)

        for i in range(self.interconnect_output_ports):
            for j in range(self.banks):
                self.wire(self._ren_out_tpose[i][j], self._ren_out[j][i])
        ##############################
        ##### READ/WRITE ARBITER #####
        ##############################
        # Hook up the read write arbiters for each bank
        self._arb_dat_out = self.var("arb_dat_out",
                                     self.data_width,
                                     size=(self.banks,
                                           self.mem_output_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        self._arb_port_out = self.var("arb_port_out",
                                      self.interconnect_output_ports,
                                      size=(self.banks,
                                            self.mem_output_ports),
                                      explicit_array=True,
                                      packed=True)
        self._arb_valid_out = self.var("arb_valid_out", self.mem_output_ports,
                                       size=self.banks,
                                       explicit_array=True,
                                       packed=True)

        self._rd_sync_gate = self.var("rd_sync_gate",
                                      self.interconnect_output_ports)

        self.arbiters = []
        for i in range(self.banks):
            rw_arb = RWArbiter(fetch_width=self.mem_width,
                               data_width=self.data_width,
                               memory_depth=self.mem_depth,
                               num_tiles=self.num_tiles,
                               int_in_ports=self.interconnect_input_ports,
                               int_out_ports=self.interconnect_output_ports,
                               strg_wr_ports=self.mem_input_ports,
                               strg_rd_ports=self.mem_output_ports,
                               read_delay=self.read_delay,
                               rw_same_cycle=self.rw_same_cycle,
                               separate_addresses=self.rw_same_cycle)
            self.arbiters.append(rw_arb)
            self.add_child(f"rw_arb_{i}", rw_arb,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           wen_in=self._wen_to_arb[i],
                           w_data=self._data_to_arb[i],
                           w_addr=self._addr_to_arb[i],
                           data_from_mem=self._data_from_strg[i],
                           mem_valid_data=self._mem_valid_data[i],
                           out_mem_valid_data=self._out_mem_valid_data[i],
                           ren_en=self._arb_ren_en,
                           rd_addr=self._oac_addr_out,
                           out_data=self._arb_dat_out[i],
                           out_port=self._arb_port_out[i],
                           out_valid=self._arb_valid_out[i],
                           cen_mem=self._cen_to_strg[i],
                           wen_mem=self._wen_to_strg[i],
                           data_to_mem=self._data_to_strg[i],
                           out_ack=self._arb_acks[i])

            # Bind the separate addrs
            if self.rw_same_cycle:
                self.wire(rw_arb.ports.wr_addr_to_mem, self._wr_addr_out[i])
                self.wire(rw_arb.ports.rd_addr_to_mem, self._rd_addr_out[i])
            else:
                self.wire(rw_arb.ports.addr_to_mem, self._addr_out[i])

            if self.remove_tb:
                self.wire(rw_arb.ports.ren_in, self._ren_out[i])
            else:
                self.wire(rw_arb.ports.ren_in, self._ren_out[i] & self._rd_sync_gate)

        self.num_tb_bits = max(1, clog2(self.num_tb))

        self._data_to_sync = self.var("data_to_sync",
                                      self.data_width,
                                      size=(self.interconnect_output_ports,
                                            self.fw_int),
                                      explicit_array=True,
                                      packed=True)

        self._valid_to_sync = self.var("valid_to_sync", self.interconnect_output_ports)

        self._data_to_tba = self.var("data_to_tba",
                                     self.data_width,
                                     size=(self.interconnect_output_ports,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        self._valid_to_tba = self.var("valid_to_tba", self.interconnect_output_ports)

        self._data_to_pref = self.var("data_to_pref",
                                      self.data_width,
                                      size=(self.interconnect_output_ports,
                                            self.fw_int),
                                      explicit_array=True,
                                      packed=True)

        self._valid_to_pref = self.var("valid_to_pref", self.interconnect_output_ports)
        #######################
        ##### DEMUX READS #####
        #######################
        dmux_rd = DemuxReads(fetch_width=self.mem_width,
                             data_width=self.data_width,
                             banks=self.banks,
                             int_out_ports=self.interconnect_output_ports,
                             strg_rd_ports=self.mem_output_ports)

        self._arb_dat_out_f = self.var("arb_dat_out_f",
                                       self.data_width,
                                       size=(self.banks * self.mem_output_ports,
                                             self.fw_int),
                                       explicit_array=True,
                                       packed=True)

        self._arb_port_out_f = self.var("arb_port_out_f",
                                        self.interconnect_output_ports,
                                        size=(self.banks * self.mem_output_ports),
                                        explicit_array=True,
                                        packed=True)
        self._arb_valid_out_f = self.var("arb_valid_out_f", self.mem_output_ports * self.banks)
        self._arb_mem_valid_data_f = self.var("arb_mem_valid_data_f", self.mem_output_ports * self.banks)

        self._arb_mem_valid_data_out = self.var("arb_mem_valid_data_out",
                                                self.interconnect_output_ports)

        self._mem_valid_data_sync = self.var("mem_valid_data_sync",
                                             self.interconnect_output_ports)

        self._mem_valid_data_pref = self.var("mem_valid_data_pref",
                                             self.interconnect_output_ports)

        tmp_cnt = 0
        for i in range(self.banks):
            for j in range(self.mem_output_ports):
                self.wire(self._arb_dat_out_f[tmp_cnt], self._arb_dat_out[i][j])
                self.wire(self._arb_port_out_f[tmp_cnt], self._arb_port_out[i][j])
                self.wire(self._arb_valid_out_f[tmp_cnt], self._arb_valid_out[i][j])
                self.wire(self._arb_mem_valid_data_f[tmp_cnt], self._out_mem_valid_data[i][j])
                tmp_cnt = tmp_cnt + 1

        # If this is end of the road...
        if self.remove_tb:
            assert self.fw_int == 1, "Make it easier on me now..."
            self.add_child("demux_rds", dmux_rd,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._arb_dat_out_f,
                           mem_valid_data=self._arb_mem_valid_data_f,
                           mem_valid_data_out=self._arb_mem_valid_data_out,
                           valid_in=self._arb_valid_out_f,
                           port_in=self._arb_port_out_f,
                           valid_out=self._valid_out_alt)
            for i in range(self.interconnect_output_ports):
                self.wire(self._data_out[i], dmux_rd.ports.data_out[i])

        else:
            self.add_child("demux_rds", dmux_rd,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._arb_dat_out_f,
                           mem_valid_data=self._arb_mem_valid_data_f,
                           mem_valid_data_out=self._arb_mem_valid_data_out,
                           valid_in=self._arb_valid_out_f,
                           port_in=self._arb_port_out_f,
                           data_out=self._data_to_sync,
                           valid_out=self._valid_to_sync)

            #######################
            ##### SYNC GROUPS #####
            #######################
            sync_group = SyncGroups(fetch_width=self.mem_width,
                                    data_width=self.data_width,
                                    int_out_ports=self.interconnect_output_ports)

            for i in range(self.interconnect_output_ports):
                self.wire(self._ren_out_reduced[i], self._ren_out_tpose[i].r_or())

            self.add_child("sync_grp", sync_group,
                           clk=self._clk,
                           rst_n=self._rst_n,
                           data_in=self._data_to_sync,
                           mem_valid_data=self._arb_mem_valid_data_out,
                           mem_valid_data_out=self._mem_valid_data_sync,
                           valid_in=self._valid_to_sync,
                           data_out=self._data_to_pref,
                           valid_out=self._valid_to_pref,
                           ren_in=self._ren_out_reduced,
                           rd_sync_gate=self._rd_sync_gate,
                           ack_in=self._ack_reduced)

            # This is the end of the line if we aren't using tb
            ######################
            ##### PREFETCHER #####
            ######################
            prefetchers = []
            for i in range(self.interconnect_output_ports):

                pref = Prefetcher(fetch_width=self.mem_width,
                                  data_width=self.data_width,
                                  max_prefetch=self.max_prefetch)

                prefetchers.append(pref)

                if self.num_tb == 0:
                    assert self.fw_int == 1, \
                        "If no transpose buffer, data width needs match memory width"
                    self.add_child(f"pre_fetch_{i}", pref,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   data_in=self._data_to_pref[i],
                                   mem_valid_data=self._mem_valid_data_sync[i],
                                   mem_valid_data_out=self._mem_valid_data_pref[i],
                                   valid_read=self._valid_to_pref[i],
                                   tba_rdy_in=self._ren[i],
                                   #    data_out=self._data_out[i],
                                   valid_out=self._valid_out_alt[i],
                                   prefetch_step=self._prefetch_step[i])
                    self.wire(self._data_out[i], pref.ports.data_out[0])
                else:
                    self.add_child(f"pre_fetch_{i}", pref,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   data_in=self._data_to_pref[i],
                                   mem_valid_data=self._mem_valid_data_sync[i],
                                   mem_valid_data_out=self._mem_valid_data_pref[i],
                                   valid_read=self._valid_to_pref[i],
                                   tba_rdy_in=self._ready_tba[i],
                                   data_out=self._data_to_tba[i],
                                   valid_out=self._valid_to_tba[i],
                                   prefetch_step=self._prefetch_step[i])

                    #############################
                    ##### TRANSPOSE BUFFERS #####
                    #############################
            if self.num_tb > 0:

                self._tb_data_out = self.var("tb_data_out", self.data_width,
                                             size=self.interconnect_output_ports,
                                             explicit_array=True,
                                             packed=True)
                self._tb_valid_out = self.var("tb_valid_out", self.interconnect_output_ports)

                for i in range(self.interconnect_output_ports):

                    tba = TransposeBufferAggregation(word_width=self.data_width,
                                                     fetch_width=self.fw_int,
                                                     num_tb=self.num_tb,
                                                     max_tb_height=self.max_tb_height,
                                                     max_range=self.tb_range_max,
                                                     max_range_inner=self.tb_range_inner_max,
                                                     max_stride=self.max_tb_stride,
                                                     tb_iterator_support=self.tb_iterator_support)

                    self.add_child(f"tba_{i}", tba,
                                   clk=self._clk,
                                   rst_n=self._rst_n,
                                   SRAM_to_tb_data=self._data_to_tba[i],
                                   valid_data=self._valid_to_tba[i],
                                   tb_index_for_data=0,
                                   ack_in=self._valid_to_tba[i],
                                   mem_valid_data=self._mem_valid_data_pref[i],
                                   tb_to_interconnect_data=self._tb_data_out[i],
                                   tb_to_interconnect_valid=self._tb_valid_out[i],
                                   tb_arbiter_rdy=self._ready_tba[i],
                                   tba_ren=self._ren[i])

                for i in range(self.interconnect_output_ports):
                    self.wire(self._data_out[i], self._tb_data_out[i])
                    # self.wire(self._valid_out[i], self._tb_valid_out[i])
            else:
                self.wire(self._valid_out, self._valid_out_alt)

        ####################
        ##### ADD CODE #####
        ####################
        self.add_code(self.transpose_acks)
        self.add_code(self.reduce_acks)
Exemple #22
0
def decrement(var, value):
    """Subtract a constant from ``var``.

    ``value`` is first wrapped with ``kts.const``, passing ``var.width`` as
    the second argument (presumably so the constant has the same bit width
    as ``var`` -- confirm against the ``kts`` API), and the wrapped constant
    is then subtracted from ``var``.

    Returns the result of ``var - <constant>``, i.e. whatever the
    subtraction operator of ``var`` yields.
    """
    sized_value = kts.const(value, var.width)
    return var - sized_value
Exemple #23
0
 def set_inc(self, idx):
     # Compute the increment flag self.inc[idx] for one index.
     #
     # The flag defaults to 0 and is raised to 1 when self.step is set,
     # idx is below self.dim, and either
     #   * idx is 0 (first branch), or
     #   * idx equals self.mux_sel (second branch).
     #
     # NOTE(review): the conditions are combined with bitwise `&` rather than
     # `and`, and idx is wrapped in const(idx, 5) before the comparison with 0;
     # presumably the operands are hardware signal expressions and 5 is a bit
     # width - confirm against the generator framework in use.
     self.inc[idx] = 0
     if (const(idx, 5) == 0) & self.step & (idx < self.dim):
         self.inc[idx] = 1
     elif (idx == self.mux_sel) & self.step & (idx < self.dim):
         self.inc[idx] = 1
 def code(self):
     # Assign both self._reg and self._out from a call to func (defined
     # outside this block) on the same constant argument.
     #
     # The argument is spelled as an explicit const(1, 2) instead of a bare
     # literal: because the types are inferred, implicit const conversion
     # doesn't work here.
     # NOTE(review): const(1, 2) presumably means value 1 with width 2 -
     # confirm against the const() signature.
     self._reg = func(const(1, 2))
     self._out = func(const(1, 2))
Exemple #25
0
    def __init__(self,
                 data_width=16,
                 banks=1,
                 memory_width=64,
                 rw_same_cycle=False,
                 read_delay=1,
                 addr_width=9):
        super().__init__("strg_fifo")

        # Generation parameters
        self.banks = banks
        self.data_width = data_width
        self.memory_width = memory_width
        self.rw_same_cycle = rw_same_cycle
        self.read_delay = read_delay
        self.addr_width = addr_width
        self.fw_int = int(self.memory_width / self.data_width)

        # assert banks > 1 or rw_same_cycle is True or self.fw_int > 1, \
        #     "Can't sustain throughput with this setup. Need potential bandwidth for " + \
        #     "1 write and 1 read in a cycle - try using more banks or a macro that supports 1R1W"

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # Inputs + Outputs
        self._push = self.input("push", 1)
        self._data_in = self.input("data_in", self.data_width)
        self._pop = self.input("pop", 1)

        self._data_out = self.output("data_out", self.data_width)
        self._valid_out = self.output("valid_out", 1)
        self._empty = self.output("empty", 1)
        self._full = self.output("full", 1)

        # get relevant signals from the storage banks
        self._data_from_strg = self.input("data_from_strg", self.data_width,
                                          size=(self.banks,
                                                self.fw_int),
                                          explicit_array=True,
                                          packed=True)

        self._wen_addr = self.var("wen_addr", self.addr_width,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        self._ren_addr = self.var("ren_addr", self.addr_width,
                                  size=self.banks,
                                  explicit_array=True,
                                  packed=True)

        self._front_combined = self.var("front_combined", self.data_width,
                                        size=self.fw_int,
                                        explicit_array=True,
                                        packed=True)

        self._data_to_strg = self.output("data_to_strg", self.data_width,
                                         size=(self.banks,
                                               self.fw_int),
                                         explicit_array=True,
                                         packed=True)

        self._wen_to_strg = self.output("wen_to_strg", self.banks)
        self._ren_to_strg = self.output("ren_to_strg", self.banks)

        self._num_words_mem = self.var("num_words_mem", self.data_width)

        if self.banks == 1:
            self._curr_bank_wr = self.var("curr_bank_wr", 1)
            self.wire(self._curr_bank_wr, kts.const(0, 1))
            self._curr_bank_rd = self.var("curr_bank_rd", 1)
            self.wire(self._curr_bank_rd, kts.const(0, 1))
        else:
            self._curr_bank_wr = self.var("curr_bank_wr", kts.clog2(self.banks))
            self._curr_bank_rd = self.var("curr_bank_rd", kts.clog2(self.banks))

        self._write_queue = self.var("write_queue", self.data_width,
                                     size=(self.banks,
                                           self.fw_int),
                                     explicit_array=True,
                                     packed=True)

        # Lets us know if the bank has a write queued up
        self._queued_write = self.var("queued_write", self.banks)

        self._front_data_out = self.var("front_data_out", self.data_width)
        self._front_pop = self.var("front_pop", 1)
        self._front_empty = self.var("front_empty", 1)
        self._front_full = self.var("front_full", 1)
        self._front_valid = self.var("front_valid", 1)
        self._front_par_read = self.var("front_par_read", 1)
        self._front_par_out = self.var("front_par_out", self.data_width,
                                       size=(self.fw_int,
                                             1),
                                       explicit_array=True,
                                       packed=True)

        self._front_rd_ptr = self.var("front_rd_ptr", max(1, clog2(self.fw_int)))

        self._front_push = self.var("front_push", 1)
        self.wire(self._front_push, self._push & (~self._full | self._pop))

        self._front_rf = RegFIFO(data_width=self.data_width,
                                 width_mult=1,
                                 depth=self.fw_int,
                                 parallel=True,
                                 break_out_rd_ptr=True)

        # This one breaks out the read pointer so we can properly
        # reorder the data to storage

        self.add_child("front_rf", self._front_rf,
                       clk=self._clk,
                       clk_en=kts.const(1, 1),
                       rst_n=self._rst_n,
                       push=self._front_push,
                       pop=self._front_pop,
                       empty=self._front_empty,
                       full=self._front_full,
                       valid=self._front_valid,
                       parallel_read=self._front_par_read,
                       parallel_load=kts.const(0, 1),  # We don't need to parallel load the front
                       parallel_in=0,  # Same reason as above
                       parallel_out=self._front_par_out,
                       num_load=0,
                       rd_ptr_out=self._front_rd_ptr)
        self.wire(self._front_rf.ports.data_in[0], self._data_in)
        self.wire(self._front_data_out, self._front_rf.ports.data_out[0])

        self._back_data_in = self.var("back_data_in", self.data_width)
        self._back_data_out = self.var("back_data_out", self.data_width)
        self._back_push = self.var("back_push", 1)
        self._back_empty = self.var("back_empty", 1)
        self._back_full = self.var("back_full", 1)
        self._back_valid = self.var("back_valid", 1)
        self._back_pl = self.var("back_pl", 1)
        self._back_par_in = self.var("back_par_in", self.data_width,
                                     size=(self.fw_int,
                                           1),
                                     explicit_array=True,
                                     packed=True)
        self._back_num_load = self.var("back_num_load", clog2(self.fw_int) + 1)

        self._back_occ = self.var("back_occ", clog2(self.fw_int) + 1)
        self._front_occ = self.var("front_occ", clog2(self.fw_int) + 1)

        self._back_rf = RegFIFO(data_width=self.data_width,
                                width_mult=1,
                                depth=self.fw_int,
                                parallel=True,
                                break_out_rd_ptr=False)

        self._fw_is_1 = self.var("fw_is_1", 1)
        self.wire(self._fw_is_1, kts.const(self.fw_int == 1, 1))

        self._back_pop = self.var("back_pop", 1)
        if self.fw_int == 1:
            self.wire(self._back_pop, self._pop & (~self._empty | self._push) & ~self._back_pl)
        else:
            self.wire(self._back_pop, self._pop & (~self._empty | self._push))

        self.add_child("back_rf", self._back_rf,
                       clk=self._clk,
                       clk_en=kts.const(1, 1),
                       rst_n=self._rst_n,
                       push=self._back_push,
                       pop=self._back_pop,
                       empty=self._back_empty,
                       full=self._back_full,
                       valid=self._back_valid,
                       parallel_read=kts.const(0, 1),
                       # Only do back load when data is going there
                       parallel_load=self._back_pl & self._back_num_load.r_or(),
                       parallel_in=self._back_par_in,
                       num_load=self._back_num_load)
        self.wire(self._back_rf.ports.data_in[0], self._back_data_in)
        self.wire(self._back_data_out, self._back_rf.ports.data_out[0])
        # send the writes through when a read isn't happening
        for i in range(self.banks):
            self.add_code(self.send_writes, idx=i)
            self.add_code(self.send_reads, idx=i)

        # Set the parallel load to back bank - if no delay it's immediate
        # if not, it's delayed :)
        if self.read_delay == 1:
            self._ren_delay = self.var("ren_delay", 1)
            self.add_code(self.set_parallel_ld_delay_1)
            self.wire(self._back_pl, self._ren_delay)
        else:
            self.wire(self._back_pl, self._ren_to_strg.r_or())

        # Combine front end data - just the items + incoming
        # this data is actually based on the rd_ptr from the front fifo
        for i in range(self.fw_int):
            self.wire(self._front_combined[i], self._front_par_out[self._front_rd_ptr + i])
        # This is always true
        # self.wire(self._front_combined[self.fw_int - 1], self._data_in)

        # prioritize queued writes, otherwise send combined data
        for i in range(self.banks):
            self.wire(self._data_to_strg[i],
                      kts.ternary(self._queued_write[i], self._write_queue[i], self._front_combined))

        # Wire the thin output from front to thin input to back
        self.wire(self._back_data_in, self._front_data_out)
        self.wire(self._back_push, self._front_valid)
        self.add_code(self.set_front_pop)

        # Queue writes
        for i in range(self.banks):
            self.add_code(self.set_write_queue, idx=i)

        # Track number of words in memory
        # if self.read_delay == 1:
        #     self.add_code(self.set_num_words_mem_delay)
        # else:
        self.add_code(self.set_num_words_mem)

        # Track occupancy of the two small fifos
        self.add_code(self.set_front_occ)
        self.add_code(self.set_back_occ)

        if self.banks > 1:
            self.add_code(self.set_curr_bank_wr)
            self.add_code(self.set_curr_bank_rd)
        if self.read_delay == 1:
            self._prev_bank_rd = self.var("prev_bank_rd", max(1, kts.clog2(self.banks)))
            self.add_code(self.set_prev_bank_rd)

        # Parallel load data to back - based on num_load
        index_into = self._curr_bank_rd
        if self.read_delay == 1:
            index_into = self._prev_bank_rd
        for i in range(self.fw_int - 1):
            # Shift data over if you bypassed from the memory output
            self.wire(self._back_par_in[i],
                      kts.ternary(self._back_num_load == self.fw_int,
                                  self._data_from_strg[index_into][i],
                                  self._data_from_strg[index_into][i + 1]))
        self.wire(self._back_par_in[self.fw_int - 1],
                  kts.ternary(self._back_num_load == self.fw_int,
                              self._data_from_strg[index_into][self.fw_int - 1],
                              kts.const(0, self.data_width)))

        # Set the parallel read to the front fifo - analogous with trying to write to the memory
        self.add_code(self.set_front_par_read)

        # Set the number being parallely loaded into the register
        self.add_code(self.set_back_num_load)

        # Data out and valid out are (in the general case) just the data and valid from the back fifo
        # In the case where we have a fresh memory read, it would be from that
        bank_idx_read = self._curr_bank_rd
        if self.read_delay == 1:
            bank_idx_read = self._prev_bank_rd
        self.wire(self._data_out,
                  kts.ternary(self._back_pl, self._data_from_strg[bank_idx_read][0], self._back_data_out))
        self.wire(self._valid_out, kts.ternary(self._back_pl, self._pop, self._back_valid))

        # Set addresses to storage
        for i in range(self.banks):
            self.add_code(self.set_wen_addr, idx=i)
            self.add_code(self.set_ren_addr, idx=i)
        # Now deal with a shared address vs separate addresses
        if self.rw_same_cycle:
            # Separate
            self._wen_addr_out = self.output("wen_addr_out", self.addr_width,
                                             size=self.banks,
                                             explicit_array=True,
                                             packed=True)
            self._ren_addr_out = self.output("ren_addr_out", self.addr_width,
                                             size=self.banks,
                                             explicit_array=True,
                                             packed=True)
            self.wire(self._wen_addr_out, self._wen_addr)
            self.wire(self._ren_addr_out, self._ren_addr)
        else:
            self._addr_out = self.output("addr_out", self.addr_width,
                                         size=self.banks,
                                         explicit_array=True,
                                         packed=True)
            # If sharing the addresses, send read addr with priority
            for i in range(self.banks):
                self.wire(self._addr_out[i],
                          kts.ternary(self._wen_to_strg[i], self._wen_addr[i], self._ren_addr[i]))

        # Do final empty/full
        self._num_items = self.var("num_items", self.data_width)
        self.add_code(self.set_num_items)
        self._fifo_depth = self.input("fifo_depth", self.data_width)
        self._fifo_depth.add_attribute(ConfigRegAttr("Fifo depth..."))
        self.wire(self._empty, self._num_items == 0)
        self.wire(self._full, self._num_items == (self._fifo_depth))
Exemple #26
0
    def __init__(self,
                 data_width,
                 config_addr_width,
                 addr_width,
                 fetch_width,
                 total_sets,
                 sets_per_macro):
        """Build the "storage_config_seq" generator.

        The generator takes a configuration data/address pair plus
        read/write strobes and a per-set enable vector, and produces the
        write data, per-bank write/read enables and the address sent
        towards storage, as well as the per-set read-back data.

        Args:
            data_width: Bit width of ``config_data_in`` and of every word in
                ``wr_data`` / ``rd_data_stg`` / ``rd_data_out``.
            config_addr_width: Bit width of ``config_addr_in``.
            addr_width: Bit width of the ``addr_out`` port.
            fetch_width: Width of one storage word; ``fw_int`` is derived as
                ``int(fetch_width / data_width)`` words per storage word.
            total_sets: Width of ``config_en`` and number of entries in
                ``rd_data_out``.
            sets_per_macro: Number of sets per bank; ``banks`` is derived as
                ``int(total_sets / sets_per_macro)``.
        """
        super().__init__("storage_config_seq")

        self.data_width = data_width
        self.config_addr_width = config_addr_width
        self.addr_width = addr_width
        self.fetch_width = fetch_width
        # Number of data_width words per fetch word (truncated towards zero)
        self.fw_int = int(self.fetch_width / self.data_width)
        self.total_sets = total_sets
        self.sets_per_macro = sets_per_macro
        # Number of banks (truncated towards zero)
        self.banks = int(self.total_sets / self.sets_per_macro)

        # Stored for reference; not used again inside this constructor
        self.set_addr_width = clog2(total_sets)

        # self.storage_addr_width = self.

        # Clock and Reset
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # Inputs
        # phases = [] TODO
        # Take in the valid and data and attach an address + direct to a port
        self._config_data_in = self.input("config_data_in",
                                          self.data_width)

        self._config_addr_in = self.input("config_addr_in",
                                          self.config_addr_width)

        self._config_wr = self.input("config_wr", 1)
        self._config_rd = self.input("config_rd", 1)
        # One enable bit per set
        self._config_en = self.input("config_en", self.total_sets)

        # Declared as a port; not wired to anything inside this constructor
        self._clk_en = self.input("clk_en", 1)

        # Read data coming back from storage: one fetch word (fw_int data words) per bank
        self._rd_data_stg = self.input("rd_data_stg", self.data_width,
                                       size=(self.banks,
                                             self.fw_int),
                                       explicit_array=True,
                                       packed=True)

        # Write data going to storage: one full fetch word
        self._wr_data = self.output("wr_data",
                                    self.data_width,
                                    size=self.fw_int,
                                    explicit_array=True,
                                    packed=True)

        # Read-back data: one data word per set
        self._rd_data_out = self.output("rd_data_out", self.data_width,
                                        size=self.total_sets,
                                        explicit_array=True,
                                        packed=True)

        self._addr_out = self.output("addr_out",
                                     self.addr_width)

        # One set per macro means we directly send the config address through
        if self.sets_per_macro == 1:
            # Number of zero bits needed to pad the config address up to addr_width
            width = self.addr_width - self.config_addr_width
            if width > 0:
                self.wire(self._addr_out, kts.concat(kts.const(0, width), self._config_addr_in))
            else:
                # No padding needed: keep only the low addr_width bits of the config address
                self.wire(self._addr_out, self._config_addr_in[self.addr_width - 1, 0])
        else:
            # Padding left over once the set index and config address are accounted for
            width = self.addr_width - self.config_addr_width - clog2(self.sets_per_macro)
            self._set_to_addr = self.var("set_to_addr",
                                         clog2(self.sets_per_macro))
            self._reduce_en = self.var("reduce_en", self.sets_per_macro)
            # reduce_en[i] is the OR of the enable bit for set i of every bank
            for i in range(self.sets_per_macro):
                reduce_var = self._config_en[i]
                for j in range(self.banks - 1):
                    reduce_var = kts.concat(reduce_var, self._config_en[i + (self.sets_per_macro * (j + 1))])
                self.wire(self._reduce_en[i], reduce_var.r_or())
            # demux_set_addr is defined outside this constructor
            # (presumably it derives set_to_addr from reduce_en - confirm)
            self.add_code(self.demux_set_addr)
            if width > 0:
                self.wire(self._addr_out, kts.concat(kts.const(0, width),
                          self._set_to_addr,
                          self._config_addr_in))
            else:
                # NOTE(review): no truncation here, so when width < 0 the concatenation
                # looks wider than addr_out - confirm this case cannot occur
                self.wire(self._addr_out, kts.concat(self._set_to_addr, self._config_addr_in))

        # One write enable and one read enable bit per bank
        self._wen_out = self.output("wen_out", self.banks)
        self._ren_out = self.output("ren_out", self.banks)

        # Handle data passing
        if self.fw_int == 1:
            # If word width is same as data width, just pass everything through
            self.wire(self._wr_data[0], self._config_data_in)
            # self.wire(self._rd_data_out, self._rd_data_stg[0])
            # Every set in a bank reads back that bank's storage data
            num = 0
            for i in range(self.banks):
                for j in range(self.sets_per_macro):
                    self.wire(self._rd_data_out[num], self._rd_data_stg[i])
                    num = num + 1
        else:
            # Holds the first fw_int - 1 words of the fetch word being assembled
            self._data_wr_reg = self.var("data_wr_reg",
                                         self.data_width,
                                         size=self.fw_int - 1,
                                         packed=True,
                                         explicit_array=True)
            # self._data_rd_reg = self.var("data_rd_reg",
            #                              self.data_width,
            #                              size=self.fw_int - 1,
            #                              packed=True,
            #                              explicit_array=True)

            # Have word counter for repeated reads/writes
            # (update_cnt / update_rd_cnt are defined outside this constructor)
            self._cnt = self.var("cnt", clog2(self.fw_int))
            self._rd_cnt = self.var("rd_cnt", clog2(self.fw_int))
            self.add_code(self.update_cnt)
            self.add_code(self.update_rd_cnt)
            # Gate wen if not about to finish the word

            # Every set in a bank reads back the word of that bank selected by rd_cnt
            num = 0
            for i in range(self.banks):
                for j in range(self.sets_per_macro):
                    self.wire(self._rd_data_out[num], self._rd_data_stg[i][self._rd_cnt])
                    num = num + 1

            # Deal with writing to the data buffer
            # (write_buffer is defined outside this constructor)
            self.add_code(self.write_buffer)

            # Wire the reg + such to this guy:
            # the low fw_int - 1 words come from the buffer register,
            # the last word comes straight from the config data input
            for i in range(self.fw_int - 1):
                self.wire(self._wr_data[i], self._data_wr_reg[i])
            self.wire(self._wr_data[self.fw_int - 1], self._config_data_in)

        # If we have one bank, we can just always rd/wr from that one
        if self.banks == 1:
            if self.fw_int == 1:
                self.wire(self._wen_out, self._config_wr)
            else:
                # Only assert the write enable when the counter is on the last word
                self.wire(self._wen_out,
                          self._config_wr & (self._cnt == (self.fw_int - 1)))
            self.wire(self._ren_out, self._config_rd)
        # Otherwise we need to extract the bank from the set
        else:
            # Bank i is selected when any enable bit in its slice of config_en is set
            if self.fw_int == 1:
                for i in range(self.banks):
                    width = self.sets_per_macro
                    self.wire(self._wen_out[i], self._config_wr &
                              self._config_en[(i + 1) * width - 1, i * width].r_or())
            else:
                for i in range(self.banks):
                    width = self.sets_per_macro
                    # Additionally gate the write on the counter reaching the last word
                    self.wire(self._wen_out[i],
                              self._config_wr & self._config_en[(i + 1) * width - 1, i * width].r_or() &
                              (self._cnt == (self.fw_int - 1)))

            for i in range(self.banks):
                width = self.sets_per_macro
                self.wire(self._ren_out[i],
                          self._config_rd & self._config_en[(i + 1) * width - 1, i * width].r_or())
Exemple #27
0
def increment(var, value):
    """Return the expression ``var + value``.

    ``value`` is wrapped in a constant that has the same width as ``var``
    before the addition is formed.
    """
    step = kts.const(value, var.width)
    return var + step
Exemple #28
0
    def __init__(self, iterator_support=6, config_width=16, use_enable=True):
        """Build the "sched_gen_<iterator_support>_<config_width>" generator.

        Args:
            iterator_support: Forwarded to the child ``AddrGen`` and used to
                size the ``mux_sel`` input (``max(clog2(iterator_support), 1)``
                bits).
            config_width: Bit width of the ``cycle_count`` input and of the
                internal ``addr_out`` signal; also forwarded to ``AddrGen``.
            use_enable: When true, ``enable`` is an input port tagged as a
                configuration register; otherwise it is an internal signal
                tied to constant 1.
        """

        super().__init__(f"sched_gen_{iterator_support}_{config_width}")

        self.iterator_support = iterator_support
        self.config_width = config_width
        self.use_enable = use_enable

        # PORT DEFS: begin

        # INPUTS
        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        # OUTPUTS
        self._valid_output = self.output("valid_output", 1)

        # VARS (cycle_count and mux_sel below are input ports, not internal vars)
        self._valid_out = self.var("valid_out", 1)
        self._cycle_count = self.input("cycle_count", self.config_width)
        # At least one select bit, even when iterator_support is 1
        self._mux_sel = self.input("mux_sel",
                                   max(clog2(self.iterator_support), 1))
        # Driven by the child address generator instantiated further down
        self._addr_out = self.var("addr_out", self.config_width)

        # Receive signal on last iteration of looping structure and
        # gate the output...
        self._finished = self.input("finished", 1)
        self._valid_gate_inv = self.var("valid_gate_inv", 1)
        # valid_gate is simply the inverse of the registered valid_gate_inv
        self._valid_gate = self.var("valid_gate", 1)
        self.wire(self._valid_gate, ~self._valid_gate_inv)

        # Since dim = 0 is not sufficient, we need a way to prevent
        # the controllers from firing on the starting offset
        if self.use_enable:
            self._enable = self.input("enable", 1)
            self._enable.add_attribute(
                ConfigRegAttr("Disable the controller so it never fires..."))
            self._enable.add_attribute(
                FormalAttr(f"{self._enable.name}",
                           FormalSignalConstraint.SOLVE))
        # Otherwise we set it as a 1 and leave it up to synthesis...
        else:
            self._enable = self.var("enable", 1)
            self.wire(self._enable, kratos.const(1, 1))

        # Register that is cleared on reset and set once `finished` is seen;
        # there is no branch that clears it again other than reset
        @always_ff((posedge, "clk"), (negedge, "rst_n"))
        def valid_gate_inv_ff():
            if ~self._rst_n:
                self._valid_gate_inv = 0
            # If we are finishing the looping structure, turn this off to implement one-shot
            elif self._finished:
                self._valid_gate_inv = 1

        self.add_code(valid_gate_inv_ff)

        # Compare based on minimum of addr + global cycle...
        # (number of bits to compare; not used again inside this constructor)
        self.c_a_cmp = min(self._cycle_count.width, self._addr_out.width)

        # PORT DEFS: end

        # Child address generator: stepped by valid_out, never restarted (restart tied to 0)
        self.add_child(f"sched_addr_gen",
                       AddrGen(iterator_support=self.iterator_support,
                               config_width=self.config_width),
                       clk=self._clk,
                       rst_n=self._rst_n,
                       step=self._valid_out,
                       mux_sel=self._mux_sel,
                       addr_out=self._addr_out,
                       restart=const(0, 1))

        # set_valid_out / set_valid_output are defined outside this constructor
        self.add_code(self.set_valid_out)
        self.add_code(self.set_valid_output)
Exemple #29
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_depth=32,
            default_iterator_support=3,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            config_data_width=32,
            config_addr_width=8,
            cycle_count_width=16,
            add_clk_enable=True,
            add_flush=True):
        super().__init__("pond", debug=True)

        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.mem_depth = mem_depth
        self.data_width = data_width
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.cycle_count_width = cycle_count_width
        self.default_iterator_support = default_iterator_support
        self.default_config_width = kts.clog2(self.mem_depth)
        # inputs
        self._clk = self.clock("clk")
        self._clk.add_attribute(
            FormalAttr(f"{self._clk.name}", FormalSignalConstraint.CLK))
        self._rst_n = self.reset("rst_n")
        self._rst_n.add_attribute(
            FormalAttr(f"{self._rst_n.name}", FormalSignalConstraint.RSTN))
        self._clk_en = self.clock_en("clk_en", 1)

        # Enable/Disable tile
        self._tile_en = self.input("tile_en", 1)
        self._tile_en.add_attribute(
            ConfigRegAttr("Tile logic enable manifested as clock gate"))

        gclk = self.var("gclk", 1)
        self._gclk = kts.util.clock(gclk)
        self.wire(gclk, kts.util.clock(self._clk & self._tile_en))

        self._cycle_count = add_counter(self, "cycle_count",
                                        self.cycle_count_width)

        # Create write enable + addr, same for read.
        # self._write = self.input("write", self.interconnect_input_ports)
        self._write = self.var("write", self.mem_input_ports)
        # self._write.add_attribute(ControlSignalAttr(is_control=True))

        self._write_addr = self.var("write_addr",
                                    kts.clog2(self.mem_depth),
                                    size=self.interconnect_input_ports,
                                    explicit_array=True,
                                    packed=True)

        # Add "_pond" suffix to avoid error during garnet RTL generation
        self._data_in = self.input("data_in_pond",
                                   self.data_width,
                                   size=self.interconnect_input_ports,
                                   explicit_array=True,
                                   packed=True)
        self._data_in.add_attribute(
            FormalAttr(f"{self._data_in.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._data_in.add_attribute(ControlSignalAttr(is_control=False))

        self._read = self.var("read", self.mem_output_ports)
        self._t_write = self.var("t_write", self.interconnect_input_ports)
        self._t_read = self.var("t_read", self.interconnect_output_ports)
        # self._read.add_attribute(ControlSignalAttr(is_control=True))

        self._read_addr = self.var("read_addr",
                                   kts.clog2(self.mem_depth),
                                   size=self.interconnect_output_ports,
                                   explicit_array=True,
                                   packed=True)

        self._s_read_addr = self.var("s_read_addr",
                                     kts.clog2(self.mem_depth),
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)

        self._data_out = self.output("data_out_pond",
                                     self.data_width,
                                     size=self.interconnect_output_ports,
                                     explicit_array=True,
                                     packed=True)
        self._data_out.add_attribute(
            FormalAttr(f"{self._data_out.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._data_out.add_attribute(ControlSignalAttr(is_control=False))

        self._valid_out = self.output("valid_out_pond",
                                      self.interconnect_output_ports)
        self._valid_out.add_attribute(
            FormalAttr(f"{self._valid_out.name}",
                       FormalSignalConstraint.SEQUENCE))
        self._valid_out.add_attribute(ControlSignalAttr(is_control=False))

        self._mem_data_out = self.var("mem_data_out",
                                      self.data_width,
                                      size=self.mem_output_ports,
                                      explicit_array=True,
                                      packed=True)

        self._s_mem_data_in = self.var("s_mem_data_in",
                                       self.data_width,
                                       size=self.interconnect_input_ports,
                                       explicit_array=True,
                                       packed=True)

        self._mem_data_in = self.var("mem_data_in",
                                     self.data_width,
                                     size=self.mem_input_ports,
                                     explicit_array=True,
                                     packed=True)

        self._s_mem_write_addr = self.var("s_mem_write_addr",
                                          kts.clog2(self.mem_depth),
                                          size=self.interconnect_input_ports,
                                          explicit_array=True,
                                          packed=True)

        self._s_mem_read_addr = self.var("s_mem_read_addr",
                                         kts.clog2(self.mem_depth),
                                         size=self.interconnect_output_ports,
                                         explicit_array=True,
                                         packed=True)

        self._mem_write_addr = self.var("mem_write_addr",
                                        kts.clog2(self.mem_depth),
                                        size=self.mem_input_ports,
                                        explicit_array=True,
                                        packed=True)

        self._mem_read_addr = self.var("mem_read_addr",
                                       kts.clog2(self.mem_depth),
                                       size=self.mem_output_ports,
                                       explicit_array=True,
                                       packed=True)

        if self.interconnect_output_ports == 1:
            self.wire(self._data_out[0], self._mem_data_out[0])
        else:
            for i in range(self.interconnect_output_ports):
                self.wire(self._data_out[i], self._mem_data_out[0])

        # Valid out is simply passing the read signal through...
        self.wire(self._valid_out, self._t_read)

        # Create write addressors
        for wr_port in range(self.interconnect_input_ports):

            RF_WRITE_ITER = ForLoop(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width)
            RF_WRITE_ADDR = AddrGen(
                iterator_support=self.default_iterator_support,
                config_width=self.default_config_width)
            RF_WRITE_SCHED = SchedGen(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width,
                use_enable=True)

            self.add_child(f"rf_write_iter_{wr_port}",
                           RF_WRITE_ITER,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_write[wr_port])
            # Whatever comes through here should hopefully just pipe through seamlessly
            # addressor modules
            self.add_child(f"rf_write_addr_{wr_port}",
                           RF_WRITE_ADDR,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_write[wr_port],
                           mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                           restart=RF_WRITE_ITER.ports.restart)
            safe_wire(self, self._write_addr[wr_port],
                      RF_WRITE_ADDR.ports.addr_out)

            self.add_child(f"rf_write_sched_{wr_port}",
                           RF_WRITE_SCHED,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           mux_sel=RF_WRITE_ITER.ports.mux_sel_out,
                           finished=RF_WRITE_ITER.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self._t_write[wr_port])

        # Create read addressors
        for rd_port in range(self.interconnect_output_ports):

            RF_READ_ITER = ForLoop(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width)
            RF_READ_ADDR = AddrGen(
                iterator_support=self.default_iterator_support,
                config_width=self.default_config_width)
            RF_READ_SCHED = SchedGen(
                iterator_support=self.default_iterator_support,
                config_width=self.cycle_count_width,
                use_enable=True)

            self.add_child(f"rf_read_iter_{rd_port}",
                           RF_READ_ITER,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_read[rd_port])

            self.add_child(f"rf_read_addr_{rd_port}",
                           RF_READ_ADDR,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           step=self._t_read[rd_port],
                           mux_sel=RF_READ_ITER.ports.mux_sel_out,
                           restart=RF_READ_ITER.ports.restart)
            if self.interconnect_output_ports > 1:
                safe_wire(self, self._read_addr[rd_port],
                          RF_READ_ADDR.ports.addr_out)
            else:
                safe_wire(self, self._read_addr[rd_port],
                          RF_READ_ADDR.ports.addr_out)

            self.add_child(f"rf_read_sched_{rd_port}",
                           RF_READ_SCHED,
                           clk=self._gclk,
                           rst_n=self._rst_n,
                           mux_sel=RF_READ_ITER.ports.mux_sel_out,
                           finished=RF_READ_ITER.ports.restart,
                           cycle_count=self._cycle_count,
                           valid_output=self._t_read[rd_port])

        self.wire(self._write, self._t_write.r_or())
        self.wire(self._mem_write_addr[0],
                  decode(self, self._t_write, self._s_mem_write_addr))

        self.wire(self._mem_data_in[0],
                  decode(self, self._t_write, self._s_mem_data_in))

        self.wire(self._read, self._t_read.r_or())
        self.wire(self._mem_read_addr[0],
                  decode(self, self._t_read, self._s_mem_read_addr))
        # ===================================
        # Instantiate config hooks...
        # ===================================
        self.fw_int = 1
        self.data_words_per_set = 2**self.config_addr_width
        self.sets = int(
            (self.fw_int * self.mem_depth) / self.data_words_per_set)

        self.sets_per_macro = max(
            1, int(self.mem_depth / self.data_words_per_set))
        self.total_sets = max(1, 1 * self.sets_per_macro)

        self._config_data_in = self.input("config_data_in",
                                          self.config_data_width)
        self._config_data_in.add_attribute(ControlSignalAttr(is_control=False))

        self._config_data_in_shrt = self.var("config_data_in_shrt",
                                             self.data_width)

        self.wire(self._config_data_in_shrt,
                  self._config_data_in[self.data_width - 1, 0])

        self._config_addr_in = self.input("config_addr_in",
                                          self.config_addr_width)
        self._config_addr_in.add_attribute(ControlSignalAttr(is_control=False))

        self._config_data_out_shrt = self.var("config_data_out_shrt",
                                              self.data_width,
                                              size=self.total_sets,
                                              explicit_array=True,
                                              packed=True)

        self._config_data_out = self.output("config_data_out",
                                            self.config_data_width,
                                            size=self.total_sets,
                                            explicit_array=True,
                                            packed=True)
        self._config_data_out.add_attribute(
            ControlSignalAttr(is_control=False))

        for i in range(self.total_sets):
            self.wire(
                self._config_data_out[i],
                self._config_data_out_shrt[i].extend(self.config_data_width))

        self._config_read = self.input("config_read", 1)
        self._config_read.add_attribute(ControlSignalAttr(is_control=False))

        self._config_write = self.input("config_write", 1)
        self._config_write.add_attribute(ControlSignalAttr(is_control=False))

        self._config_en = self.input("config_en", self.total_sets)
        self._config_en.add_attribute(ControlSignalAttr(is_control=False))

        self._mem_data_cfg = self.var("mem_data_cfg",
                                      self.data_width,
                                      explicit_array=True,
                                      packed=True)

        self._mem_addr_cfg = self.var("mem_addr_cfg",
                                      kts.clog2(self.mem_depth))

        # Add config...
        stg_cfg_seq = StorageConfigSeq(
            data_width=self.data_width,
            config_addr_width=self.config_addr_width,
            addr_width=kts.clog2(self.mem_depth),
            fetch_width=self.data_width,
            total_sets=self.total_sets,
            sets_per_macro=self.sets_per_macro)

        # The clock to config sequencer needs to be the normal clock or
        # if the tile is off, we bring the clock back in based on config_en
        cfg_seq_clk = self.var("cfg_seq_clk", 1)
        self._cfg_seq_clk = kts.util.clock(cfg_seq_clk)
        self.wire(cfg_seq_clk, kts.util.clock(self._gclk))

        self.add_child(f"config_seq",
                       stg_cfg_seq,
                       clk=self._cfg_seq_clk,
                       rst_n=self._rst_n,
                       clk_en=self._clk_en | self._config_en.r_or(),
                       config_data_in=self._config_data_in_shrt,
                       config_addr_in=self._config_addr_in,
                       config_wr=self._config_write,
                       config_rd=self._config_read,
                       config_en=self._config_en,
                       wr_data=self._mem_data_cfg,
                       rd_data_out=self._config_data_out_shrt,
                       addr_out=self._mem_addr_cfg)

        if self.interconnect_output_ports == 1:
            self.wire(stg_cfg_seq.ports.rd_data_stg, self._mem_data_out)
        else:
            self.wire(stg_cfg_seq.ports.rd_data_stg[0], self._mem_data_out[0])

        self.RF_GEN = RegisterFile(data_width=self.data_width,
                                   write_ports=self.mem_input_ports,
                                   read_ports=self.mem_output_ports,
                                   width_mult=1,
                                   depth=self.mem_depth,
                                   read_delay=0)

        # Now we can instantiate and wire up the register file
        self.add_child(f"rf",
                       self.RF_GEN,
                       clk=self._gclk,
                       rst_n=self._rst_n,
                       data_out=self._mem_data_out)

        # Opt in for config_write
        self._write_rf = self.var("write_rf", self.mem_input_ports)
        self.wire(
            self._write_rf[0],
            kts.ternary(self._config_en.r_or(), self._config_write,
                        self._write[0]))
        for i in range(self.mem_input_ports - 1):
            self.wire(
                self._write_rf[i + 1],
                kts.ternary(self._config_en.r_or(), kts.const(0, 1),
                            self._write[i + 1]))
        self.wire(self.RF_GEN.ports.wen, self._write_rf)

        # Opt in for config_data_in
        for i in range(self.interconnect_input_ports):
            self.wire(
                self._s_mem_data_in[i],
                kts.ternary(self._config_en.r_or(), self._mem_data_cfg,
                            self._data_in[i]))
        self.wire(self.RF_GEN.ports.data_in, self._mem_data_in)

        # Opt in for config_addr
        for i in range(self.interconnect_input_ports):
            self.wire(
                self._s_mem_write_addr[i],
                kts.ternary(self._config_en.r_or(), self._mem_addr_cfg,
                            self._write_addr[i]))

        self.wire(self.RF_GEN.ports.wr_addr, self._mem_write_addr[0])

        for i in range(self.interconnect_output_ports):
            self.wire(
                self._s_mem_read_addr[i],
                kts.ternary(self._config_en.r_or(), self._mem_addr_cfg,
                            self._read_addr[i]))

        self.wire(self.RF_GEN.ports.rd_addr, self._mem_read_addr[0])

        if self.add_clk_enable:
            # self.clock_en("clk_en")
            kts.passes.auto_insert_clock_enable(self.internal_generator)
            clk_en_port = self.internal_generator.get_port("clk_en")
            clk_en_port.add_attribute(ControlSignalAttr(False))

        if self.add_flush:
            self.add_attribute("sync-reset=flush")
            kts.passes.auto_insert_sync_reset(self.internal_generator)
            flush_port = self.internal_generator.get_port("flush")
            flush_port.add_attribute(ControlSignalAttr(True))

        # Finally, lift the config regs...
        lift_config_reg(self.internal_generator)
Exemple #30
0
 def code():
     out_ = test_add(in_, const(1, 16))