Esempio n. 1
0
    class _Mux(Circuit):
        """Multiplexer that picks one of ``n`` inputs of type ``t``.

        Ports: ``data`` (array of ``n`` elements of ``t``), ``sel`` (index of
        the element to forward) and ``out`` (the selected element).
        """
        name = 'Mux_{}t_{}n'.format(cleanName(str(t)), n)
        addr_width = getRAMAddrWidth(n)
        IO = [
            'data', In(Array[n, t]),
            'sel', In(Bits[addr_width]),
            'out', Out(t),
        ]

        @classmethod
        def definition(cls):
            """Wire the mux body; a single input is simply passed through."""
            if n <= 1:
                # Nothing to choose between: forward the only element and
                # terminate the select input so it is not left dangling.
                wire(cls.data[0], cls.out)
                unused_sel = DefineTermAnyType(Bits[cls.addr_width])()
                wire(cls.sel, unused_sel.I)
                return

            elem_bits = GetCoreIRBackend().get_type(t).size
            selector = CommonlibMuxN(n, elem_bits)

            # Convert every input element to its flat bit representation
            # before handing it to the bit-level mux.
            flatten_inputs = DefineNativeMapParallel(n, DefineDehydrate(t))()
            wire(cls.data, flatten_inputs.I)
            wire(flatten_inputs.out, selector.I.data)

            # Rebuild a value of type t from the selected bits.
            rebuild = Hydrate(t)
            wire(selector.out, rebuild.I)
            wire(rebuild.out, cls.out)

            wire(cls.sel, selector.I.sel)
Esempio n. 2
0
    class _RAM(Circuit):
        """RAM with ``n`` entries whose read/write data ports have type ``t``.

        The data is converted to and from a flat bit vector around a
        bit-level RAM created with ``DefineRAM``.
        """
        # NOTE(review): read_latency is passed to DefineRAM but is not part of
        # the circuit name -- confirm that cannot cause name clashes.
        name = 'RAM_{}t_{}n'.format(cleanName(str(t)), n)
        addr_width = getRAMAddrWidth(n)
        IO = [
            'RADDR', In(Bits[addr_width]),
            'RDATA', Out(t),
            'WADDR', In(Bits[addr_width]),
            'WDATA', In(t),
            'WE', In(Bit),
        ] + ClockInterface()

        @classmethod
        def definition(cls):
            """Instantiate the bit-level RAM and the type converters around it."""
            word_bits = GetCoreIRBackend().get_type(t).size
            storage = DefineRAM(n, word_bits, read_latency=read_latency)()

            # Write path: typed value -> bits -> RAM.
            flatten = Dehydrate(t)
            wire(cls.WDATA, flatten.I)
            wire(flatten.out, storage.WDATA)

            # Read path: RAM -> bits -> typed value.
            rebuild = Hydrate(t)
            wire(storage.RDATA, rebuild.I)
            wire(rebuild.out, cls.RDATA)

            # Addresses and write enable go straight through.
            wire(cls.RADDR, storage.RADDR)
            wire(storage.WADDR, cls.WADDR)

            wire(cls.WE, storage.WE)
Esempio n. 3
0
    class _ROM(Circuit):
        """Lookup table with ``n`` entries of type ``t``, built from 1-bit LUTs.

        ``addr`` selects an entry of ``init``; ``data`` is that entry as a
        value of type ``t``.
        """
        # NOTE(review): the contents of init are not part of the circuit name
        # -- confirm two tables with equal t and n cannot clash.
        name = 'LUT_{}t_{}n'.format(cleanName(str(t)), n)
        addr_width = getRAMAddrWidth(n)
        IO = [
            'addr', In(Bits[addr_width]),
            'data', Out(t),
        ] + ClockInterface()

        @classmethod
        def definition(cls):
            """Create one LUT per output bit and fan the address out to all."""
            width = GetCoreIRBackend().get_type(t).size

            # LUT k holds bit k of every entry in init.
            column_luts = [
                DefineLUT(seq2int([entry[bit_idx] for entry in init]),
                          getRAMAddrWidth(n))()
                for bit_idx in range(width)
            ]

            # Collect the per-bit outputs and rebuild a value of type t.
            rebuild = Hydrate(t)
            for bit_idx, lut in enumerate(column_luts):
                wire(lut.O, rebuild.I[bit_idx])
            wire(rebuild.out, cls.data)

            # Every LUT sees the same address; its inputs are named I0, I1, ...
            for lut in column_luts:
                for addr_bit in range(cls.addr_width):
                    wire(cls.addr[addr_bit],
                         getattr(lut, "I" + str(addr_bit)))
Esempio n. 4
0
        def definition(cls):
            """Create one 1-bit LUT per output bit and share the address."""
            width = GetCoreIRBackend().get_type(t).size

            # LUT k holds bit k of every entry in init.
            column_luts = [
                DefineLUT(seq2int([entry[bit_idx] for entry in init]),
                          getRAMAddrWidth(n))()
                for bit_idx in range(width)
            ]

            # Collect the per-bit outputs and rebuild a value of type t.
            rebuild = Hydrate(t)
            for bit_idx, lut in enumerate(column_luts):
                wire(lut.O, rebuild.I[bit_idx])
            wire(rebuild.out, cls.data)

            # Every LUT sees the same address; its inputs are named I0, I1, ...
            for lut in column_luts:
                for addr_bit in range(cls.addr_width):
                    wire(cls.addr[addr_bit],
                         getattr(lut, "I" + str(addr_bit)))
Esempio n. 5
0
    class _RAM_ST(Circuit):
        """RAM of ``n`` elements of space-time type ``t``.

        Each element gets its own RAM with ``t.valid_clocks()`` entries of
        ``t.magma_repr()``. ``WADDR`` picks the RAM that is written and
        ``RADDR`` picks the RAM whose output is forwarded; the position
        inside a RAM comes from nested counters over ``t``.
        """
        name = 'RAM_ST_{}_hasReset{}'.format(cleanName(str(t)), str(has_reset))
        addr_width = getRAMAddrWidth(n)
        IO = [
            'RADDR', In(Bits[addr_width]),
            'RDATA', Out(t.magma_repr()),
            'WADDR', In(Bits[addr_width]),
            'WDATA', In(t.magma_repr()),
            'WE', In(Bit),
            'RE', In(Bit),
        ] + ClockInterface(has_ce=False, has_reset=has_reset)

        @classmethod
        def definition(cls):
            """Instantiate the per-element RAMs, position counters and read mux."""
            # One RAM per element; each valid clock stores one magma_repr.
            banks = [
                DefineRAMAnyType(t.magma_repr(), t.valid_clocks(),
                                 read_latency=read_latency)()
                for _ in range(n)
            ]
            rd_position = DefineNestedCounters(
                t, has_cur_valid=True, has_ce=True, has_reset=has_reset)()
            rd_valid_sink = TermAnyType(Bit)
            rd_last_sink = TermAnyType(Bit)
            wr_position = DefineNestedCounters(
                t, has_cur_valid=True, has_ce=True, has_reset=has_reset)()
            wr_valid_sink = TermAnyType(Bit)
            wr_last_sink = TermAnyType(Bit)
            bank_mux = DefineMuxAnyType(t.magma_repr(), n)()

            for bank_idx, bank in enumerate(banks):
                wire(cls.WDATA, bank.WDATA)
                wire(wr_position.cur_valid, bank.WADDR)
                wire(bank_mux.data[bank_idx], bank.RDATA)
                wire(rd_position.cur_valid, bank.RADDR)
                # Only the RAM addressed by WADDR is written, and only on
                # clocks the write counter reports as valid.
                bank_selected = Decode(bank_idx, cls.WADDR.N)(cls.WADDR)
                wire(bank_selected & wr_position.valid, bank.WE)

            wire(cls.RADDR, bank_mux.sel)
            wire(cls.RDATA, bank_mux.out)

            # The external enables advance the position counters.
            wire(cls.WE, wr_position.CE)
            wire(cls.RE, rd_position.CE)

            # Counter outputs that are not otherwise consumed are terminated.
            wire(rd_position.valid, rd_valid_sink.I)
            wire(rd_position.last, rd_last_sink.I)
            wire(wr_position.valid, wr_valid_sink.I)
            wire(wr_position.last, wr_last_sink.I)

            if has_reset:
                wire(cls.RESET, wr_position.RESET)
                wire(cls.RESET, rd_position.RESET)
Esempio n. 6
0
    class _ROM(Circuit):
        """ROM with ``n`` entries of type ``t`` initialised from ``init``.

        Ports: ``RADDR`` (entry index), ``REN`` (read enable) and ``RDATA``
        (the entry as a value of type ``t``).
        """
        name = 'ROM_{}t_{}n'.format(cleanName(str(t)), n)
        addr_width = getRAMAddrWidth(n)
        IO = [
            'RADDR', In(Bits[addr_width]),
            'RDATA', Out(t),
            'REN', In(Bit),
        ] + ClockInterface()

        @classmethod
        def definition(cls):
            """Instantiate the bit-level ROM and convert its output to ``t``."""
            word_bits = GetCoreIRBackend().get_type(t).size
            # The init config is passed as a list of lists, one per entry.
            contents = [list(entry) for entry in init]
            memory = DefineROM(n, word_bits)(
                coreir_configargs={"init": contents})

            rebuild = Hydrate(t)
            wire(memory.rdata, rebuild.I)
            wire(rebuild.out, cls.RDATA)

            wire(cls.RADDR, memory.raddr)

            wire(cls.REN, memory.ren)
Esempio n. 7
0
        def definition(cls):
            """Build the circuit that permutes data from layout t_in to t_out.

            The definition is assembled in stages:

            1. Build a permutation graph from the differing parts of t_in
               and t_out and derive, per bank and per lane, the address/bank
               used on every clock.
            2. Create one RAM_ST per bank plus LUTs holding the per-clock
               write address, write valid, write bank, read address and
               output lane.
            3. Create the counters that index those LUTs and gate them with
               valid_up / CE / RESET as configured.
            4. Create two bitonic sorting networks (inputs sorted by bank,
               outputs sorted by lane) and wire ports, LUTs, RAMs and
               sorting networks together.

            Side effect: sets ``cls.output_delay`` so tests know the latency.
            """
            # first section creates the RAMs and LUTs that set values in them and the sorting network
            shared_and_diff_subtypes = get_shared_and_diff_subtypes(
                t_in, t_out)
            t_in_diff = shared_and_diff_subtypes.diff_input
            t_out_diff = shared_and_diff_subtypes.diff_output
            graph = build_permutation_graph(ST_TSeq(2, 0, t_in_diff),
                                            ST_TSeq(2, 0, t_out_diff))
            banks_write_addr_per_input_lane = get_banks_addr_per_lane(
                graph.input_nodes)
            input_lane_write_addr_per_bank = get_lane_addr_per_banks(
                graph.input_nodes)
            output_lane_read_addr_per_bank = get_lane_addr_per_banks(
                graph.output_nodes)

            # each ram only needs to be large enough to handle the number of addresses assigned to it
            # all rams receive the same number of writes
            # but some of those writes don't happen as the data is invalid, so don't need storage for them
            max_ram_addrs = [
                max([bank_clock_data.addr for bank_clock_data in bank_data])
                for bank_data in output_lane_read_addr_per_bank
            ]
            # rams also handle parallelism from outer_shared type as this affects all banks the same
            outer_shared_sseqs = remove_tseqs(
                shared_and_diff_subtypes.shared_outer)
            if outer_shared_sseqs == ST_Tombstone():
                ram_element_type = shared_and_diff_subtypes.shared_inner
            else:
                ram_element_type = replace_tombstone(
                    outer_shared_sseqs, shared_and_diff_subtypes.shared_inner)
            # can use wider rams rather than duplicate for outer_shared_sseqs because will
            # transpose dimensions of input wires below to wire up as if outer, shared dimensions
            # were on the inside
            rams = [
                DefineRAM_ST(ram_element_type, ram_max_addr + 1)()
                for ram_max_addr in max_ram_addrs
            ]
            rams_addr_widths = [ram.WADDR.N for ram in rams]

            # for bank, the addresses to write to each clock
            write_addr_for_bank_luts = []
            for bank_idx in range(len(rams)):
                ram_addr_width = rams_addr_widths[bank_idx]
                num_addrs = len(input_lane_write_addr_per_bank[bank_idx])
                #assert num_addrs == t_in_diff.time()
                write_addrs = [
                    builtins.tuple(
                        int2seq(write_data_per_bank_per_clock.addr,
                                ram_addr_width))
                    for write_data_per_bank_per_clock in
                    input_lane_write_addr_per_bank[bank_idx]
                ]
                write_addr_for_bank_luts.append(
                    DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                                     builtins.tuple(write_addrs))())

            # for bank, whether to actually write this clock
            write_valid_for_bank_luts = []
            for bank_idx in range(len(rams)):
                num_valids = len(input_lane_write_addr_per_bank[bank_idx])
                #assert num_valids == t_in_diff.time()
                valids = [
                    builtins.tuple([write_data_per_bank_per_clock.valid])
                    for write_data_per_bank_per_clock in
                    input_lane_write_addr_per_bank[bank_idx]
                ]
                write_valid_for_bank_luts.append(
                    DefineLUTAnyType(Bit, num_valids,
                                     builtins.tuple(valids))())

            # for each input lane, the bank to write to each clock
            write_bank_for_input_lane_luts = []
            bank_idx_width = getRAMAddrWidth(len(rams))
            for lane_idx in range(len(banks_write_addr_per_input_lane)):
                num_bank_idxs = len(banks_write_addr_per_input_lane[lane_idx])
                #assert num_bank_idxs == t_in_diff.time()
                bank_idxs = [
                    builtins.tuple(
                        int2seq(write_data_per_lane_per_clock.bank,
                                bank_idx_width))
                    for write_data_per_lane_per_clock in
                    banks_write_addr_per_input_lane[lane_idx]
                ]
                write_bank_for_input_lane_luts.append(
                    DefineLUTAnyType(Array[bank_idx_width, Bit], num_bank_idxs,
                                     builtins.tuple(bank_idxs))())

            # for each bank, the address to read from each clock
            read_addr_for_bank_luts = []
            for bank_idx in range(len(rams)):
                ram_addr_width = rams_addr_widths[bank_idx]
                num_addrs = len(output_lane_read_addr_per_bank[bank_idx])
                #assert num_addrs == t_in_diff.time()
                read_addrs = [
                    builtins.tuple(
                        int2seq(read_data_per_bank_per_clock.addr,
                                ram_addr_width))
                    for read_data_per_bank_per_clock in
                    output_lane_read_addr_per_bank[bank_idx]
                ]
                read_addr_for_bank_luts.append(
                    DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                                     builtins.tuple(read_addrs))())

            # for each bank, the lane to send each read to
            output_lane_for_bank_luts = []
            # number of lanes equals number of banks
            # some of the lanes are just always invalid, added so input lane width equals output lane width
            lane_idx_width = getRAMAddrWidth(len(rams))
            for bank_idx in range(len(rams)):
                num_lane_idxs = len(output_lane_read_addr_per_bank[bank_idx])
                #assert num_lane_idxs == t_in_diff.time()
                lane_idxs = [
                    builtins.tuple(
                        int2seq(read_data_per_bank_per_clock.s,
                                lane_idx_width))
                    for read_data_per_bank_per_clock in
                    output_lane_read_addr_per_bank[bank_idx]
                ]
                output_lane_for_bank_luts.append(
                    DefineLUTAnyType(Array[lane_idx_width, Bit], num_lane_idxs,
                                     builtins.tuple(lane_idxs))())

            # second part creates the counters that index into the LUTs
            # elem_per counts time per element of the reshape
            elem_per_reshape_counter = AESizedCounterModM(
                ram_element_type.time(), has_ce=True)
            # high on the last clock of the current element
            end_cur_elem = Decode(ram_element_type.time() - 1,
                                  elem_per_reshape_counter.O.N)(
                                      elem_per_reshape_counter.O)
            # reshape counts which element in the reshape
            num_clocks = len(output_lane_read_addr_per_bank[0])
            reshape_write_counter = AESizedCounterModM(num_clocks,
                                                       has_ce=True,
                                                       has_reset=has_reset)
            reshape_read_counter = AESizedCounterModM(num_clocks,
                                                      has_ce=True,
                                                      has_reset=has_reset)

            # delay before reads start: first output latency of the graph,
            # scaled by the number of clocks per RAM element
            output_delay = (
                get_output_latencies(graph)[0]) * ram_element_type.time()
            # this is present so testing knows the delay
            cls.output_delay = output_delay
            reshape_read_delay_counter = DefineInitialDelayCounter(
                output_delay, has_ce=True, has_reset=has_reset)()
            # outer counter that repeats the reshape
            #wire(reshape_write_counter.O, cls.reshape_write_counter)

            # enabled starts as constant 1 and is ANDed with valid_up and CE
            # when those ports exist
            enabled = DefineCoreirConst(1, 1)().O[0]
            if has_valid:
                enabled = cls.valid_up & enabled
                wire(reshape_read_delay_counter.valid, cls.valid_down)
            if has_ce:
                enabled = bit(cls.CE) & enabled
            wire(enabled, elem_per_reshape_counter.CE)
            wire(enabled, reshape_read_delay_counter.CE)
            # the write counter advances once per completed element; the read
            # counter additionally waits for the initial delay to elapse
            wire(enabled & end_cur_elem, reshape_write_counter.CE)
            wire(enabled & end_cur_elem & reshape_read_delay_counter.valid,
                 reshape_read_counter.CE)

            if has_reset:
                wire(cls.RESET, elem_per_reshape_counter.RESET)
                wire(cls.RESET, reshape_read_delay_counter.RESET)
                wire(cls.RESET, reshape_write_counter.RESET)
                wire(cls.RESET, reshape_read_counter.RESET)

            # wire read and write counters to all LUTs
            for lut in write_bank_for_input_lane_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in write_addr_for_bank_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in write_valid_for_bank_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in read_addr_for_bank_luts:
                wire(reshape_read_counter.O, lut.addr)

            for lut in output_lane_for_bank_luts:
                wire(reshape_read_counter.O, lut.addr)

            # third and final instance creation part creates the sorting networks that map lanes to banks
            # each sorted item pairs the value with the key (bank or lane) it is sorted by
            input_sorting_network_t = Tuple(
                bank=Array[write_bank_for_input_lane_luts[0].data.N, Bit],
                val=ram_element_type.magma_repr())
            input_sorting_network = DefineBitonicSort(input_sorting_network_t,
                                                      len(rams),
                                                      lambda x: x.bank)()

            output_sorting_network_t = Tuple(
                lane=Array[output_lane_for_bank_luts[0].data.N, Bit],
                val=ram_element_type.magma_repr())
            output_sorting_network = DefineBitonicSort(
                output_sorting_network_t, len(rams), lambda x: x.lane)()

            # wire luts, sorting networks, inputs, and rams
            # flatten all the sseq_layers to get flat magma type of inputs and outputs
            # tseqs don't affect magma types
            num_sseq_layers_inputs = num_nested_layers(
                remove_tseqs(shared_and_diff_subtypes.diff_input))
            num_sseq_layers_to_remove_inputs = max(0,
                                                   num_sseq_layers_inputs - 1)
            num_sseq_layers_outputs = num_nested_layers(
                remove_tseqs(shared_and_diff_subtypes.diff_output))
            num_sseq_layers_to_remove_outputs = max(
                0, num_sseq_layers_outputs - 1)
            if remove_tseqs(
                    shared_and_diff_subtypes.shared_outer) != ST_Tombstone():
                #num_sseq_layers_inputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
                #num_sseq_layers_outputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
                input_ports = flatten_ports(
                    transpose_outer_dimensions(
                        shared_and_diff_subtypes.shared_outer,
                        shared_and_diff_subtypes.diff_input, cls.I),
                    num_sseq_layers_to_remove_inputs)
                output_ports = flatten_ports(
                    transpose_outer_dimensions(
                        shared_and_diff_subtypes.shared_outer,
                        shared_and_diff_subtypes.diff_output, cls.O),
                    num_sseq_layers_to_remove_outputs)
            else:
                input_ports = flatten_ports(cls.I,
                                            num_sseq_layers_to_remove_inputs)
                output_ports = flatten_ports(
                    cls.O, num_sseq_layers_to_remove_outputs)
            # this is only used if the shared outer layers contains any sseqs
            sseq_layers_to_flatten = max(
                num_nested_layers(
                    remove_tseqs(shared_and_diff_subtypes.shared_outer)) - 1,
                0)
            for idx in range(len(rams)):
                # wire input and bank to input sorting network
                wire(write_bank_for_input_lane_luts[idx].data,
                     input_sorting_network.I[idx].bank)
                #if idx == 0:
                #    wire(cls.first_valid, write_valid_for_bank_luts[idx].data)
                if idx < t_in_diff.port_width():
                    # since the input_ports are lists, need to wire them individually to the sorting ports
                    if remove_tseqs(shared_and_diff_subtypes.shared_outer
                                    ) != ST_Tombstone():
                        cur_input_port = flatten_ports(input_ports[idx],
                                                       sseq_layers_to_flatten)
                        cur_sort_port = flatten_ports(
                            input_sorting_network.I[idx].val,
                            sseq_layers_to_flatten)
                        for i in range(len(cur_input_port)):
                            wire(cur_input_port[i], cur_sort_port[i])
                    else:
                        if num_sseq_layers_inputs == 0:
                            # input_ports will be an array of bits for 1 element
                            # if no sseq in t_in
                            wire(input_ports, input_sorting_network.I[idx].val)
                        else:
                            wire(input_ports[idx],
                                 input_sorting_network.I[idx].val)
                    #wire(cls.ram_wr, input_sorting_network.O[idx].val)
                    #wire(cls.ram_rd, rams[idx].RDATA)
                else:
                    # no input port for this sorting-network slot: drive it with zeros
                    zero_const = DefineCoreirConst(
                        ram_element_type.magma_repr().size(), 0)().O
                    cur_sn_input = input_sorting_network.I[idx].val
                    # descend into element 0 until the lengths match the constant
                    while len(cur_sn_input) != len(zero_const):
                        cur_sn_input = cur_sn_input[0]
                    wire(zero_const, cur_sn_input)

                # wire input sorting network, write addr, and write valid luts to banks
                wire(input_sorting_network.O[idx].val, rams[idx].WDATA)
                wire(write_addr_for_bank_luts[idx].data, rams[idx].WADDR)
                #wire(write_addr_for_bank_luts[idx].data[0], cls.addr_wr[idx])
                if has_ce:
                    wire(write_valid_for_bank_luts[idx].data & bit(cls.CE),
                         rams[idx].WE)
                else:
                    wire(write_valid_for_bank_luts[idx].data, rams[idx].WE)

                # wire output sorting network, read addr, read bank, and read enable
                wire(rams[idx].RDATA, output_sorting_network.I[idx].val)
                wire(output_lane_for_bank_luts[idx].data,
                     output_sorting_network.I[idx].lane)
                wire(read_addr_for_bank_luts[idx].data, rams[idx].RADDR)
                #wire(read_addr_for_bank_luts[idx].data[0], cls.addr_rd[idx])
                # ok to read invalid things, so in read value LUT
                if has_ce:
                    wire(bit(cls.CE), rams[idx].RE)
                else:
                    wire(DefineCoreirConst(1, 1)().O[0], rams[idx].RE)
                if has_reset:
                    wire(cls.RESET, rams[idx].RESET)

                # wire output sorting network value to output or term
                if idx < t_out_diff.port_width():
                    # since the output_ports are lists, need to wire them individually to the sorting ports
                    if remove_tseqs(shared_and_diff_subtypes.shared_outer
                                    ) != ST_Tombstone():
                        cur_output_port = flatten_ports(
                            output_ports[idx], sseq_layers_to_flatten)
                        cur_sort_port = flatten_ports(
                            output_sorting_network.O[idx].val,
                            sseq_layers_to_flatten)
                        for i in range(len(cur_output_port)):
                            wire(cur_output_port[i], cur_sort_port[i])
                    else:
                        if num_sseq_layers_outputs == 0:
                            # output_ports will be an array of bits for 1 element
                            # if no sseq in t_out
                            wire(output_sorting_network.O[idx].val,
                                 output_ports)
                        else:
                            wire(output_sorting_network.O[idx].val,
                                 output_ports[idx])
                else:
                    # no output port for this slot: terminate the unused value
                    wire(output_sorting_network.O[idx].val,
                         TermAnyType(type(output_sorting_network.O[idx].val)))

                # wire sorting networks bank/lane to term as not used on outputs, just used for sorting
                wire(input_sorting_network.O[idx].bank,
                     TermAnyType(type(input_sorting_network.O[idx].bank)))
                wire(output_sorting_network.O[idx].lane,
                     TermAnyType(type(output_sorting_network.O[idx].lane)))
        def definition(TSBankGenerator):
            """Wire up the per-lane bank and address outputs of TSBankGenerator.

            For each of the ``ni`` lanes a flat index is computed, and the
            lane's bank is ``(flat_idx + flat_idx / lcm(no, ni)) % ni``
            truncated to ``bank_width`` bits. Every lane's address is the
            shared time counter truncated to ``addr_width`` bits.
            """
            flat_idx_width = getRAMAddrWidth(no * ni)
            # next element each time_per_element clock
            if time_per_element > 1:
                index_in_cur_element = SizedCounterModM(time_per_element,
                                                        has_ce=has_ce,
                                                        has_reset=has_reset)
                next_element = Decode(time_per_element - 1,
                                      index_in_cur_element.O.N)(
                                          index_in_cur_element.O)
            else:
                next_element = DefineCoreirConst(1, 1)()
            # NOTE(review): index_in_cur_element only exists when
            # time_per_element > 1, yet it is wired below whenever has_ce or
            # has_reset is set -- confirm callers never combine those.
            # each element of the SSeq is a separate vector lane
            first_lane_flat_idx = SizedCounterModM((no + io) * ni,
                                                   incr=ni,
                                                   has_ce=True,
                                                   has_reset=has_reset)()
            time_counter = SizedCounterModM(no + io,
                                            has_ce=True,
                                            has_reset=has_reset)
            wire(next_element.O, first_lane_flat_idx.CE)
            wire(next_element.O, time_counter.CE)
            if has_ce:
                wire(TSBankGenerator.CE, index_in_cur_element.CE)
            if has_reset:
                wire(TSBankGenerator.RESET, index_in_cur_element.RESET)
                wire(TSBankGenerator.RESET, first_lane_flat_idx.RESET)
                wire(TSBankGenerator.RESET, time_counter.RESET)

            lane_flat_idxs = [first_lane_flat_idx.O]

            # compute the current flat_idx for each lane
            for i in range(1, ni):
                cur_lane_flat_idx_adder = DefineAdd(flat_idx_width)()
                wire(cur_lane_flat_idx_adder.I0, first_lane_flat_idx.O)
                wire(cur_lane_flat_idx_adder.I1,
                     DefineCoreirConst(flat_idx_width, i * no)().O)

                lane_flat_idxs += [cur_lane_flat_idx_adder.O]

            lane_flat_div_lcms = []
            # the divisor is the same for every lane, so build it once;
            # DefineCoreirConst takes (width, value)
            lcm_dim = DefineCoreirConst(flat_idx_width, lcm(no, ni))()
            # compute flat_idx / lcm_dim for each lane
            for i in range(ni):
                cur_lane_lcm_div = DefineUDiv(flat_idx_width)()
                # lane_flat_idxs already holds output ports, one per lane
                wire(cur_lane_lcm_div.I0, lane_flat_idxs[i])
                wire(cur_lane_lcm_div.I1, lcm_dim.O)

                # keep the quotient of this lane's divider
                lane_flat_div_lcms += [cur_lane_lcm_div.O]

            # compute ((flat_idx % sseq_dim) + (flat_idx / lcm_dim)) % sseq_dim for each lane
            # note that s_ts == flat_idx % sseq_dim
            # only need to mod sseq_dim at end as that is same as also doing it flat_idx before addition
            for i in range(ni):
                pre_mod_add = DefineAdd(flat_idx_width)()
                wire(pre_mod_add.I0, lane_flat_idxs[i])
                wire(pre_mod_add.I1, lane_flat_div_lcms[i])

                bank_mod = DefineUMod(flat_idx_width)()
                wire(bank_mod.I0, pre_mod_add.O)
                # the modulus goes on the second operand
                wire(bank_mod.I1, DefineCoreirConst(flat_idx_width, ni)().O)

                wire(TSBankGenerator.bank[i],
                     bank_mod.O[0:TSBankGenerator.bank_width])

            # compute t for each lane addr
            for i in range(0, ni):
                wire(TSBankGenerator.addr[i],
                     time_counter.O[0:TSBankGenerator.addr_width])
        def definition(STBankGenerator):
            """Wire up the per-lane bank and address outputs of STBankGenerator.

            For each of the ``no`` lanes a flat index is computed. The
            lane's bank is ``(flat_idx + flat_idx / lcm(no, ni)) % no`` and
            its address is ``flat_idx / no``, each truncated to the width of
            the corresponding output port; unused upper bits are terminated.
            """
            flat_idx_width = getRAMAddrWidth(no * ni)
            # next element each time_per_element clock
            if time_per_element > 1:
                index_in_cur_element = SizedCounterModM(time_per_element,
                                                        has_ce=has_ce,
                                                        has_reset=has_reset)
                next_element = Decode(time_per_element - 1,
                                      index_in_cur_element.O.N)(
                                          index_in_cur_element.O)
            else:
                next_element = DefineCoreirConst(1, 1)()
            # NOTE(review): index_in_cur_element only exists when
            # time_per_element > 1, yet it is wired below whenever has_ce or
            # has_reset is set -- confirm callers never combine those.
            # each element of the SSeq is a separate vector lane
            first_lane_flat_idx = DefineCounterModM(ni + ii,
                                                    flat_idx_width,
                                                    cout=False,
                                                    has_ce=True,
                                                    has_reset=has_reset)()
            wire(next_element.O[0], first_lane_flat_idx.CE)
            if has_ce:
                wire(STBankGenerator.CE, index_in_cur_element.CE)
            if has_reset:
                wire(STBankGenerator.RESET, index_in_cur_element.RESET)
                wire(STBankGenerator.RESET, first_lane_flat_idx.RESET)

            lane_flat_idxs = [first_lane_flat_idx.O]

            # compute the current flat_idx for each lane
            for i in range(1, no):
                cur_lane_flat_idx_adder = DefineAdd(flat_idx_width)()
                wire(cur_lane_flat_idx_adder.I0, first_lane_flat_idx.O)
                wire(cur_lane_flat_idx_adder.I1,
                     DefineCoreirConst(flat_idx_width, i * ni)().O)

                lane_flat_idxs += [cur_lane_flat_idx_adder.O]

            lane_flat_div_lcms = []
            lcm_dim = DefineCoreirConst(flat_idx_width, lcm(no, ni))()
            # compute flat_idx / lcm_dim for each lane
            for i in range(no):
                cur_lane_lcm_div = DefineUDiv(flat_idx_width)()
                wire(cur_lane_lcm_div.I0, lane_flat_idxs[i])
                wire(cur_lane_lcm_div.I1, lcm_dim.O)

                lane_flat_div_lcms += [cur_lane_lcm_div.O]

            # compute ((flat_idx % sseq_dim) + (flat_idx / lcm_dim)) % sseq_dim for each lane
            # only need to mod sseq_dim at end as that is same as also doing it flat_idx before addition
            for i in range(no):
                pre_mod_add = DefineAdd(flat_idx_width)()
                wire(pre_mod_add.I0, lane_flat_idxs[i])
                wire(pre_mod_add.I1, lane_flat_div_lcms[i])

                bank_mod = DefineUMod(flat_idx_width)()
                wire(bank_mod.I0, pre_mod_add.O)
                wire(bank_mod.I1, DefineCoreirConst(flat_idx_width, no)().O)

                wire(STBankGenerator.bank[i],
                     bank_mod.O[0:STBankGenerator.bank_width])
                # terminate the upper bits that do not fit in the bank port
                if len(bank_mod.O) > STBankGenerator.bank_width:
                    bits_to_term = len(bank_mod.O) - STBankGenerator.bank_width
                    term = TermAnyType(Array[bits_to_term, Bit])
                    wire(bank_mod.O[STBankGenerator.bank_width:], term.I)

            # compute flat_idx / sseq_dim for each lane addr
            for i in range(no):
                flat_idx_sseq_dim_div = DefineUDiv(flat_idx_width)()
                # each lane divides its own flat index, not lane 0's
                wire(flat_idx_sseq_dim_div.I0, lane_flat_idxs[i])
                wire(flat_idx_sseq_dim_div.I1,
                     DefineCoreirConst(flat_idx_width, no)().O)

                wire(STBankGenerator.addr[i],
                     flat_idx_sseq_dim_div.O[0:STBankGenerator.addr_width])
                # terminate the upper bits that do not fit in the addr port;
                # size the terminator from this divider's own output
                if len(flat_idx_sseq_dim_div.O) > STBankGenerator.addr_width:
                    bits_to_term = (len(flat_idx_sseq_dim_div.O) -
                                    STBankGenerator.addr_width)
                    term = TermAnyType(Array[bits_to_term, Bit])
                    wire(flat_idx_sseq_dim_div.O[STBankGenerator.addr_width:],
                         term.I)
Esempio n. 10
0
    class _NestedCounters(Circuit):
        """Circuit producing `valid` (and optionally `last` / `cur_valid`) for type `t`.

        `t`, `has_ce`, `has_reset`, `has_last`, `has_cur_valid` and
        `valid_when_ce_off` are free variables taken from the enclosing scope
        (not visible in this block).

        Ports:
            valid: always present, single-bit output.
            CE / RESET: supplied by ClockInterface according to `has_ce` /
                `has_reset`.
            last: single-bit output, only present when `has_last`.
            cur_valid: output of width getRAMAddrWidth(t.valid_clocks()),
                only present when `has_cur_valid`.
        """
        # NOTE(review): the name encodes only t, has_ce and has_reset, while the
        # port list also depends on has_last and has_cur_valid (and the wiring
        # on valid_when_ce_off). Circuits differing only in those flags would
        # share a name -- confirm this cannot cause a collision downstream.
        name = 'NestedCounters_{}_hasCE{}_hasReset{}'.format(
            cleanName(str(t)), str(has_ce), str(has_reset))
        IO = ['valid', Out(Bit)] + ClockInterface(has_ce=has_ce,
                                                  has_reset=has_reset)
        if has_last:
            IO += ['last', Out(Bit)]
        if has_cur_valid:
            IO += [
                'cur_valid',
                Out(Array[getRAMAddrWidth(t.valid_clocks()), Bit])
            ]

        @classmethod
        def definition(cls):
            """Build the counter hierarchy for `t` and wire it to the ports.

            Three cases, selected on `t`:
              * `t` is exactly an ST_TSeq: an outer counter sized by
                t.n + t.i wraps a recursively defined counter circuit for
                t.t; `valid` and `last` combine the outer and inner signals.
              * `t` is nested (per is_nested) but not an ST_TSeq: the circuit
                is a pass-through to the recursively defined circuit for t.t.
              * otherwise (single element): `last` is constant 1 and `valid`
                is constant 1 or follows CE, depending on `valid_when_ce_off`.
            """
            if type(t) == ST_TSeq:
                # Outer counter always has a CE port: it is driven from the
                # inner counters' `last` (see the has_ce handling below).
                outer_counter = AESizedCounterModM(t.n + t.i,
                                                   has_ce=True,
                                                   has_reset=has_reset)
                # The inner circuit is always built with a `last` output (it
                # drives the outer CE) and never with `cur_valid`; cur_valid
                # for this level is produced by a dedicated counter below.
                inner_counters = DefineNestedCounters(
                    t.t,
                    has_last=True,
                    has_cur_valid=False,
                    has_ce=has_ce,
                    has_reset=has_reset,
                    valid_when_ce_off=valid_when_ce_off)()
                if has_last:
                    # Presumably asserts when the outer counter equals
                    # t.n + t.i - 1, i.e. its final value -- confirm Decode
                    # semantics.
                    is_last = Decode(t.n + t.i - 1,
                                     outer_counter.O.N)(outer_counter.O)
                if has_cur_valid:
                    # Separate counter sized by t.valid_clocks(); its output is
                    # exposed directly as cur_valid and its CE is wired below.
                    cur_valid_counter = AESizedCounterModM(t.valid_clocks(),
                                                           has_ce=True,
                                                           has_reset=has_reset)
                    wire(cur_valid_counter.O, cls.cur_valid)

                # If an outer_counter.O.N-bit value cannot represent t.n (the
                # original author's example: t.n is a power of 2 and the
                # sequence is always valid), the valid_length constant in the
                # else branch could not hold t.n and the comparison would get
                # the wrong input. The sequence is always valid in this case,
                # so just emit a constant 1.
                if math.pow(2, outer_counter.O.N) - 1 < t.n:
                    is_valid = DefineCoreirConst(1, 1)().O[0]
                    if not has_last:
                        # outer_counter's output is never used if not has_last,
                        # so connect it to a terminator instance.
                        last_term = TermAnyType(type(outer_counter.O))
                        wire(outer_counter.O, last_term.I)
                else:
                    # valid while outer_counter.O < t.n (unsigned less-than).
                    valid_length = DefineCoreirConst(outer_counter.O.N, t.n)()
                    is_valid_cmp = DefineCoreirUlt(outer_counter.O.N)()
                    wire(is_valid_cmp.I0, outer_counter.O)
                    wire(is_valid_cmp.I1, valid_length.O)
                    is_valid = is_valid_cmp.O

                # Valid only when both the inner circuit and this level agree.
                wire(inner_counters.valid & is_valid, cls.valid)
                if has_last:
                    # Last only on the outer counter's final value AND the
                    # inner circuit's last.
                    wire(is_last & inner_counters.last, cls.last)
                if has_reset:
                    # Fan RESET out to every counter created above.
                    wire(cls.RESET, outer_counter.RESET)
                    wire(cls.RESET, inner_counters.RESET)
                    if has_cur_valid:
                        wire(cls.RESET, cur_valid_counter.RESET)
                if has_ce:
                    # Outer counter is enabled only when CE is high and the
                    # inner circuit reports `last`.
                    wire(bit(cls.CE) & inner_counters.last, outer_counter.CE)
                    wire(cls.CE, inner_counters.CE)
                    if has_cur_valid:
                        # cur_valid counter is enabled only while CE is high
                        # and the output is valid.
                        wire(
                            bit(cls.CE) & inner_counters.valid & is_valid,
                            cur_valid_counter.CE)
                else:
                    # No external CE: the inner `last` alone enables the outer
                    # counter, and validity alone enables the cur_valid counter.
                    wire(inner_counters.last, outer_counter.CE)
                    if has_cur_valid:
                        wire(inner_counters.valid & is_valid,
                             cur_valid_counter.CE)
            elif is_nested(t):
                # Nested but not an ST_TSeq: no counter at this level, so
                # forward every port to the circuit for the element type t.t.
                # The flags are passed positionally here; this assumes
                # DefineNestedCounters takes them in the order (t, has_last,
                # has_cur_valid, has_ce, has_reset) -- TODO confirm.
                inner_counters = DefineNestedCounters(
                    t.t,
                    has_last,
                    has_cur_valid,
                    has_ce,
                    has_reset,
                    valid_when_ce_off=valid_when_ce_off)()

                wire(inner_counters.valid, cls.valid)
                if has_last:
                    wire(inner_counters.last, cls.last)
                if has_reset:
                    wire(cls.RESET, inner_counters.RESET)
                if has_ce:
                    wire(cls.CE, inner_counters.CE)
                if has_cur_valid:
                    wire(inner_counters.cur_valid, cls.cur_valid)
            else:
                # only 1 element, so always last and valid element
                valid_and_last = DefineCoreirConst(1, 1)()
                if has_last:
                    wire(valid_and_last.O[0], cls.last)
                if has_cur_valid:
                    # cur_valid is driven with a constant 0.
                    cur_valid = DefineCoreirConst(1, 0)()
                    wire(cur_valid.O, cls.cur_valid)
                if has_ce:
                    if valid_when_ce_off:
                        # valid is constant 1 regardless of CE; CE is unused,
                        # so it is connected to a terminator instance.
                        # NOTE(review): arguments here are in the opposite
                        # order from the other wire() calls in this block --
                        # confirm wire() accepts either order.
                        wire(cls.valid, valid_and_last.O[0])
                        ce_term = TermAnyType(Bit)
                        wire(cls.CE, ce_term.I)
                    else:
                        # valid simply follows CE.
                        wire(cls.valid, cls.CE)
                else:
                    wire(valid_and_last.O[0], cls.valid)
                if has_reset:
                    # RESET is unused in this case; connect it to a terminator
                    # instance.
                    reset_term = TermAnyType(Bit)
                    wire(reset_term.I, cls.RESET)