コード例 #1
0
ファイル: extract_tile_info.py プロジェクト: StanfordAHA/lake
        curr_port = int_gen.get_port(port_name)
        attrs = curr_port.find_attribute(lambda a: isinstance(a, ControlSignalAttr))
        if len(attrs) != 1:
            continue
        cr_attr = attrs[0]
        # Now we have this
        intf_sigs.append(IO_info(port_name=port_name,
                                 port_size=curr_port.size,
                                 port_width=curr_port.width,
                                 is_ctrl=cr_attr.get_control(),
                                 port_dir=str(curr_port.port_direction),
                                 expl_arr=curr_port.explicit_array))
    return intf_sigs


# Script entry point: build a LakeTop with a fixed configuration and run
# extract_top_config on it.
if __name__ == "__main__":
    # Macro description handed to both LakeTop and the SRAM port-renaming pass.
    tsmc_info = SRAMMacroInfo("tsmc_name")
    use_sram_stub = False
    fifo_mode = True
    mem_width = 64
    # Every LakeTop argument not listed here keeps its constructor default.
    lake_dut = LakeTop(mem_width=mem_width,
                       sram_macro_info=tsmc_info,
                       use_sram_stub=use_sram_stub,
                       fifo_mode=fifo_mode,
                       add_clk_enable=True,
                       add_flush=True)
    # NOTE(review): the returned pass object is not used anywhere in this
    # block -- confirm whether it was meant to be applied to lake_dut.
    sram_port_pass = change_sram_port_names(use_sram_stub=use_sram_stub, sram_macro_info=tsmc_info)
    # Perform pass to move config_reg
    extract_top_config(lake_dut)
    # get_interface(lake_dut)
コード例 #2
0
ファイル: test_strg_fifo.py プロジェクト: StanfordAHA/lake
def test_storage_fifo(
        mem_width,  # CGRA Params
        depth,
        in_out_ports,
        banks=1,
        data_width=16,
        mem_depth=512,
        input_iterator_support=6,  # Addr Controllers
        output_iterator_support=6,  # Addr Controllers
        mem_input_ports=1,
        mem_output_ports=1,
        use_sram_stub=1,
        read_delay=1,  # Cycle delay in read (SRAM vs Register File)
        rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
        agg_height=4,
        config_data_width=32,
        config_addr_width=8,
        fifo_mode=True):
    """Check LakeTop in FIFO mode against RegFIFOModel under random traffic.

    A LakeTop is generated from the given parameters, converted to a magma
    circuit and driven for 2000 iterations with random data, push and pop
    values (seeded with 0 so the stimulus is reproducible).  On every
    iteration the ``empty``, ``full`` and ``valid_out`` outputs are expected
    to match the model; the data output is only checked when the model
    reports a valid word.  The recorded actions are finally compiled and run
    with verilator in a temporary directory.

    The test returns immediately, without building anything, when
    ``banks == 1`` and a memory word holds exactly one data word.

    Args:
        mem_width: Memory word width; ``mem_width // data_width`` is used as
            the model's ``width_mult``.
        depth: Value written to the ``fifo_ctrl_fifo_depth`` configuration
            and used as the model depth.
        in_out_ports: Used for both ``interconnect_input_ports`` and
            ``interconnect_output_ports``; when greater than 1 the data
            ports are addressed as ``data_in_0`` / ``data_out_0``.
        banks, data_width, mem_depth, input_iterator_support,
        output_iterator_support, mem_input_ports, mem_output_ports,
        use_sram_stub, read_delay, rw_same_cycle, agg_height,
        config_data_width, config_addr_width, fifo_mode: Forwarded to
            ``LakeTop`` unchanged.
    """
    # Number of data words that fit in one memory word.
    fw_int = mem_width // data_width

    if banks == 1 and fw_int == 1:
        return

    # Configuration values poked onto the circuit before reset.
    new_config = {
        "fifo_ctrl_fifo_depth": depth,
        "mode": 1,
        "tile_en": 1,
    }

    model_rf = RegFIFOModel(data_width=data_width,
                            width_mult=fw_int,
                            depth=depth)

    # DUT
    lt_dut = LakeTop(data_width=data_width,
                     mem_width=mem_width,
                     mem_depth=mem_depth,
                     banks=banks,
                     input_iterator_support=input_iterator_support,
                     output_iterator_support=output_iterator_support,
                     interconnect_input_ports=in_out_ports,
                     interconnect_output_ports=in_out_ports,
                     mem_input_ports=mem_input_ports,
                     mem_output_ports=mem_output_ports,
                     use_sram_stub=use_sram_stub,
                     read_delay=read_delay,
                     rw_same_cycle=rw_same_cycle,
                     agg_height=agg_height,
                     config_data_width=config_data_width,
                     config_addr_width=config_addr_width,
                     fifo_mode=fifo_mode)

    magma_dut = kts.util.to_magma(lt_dut,
                                  flatten_array=True,
                                  check_multiple_driver=False,
                                  optimize_if=False,
                                  check_flip_flop_always_ff=False)

    tester = fault.Tester(magma_dut, magma_dut.clk)
    tester.zero_inputs()
    for key, value in new_config.items():
        setattr(tester.circuit, key, value)

    rand.seed(0)
    # Hold rst_n low for two steps, then release it for two more.
    tester.circuit.clk = 0
    tester.circuit.rst_n = 0
    tester.step(2)
    tester.circuit.rst_n = 1
    tester.step(2)

    tester.circuit.clk_en = 1

    for _ in range(2000):
        # The order of these draws fixes the stimulus for seed 0.
        data_in = rand.randint(0, 2**data_width - 1)
        push = rand.randint(0, 1)
        pop = rand.randint(0, 1)

        if in_out_ports > 1:
            tester.circuit.data_in_0 = data_in
        else:
            tester.circuit.data_in = data_in

        tester.circuit.ren_in[0] = pop
        tester.circuit.wen_in[0] = push

        (model_out, model_val_x, model_empty, model_full,
         model_val) = model_rf.interact(push, pop, [data_in], push)

        tester.eval()

        tester.circuit.empty.expect(model_empty)
        tester.circuit.full.expect(model_full)
        # Now check the outputs
        tester.circuit.valid_out.expect(model_val)
        if model_val:
            if in_out_ports > 1:
                tester.circuit.data_out_0.expect(model_out[0])
            else:
                tester.circuit.data_out.expect(model_out[0])

        tester.step(2)

    with tempfile.TemporaryDirectory() as tempdir:
        tester.compile_and_run(target="verilator",
                               directory=tempdir,
                               magma_output="verilog",
                               flags=["-Wno-fatal"])
コード例 #3
0
ファイル: test_strg_ram.py プロジェクト: StanfordAHA/lake
def test_storage_ram(mem_width,  # CGRA Params
                     in_out_ports,
                     banks=1,
                     data_width=16,
                     mem_depth=512,
                     input_iterator_support=6,  # Addr Controllers
                     output_iterator_support=6,
                     mem_input_ports=1,
                     mem_output_ports=1,
                     use_sram_stub=1,
                     read_delay=1,  # Cycle delay in read (SRAM vs Register File)
                     rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
                     agg_height=4,
                     num_tiles=1,
                     config_data_width=32,
                     config_addr_width=8,
                     fifo_mode=True):
    """Check LakeTop in RAM mode (``mode`` = 2) against SRAMModel.

    A LakeTop is generated from the given parameters, converted to a magma
    circuit and driven for 2000 iterations with random data, address, write
    and read values (seeded with 0).  ``valid_out`` is expected to equal the
    read enable applied on the previous iteration, and the data output is
    compared with the model only when that previous read enable was set.
    The recorded actions are finally compiled and run with verilator in a
    temporary directory.

    The test returns immediately when ``mem_width == 16`` and
    ``in_out_ports == 2`` (marked in the code as a configuration that does
    not currently generate).

    Args:
        mem_width: Memory word width; ``mem_width // data_width`` is used as
            the model's ``width_mult``.
        in_out_ports: Used for both ``interconnect_input_ports`` and
            ``interconnect_output_ports``; when greater than 1 the data and
            address ports are addressed with a ``_0`` suffix.
        num_tiles: Forwarded to both ``SRAMModel`` and ``LakeTop``.
        banks, data_width, mem_depth, input_iterator_support,
        output_iterator_support, mem_input_ports, mem_output_ports,
        use_sram_stub, read_delay, rw_same_cycle, agg_height,
        config_data_width, config_addr_width, fifo_mode: Forwarded to
            ``LakeTop`` unchanged.
    """
    # TODO: This currently doesn't generate...
    if mem_width == 16 and in_out_ports == 2:
        return

    # Number of data words that fit in one memory word.
    fw_int = mem_width // data_width

    # Configuration values poked onto the circuit before reset.
    new_config = {
        "mode": 2,
        "tile_en": 1,
    }

    sram_model = SRAMModel(data_width=data_width,
                           width_mult=fw_int,
                           depth=mem_depth,
                           num_tiles=num_tiles)

    # DUT
    lt_dut = LakeTop(data_width=data_width,
                     mem_width=mem_width,
                     mem_depth=mem_depth,
                     banks=banks,
                     input_iterator_support=input_iterator_support,
                     output_iterator_support=output_iterator_support,
                     interconnect_input_ports=in_out_ports,
                     interconnect_output_ports=in_out_ports,
                     mem_input_ports=mem_input_ports,
                     mem_output_ports=mem_output_ports,
                     use_sram_stub=use_sram_stub,
                     num_tiles=num_tiles,
                     read_delay=read_delay,
                     rw_same_cycle=rw_same_cycle,
                     agg_height=agg_height,
                     config_data_width=config_data_width,
                     config_addr_width=config_addr_width,
                     fifo_mode=fifo_mode)

    magma_dut = kts.util.to_magma(lt_dut,
                                  flatten_array=True,
                                  check_multiple_driver=False,
                                  optimize_if=False,
                                  check_flip_flop_always_ff=False)

    tester = fault.Tester(magma_dut, magma_dut.clk)
    tester.zero_inputs()
    for key, value in new_config.items():
        setattr(tester.circuit, key, value)

    rand.seed(0)
    # Hold rst_n low for two steps, then release it for two more.
    tester.circuit.clk = 0
    tester.circuit.rst_n = 0
    tester.step(2)
    tester.circuit.rst_n = 1
    tester.step(2)

    # State carried from one iteration to the next.
    prev_wr = 0
    prev_rd = 0

    # With more than one data word per memory word, the iteration after a
    # write issues neither a write nor a read.
    stall = fw_int > 1

    tester.circuit.clk_en = 1

    for _ in range(2000):
        # The order of these draws fixes the stimulus for seed 0.
        data_in = rand.randint(0, 2 ** data_width - 1)
        write = rand.randint(0, 1)
        read = rand.randint(0, 1)
        addr_in = rand.randint(0, 64)

        if prev_wr == 1 and stall:
            write = 0
            read = 0
            prev_wr = 0

        # A write takes priority: never read on an iteration that writes.
        if write:
            prev_wr = 1
            read = 0

        if in_out_ports > 1:
            tester.circuit.data_in_0 = data_in
            tester.circuit.addr_in_0 = addr_in
        else:
            tester.circuit.data_in = data_in
            tester.circuit.addr_in = addr_in

        tester.circuit.wen_in[0] = write
        tester.circuit.ren_in[0] = read
        model_out = sram_model.interact(wen=write, cen=(write | read), addr=addr_in, data=[data_in])

        tester.eval()

        # Now check the outputs
        tester.circuit.valid_out.expect(prev_rd)
        if prev_rd:
            if in_out_ports > 1:
                tester.circuit.data_out_0.expect(model_out[0])
            else:
                tester.circuit.data_out.expect(model_out[0])

        tester.step(2)
        prev_rd = read

    with tempfile.TemporaryDirectory() as tempdir:
        # For debugging, a fixed directory such as "strg_ram_dump" can be
        # used here instead of the temporary one.
        tester.compile_and_run(target="verilator",
                               directory=tempdir,
                               magma_output="verilog",
                               flags=["-Wno-fatal"])
コード例 #4
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=16,
            mem_depth=256,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=1,  # Connection to int
            interconnect_output_ports=1,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=True,
            sram_macro_info=SRAMMacroInfo(),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=True,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            tb_sched_max=16,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=1,
            remove_tb=False,
            fifo_mode=False,
            add_clk_enable=True,
            add_flush=True,
            override_name=None):
        """Wrap a LakeTop circuit as a configurable memory core.

        The constructor performs, in order:

        1. Stores the parameters on ``self`` and builds (or fetches from the
           class-level circuit cache) the magma circuit for a ``LakeTop``
           with these parameters.
        2. Creates one core port per interface signal reported by
           ``get_interface`` (except the configuration-bus signals and
           ``clk_en``), and wires it to the underlying circuit.  Signals
           flagged as control signals go through a 2:1 mux so a
           configuration register can override the routed value.
        3. Wires clk, an inverted ``reset`` to ``rst_n`` and an inverted
           ``stall`` to ``clk_en``.
        4. Creates one ``CoreFeature`` per ``lt_dut.total_sets`` and the
           matching ``config_{idx}``, ``read_config_data_{idx}`` and
           ``config_en_{idx}`` ports, OR-ing the per-feature configuration
           address/data/read/write onto the underlying circuit.
        5. Adds one configuration register per entry returned by
           ``extract_top_config`` and wires it to the underlying port of the
           same name.
        6. Writes the sorted register names to ``mem_cfg.txt`` and
           ``mem_synth.txt`` in the current working directory.

        Args:
            data_width: Width of one data word; also the width of the core's
                data ports for explicit-array signals.
            mem_width: Memory word width; ``fw_int`` is
                ``int(mem_width / data_width)``.
            mem_depth, banks, input_iterator_support,
            output_iterator_support, input_config_width,
            output_config_width, interconnect_input_ports,
            interconnect_output_ports, read_delay, rw_same_cycle,
            agg_height, remove_tb, fifo_mode, add_clk_enable, add_flush:
                Stored on ``self`` and forwarded to ``LakeTop``.
            mem_input_ports, mem_output_ports: Stored on ``self``; not
                forwarded to ``LakeTop`` by this method.
            use_sram_stub, sram_macro_info: Forwarded to ``LakeTop`` and to
                ``change_sram_port_names``.
            tb_sched_max: Accepted but not used by this method.
            config_data_width, config_addr_width: Passed to the parent
                constructor and used for every configuration port.
            num_tiles: Forwarded to ``LakeTop``; also determines
                ``chain_idx_bits``.
            override_name: When truthy, the core is named
                ``override_name + "Core"`` and the LakeTop is named
                ``override_name``; otherwise ``"MemCore"`` / ``"LakeTop"``.
        """
        # name
        if override_name:
            self.__name = override_name + "Core"
            lake_name = override_name
        else:
            self.__name = "MemCore"
            lake_name = "LakeTop"

        super().__init__(config_addr_width, config_data_width)

        # Capture everything to the tile object
        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.fw_int = int(self.mem_width / self.data_width)
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.remove_tb = remove_tb
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush

        # Typedefs for ease
        TData = magma.Bits[self.data_width]
        TBit = magma.Bits[1]

        self.__inputs = []
        self.__outputs = []

        # The key must contain every value that is handed to LakeTop below;
        # otherwise two cores that differ only in an omitted value would
        # silently share the first one's circuit.  The config widths and the
        # LakeTop name are therefore part of the key as well.
        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.input_config_width, self.output_config_width,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.config_data_width,
                     self.config_addr_width, self.num_tiles, self.remove_tb,
                     self.fifo_mode, self.add_clk_enable, self.add_flush,
                     lake_name)

        # Check for circuit caching
        if cache_key not in MemCore.__circuit_cache:

            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            lt_dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                remove_tb=self.remove_tb,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                name=lake_name,
                gen_addr=False)

            change_sram_port_pass = change_sram_port_names(
                use_sram_stub, sram_macro_info)
            circ = kts.util.to_magma(
                lt_dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
        else:
            circ, lt_dut = MemCore.__circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        # Enumerate input and output ports
        # (clk and reset are assumed)
        core_interface = get_interface(lt_dut)
        cfgs = extract_top_config(lt_dut)
        assert len(cfgs) > 0, "No configs?"

        # We basically add in the configuration bus differently
        # than the other ports...
        skip_names = [
            "config_data_in", "config_write", "config_addr_in",
            "config_data_out", "config_read", "config_en", "clk_en"
        ]

        # Create a list of signals that will be able to be
        # hardwired to a constant at runtime...
        control_signals = []
        # The rest of the signals to wire to the underlying representation...
        other_signals = []

        for io_info in core_interface:
            if io_info.port_name in skip_names:
                continue
            # By default a signal is split into port_width 1-bit ports.
            ind_ports = io_info.port_width
            intf_type = TBit
            # For our purposes, an explicit array means the inner data HAS to be 16 bits
            if io_info.expl_arr:
                ind_ports = io_info.port_size[0]
                intf_type = TData
            dir_type = magma.In
            app_list = self.__inputs
            if io_info.port_dir == "PortDirection.Out":
                dir_type = magma.Out
                app_list = self.__outputs
            if ind_ports > 1:
                for i in range(ind_ports):
                    self.add_port(f"{io_info.port_name}_{i}",
                                  dir_type(intf_type))
                    app_list.append(self.ports[f"{io_info.port_name}_{i}"])
            else:
                self.add_port(io_info.port_name, dir_type(intf_type))
                app_list.append(self.ports[io_info.port_name])

            # classify each signal for wiring to underlying representation...
            if io_info.is_ctrl:
                control_signals.append((io_info.port_name, io_info.port_width))
            else:
                # Entries are (core port name, direction, explicit-array flag,
                # index into the underlying port, underlying port name).
                if ind_ports > 1:
                    for i in range(ind_ports):
                        other_signals.append(
                            (f"{io_info.port_name}_{i}", io_info.port_dir,
                             io_info.expl_arr, i, io_info.port_name))
                else:
                    other_signals.append(
                        (io_info.port_name, io_info.port_dir, io_info.expl_arr,
                         0, io_info.port_name))

        assert (len(self.__outputs) > 0)

        # We call clk_en stall at this level for legacy reasons????
        self.add_ports(stall=magma.In(TBit), )

        self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

        # put a 1-bit register and a mux to select the control signals
        for control_signal, width in control_signals:
            if width == 1:
                mux = MuxWrapper(2, 1, name=f"{control_signal}_sel")
                reg_value_name = f"{control_signal}_reg_value"
                reg_sel_name = f"{control_signal}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0], self.ports[control_signal])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][0])
            else:
                # Multi-bit control signal: one mux and register pair per bit.
                for i in range(width):
                    mux = MuxWrapper(2, 1, name=f"{control_signal}_{i}_sel")
                    reg_value_name = f"{control_signal}_{i}_reg_value"
                    reg_sel_name = f"{control_signal}_{i}_reg_sel"
                    self.add_config(reg_value_name, 1)
                    self.add_config(reg_sel_name, 1)
                    self.wire(mux.ports.I[0],
                              self.ports[f"{control_signal}_{i}"])
                    self.wire(mux.ports.I[1],
                              self.registers[reg_value_name].ports.O)
                    self.wire(mux.ports.S,
                              self.registers[reg_sel_name].ports.O)
                    # 0 is the default wire, which takes from the routing network
                    self.wire(mux.ports.O[0],
                              self.underlying.ports[control_signal][i])

        # Wire the other signals up...
        for pname, pdir, expl_arr, ind, uname in other_signals:
            # If we are in an explicit array moment, use the given wire name...
            if expl_arr is False:
                # And if not, use the index
                self.wire(self.ports[pname][0],
                          self.underlying.ports[uname][ind])
            else:
                self.wire(self.ports[pname], self.underlying.ports[pname])

        # CLK, RESET, and STALL PER STANDARD PROCEDURE

        # Need to invert this
        self.resetInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.resetInverter.ports.I[0], self.ports.reset)
        self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
        self.wire(self.ports.clk, self.underlying.ports.clk)

        # Mem core uses clk_en (essentially active low stall)
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall)
        self.wire(self.stallInverter.ports.O[0],
                  self.underlying.ports.clk_en[0])

        # Feature 0 is this core itself; features 1..total_sets are the
        # SRAM features reported by the LakeTop.
        self.__features: List[CoreFeature] = [self]
        self.num_sram_features = lt_dut.total_sets
        for sram_index in range(self.num_sram_features):
            core_feature = CoreFeature(self, sram_index + 1)
            self.__features.append(core_feature)

        # Wire the config
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(
                    f"config_{idx}",
                    magma.In(
                        ConfigurationType(self.config_addr_width,
                                          self.config_data_width)))
                # port aliasing
                core_feature.ports["config"] = self.ports[f"config_{idx}"]
        self.add_port(
            "config",
            magma.In(
                ConfigurationType(self.config_addr_width,
                                  self.config_data_width)))

        # or the signal up
        t = ConfigurationType(self.config_addr_width, self.config_data_width)
        t_names = ["config_addr", "config_data"]
        or_gates = {}
        for t_name in t_names:
            port_type = t[t_name]
            or_gate = FromMagma(
                mantle.DefineOr(len(self.__features), len(port_type)))
            or_gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, core_feature in enumerate(self.__features):
                self.wire(or_gate.ports[f"I{idx}"],
                          core_feature.ports.config[t_name])
            or_gates[t_name] = or_gate

        self.wire(
            or_gates["config_addr"].ports.O,
            self.underlying.ports.config_addr_in[0:self.config_addr_width])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data_in)

        # read data out
        for idx, core_feature in enumerate(self.__features):
            if (idx > 0):
                self.add_port(f"read_config_data_{idx}",
                              magma.Out(magma.Bits[self.config_data_width]))
                # port aliasing
                core_feature.ports["read_config_data"] = \
                    self.ports[f"read_config_data_{idx}"]

        # MEM Config: list of (register name, width) pairs.
        configurations = []
        skip_cfgs = []

        for cfg_info in cfgs:
            if cfg_info.port_name in skip_cfgs:
                continue
            if cfg_info.expl_arr:
                if cfg_info.port_size[0] > 1:
                    # One register per element of the explicit array.
                    for i in range(cfg_info.port_size[0]):
                        configurations.append(
                            (f"{cfg_info.port_name}_{i}", cfg_info.port_width))
                else:
                    configurations.append(
                        (cfg_info.port_name, cfg_info.port_width))
            else:
                configurations.append(
                    (cfg_info.port_name, cfg_info.port_width))

        # Do all the stuff for the main config
        main_feature = self.__features[0]
        for config_reg_name, width in configurations:
            main_feature.add_config(config_reg_name, width)
            if (width == 1):
                self.wire(main_feature.registers[config_reg_name].ports.O[0],
                          self.underlying.ports[config_reg_name][0])
            else:
                self.wire(main_feature.registers[config_reg_name].ports.O,
                          self.underlying.ports[config_reg_name])

        # SRAM
        # These should also account for num features
        # NOTE(review): the instance names below look swapped (the read gate
        # is named "WR" and the write gate "RD").  They are left untouched
        # because renaming changes the emitted netlist -- confirm and fix.
        or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_rd.instance_name = "OR_CONFIG_WR_SRAM"
        or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_wr.instance_name = "OR_CONFIG_RD_SRAM"
        for sram_index in range(self.num_sram_features):
            core_feature = self.__features[sram_index + 1]
            self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
            # port aliasing
            core_feature.ports["config_en"] = \
                self.ports[f"config_en_{sram_index}"]
            # Sort of a temp hack - the name is just config_data_out
            if self.num_sram_features == 1:
                self.wire(core_feature.ports.read_config_data,
                          self.underlying.ports["config_data_out"])
            else:
                self.wire(
                    core_feature.ports.read_config_data,
                    self.underlying.ports[f"config_data_out_{sram_index}"])
            # also need to wire the sram signal
            # the config enable is the OR of the rd+wr
            # NOTE(review): only the inputs of or_gate_en are wired here; its
            # output is not connected in this method, and config_en is driven
            # from the feature's config_en port instead -- confirm intent.
            or_gate_en = FromMagma(mantle.DefineOr(2, 1))
            or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"

            self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
            self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
            self.wire(core_feature.ports.config_en,
                      self.underlying.ports["config_en"][sram_index])
            # Still connect to the OR of all the config rd/wr
            self.wire(core_feature.ports.config.write,
                      or_all_cfg_wr.ports[f"I{sram_index}"])
            self.wire(core_feature.ports.config.read,
                      or_all_cfg_rd.ports[f"I{sram_index}"])

        self.wire(or_all_cfg_rd.ports.O[0],
                  self.underlying.ports.config_read[0])
        self.wire(or_all_cfg_wr.ports.O[0],
                  self.underlying.ports.config_write[0])
        self._setup_config()

        # Dump the register names, sorted, to two text files in the current
        # working directory: one with name and width, one with names only.
        conf_names = sorted(self.registers.keys())
        with open("mem_cfg.txt", "w+") as cfg_dump:
            for reg in conf_names:
                write_line = f"(\"{reg}\", 0),  # {self.registers[reg].width}\n"
                cfg_dump.write(write_line)
        with open("mem_synth.txt", "w+") as cfg_dump:
            for reg in conf_names:
                write_line = f"{reg}\n"
                cfg_dump.write(write_line)
コード例 #5
0
ファイル: memory_core_magma.py プロジェクト: zamyers/garnet
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=1,
            sram_macro_info=SRAMMacroInfo("TS1N16FFCLLSBLVTC512X32M4S"),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            max_agg_schedule=16,
            input_max_port_sched=16,
            output_max_port_sched=16,
            align_input=1,
            max_line_length=128,
            max_tb_height=1,
            tb_range_max=1024,
            tb_range_inner_max=64,
            tb_sched_max=16,
            max_tb_stride=15,
            num_tb=1,
            tb_iterator_support=2,
            multiwrite=1,
            max_prefetch=8,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=2,
            app_ctrl_depth_width=16,
            remove_tb=False,
            fifo_mode=True,
            add_clk_enable=True,
            add_flush=True,
            core_reset_pos=False,
            stcl_valid_iter=4):
        """Build the memory tile core as a wrapper around a ``LakeTop`` circuit.

        The constructor stores every parameter on ``self``, declares the
        tile-level ports, obtains the magma circuit generated from ``LakeTop``
        (from the class-level circuit cache when an identical parameter set
        was already built), and then wires ports, control-signal override
        muxes, configuration registers and the per-SRAM-set configuration
        features to that circuit.

        Args:
            data_width: Bit width of the data/address ports
                (``magma.Bits[data_width]``).
            mem_width: Together with ``data_width`` it defines
                ``fw_int = int(mem_width / data_width)``.
            interconnect_input_ports: Number of input port groups. When it is
                greater than 1 the ports are named ``addr_in_{i}``,
                ``data_in_{i}``, ``wen_in_{i}``; otherwise the suffix is
                dropped.
            interconnect_output_ports: Number of output port groups; the same
                suffix rule applies to ``data_out``, ``ren_in``, ``valid_out``
                and the ``chain_*`` ports.
            config_addr_width, config_data_width: Passed to the base-class
                constructor and to ``LakeTop``.
            num_tiles: Determines ``chain_idx_bits`` as
                ``max(1, clog2(num_tiles))``.
            mem_input_ports, mem_output_ports, core_reset_pos: Only stored on
                ``self`` by this constructor; they are not forwarded to
                ``LakeTop``.
            sram_macro_info: Forwarded to ``LakeTop`` and to
                ``change_sram_port_names``. It is also part of the circuit
                cache key, so it must be hashable.
            Remaining parameters: Stored on ``self`` and forwarded unchanged
                to ``LakeTop``; several of them also size the configuration
                registers declared here.

        Side effects:
            Writes ``mem_cfg.txt`` and ``mem_synth.txt`` in the current
            working directory, overwriting existing files, with the sorted
            list of configuration register names.
        """

        super().__init__(config_addr_width, config_data_width)

        # Capture everything to the tile object
        self.data_width = data_width
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.fw_int = int(self.mem_width / self.data_width)
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.max_agg_schedule = max_agg_schedule
        self.input_max_port_sched = input_max_port_sched
        self.output_max_port_sched = output_max_port_sched
        self.align_input = align_input
        self.max_line_length = max_line_length
        self.max_tb_height = max_tb_height
        self.tb_range_max = tb_range_max
        self.tb_range_inner_max = tb_range_inner_max
        self.tb_sched_max = tb_sched_max
        self.max_tb_stride = max_tb_stride
        self.num_tb = num_tb
        self.tb_iterator_support = tb_iterator_support
        self.multiwrite = multiwrite
        self.max_prefetch = max_prefetch
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.remove_tb = remove_tb
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.core_reset_pos = core_reset_pos
        self.app_ctrl_depth_width = app_ctrl_depth_width
        self.stcl_valid_iter = stcl_valid_iter

        # Typedefs for ease
        TData = magma.Bits[self.data_width]
        TBit = magma.Bits[1]

        self.__inputs = []
        self.__outputs = []

        # Enumerate input and output ports
        # (clk and reset are assumed)
        if self.interconnect_input_ports > 1:
            for i in range(self.interconnect_input_ports):
                self.add_port(f"addr_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"addr_in_{i}"])
                self.add_port(f"data_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"data_in_{i}"])
                self.add_port(f"wen_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"wen_in_{i}"])
        else:
            self.add_port("addr_in", magma.In(TData))
            self.__inputs.append(self.ports["addr_in"])
            self.add_port("data_in", magma.In(TData))
            self.__inputs.append(self.ports["data_in"])
            self.add_port("wen_in", magma.In(TBit))
            self.__inputs.append(self.ports.wen_in)

        if self.interconnect_output_ports > 1:
            for i in range(self.interconnect_output_ports):
                self.add_port(f"data_out_{i}", magma.Out(TData))
                self.__outputs.append(self.ports[f"data_out_{i}"])
                self.add_port(f"ren_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"ren_in_{i}"])
                self.add_port(f"valid_out_{i}", magma.Out(TBit))
                self.__outputs.append(self.ports[f"valid_out_{i}"])
                # Chaining
                self.add_port(f"chain_valid_in_{i}", magma.In(TBit))
                self.__inputs.append(self.ports[f"chain_valid_in_{i}"])
                self.add_port(f"chain_data_in_{i}", magma.In(TData))
                self.__inputs.append(self.ports[f"chain_data_in_{i}"])
                self.add_port(f"chain_data_out_{i}", magma.Out(TData))
                self.__outputs.append(self.ports[f"chain_data_out_{i}"])
                self.add_port(f"chain_valid_out_{i}", magma.Out(TBit))
                self.__outputs.append(self.ports[f"chain_valid_out_{i}"])
        else:
            self.add_port("data_out", magma.Out(TData))
            self.__outputs.append(self.ports["data_out"])
            self.add_port("ren_in", magma.In(TBit))
            self.__inputs.append(self.ports["ren_in"])
            self.add_port("valid_out", magma.Out(TBit))
            self.__outputs.append(self.ports["valid_out"])
            # Chaining
            self.add_port("chain_valid_in", magma.In(TBit))
            self.__inputs.append(self.ports["chain_valid_in"])
            self.add_port("chain_data_in", magma.In(TData))
            self.__inputs.append(self.ports["chain_data_in"])
            self.add_port("chain_data_out", magma.Out(TData))
            self.__outputs.append(self.ports["chain_data_out"])
            self.add_port("chain_valid_out", magma.Out(TBit))
            self.__outputs.append(self.ports["chain_valid_out"])

        self.add_ports(flush=magma.In(TBit),
                       full=magma.Out(TBit),
                       empty=magma.Out(TBit),
                       stall=magma.In(TBit),
                       sram_ready_out=magma.Out(TBit))

        self.__inputs.append(self.ports.flush)
        # `stall` is declared as a port but is not recorded in the inputs
        # list (the append below was already disabled in the original code).
        # self.__inputs.append(self.ports.stall)

        self.__outputs.append(self.ports.full)
        self.__outputs.append(self.ports.empty)
        self.__outputs.append(self.ports.sram_ready_out)

        # The key must contain every parameter that is forwarded to LakeTop,
        # otherwise two tiles with different generator parameters would share
        # one cached circuit. input_config_width, output_config_width and
        # tb_range_inner_max are forwarded below and therefore belong here.
        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.input_config_width, self.output_config_width,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.max_agg_schedule,
                     self.input_max_port_sched, self.output_max_port_sched,
                     self.align_input, self.max_line_length,
                     self.max_tb_height, self.tb_range_max,
                     self.tb_range_inner_max, self.tb_sched_max,
                     self.max_tb_stride, self.num_tb, self.tb_iterator_support,
                     self.multiwrite, self.max_prefetch,
                     self.config_data_width, self.config_addr_width,
                     self.num_tiles, self.remove_tb, self.fifo_mode,
                     self.stcl_valid_iter, self.add_clk_enable, self.add_flush,
                     self.app_ctrl_depth_width)

        # Check for circuit caching
        if cache_key not in MemCore.__circuit_cache:

            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            lt_dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                max_agg_schedule=self.max_agg_schedule,
                input_max_port_sched=self.input_max_port_sched,
                output_max_port_sched=self.output_max_port_sched,
                align_input=self.align_input,
                max_line_length=self.max_line_length,
                max_tb_height=self.max_tb_height,
                tb_range_max=self.tb_range_max,
                tb_range_inner_max=self.tb_range_inner_max,
                tb_sched_max=self.tb_sched_max,
                max_tb_stride=self.max_tb_stride,
                num_tb=self.num_tb,
                tb_iterator_support=self.tb_iterator_support,
                multiwrite=self.multiwrite,
                max_prefetch=self.max_prefetch,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                app_ctrl_depth_width=self.app_ctrl_depth_width,
                remove_tb=self.remove_tb,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                stcl_valid_iter=self.stcl_valid_iter)

            change_sram_port_pass = change_sram_port_names(
                self.use_sram_stub, self.sram_macro_info)
            circ = kts.util.to_magma(
                lt_dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            MemCore.__circuit_cache[cache_key] = (circ, lt_dut)
        else:
            circ, lt_dut = MemCore.__circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        self.chain_idx_bits = max(1, kts.clog2(self.num_tiles))

        # put a 1-bit register and a mux to select the control signals
        # TODO: check if enable_chain_output needs to be here? I don't think so?
        control_signals = [("wen_in", self.interconnect_input_ports),
                           ("ren_in", self.interconnect_output_ports),
                           ("flush", 1),
                           ("chain_valid_in", self.interconnect_output_ports)]
        for control_signal, width in control_signals:
            # TODO: consult with Ankita to see if we can use the normal
            # mux here
            # A single-lane signal uses the unsuffixed tile port; a multi-lane
            # signal uses one `<signal>_<i>` tile port per lane. Either way
            # lane i drives bit i of the underlying port.
            if width == 1:
                lanes = [(control_signal, 0)]
            else:
                lanes = [(f"{control_signal}_{i}", i) for i in range(width)]
            for lane_name, lane_idx in lanes:
                mux = MuxWrapper(2, 1, name=f"{lane_name}_sel")
                reg_value_name = f"{lane_name}_reg_value"
                reg_sel_name = f"{lane_name}_reg_sel"
                self.add_config(reg_value_name, 1)
                self.add_config(reg_sel_name, 1)
                self.wire(mux.ports.I[0], self.ports[lane_name])
                self.wire(mux.ports.I[1],
                          self.registers[reg_value_name].ports.O)
                self.wire(mux.ports.S, self.registers[reg_sel_name].ports.O)
                # 0 is the default wire, which takes from the routing network
                self.wire(mux.ports.O[0],
                          self.underlying.ports[control_signal][lane_idx])

        if self.interconnect_input_ports > 1:
            for i in range(self.interconnect_input_ports):
                self.wire(self.ports[f"data_in_{i}"],
                          self.underlying.ports[f"data_in_{i}"])
                self.wire(self.ports[f"addr_in_{i}"],
                          self.underlying.ports[f"addr_in_{i}"])
        else:
            self.wire(self.ports.addr_in, self.underlying.ports.addr_in)
            self.wire(self.ports.data_in, self.underlying.ports.data_in)

        if self.interconnect_output_ports > 1:
            for i in range(self.interconnect_output_ports):
                self.wire(self.ports[f"data_out_{i}"],
                          self.underlying.ports[f"data_out_{i}"])
                self.wire(self.ports[f"chain_data_in_{i}"],
                          self.underlying.ports[f"chain_data_in_{i}"])
                self.wire(self.ports[f"chain_data_out_{i}"],
                          self.underlying.ports[f"chain_data_out_{i}"])
        else:
            self.wire(self.ports.data_out, self.underlying.ports.data_out)
            self.wire(self.ports.chain_data_in,
                      self.underlying.ports.chain_data_in)
            self.wire(self.ports.chain_data_out,
                      self.underlying.ports.chain_data_out)

        # The tile reset is inverted before it drives the underlying rst_n
        self.resetInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.resetInverter.ports.I[0], self.ports.reset)
        self.wire(self.resetInverter.ports.O[0], self.underlying.ports.rst_n)
        self.wire(self.ports.clk, self.underlying.ports.clk)
        if self.interconnect_output_ports == 1:
            self.wire(self.ports.valid_out[0],
                      self.underlying.ports.valid_out[0])
            self.wire(self.ports.chain_valid_out[0],
                      self.underlying.ports.chain_valid_out[0])
        else:
            for j in range(self.interconnect_output_ports):
                self.wire(self.ports[f"valid_out_{j}"][0],
                          self.underlying.ports.valid_out[j])
                self.wire(self.ports[f"chain_valid_out_{j}"][0],
                          self.underlying.ports.chain_valid_out[j])
        self.wire(self.ports.empty[0], self.underlying.ports.empty[0])
        self.wire(self.ports.full[0], self.underlying.ports.full[0])

        # PE core uses clk_en (essentially active low stall)
        self.stallInverter = FromMagma(mantle.DefineInvert(1))
        self.wire(self.stallInverter.ports.I, self.ports.stall)
        self.wire(self.stallInverter.ports.O[0],
                  self.underlying.ports.clk_en[0])

        self.wire(self.ports.sram_ready_out[0],
                  self.underlying.ports.sram_ready_out[0])

        # Features: index 0 is the tile itself; indices 1..total_sets are one
        # CoreFeature per SRAM set reported by the LakeTop object.
        self.__features: List[CoreFeature] = [self]
        self.num_sram_features = lt_dut.total_sets
        for sram_index in range(self.num_sram_features):
            core_feature = CoreFeature(self, sram_index + 1)
            self.__features.append(core_feature)

        # Wire the config
        # NOTE(review): the config bus is declared with hard-coded widths
        # (8, 32) here and below rather than with config_addr_width /
        # config_data_width - confirm this is intended for non-default widths.
        for idx, core_feature in enumerate(self.__features):
            if idx > 0:
                self.add_port(f"config_{idx}",
                              magma.In(ConfigurationType(8, 32)))
                # port aliasing
                core_feature.ports["config"] = self.ports[f"config_{idx}"]
        self.add_port("config", magma.In(ConfigurationType(8, 32)))

        # OR the config address/data of all features together so that any
        # feature can drive the single config bus of the underlying circuit.
        t = ConfigurationType(8, 32)
        t_names = ["config_addr", "config_data"]
        or_gates = {}
        for t_name in t_names:
            port_type = t[t_name]
            or_gate = FromMagma(
                mantle.DefineOr(len(self.__features), len(port_type)))
            or_gate.instance_name = f"OR_{t_name}_FEATURE"
            for idx, core_feature in enumerate(self.__features):
                self.wire(or_gate.ports[f"I{idx}"],
                          core_feature.ports.config[t_name])
            or_gates[t_name] = or_gate

        self.wire(or_gates["config_addr"].ports.O,
                  self.underlying.ports.config_addr_in[0:8])
        self.wire(or_gates["config_data"].ports.O,
                  self.underlying.ports.config_data_in)

        # read data out
        for idx, core_feature in enumerate(self.__features):
            if idx > 0:
                self.add_port(f"read_config_data_{idx}",
                              magma.Out(magma.Bits[32]))
                # port aliasing
                core_feature.ports["read_config_data"] = \
                    self.ports[f"read_config_data_{idx}"]

        # MEM Config: (register name, width) pairs. Each name must match a
        # port of the underlying circuit because it is wired by name below.
        configurations = [("tile_en", 1), ("fifo_ctrl_fifo_depth", 16),
                          ("mode", 2), ("enable_chain_output", 1),
                          ("enable_chain_input", 1)]
        #            ("stencil_width", 16), NOT YET

        # (register name, total width, number of packed fields) triples for
        # registers that pack several equally sized underlying ports.
        merged_configs = []

        # Add config registers to configurations
        # TODO: Have lake spit this information out automatically from the wrapper

        configurations.append(("chain_idx_input", self.chain_idx_bits))
        configurations.append(("chain_idx_output", self.chain_idx_bits))
        for i in range(self.interconnect_input_ports):
            configurations.append((f"strg_ub_agg_align_{i}_line_length",
                                   kts.clog2(self.max_line_length)))
            configurations.append((f"strg_ub_agg_in_{i}_in_period",
                                   kts.clog2(self.input_max_port_sched)))

            # Every schedule entry gets its own register; a packed ("merged")
            # encoding like the one used for the TBA indices below is not
            # applied to the aggregator schedules.
            for j in range(self.input_max_port_sched):
                configurations.append((f"strg_ub_agg_in_{i}_in_sched_{j}",
                                       kts.clog2(self.agg_height)))

            configurations.append((f"strg_ub_agg_in_{i}_out_period",
                                   kts.clog2(self.input_max_port_sched)))

            for j in range(self.output_max_port_sched):
                configurations.append((f"strg_ub_agg_in_{i}_out_sched_{j}",
                                       kts.clog2(self.agg_height)))

            configurations.append((f"strg_ub_app_ctrl_write_depth_wo_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append((f"strg_ub_app_ctrl_write_depth_ss_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append(
                (f"strg_ub_app_ctrl_coarse_write_depth_wo_{i}",
                 self.app_ctrl_depth_width))
            configurations.append(
                (f"strg_ub_app_ctrl_coarse_write_depth_ss_{i}",
                 self.app_ctrl_depth_width))

            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_dimensionality",
                 1 + kts.clog2(self.input_iterator_support)))
            configurations.append(
                (f"strg_ub_input_addr_ctrl_address_gen_{i}_starting_addr",
                 self.input_config_width))
            for j in range(self.input_iterator_support):
                configurations.append(
                    (f"strg_ub_input_addr_ctrl_address_gen_{i}_ranges_{j}",
                     self.input_config_width))
                configurations.append(
                    (f"strg_ub_input_addr_ctrl_address_gen_{i}_strides_{j}",
                     self.input_config_width))

        configurations.append(
            ("strg_ub_app_ctrl_prefill", self.interconnect_output_ports))
        configurations.append(("strg_ub_app_ctrl_coarse_prefill",
                               self.interconnect_output_ports))

        for i in range(self.stcl_valid_iter):
            configurations.append((f"strg_ub_app_ctrl_ranges_{i}", 16))
            configurations.append((f"strg_ub_app_ctrl_threshold_{i}", 16))

        for i in range(self.interconnect_output_ports):
            configurations.append((f"strg_ub_app_ctrl_input_port_{i}",
                                   kts.clog2(self.interconnect_input_ports)))
            configurations.append((f"strg_ub_app_ctrl_read_depth_{i}",
                                   self.app_ctrl_depth_width))
            configurations.append((f"strg_ub_app_ctrl_coarse_input_port_{i}",
                                   kts.clog2(self.interconnect_input_ports)))
            configurations.append((f"strg_ub_app_ctrl_coarse_read_depth_{i}",
                                   self.app_ctrl_depth_width))

            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_dimensionality",
                 1 + kts.clog2(self.output_iterator_support)))
            configurations.append(
                (f"strg_ub_output_addr_ctrl_address_gen_{i}_starting_addr",
                 self.output_config_width))
            for j in range(self.output_iterator_support):
                configurations.append(
                    (f"strg_ub_output_addr_ctrl_address_gen_{i}_ranges_{j}",
                     self.output_config_width))
                configurations.append(
                    (f"strg_ub_output_addr_ctrl_address_gen_{i}_strides_{j}",
                     self.output_config_width))

            configurations.append((f"strg_ub_pre_fetch_{i}_input_latency",
                                   kts.clog2(self.max_prefetch) + 1))
            configurations.append((f"strg_ub_sync_grp_sync_group_{i}",
                                   self.interconnect_output_ports))
            configurations.append(
                (f"strg_ub_rate_matched_{i}",
                 1 + kts.clog2(self.interconnect_input_ports)))

            for j in range(self.num_tb):
                configurations.append(
                    (f"strg_ub_tba_{i}_tb_{j}_dimensionality", 2))

                # The tb_range_inner_max index ports are packed into as few
                # config registers as possible: each register holds up to
                # indices_per_feat fields of num_indices_bits bits, and the
                # last register holds whatever remains.
                num_indices_bits = 1 + kts.clog2(self.fw_int)
                indices_per_feat = math.floor(self.config_data_width /
                                              num_indices_bits)
                num_feats_merge = math.ceil(self.tb_range_inner_max /
                                            indices_per_feat)
                for k in range(num_feats_merge):
                    first_index = k * indices_per_feat
                    num_idx = min(indices_per_feat,
                                  self.tb_range_inner_max - first_index)
                    # The numeric suffix is the index of the first field
                    # packed into this register; it is parsed back below.
                    merged_configs.append((
                        f"strg_ub_tba_{i}_tb_{j}_indices_merged_{first_index}",
                        num_idx * num_indices_bits, num_idx))

                configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_inner",
                                       kts.clog2(self.tb_range_inner_max)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_range_outer",
                                       kts.clog2(self.tb_range_max)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_stride",
                                       kts.clog2(self.max_tb_stride)))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_tb_height",
                                       max(1, kts.clog2(self.num_tb))))
                configurations.append((f"strg_ub_tba_{i}_tb_{j}_starting_addr",
                                       max(1, kts.clog2(self.fw_int))))

        # Do all the stuff for the main config
        main_feature = self.__features[0]
        for config_reg_name, width in configurations:
            main_feature.add_config(config_reg_name, width)
            if width == 1:
                self.wire(main_feature.registers[config_reg_name].ports.O[0],
                          self.underlying.ports[config_reg_name][0])
            else:
                self.wire(main_feature.registers[config_reg_name].ports.O,
                          self.underlying.ports[config_reg_name])

        # Slice every merged register into its fields and wire each field to
        # the matching `<base_name>_<index>` port of the underlying circuit.
        for config_reg_name, width, num_merged in merged_configs:
            main_feature.add_config(config_reg_name, width)
            base_name = config_reg_name.split("_merged")[0]
            base_indices = int(config_reg_name.split("_merged_")[1])
            num_bits = width // num_merged
            merged_out = main_feature.registers[config_reg_name].ports.O
            for slot in range(num_merged):
                self.wire(
                    merged_out[slot * num_bits:(slot + 1) * num_bits],
                    self.underlying.ports[f"{base_name}_{base_indices + slot}"])

        # SRAM
        # These should also account for num features
        # NOTE(review): the two instance names below look swapped relative to
        # the variables they label (the read OR is named "..._WR_..." and the
        # write OR "..._RD_..."). They are kept unchanged because instance
        # names end up in the generated netlist - confirm before renaming.
        or_all_cfg_rd = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_rd.instance_name = "OR_CONFIG_WR_SRAM"
        or_all_cfg_wr = FromMagma(mantle.DefineOr(self.num_sram_features, 1))
        or_all_cfg_wr.instance_name = "OR_CONFIG_RD_SRAM"
        for sram_index in range(self.num_sram_features):
            core_feature = self.__features[sram_index + 1]
            self.add_port(f"config_en_{sram_index}", magma.In(magma.Bit))
            # port aliasing
            core_feature.ports["config_en"] = \
                self.ports[f"config_en_{sram_index}"]
            self.wire(core_feature.ports.read_config_data,
                      self.underlying.ports[f"config_data_out_{sram_index}"])
            # also need to wire the sram signal
            # NOTE(review): this OR of the feature's config write/read gets
            # its inputs wired, but its output is not connected anywhere in
            # this constructor; the underlying config_en bit is driven
            # directly by the feature's config_en port instead.
            or_gate_en = FromMagma(mantle.DefineOr(2, 1))
            or_gate_en.instance_name = f"OR_CONFIG_EN_SRAM_{sram_index}"

            self.wire(or_gate_en.ports.I0, core_feature.ports.config.write)
            self.wire(or_gate_en.ports.I1, core_feature.ports.config.read)
            self.wire(core_feature.ports.config_en,
                      self.underlying.ports["config_en"][sram_index])
            # Still connect to the OR of all the config rd/wr
            self.wire(core_feature.ports.config.write,
                      or_all_cfg_wr.ports[f"I{sram_index}"])
            self.wire(core_feature.ports.config.read,
                      or_all_cfg_rd.ports[f"I{sram_index}"])

        self.wire(or_all_cfg_rd.ports.O[0],
                  self.underlying.ports.config_read[0])
        self.wire(or_all_cfg_wr.ports.O[0],
                  self.underlying.ports.config_write[0])
        self._setup_config()

        # Dump the sorted register names: one table row (name, position in
        # the sorted list, width) per register in mem_cfg.txt, and the bare
        # names in mem_synth.txt.
        conf_names = sorted(self.registers.keys())
        with open("mem_cfg.txt", "w") as cfg_dump:
            for idx, reg in enumerate(conf_names):
                cfg_dump.write(
                    f"|{reg}|{idx}|{self.registers[reg].width}||\n")
        with open("mem_synth.txt", "w") as cfg_dump:
            for reg in conf_names:
                cfg_dump.write(f"{reg}\n")
コード例 #6
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=2,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            interconnect_input_ports=1,  # Connection to int
            interconnect_output_ports=3,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=1,
            sram_macro_info=SRAMMacroInfo(),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            max_agg_schedule=32,
            input_max_port_sched=32,
            output_max_port_sched=32,
            align_input=1,
            max_line_length=128,
            max_tb_height=1,
            tb_range_max=128,
            tb_sched_max=64,
            max_tb_stride=15,
            num_tb=1,
            tb_iterator_support=2,
            multiwrite=1,
            max_prefetch=64,
            config_data_width=16,
            config_addr_width=8,
            num_tiles=2,
            remove_tb=False,
            fifo_mode=False,
            add_clk_enable=False,
            add_flush=False):
        super().__init__("LakeChain", debug=True)

        fw_int = int(mem_width / data_width)
        data_words_per_set = 2**config_addr_width
        sets = int((fw_int * mem_depth) / data_words_per_set)

        sets_per_macro = max(1, int(mem_depth / data_words_per_set))
        total_sets = max(1, banks * sets_per_macro)

        self._clk = self.clock("clk")
        self._rst_n = self.reset("rst_n")

        self._data_in = self.input("data_in",
                                   data_width,
                                   size=interconnect_input_ports,
                                   packed=True,
                                   explicit_array=True)
        self._addr_in = self.input("addr_in",
                                   data_width,
                                   size=interconnect_input_ports,
                                   packed=True,
                                   explicit_array=True)

        self._wen = self.input("wen", interconnect_input_ports)
        self._ren = self.input("ren", interconnect_output_ports)

        self._config_data_in = self.input("config_data_in", config_data_width)

        self._config_addr_in = self.input("config_addr_in", config_addr_width)

        self._config_data_out = self.output("config_data_out",
                                            config_data_width,
                                            size=(num_tiles, total_sets),
                                            explicit_array=True,
                                            packed=True)

        self._config_read = self.input("config_read", 1)
        self._config_write = self.input("config_write", 1)
        self._config_en = self.input("config_en", total_sets)

        self._data_out = self.output("data_out",
                                     data_width,
                                     size=(num_tiles,
                                           interconnect_output_ports),
                                     packed=True,
                                     explicit_array=True)

        self._data_out_inter = self.var("data_out_inter",
                                        data_width,
                                        size=(num_tiles,
                                              interconnect_output_ports),
                                        packed=True,
                                        explicit_array=True)

        self._valid_out = self.output("valid_out",
                                      interconnect_output_ports,
                                      size=num_tiles,
                                      packed=True,
                                      explicit_array=True)

        self._valid_out_inter = self.var("valid_out_inter",
                                         interconnect_output_ports,
                                         size=num_tiles,
                                         packed=True,
                                         explicit_array=True)

        self._enable_chain_output = self.input("enable_chain_output", 1)

        self._chain_data_out = self.output("chain_data_out",
                                           data_width,
                                           size=interconnect_output_ports,
                                           packed=True,
                                           explicit_array=True)

        self._chain_valid_out = self.output("chain_valid_out",
                                            interconnect_output_ports)

        self._tile_output_en = self.var("tile_output_en",
                                        1,
                                        size=(num_tiles,
                                              interconnect_output_ports),
                                        packed=True,
                                        explicit_array=True)

        self.is_valid_ = self.var("is_valid",
                                  1,
                                  size=interconnect_output_ports,
                                  packed=True,
                                  explicit_array=True)

        self.valids = self.var("valids",
                               clog2(num_tiles),
                               size=interconnect_output_ports,
                               packed=True,
                               explicit_array=True)

        for i in range(num_tiles):
            tile = LakeTop(data_width=data_width,
                           mem_width=mem_width,
                           mem_depth=mem_depth,
                           banks=banks,
                           input_iterator_support=input_iterator_support,
                           output_iterator_support=output_iterator_support,
                           interconnect_input_ports=interconnect_input_ports,
                           interconnect_output_ports=interconnect_output_ports,
                           mem_input_ports=mem_input_ports,
                           mem_output_ports=mem_output_ports,
                           use_sram_stub=use_sram_stub,
                           sram_macro_info=sram_macro_info,
                           read_delay=read_delay,
                           rw_same_cycle=rw_same_cycle,
                           agg_height=agg_height,
                           max_agg_schedule=max_agg_schedule,
                           input_max_port_sched=input_max_port_sched,
                           output_max_port_sched=output_max_port_sched,
                           align_input=align_input,
                           max_line_length=max_line_length,
                           max_tb_height=max_tb_height,
                           tb_range_max=tb_range_max,
                           tb_sched_max=tb_sched_max,
                           max_tb_stride=max_tb_stride,
                           num_tb=num_tb,
                           tb_iterator_support=tb_iterator_support,
                           multiwrite=multiwrite,
                           max_prefetch=max_prefetch,
                           config_data_width=config_data_width,
                           config_addr_width=config_addr_width,
                           num_tiles=num_tiles,
                           remove_tb=remove_tb,
                           fifo_mode=fifo_mode,
                           add_clk_enable=add_clk_enable,
                           add_flush=add_flush)

            self.add_child(
                f"tile_{i}",
                tile,
                clk=self._clk,
                rst_n=self._rst_n,
                enable_chain_output=self._enable_chain_output,
                # tile index
                chain_idx_input=i,
                chain_idx_output=0,
                tile_output_en=self._tile_output_en[i],
                # broadcast input data to all tiles
                data_in=self._data_in,
                addr_in=self._addr_in,
                wen=self._wen,
                ren=self._ren,
                config_data_in=self._config_data_in,
                config_addr_in=self._config_addr_in,
                config_data_out=self._config_data_out[i],
                config_read=self._config_read,
                config_write=self._config_write,
                config_en=self._config_en,
                # used if output chaining not enabled
                data_out=self._data_out_inter[i],
                valid_out=self._valid_out_inter[i],
                # unused currently?
                tile_en=1,
                # UB mode
                mode=0)

        self.add_code(self.set_data_out)
        self.add_code(self.set_valid_out)
        self.add_code(self.set_chain_outputs)

        # config regs
        lift_config_reg(self.internal_generator)
コード例 #7
0
    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=True,
            sram_macro_info=SRAMMacroInfo("TS1N16FFCLLSBLVTC512X32M4S",
                                          wtsel_value=0,
                                          rtsel_value=1),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            tb_sched_max=16,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=1,
            fifo_mode=True,
            add_clk_enable=True,
            add_flush=True,
            override_name=None,
            gen_addr=True):
        """Build (or fetch from the cache) a LakeTop circuit and wrap it.

        The generator parameters are stored on the instance, a ``LakeTop``
        is created and converted to magma unless an identically
        parameterised circuit is already present in
        ``LakeCoreBase._circuit_cache``, and the result is wrapped via
        ``self.wrap_lake_core()``.

        Side effects: overwrites ``mem_cfg.txt`` and ``mem_synth.txt`` in the
        current working directory with the sorted register names.

        Notes:
            * ``tb_sched_max`` and ``override_name`` are accepted but not
              used by this constructor.
            * ``mem_input_ports`` / ``mem_output_ports`` are stored on the
              instance but are not forwarded to ``LakeTop``.
        """
        lake_name = "LakeTop"

        super().__init__(config_data_width=config_data_width,
                         config_addr_width=config_addr_width,
                         data_width=data_width,
                         name="MemCore")

        # Capture everything to the tile object.
        # self.data_width is not assigned here; presumably the base-class
        # constructor sets it from the data_width argument - TODO confirm.
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.fw_int = int(self.mem_width / self.data_width)
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.gen_addr = gen_addr

        # The key must contain every parameter that is forwarded to LakeTop;
        # otherwise two differently parameterised cores would share one
        # cached circuit. input/output_config_width are therefore included.
        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.input_config_width, self.output_config_width,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.config_data_width,
                     self.config_addr_width, self.num_tiles, self.fifo_mode,
                     self.add_clk_enable, self.add_flush, self.gen_addr)

        # Check for circuit caching
        if cache_key not in LakeCoreBase._circuit_cache:

            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            self.dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                name=lake_name,
                gen_addr=self.gen_addr)

            change_sram_port_pass = change_sram_port_names(
                use_sram_stub, sram_macro_info)
            circ = kts.util.to_magma(
                self.dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            LakeCoreBase._circuit_cache[cache_key] = (circ, self.dut)
        else:
            circ, self.dut = LakeCoreBase._circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        self.wrap_lake_core()

        # Dump the register names (sorted) for reference: one file with
        # "(name, 0),  # width" tuples and one with the bare names.
        conf_names = sorted(self.registers.keys())
        with open("mem_cfg.txt", "w+") as cfg_dump:
            cfg_dump.writelines(
                f'("{reg}", 0),  # {self.registers[reg].width}\n'
                for reg in conf_names)
        with open("mem_synth.txt", "w+") as cfg_dump:
            cfg_dump.writelines(f"{reg}\n" for reg in conf_names)
コード例 #8
0
class MemCore(LakeCoreBase):
    """Memory core that wraps a ``LakeTop`` generator as a magma circuit."""

    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=True,
            sram_macro_info=SRAMMacroInfo("TS1N16FFCLLSBLVTC512X32M4S",
                                          wtsel_value=0,
                                          rtsel_value=1),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            tb_sched_max=16,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=1,
            fifo_mode=True,
            add_clk_enable=True,
            add_flush=True,
            override_name=None,
            gen_addr=True):
        """Build (or fetch from the cache) a LakeTop circuit and wrap it.

        The generator parameters are stored on the instance, a ``LakeTop``
        is created and converted to magma unless an identically
        parameterised circuit is already present in
        ``LakeCoreBase._circuit_cache``, and the result is wrapped via
        ``self.wrap_lake_core()``.

        Side effects: overwrites ``mem_cfg.txt`` and ``mem_synth.txt`` in the
        current working directory with the sorted register names.

        Notes:
            * ``tb_sched_max`` and ``override_name`` are accepted but not
              used by this constructor.
            * ``mem_input_ports`` / ``mem_output_ports`` are stored on the
              instance but are not forwarded to ``LakeTop``.
        """
        lake_name = "LakeTop"

        super().__init__(config_data_width=config_data_width,
                         config_addr_width=config_addr_width,
                         data_width=data_width,
                         name="MemCore")

        # Capture everything to the tile object.
        # self.data_width is not assigned here; presumably the base-class
        # constructor sets it from the data_width argument - TODO confirm.
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.fw_int = int(self.mem_width / self.data_width)
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.gen_addr = gen_addr

        # The key must contain every parameter that is forwarded to LakeTop;
        # otherwise two differently parameterised cores would share one
        # cached circuit. input/output_config_width are therefore included.
        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.input_config_width, self.output_config_width,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.config_data_width,
                     self.config_addr_width, self.num_tiles, self.fifo_mode,
                     self.add_clk_enable, self.add_flush, self.gen_addr)

        # Check for circuit caching
        if cache_key not in LakeCoreBase._circuit_cache:

            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            self.dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                name=lake_name,
                gen_addr=self.gen_addr)

            change_sram_port_pass = change_sram_port_names(
                use_sram_stub, sram_macro_info)
            circ = kts.util.to_magma(
                self.dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            LakeCoreBase._circuit_cache[cache_key] = (circ, self.dut)
        else:
            circ, self.dut = LakeCoreBase._circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        self.wrap_lake_core()

        # Dump the register names (sorted) for reference: one file with
        # "(name, 0),  # width" tuples and one with the bare names.
        conf_names = sorted(self.registers.keys())
        with open("mem_cfg.txt", "w+") as cfg_dump:
            cfg_dump.writelines(
                f'("{reg}", 0),  # {self.registers[reg].width}\n'
                for reg in conf_names)
        with open("mem_synth.txt", "w+") as cfg_dump:
            cfg_dump.writelines(f"{reg}\n" for reg in conf_names)

    def get_config_bitstream(self, instr):
        """Translate a memory instruction into a list of config entries.

        Args:
            instr: mapping describing the memory. Keys read here:
                ``"config"`` (optional; may itself carry ``"init"`` and, for
                FIFO mode, ``"depth"``), ``"init"`` (optional preload
                content), ``"is_rom"`` (optional flag) and ``"mode"``
                (one of ``"lake"``, ``"rom"``, ``"sram"``, ``"fifo"``; read
                only when ``"is_rom"`` is absent or falsy).
                When ``instr["config"]`` contains ``"init"``, it is copied
                into ``instr["init"]`` (the caller's mapping is modified).

        Returns:
            A list whose leading entries are the results of
            ``self.get_config_data(name, value)`` for the mode's runtime
            registers (last-listed register first), followed by one
            ``(addr, feat_addr, data)`` tuple per preload word.

        Raises:
            ValueError: unified-buffer mode was requested without a
                ``"config"`` entry.
            AssertionError: FIFO mode was requested without a ``'depth'``.
            KeyError: ``"mode"`` is missing or not a known mode name.
        """
        configs = []
        config_runtime = []

        mode_map = {
            "lake": MemoryMode.UNIFIED_BUFFER,
            "rom": MemoryMode.ROM,
            "sram": MemoryMode.SRAM,
            "fifo": MemoryMode.FIFO,
        }

        # Extract the runtime + preload config. top_config is always bound
        # (empty when the instruction has no "config") so the mode branches
        # below can report a meaningful error instead of an unbound local.
        has_config = "config" in instr
        top_config = instr['config'] if has_config else {}
        if "init" in top_config:
            instr["init"] = top_config["init"]

        # Add in preloaded memory
        if "init" in instr:
            # this is SRAM content
            content = instr['init']
            for addr, data in enumerate(content):
                # An entry may be an explicit (addr, data) pair instead of a
                # bare data word at the enumerated position.
                if (not isinstance(data, int)) and len(data) == 2:
                    addr, data = data
                addr = addr >> 2
                feat_addr = addr // 256 + 1
                addr = (addr % 256)
                configs.append((addr, feat_addr, data))

        # Extract mode to the enum
        if "is_rom" in instr and instr["is_rom"]:
            mode = mode_map["rom"]
        else:
            mode = mode_map[instr['mode']]

        if mode == MemoryMode.UNIFIED_BUFFER:
            if not has_config:
                raise ValueError(
                    "Unified buffer configuration needs a 'config' entry - "
                    "please include one in the instruction")
            config_runtime = self.dut.get_static_bitstream_json(top_config)
        elif mode == MemoryMode.ROM:
            # Rom mode is simply SRAM mode with the writes disabled
            config_runtime = [("tile_en", 1), ("mode", 2),
                              ("wen_in_0_reg_sel", 1), ("wen_in_1_reg_sel", 1)]
        elif mode == MemoryMode.SRAM:
            # SRAM mode gives 1 write port, 1 read port currently
            config_runtime = [("tile_en", 1), ("mode", 2),
                              ("wen_in_1_reg_sel", 1)]
        elif mode == MemoryMode.FIFO:
            # FIFO mode gives 1 write port, 1 read port currently
            assert 'depth' in top_config, "FIFO configuration needs a 'depth' - please include one in the config"
            fifo_depth = int(top_config['depth'])

            config_runtime = [("tile_en", 1), ("mode", 1),
                              ("wen_in_1_reg_sel", 1),
                              ("strg_fifo_fifo_depth", fifo_depth)]

        # Add the runtime configuration in front of the preload content.
        # Each entry used to be prepended one at a time, which leaves the
        # runtime entries in reverse order; build once and reverse instead.
        runtime_entries = [self.get_config_data(name, v)
                           for name, v in config_runtime]
        runtime_entries.reverse()
        return runtime_entries + configs

    def get_static_bitstream(self, config_path, in_file_name, out_file_name):
        """Delegate static bitstream generation to the wrapped ``self.dut``."""
        # Don't do the rest anymore...
        return self.dut.get_static_bitstream(config_path=config_path,
                                             in_file_name=in_file_name,
                                             out_file_name=out_file_name)

    def pnr_info(self):
        """Return the place-and-route tag ``"m"`` with this core's priorities."""
        return PnRTag("m", self.DEFAULT_PRIORITY - 1, self.DEFAULT_PRIORITY)
コード例 #9
0
class MemCore(LakeCoreBase):
    """Memory core that wraps a ``LakeTop`` generator as a magma circuit."""

    def __init__(
            self,
            data_width=16,  # CGRA Params
            mem_width=64,
            mem_depth=512,
            banks=1,
            input_iterator_support=6,  # Addr Controllers
            output_iterator_support=6,
            input_config_width=16,
            output_config_width=16,
            interconnect_input_ports=2,  # Connection to int
            interconnect_output_ports=2,
            mem_input_ports=1,
            mem_output_ports=1,
            use_sram_stub=True,
            sram_macro_info=SRAMMacroInfo("TS1N16FFCLLSBLVTC512X32M4S",
                                          wtsel_value=0,
                                          rtsel_value=1),
            read_delay=1,  # Cycle delay in read (SRAM vs Register File)
            rw_same_cycle=False,  # Does the memory allow r+w in same cycle?
            agg_height=4,
            tb_sched_max=16,
            config_data_width=32,
            config_addr_width=8,
            num_tiles=1,
            fifo_mode=True,
            add_clk_enable=True,
            add_flush=True,
            override_name=None,
            gen_addr=True):
        """Build (or fetch from the cache) a LakeTop circuit and wrap it.

        The generator parameters are stored on the instance, a ``LakeTop``
        is created and converted to magma unless an identically
        parameterised circuit is already present in
        ``LakeCoreBase._circuit_cache``, and the result is wrapped via
        ``self.wrap_lake_core()``.

        Side effects: overwrites ``mem_cfg.txt`` and ``mem_synth.txt`` in the
        current working directory with the sorted register names.

        Notes:
            * ``tb_sched_max`` and ``override_name`` are accepted but not
              used by this constructor.
            * ``mem_input_ports`` / ``mem_output_ports`` are stored on the
              instance but are not forwarded to ``LakeTop``.
        """
        lake_name = "LakeTop"

        super().__init__(config_data_width=config_data_width,
                         config_addr_width=config_addr_width,
                         data_width=data_width,
                         name="MemCore")

        # Capture everything to the tile object.
        # self.data_width is not assigned here; presumably the base-class
        # constructor sets it from the data_width argument - TODO confirm.
        self.mem_width = mem_width
        self.mem_depth = mem_depth
        self.banks = banks
        self.fw_int = int(self.mem_width / self.data_width)
        self.input_iterator_support = input_iterator_support
        self.output_iterator_support = output_iterator_support
        self.input_config_width = input_config_width
        self.output_config_width = output_config_width
        self.interconnect_input_ports = interconnect_input_ports
        self.interconnect_output_ports = interconnect_output_ports
        self.mem_input_ports = mem_input_ports
        self.mem_output_ports = mem_output_ports
        self.use_sram_stub = use_sram_stub
        self.sram_macro_info = sram_macro_info
        self.read_delay = read_delay
        self.rw_same_cycle = rw_same_cycle
        self.agg_height = agg_height
        self.config_data_width = config_data_width
        self.config_addr_width = config_addr_width
        self.num_tiles = num_tiles
        self.fifo_mode = fifo_mode
        self.add_clk_enable = add_clk_enable
        self.add_flush = add_flush
        self.gen_addr = gen_addr

        # The key must contain every parameter that is forwarded to LakeTop;
        # otherwise two differently parameterised cores would share one
        # cached circuit. input/output_config_width are therefore included.
        cache_key = (self.data_width, self.mem_width, self.mem_depth,
                     self.banks, self.input_iterator_support,
                     self.output_iterator_support,
                     self.input_config_width, self.output_config_width,
                     self.interconnect_input_ports,
                     self.interconnect_output_ports, self.use_sram_stub,
                     self.sram_macro_info, self.read_delay, self.rw_same_cycle,
                     self.agg_height, self.config_data_width,
                     self.config_addr_width, self.num_tiles, self.fifo_mode,
                     self.add_clk_enable, self.add_flush, self.gen_addr)

        # Check for circuit caching
        if cache_key not in LakeCoreBase._circuit_cache:

            # Instantiate core object here - will only use the object representation to
            # query for information. The circuit representation will be cached and retrieved
            # in the following steps.
            self.dut = LakeTop(
                data_width=self.data_width,
                mem_width=self.mem_width,
                mem_depth=self.mem_depth,
                banks=self.banks,
                input_iterator_support=self.input_iterator_support,
                output_iterator_support=self.output_iterator_support,
                input_config_width=self.input_config_width,
                output_config_width=self.output_config_width,
                interconnect_input_ports=self.interconnect_input_ports,
                interconnect_output_ports=self.interconnect_output_ports,
                use_sram_stub=self.use_sram_stub,
                sram_macro_info=self.sram_macro_info,
                read_delay=self.read_delay,
                rw_same_cycle=self.rw_same_cycle,
                agg_height=self.agg_height,
                config_data_width=self.config_data_width,
                config_addr_width=self.config_addr_width,
                num_tiles=self.num_tiles,
                fifo_mode=self.fifo_mode,
                add_clk_enable=self.add_clk_enable,
                add_flush=self.add_flush,
                name=lake_name,
                gen_addr=self.gen_addr)

            change_sram_port_pass = change_sram_port_names(
                use_sram_stub, sram_macro_info)
            circ = kts.util.to_magma(
                self.dut,
                flatten_array=True,
                check_multiple_driver=False,
                optimize_if=False,
                check_flip_flop_always_ff=False,
                additional_passes={"change_sram_port": change_sram_port_pass})
            LakeCoreBase._circuit_cache[cache_key] = (circ, self.dut)
        else:
            circ, self.dut = LakeCoreBase._circuit_cache[cache_key]

        # Save as underlying circuit object
        self.underlying = FromMagma(circ)

        self.wrap_lake_core()

        # Dump the register names (sorted) for reference: one file with
        # "(name, 0),  # width" tuples and one with the bare names.
        conf_names = sorted(self.registers.keys())
        with open("mem_cfg.txt", "w+") as cfg_dump:
            cfg_dump.writelines(
                f'("{reg}", 0),  # {self.registers[reg].width}\n'
                for reg in conf_names)
        with open("mem_synth.txt", "w+") as cfg_dump:
            cfg_dump.writelines(f"{reg}\n" for reg in conf_names)

    def _sram_mode_configs(self):
        """Return the SRAM-mode register writes, last-listed register first."""
        config_mem = [("tile_en", 1), ("mode", 2), ("wen_in_0_reg_sel", 1),
                      ("wen_in_1_reg_sel", 1)]
        entries = [self.get_config_data(name, v) for name, v in config_mem]
        # Callers historically prepended one entry at a time, which yields
        # the reverse of the listing order; keep that ordering.
        entries.reverse()
        return entries

    def get_config_bitstream(self, instr):
        """Translate a memory instruction into a list of config entries.

        Three cases are handled, in this order:

        1. ``instr['config'][1]`` contains ``"init"``: the SRAM-mode
           registers are emitted followed by one ``(addr, feat_addr, data)``
           tuple per preload word.
        2. ``"depth"`` is in ``instr`` (or ``"is_ub"`` is truthy, in which
           case ``instr["depth"]`` is set from ``instr["range_0"]``): the
           registers returned by ``self.dut.get_static_bitstream_json`` for
           ``instr['config'][1]`` are emitted, followed by
           ``wen_in_1_reg_sel = 1``.
        3. Otherwise the core is configured as SRAM.

        The resulting list is printed to stdout before being returned.

        Args:
            instr: mapping describing the memory; ``instr['config'][1]`` must
                exist. In case 2 ``instr["app_name"]`` is also read.

        Returns:
            List of config entries as described above.
        """
        if "init" in instr['config'][1]:
            configs = self._sram_mode_configs()
            # this is SRAM content
            content = instr['config'][1]['init']
            for addr, data in enumerate(content):
                # An entry may be an explicit (addr, data) pair instead of a
                # bare data word at the enumerated position.
                if (not isinstance(data, int)) and len(data) == 2:
                    addr, data = data
                feat_addr = addr // 256 + 1
                addr = (addr % 256) >> 2
                configs.append((addr, feat_addr, data))
            print(configs)
            return configs

        configs = []
        # unified buffer buffer stuff
        if "is_ub" in instr and instr["is_ub"]:
            depth = instr["range_0"]
            instr["depth"] = depth
            print("configure ub to have depth", depth)
        if "depth" in instr:
            # need to download the csv and get configuration files
            app_name = instr["app_name"]
            # hardcode the config bitstream depends on the apps
            config_mem = []
            print("app is", app_name)
            # NOTE: use_json is hard-coded, so the CSV download branch below
            # is currently never taken.
            use_json = True
            if use_json:
                top_controller_node = instr['config'][1]
                config_mem = self.dut.get_static_bitstream_json(
                    top_controller_node)
            elif app_name == "conv_3_3":
                # Create a tempdir and download the files...
                with tempfile.TemporaryDirectory() as tempdir:
                    # Download files here and leverage lake bitstream code....
                    print(f'Downloading app files for {app_name}')
                    url_prefix = "https://raw.githubusercontent.com/dillonhuff/clockwork/" +\
                                 "fix_config/lake_controllers/conv_3_3_aha/buf_inst_input" +\
                                 "_10_to_buf_inst_output_3_ubuf/"
                    file_suffix = [
                        "input_agg2sram.csv", "input_in2agg_0.csv",
                        "output_2_sram2tb.csv", "output_2_tb2out_0.csv",
                        "output_2_tb2out_1.csv", "stencil_valid.csv"
                    ]
                    for fs in file_suffix:
                        full_url = url_prefix + fs
                        print(f"Downloading from {full_url}")
                        urllib.request.urlretrieve(full_url,
                                                   tempdir + "/" + fs)
                    config_path = tempdir
                    config_mem = self.get_static_bitstream(
                        config_path=config_path,
                        in_file_name="input",
                        out_file_name="output")

            for name, v in config_mem:
                configs.append(self.get_config_data(name, v))
            # gate config signals
            for conf_name in ["wen_in_1_reg_sel"]:
                configs.append(self.get_config_data(conf_name, 1))
        else:
            # for now config it as sram
            configs = self._sram_mode_configs()
        print(configs)
        return configs

    def get_static_bitstream(self, config_path, in_file_name, out_file_name):
        """Delegate static bitstream generation to the wrapped ``self.dut``."""
        # Don't do the rest anymore...
        return self.dut.get_static_bitstream(config_path=config_path,
                                             in_file_name=in_file_name,
                                             out_file_name=out_file_name)

    def pnr_info(self):
        """Return the place-and-route tag ``"m"`` with this core's priorities."""
        return PnRTag("m", self.DEFAULT_PRIORITY - 1, self.DEFAULT_PRIORITY)