class _Mux(Circuit):
    """Mux selecting one of n values of magma type t via a binary address.

    Implemented by flattening each typed element to a bit vector,
    muxing the bit vectors, and rebuilding the typed result.
    """
    name = 'Mux_{}t_{}n'.format(cleanName(str(t)), n)
    addr_width = getRAMAddrWidth(n)
    IO = [
        'data', In(Array[n, t]),
        'sel', In(Bits[addr_width]),
        'out', Out(t)
    ]

    @classmethod
    def definition(cls):
        if n > 1:
            # Element size in raw bits, per the CoreIR backend's view of t.
            elem_bits = GetCoreIRBackend().get_type(t).size
            bit_mux = CommonlibMuxN(n, elem_bits)
            # Dehydrate every input element in parallel into flat bits.
            flatteners = DefineNativeMapParallel(n, DefineDehydrate(t))()
            wire(cls.data, flatteners.I)
            wire(flatteners.out, bit_mux.I.data)
            # Hydrate the selected bit vector back into type t.
            rebuilder = Hydrate(t)
            wire(bit_mux.out, rebuilder.I)
            wire(rebuilder.out, cls.out)
            wire(cls.sel, bit_mux.I.sel)
        else:
            # Degenerate single-input mux: pass through and terminate
            # the (unused) select so no port is left dangling.
            wire(cls.data[0], cls.out)
            sel_term = DefineTermAnyType(Bits[cls.addr_width])()
            wire(cls.sel, sel_term.I)
class _RAM(Circuit):
    """RAM holding n entries of magma type t, with separate read/write ports.

    Backed by a flat bit-vector RAM; entries are dehydrated on write
    and hydrated on read.
    """
    name = 'RAM_{}t_{}n'.format(cleanName(str(t)), n)
    addr_width = getRAMAddrWidth(n)
    IO = [
        'RADDR', In(Bits[addr_width]),
        'RDATA', Out(t),
        'WADDR', In(Bits[addr_width]),
        'WDATA', In(t),
        'WE', In(Bit)
    ] + ClockInterface()

    @classmethod
    def definition(cls):
        # Size of one entry in raw bits, per the CoreIR backend.
        elem_bits = GetCoreIRBackend().get_type(t).size
        bit_ram = DefineRAM(n, elem_bits, read_latency=read_latency)()
        # Write path: typed value -> flat bits.
        to_bits = Dehydrate(t)
        wire(cls.WDATA, to_bits.I)
        wire(to_bits.out, bit_ram.WDATA)
        # Read path: flat bits -> typed value.
        from_bits = Hydrate(t)
        wire(bit_ram.RDATA, from_bits.I)
        wire(from_bits.out, cls.RDATA)
        # Addresses and write-enable pass straight through.
        wire(cls.RADDR, bit_ram.RADDR)
        wire(bit_ram.WADDR, cls.WADDR)
        wire(cls.WE, bit_ram.WE)
class _ROM(Circuit):
    """LUT-backed ROM: one single-bit LUT per bit position of element type t.

    LUT k is programmed with bit k of every entry in `init`; all LUTs
    share the same address, and their outputs are hydrated back into t.
    """
    name = 'LUT_{}t_{}n'.format(cleanName(str(t)), n)
    addr_width = getRAMAddrWidth(n)
    IO = [
        'addr', In(Bits[addr_width]),
        'data', Out(t),
    ] + ClockInterface()

    @classmethod
    def definition(cls):
        elem_bits = GetCoreIRBackend().get_type(t).size
        # Build one LUT per bit position; each LUT's contents are the
        # k-th bit column of the init table, packed into an integer.
        bit_luts = [
            DefineLUT(seq2int([entry[k] for entry in init]),
                      getRAMAddrWidth(n))()
            for k in range(elem_bits)
        ]
        assembler = Hydrate(t)
        for k, lut in enumerate(bit_luts):
            wire(lut.O, assembler.I[k])
        wire(assembler.out, cls.data)
        # Broadcast every address bit to each LUT's select inputs I0..I(w-1).
        for lut in bit_luts:
            for j in range(cls.addr_width):
                wire(cls.addr[j], getattr(lut, "I" + str(j)))
def definition(cls):
    """Wire a bank of single-bit LUTs implementing a ROM of `init` entries.

    Each LUT stores one bit column of the init table; the LUT outputs
    are hydrated back into the element type and driven onto cls.data.
    """
    elem_bits = GetCoreIRBackend().get_type(t).size
    addr_bits = getRAMAddrWidth(n)
    bit_luts = []
    for i in range(elem_bits):
        # Bit i of every init entry, packed into a single LUT constant.
        column = seq2int([el[i] for el in init])
        bit_luts.append(DefineLUT(column, addr_bits)())
    rebuilder = Hydrate(t)
    for i, lut in enumerate(bit_luts):
        wire(lut.O, rebuilder.I[i])
    wire(rebuilder.out, cls.data)
    # Fan the shared address out to each LUT's select pins I0..I(w-1).
    for lut in bit_luts:
        for j in range(cls.addr_width):
            wire(cls.addr[j], getattr(lut, "I" + str(j)))
class _RAM_ST(Circuit):
    # Space-time RAM: n banks, each large enough for one full space-time
    # element of type t (t.valid_clocks() entries of t.magma_repr()).
    # RADDR/WADDR select the bank; the position within a bank is driven
    # by NestedCounters tracking the valid clocks of t.
    name = 'RAM_ST_{}_hasReset{}'.format(cleanName(str(t)), str(has_reset))
    addr_width = getRAMAddrWidth(n)
    IO = ['RADDR', In(Bits[addr_width]),
          'RDATA', Out(t.magma_repr()),
          'WADDR', In(Bits[addr_width]),
          'WDATA', In(t.magma_repr()),
          'WE', In(Bit),
          'RE', In(Bit)
          ] + ClockInterface(has_ce=False, has_reset=has_reset)

    @classmethod
    def definition(cls):
        # each valid clock, going to get a magma_repr in
        # read or write each one of those to a location
        rams = [DefineRAMAnyType(t.magma_repr(), t.valid_clocks(),
                                 read_latency=read_latency)()
                for _ in range(n)]
        # Counters track the current position within a space-time element;
        # cur_valid is used as the per-bank read/write address.
        read_time_position_counter = DefineNestedCounters(
            t, has_cur_valid=True, has_ce=True, has_reset=has_reset)()
        # valid/last outputs of the counters are unused here — terminate them.
        read_valid_term = TermAnyType(Bit)
        read_last_term = TermAnyType(Bit)
        write_time_position_counter = DefineNestedCounters(
            t, has_cur_valid=True, has_ce=True, has_reset=has_reset)()
        write_valid_term = TermAnyType(Bit)
        write_last_term = TermAnyType(Bit)
        # Selects which bank's RDATA is forwarded to the output.
        read_selector = DefineMuxAnyType(t.magma_repr(), n)()
        for i in range(n):
            # All banks see the same data and intra-element address;
            # only the decoded bank actually gets WE asserted.
            wire(cls.WDATA, rams[i].WDATA)
            wire(write_time_position_counter.cur_valid, rams[i].WADDR)
            wire(read_selector.data[i], rams[i].RDATA)
            wire(read_time_position_counter.cur_valid, rams[i].RADDR)
            write_cur_ram = Decode(i, cls.WADDR.N)(cls.WADDR)
            wire(write_cur_ram & write_time_position_counter.valid, rams[i].WE)
        wire(cls.RADDR, read_selector.sel)
        wire(cls.RDATA, read_selector.out)
        # WE/RE gate the position counters, so the position only advances
        # on clocks where a write/read actually happens.
        wire(cls.WE, write_time_position_counter.CE)
        wire(cls.RE, read_time_position_counter.CE)
        wire(read_time_position_counter.valid, read_valid_term.I)
        wire(read_time_position_counter.last, read_last_term.I)
        wire(write_time_position_counter.valid, write_valid_term.I)
        wire(write_time_position_counter.last, write_last_term.I)
        if has_reset:
            wire(cls.RESET, write_time_position_counter.RESET)
            wire(cls.RESET, read_time_position_counter.RESET)
class _ROM(Circuit):
    """ROM of n entries of magma type t, preloaded from `init`.

    Contents live in a flat CoreIR ROM; reads are hydrated back into t.
    """
    name = 'ROM_{}t_{}n'.format(cleanName(str(t)), n)
    addr_width = getRAMAddrWidth(n)
    IO = [
        'RADDR', In(Bits[addr_width]),
        'RDATA', Out(t),
        'REN', In(Bit)
    ] + ClockInterface()

    @classmethod
    def definition(cls):
        elem_bits = GetCoreIRBackend().get_type(t).size
        # Each init entry is passed to CoreIR as a list of its bits.
        bit_rom = DefineROM(n, elem_bits)(
            coreir_configargs={"init": [list(el) for el in init]})
        rebuilder = Hydrate(t)
        wire(bit_rom.rdata, rebuilder.I)
        wire(rebuilder.out, cls.RDATA)
        wire(cls.RADDR, bit_rom.raddr)
        wire(cls.REN, bit_rom.ren)
def definition(cls):
    """Reshape between two space-time types via banked RAMs.

    Structure (three sections):
      1. analyze the permutation graph between t_in and t_out, build one
         RAM_ST per bank plus LUTs holding per-clock write/read addresses,
         write-valid flags, and bank/lane routing;
      2. build the counters that index those LUTs over time;
      3. build two bitonic sorting networks that route input lanes to
         banks (by bank id) and bank outputs to lanes (by lane id), then
         wire everything together.
    """
    # first section creates the RAMs and LUTs that set values in them and the sorting network
    shared_and_diff_subtypes = get_shared_and_diff_subtypes(t_in, t_out)
    t_in_diff = shared_and_diff_subtypes.diff_input
    t_out_diff = shared_and_diff_subtypes.diff_output
    # NOTE(review): the differing subtypes are wrapped in TSeq(2, 0, ...)
    # before graph construction — presumably to model double-buffering;
    # confirm against build_permutation_graph's contract.
    graph = build_permutation_graph(ST_TSeq(2, 0, t_in_diff),
                                    ST_TSeq(2, 0, t_out_diff))
    banks_write_addr_per_input_lane = get_banks_addr_per_lane(
        graph.input_nodes)
    input_lane_write_addr_per_bank = get_lane_addr_per_banks(
        graph.input_nodes)
    output_lane_read_addr_per_bank = get_lane_addr_per_banks(
        graph.output_nodes)
    # each ram only needs to be large enough to handle the number of addresses assigned to it
    # all rams receive the same number of writes
    # but some of those writes don't happen as the data is invalid, so don't need storage for them
    max_ram_addrs = [
        max([bank_clock_data.addr for bank_clock_data in bank_data])
        for bank_data in output_lane_read_addr_per_bank
    ]
    # rams also handle parallelism from outer_shared type as this affects all banks the same
    outer_shared_sseqs = remove_tseqs(shared_and_diff_subtypes.shared_outer)
    if outer_shared_sseqs == ST_Tombstone():
        ram_element_type = shared_and_diff_subtypes.shared_inner
    else:
        ram_element_type = replace_tombstone(
            outer_shared_sseqs, shared_and_diff_subtypes.shared_inner)
    # can use wider rams rather than duplicate for outer_shared_sseqs because will
    # transpose dimenions of input wires below to wire up as if outer, shared dimensions
    # were on the inside
    rams = [
        DefineRAM_ST(ram_element_type, ram_max_addr + 1)()
        for ram_max_addr in max_ram_addrs
    ]
    rams_addr_widths = [ram.WADDR.N for ram in rams]
    # for bank, the addresses to write to each clock
    write_addr_for_bank_luts = []
    for bank_idx in range(len(rams)):
        ram_addr_width = rams_addr_widths[bank_idx]
        num_addrs = len(input_lane_write_addr_per_bank[bank_idx])
        #assert num_addrs == t_in_diff.time()
        # Each LUT entry is the write address for this bank on that clock,
        # as a tuple of bits.
        write_addrs = [
            builtins.tuple(
                int2seq(write_data_per_bank_per_clock.addr, ram_addr_width))
            for write_data_per_bank_per_clock in
            input_lane_write_addr_per_bank[bank_idx]
        ]
        write_addr_for_bank_luts.append(
            DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                             builtins.tuple(write_addrs))())
    # for bank, whether to actually write this clock
    write_valid_for_bank_luts = []
    for bank_idx in range(len(rams)):
        num_valids = len(input_lane_write_addr_per_bank[bank_idx])
        #assert num_valids == t_in_diff.time()
        valids = [
            builtins.tuple([write_data_per_bank_per_clock.valid])
            for write_data_per_bank_per_clock in
            input_lane_write_addr_per_bank[bank_idx]
        ]
        write_valid_for_bank_luts.append(
            DefineLUTAnyType(Bit, num_valids, builtins.tuple(valids))())
    # for each input lane, the bank to write to each clock
    write_bank_for_input_lane_luts = []
    bank_idx_width = getRAMAddrWidth(len(rams))
    for lane_idx in range(len(banks_write_addr_per_input_lane)):
        num_bank_idxs = len(banks_write_addr_per_input_lane[lane_idx])
        #assert num_bank_idxs == t_in_diff.time()
        bank_idxs = [
            builtins.tuple(
                int2seq(write_data_per_lane_per_clock.bank, bank_idx_width))
            for write_data_per_lane_per_clock in
            banks_write_addr_per_input_lane[lane_idx]
        ]
        write_bank_for_input_lane_luts.append(
            DefineLUTAnyType(Array[bank_idx_width, Bit], num_bank_idxs,
                             builtins.tuple(bank_idxs))())
    # for each bank, the address to read from each clock
    read_addr_for_bank_luts = []
    for bank_idx in range(len(rams)):
        ram_addr_width = rams_addr_widths[bank_idx]
        num_addrs = len(output_lane_read_addr_per_bank[bank_idx])
        #assert num_addrs == t_in_diff.time()
        read_addrs = [
            builtins.tuple(
                int2seq(read_data_per_bank_per_clock.addr, ram_addr_width))
            for read_data_per_bank_per_clock in
            output_lane_read_addr_per_bank[bank_idx]
        ]
        read_addr_for_bank_luts.append(
            DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                             builtins.tuple(read_addrs))())
    # for each bank, the lane to send each read to
    output_lane_for_bank_luts = []
    # number of lanes equals number of banks
    # some the lanes are just always invalid, added so input lane width equals output lane width
    lane_idx_width = getRAMAddrWidth(len(rams))
    for bank_idx in range(len(rams)):
        num_lane_idxs = len(output_lane_read_addr_per_bank[bank_idx])
        #assert num_lane_idxs == t_in_diff.time()
        lane_idxs = [
            builtins.tuple(
                int2seq(read_data_per_bank_per_clock.s, lane_idx_width))
            for read_data_per_bank_per_clock in
            output_lane_read_addr_per_bank[bank_idx]
        ]
        output_lane_for_bank_luts.append(
            DefineLUTAnyType(Array[lane_idx_width, Bit], num_lane_idxs,
                             builtins.tuple(lane_idxs))())
    # second part creates the counters that index into the LUTs
    # elem_per counts time per element of the reshape
    elem_per_reshape_counter = AESizedCounterModM(
        ram_element_type.time(), has_ce=True)
    end_cur_elem = Decode(ram_element_type.time() - 1,
                          elem_per_reshape_counter.O.N)(
                              elem_per_reshape_counter.O)
    # reshape counts which element in the reshape
    num_clocks = len(output_lane_read_addr_per_bank[0])
    reshape_write_counter = AESizedCounterModM(num_clocks,
                                               has_ce=True,
                                               has_reset=has_reset)
    reshape_read_counter = AESizedCounterModM(num_clocks,
                                              has_ce=True,
                                              has_reset=has_reset)
    # Reads lag writes by the graph's output latency, in clocks.
    output_delay = (get_output_latencies(graph)[0]) * ram_element_type.time()
    # this is present so testing knows the delay
    cls.output_delay = output_delay
    reshape_read_delay_counter = DefineInitialDelayCounter(
        output_delay, has_ce=True, has_reset=has_reset)()
    # outer counter the repeats the reshape
    #wire(reshape_write_counter.O, cls.reshape_write_counter)
    # Build the enable chain: constant-1, optionally gated by valid_up / CE.
    enabled = DefineCoreirConst(1, 1)().O[0]
    if has_valid:
        enabled = cls.valid_up & enabled
        wire(reshape_read_delay_counter.valid, cls.valid_down)
    if has_ce:
        enabled = bit(cls.CE) & enabled
    wire(enabled, elem_per_reshape_counter.CE)
    wire(enabled, reshape_read_delay_counter.CE)
    # Element counters advance only at the end of each reshape element;
    # the read counter additionally waits for the initial delay to expire.
    wire(enabled & end_cur_elem, reshape_write_counter.CE)
    wire(enabled & end_cur_elem & reshape_read_delay_counter.valid,
         reshape_read_counter.CE)
    if has_reset:
        wire(cls.RESET, elem_per_reshape_counter.RESET)
        wire(cls.RESET, reshape_read_delay_counter.RESET)
        wire(cls.RESET, reshape_write_counter.RESET)
        wire(cls.RESET, reshape_read_counter.RESET)
    # wire read and write counters to all LUTs
    for lut in write_bank_for_input_lane_luts:
        wire(reshape_write_counter.O, lut.addr)
    for lut in write_addr_for_bank_luts:
        wire(reshape_write_counter.O, lut.addr)
    for lut in write_valid_for_bank_luts:
        wire(reshape_write_counter.O, lut.addr)
    for lut in read_addr_for_bank_luts:
        wire(reshape_read_counter.O, lut.addr)
    for lut in output_lane_for_bank_luts:
        wire(reshape_read_counter.O, lut.addr)
    # third and final instance creation part creates the sorting networks that map lanes to banks
    input_sorting_network_t = Tuple(
        bank=Array[write_bank_for_input_lane_luts[0].data.N, Bit],
        val=ram_element_type.magma_repr())
    input_sorting_network = DefineBitonicSort(input_sorting_network_t,
                                              len(rams),
                                              lambda x: x.bank)()
    output_sorting_network_t = Tuple(
        lane=Array[output_lane_for_bank_luts[0].data.N, Bit],
        val=ram_element_type.magma_repr())
    output_sorting_network = DefineBitonicSort(output_sorting_network_t,
                                               len(rams),
                                               lambda x: x.lane)()
    # wire luts, sorting networks, inputs, and rams
    # flatten all the sseq_layers to get flat magma type of inputs and outputs
    # tseqs don't affect magma types
    num_sseq_layers_inputs = num_nested_layers(
        remove_tseqs(shared_and_diff_subtypes.diff_input))
    num_sseq_layers_to_remove_inputs = max(0, num_sseq_layers_inputs - 1)
    num_sseq_layers_outputs = num_nested_layers(
        remove_tseqs(shared_and_diff_subtypes.diff_output))
    num_sseq_layers_to_remove_outputs = max(0, num_sseq_layers_outputs - 1)
    if remove_tseqs(shared_and_diff_subtypes.shared_outer) != ST_Tombstone():
        #num_sseq_layers_inputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
        #num_sseq_layers_outputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
        # Shared outer SSeq dimensions are transposed to the inside so the
        # per-bank wiring below can treat them uniformly.
        input_ports = flatten_ports(
            transpose_outer_dimensions(
                shared_and_diff_subtypes.shared_outer,
                shared_and_diff_subtypes.diff_input, cls.I),
            num_sseq_layers_to_remove_inputs)
        output_ports = flatten_ports(
            transpose_outer_dimensions(
                shared_and_diff_subtypes.shared_outer,
                shared_and_diff_subtypes.diff_output, cls.O),
            num_sseq_layers_to_remove_outputs)
    else:
        input_ports = flatten_ports(cls.I, num_sseq_layers_to_remove_inputs)
        output_ports = flatten_ports(cls.O, num_sseq_layers_to_remove_outputs)
    # this is only used if the shared outer layers contains any sseqs
    sseq_layers_to_flatten = max(
        num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
        - 1, 0)
    for idx in range(len(rams)):
        # wire input and bank to input sorting network
        wire(write_bank_for_input_lane_luts[idx].data,
             input_sorting_network.I[idx].bank)
        #if idx == 0:
        #    wire(cls.first_valid, write_valid_for_bank_luts[idx].data)
        if idx < t_in_diff.port_width():
            # since the input_ports are lists, need to wire them individually to the sorting ports
            if remove_tseqs(
                    shared_and_diff_subtypes.shared_outer) != ST_Tombstone():
                cur_input_port = flatten_ports(input_ports[idx],
                                               sseq_layers_to_flatten)
                cur_sort_port = flatten_ports(
                    input_sorting_network.I[idx].val, sseq_layers_to_flatten)
                for i in range(len(cur_input_port)):
                    wire(cur_input_port[i], cur_sort_port[i])
            else:
                if num_sseq_layers_inputs == 0:
                    # input_ports will be an array of bits for 1 element
                    # if no sseq in t_in
                    wire(input_ports, input_sorting_network.I[idx].val)
                else:
                    wire(input_ports[idx], input_sorting_network.I[idx].val)
            #wire(cls.ram_wr, input_sorting_network.O[idx].val)
            #wire(cls.ram_rd, rams[idx].RDATA)
        else:
            # Lane has no real input: feed zeros, descending into nested
            # arrays until the widths match the zero constant.
            zero_const = DefineCoreirConst(
                ram_element_type.magma_repr().size(), 0)().O
            cur_sn_input = input_sorting_network.I[idx].val
            while len(cur_sn_input) != len(zero_const):
                cur_sn_input = cur_sn_input[0]
            wire(zero_const, cur_sn_input)
        # wire input sorting network, write addr, and write valid luts to banks
        wire(input_sorting_network.O[idx].val, rams[idx].WDATA)
        wire(write_addr_for_bank_luts[idx].data, rams[idx].WADDR)
        #wire(write_addr_for_bank_luts[idx].data[0], cls.addr_wr[idx])
        if has_ce:
            wire(write_valid_for_bank_luts[idx].data & bit(cls.CE),
                 rams[idx].WE)
        else:
            wire(write_valid_for_bank_luts[idx].data, rams[idx].WE)
        # wire output sorting network, read addr, read bank, and read enable
        wire(rams[idx].RDATA, output_sorting_network.I[idx].val)
        wire(output_lane_for_bank_luts[idx].data,
             output_sorting_network.I[idx].lane)
        wire(read_addr_for_bank_luts[idx].data, rams[idx].RADDR)
        #wire(read_addr_for_bank_luts[idx].data[0], cls.addr_rd[idx])
        # ok to read invalid things, so in read value LUT
        if has_ce:
            wire(bit(cls.CE), rams[idx].RE)
        else:
            wire(DefineCoreirConst(1, 1)().O[0], rams[idx].RE)
        if has_reset:
            wire(cls.RESET, rams[idx].RESET)
        # wire output sorting network value to output or term
        if idx < t_out_diff.port_width():
            # since the output_ports are lists, need to wire them individually to the sorting ports
            if remove_tseqs(
                    shared_and_diff_subtypes.shared_outer) != ST_Tombstone():
                cur_output_port = flatten_ports(output_ports[idx],
                                                sseq_layers_to_flatten)
                cur_sort_port = flatten_ports(
                    output_sorting_network.O[idx].val, sseq_layers_to_flatten)
                for i in range(len(cur_output_port)):
                    wire(cur_output_port[i], cur_sort_port[i])
            else:
                if num_sseq_layers_outputs == 0:
                    # output_ports will be an array of bits for 1 element
                    # if no sseq in t_out
                    wire(output_sorting_network.O[idx].val, output_ports)
                else:
                    wire(output_sorting_network.O[idx].val, output_ports[idx])
        else:
            wire(output_sorting_network.O[idx].val,
                 TermAnyType(type(output_sorting_network.O[idx].val)))
        # wire sorting networks bank/lane to term as not used on outputs, just used for sorting
        wire(input_sorting_network.O[idx].bank,
             TermAnyType(type(input_sorting_network.O[idx].bank)))
        wire(output_sorting_network.O[idx].lane,
             TermAnyType(type(output_sorting_network.O[idx].lane)))
def definition(TSBankGenerator):
    """Generate per-lane bank and address streams for a TSeq->SSeq reshape.

    For each of the ni lanes the bank is
    ((flat_idx % sseq_dim) + (flat_idx / lcm(no, ni))) % sseq_dim
    and the address is the current time step. Mirrors the parallel
    STBankGenerator implementation.
    """
    flat_idx_width = getRAMAddrWidth(no * ni)
    # next element each time_per_element clock
    if time_per_element > 1:
        index_in_cur_element = SizedCounterModM(time_per_element,
                                                has_ce=has_ce,
                                                has_reset=has_reset)
        next_element = Decode(time_per_element - 1,
                              index_in_cur_element.O.N)(
                                  index_in_cur_element.O)
    else:
        next_element = DefineCoreirConst(1, 1)()
    # each element of the SSeq is a separate vector lane
    first_lane_flat_idx = SizedCounterModM((no + io) * ni,
                                           incr=ni,
                                           has_ce=True,
                                           has_reset=has_reset)()
    time_counter = SizedCounterModM(no + io, has_ce=True,
                                    has_reset=has_reset)
    wire(next_element.O, first_lane_flat_idx.CE)
    wire(next_element.O, time_counter.CE)
    if has_ce:
        wire(TSBankGenerator.CE, index_in_cur_element.CE)
    if has_reset:
        wire(TSBankGenerator.RESET, index_in_cur_element.RESET)
        wire(TSBankGenerator.RESET, first_lane_flat_idx.RESET)
        wire(TSBankGenerator.RESET, time_counter.RESET)
    # compute the current flat_idx for each lane (lane i is offset i*no
    # from lane 0's counter value)
    lane_flat_idxs = [first_lane_flat_idx.O]
    for i in range(1, ni):
        cur_lane_flat_idx_adder = DefineAdd(flat_idx_width)()
        wire(cur_lane_flat_idx_adder.I0, first_lane_flat_idx.O)
        wire(cur_lane_flat_idx_adder.I1,
             DefineCoreirConst(flat_idx_width, i * no)().O)
        lane_flat_idxs += [cur_lane_flat_idx_adder.O]
    # compute flat_idx / lcm_dim for each lane
    lane_flat_div_lcms = []
    for i in range(ni):
        cur_lane_lcm_div = DefineUDiv(flat_idx_width)()
        # fix: divide this lane's own flat index — was lane_flat_idxs[0].O,
        # which both used lane 0 for every lane and applied a spurious .O
        # to a value that is already an output port
        wire(cur_lane_lcm_div.I0, lane_flat_idxs[i])
        # fix: DefineCoreirConst takes (width, value); arguments were
        # swapped relative to every other call site
        wire(cur_lane_lcm_div.I1,
             DefineCoreirConst(flat_idx_width, lcm(no, ni))().O)
        # fix: record the divider's output — was appending the leftover
        # adder output from the previous loop
        lane_flat_div_lcms += [cur_lane_lcm_div.O]
    # compute ((flat_idx % sseq_dim) + (flat_idx / lcm_dim)) % sseq_dim for each lane
    # note that s_ts == flat_idx % sseq_dim
    # only need to mod sseq_dim at end as that is same as also doing it flat_idx before addition
    for i in range(ni):
        pre_mod_add = DefineAdd(flat_idx_width)()
        wire(pre_mod_add.I0, lane_flat_idxs[i])
        wire(pre_mod_add.I1, lane_flat_div_lcms[i])
        bank_mod = DefineUMod(flat_idx_width)()
        wire(bank_mod.I0, pre_mod_add.O)
        # fix: the modulus constant belongs on I1 — I0 was wired twice,
        # leaving I1 unconnected
        wire(bank_mod.I1, DefineCoreirConst(flat_idx_width, ni)().O)
        wire(TSBankGenerator.bank[i],
             bank_mod.O[0:TSBankGenerator.bank_width])
    # compute t for each lane addr
    for i in range(0, ni):
        wire(TSBankGenerator.addr[i],
             time_counter.O[0:TSBankGenerator.addr_width])
def definition(STBankGenerator):
    """Generate per-lane bank and address streams for an SSeq->TSeq reshape.

    For each of the no lanes the bank is
    ((flat_idx % sseq_dim) + (flat_idx / lcm(no, ni))) % sseq_dim
    and the address is flat_idx / sseq_dim.
    """
    flat_idx_width = getRAMAddrWidth(no * ni)
    # next element each time_per_element clock
    if time_per_element > 1:
        index_in_cur_element = SizedCounterModM(time_per_element,
                                                has_ce=has_ce,
                                                has_reset=has_reset)
        next_element = Decode(time_per_element - 1,
                              index_in_cur_element.O.N)(
                                  index_in_cur_element.O)
    else:
        next_element = DefineCoreirConst(1, 1)()
    # each element of the SSeq is a separate vector lane
    first_lane_flat_idx = DefineCounterModM(ni + ii, flat_idx_width,
                                            cout=False, has_ce=True,
                                            has_reset=has_reset)()
    wire(next_element.O[0], first_lane_flat_idx.CE)
    if has_ce:
        wire(STBankGenerator.CE, index_in_cur_element.CE)
    if has_reset:
        wire(STBankGenerator.RESET, index_in_cur_element.RESET)
        wire(STBankGenerator.RESET, first_lane_flat_idx.RESET)
    # compute the current flat_idx for each lane (lane i is offset i*ni
    # from lane 0's counter value)
    lane_flat_idxs = [first_lane_flat_idx.O]
    for i in range(1, no):
        cur_lane_flat_idx_adder = DefineAdd(flat_idx_width)()
        wire(cur_lane_flat_idx_adder.I0, first_lane_flat_idx.O)
        wire(cur_lane_flat_idx_adder.I1,
             DefineCoreirConst(flat_idx_width, i * ni)().O)
        lane_flat_idxs += [cur_lane_flat_idx_adder.O]
    lane_flat_div_lcms = []
    lcm_dim = DefineCoreirConst(flat_idx_width, lcm(no, ni))()
    # compute flat_idx / lcm_dim for each lane
    for i in range(no):
        cur_lane_lcm_div = DefineUDiv(flat_idx_width)()
        wire(cur_lane_lcm_div.I0, lane_flat_idxs[i])
        wire(cur_lane_lcm_div.I1, lcm_dim.O)
        lane_flat_div_lcms += [cur_lane_lcm_div.O]
    # compute ((flat_idx % sseq_dim) + (flat_idx / lcm_dim)) % sseq_dim for each lane
    # only need to mod sseq_dim at end as that is same as also doing it flat_idx before addition
    for i in range(no):
        pre_mod_add = DefineAdd(flat_idx_width)()
        wire(pre_mod_add.I0, lane_flat_idxs[i])
        wire(pre_mod_add.I1, lane_flat_div_lcms[i])
        bank_mod = DefineUMod(flat_idx_width)()
        wire(bank_mod.I0, pre_mod_add.O)
        wire(bank_mod.I1, DefineCoreirConst(flat_idx_width, no)().O)
        wire(STBankGenerator.bank[i],
             bank_mod.O[0:STBankGenerator.bank_width])
        # terminate any high-order bits of the mod result beyond bank_width
        if len(bank_mod.O) > STBankGenerator.bank_width:
            bits_to_term = len(bank_mod.O) - STBankGenerator.bank_width
            term = TermAnyType(Array[bits_to_term, Bit])
            wire(bank_mod.O[STBankGenerator.bank_width:], term.I)
    # compute flat_idx / sseq_dim for each lane addr
    for i in range(no):
        flat_idx_sseq_dim_div = DefineUDiv(flat_idx_width)()
        # fix: use this lane's own flat index — was lane_flat_idxs[0],
        # which gave every lane the address of lane 0
        wire(flat_idx_sseq_dim_div.I0, lane_flat_idxs[i])
        wire(flat_idx_sseq_dim_div.I1,
             DefineCoreirConst(flat_idx_width, no)().O)
        wire(STBankGenerator.addr[i],
             flat_idx_sseq_dim_div.O[0:STBankGenerator.addr_width])
        # terminate any high-order bits of the divide result beyond addr_width
        if len(flat_idx_sseq_dim_div.O) > STBankGenerator.addr_width:
            # fix: measure the divider output being terminated — was
            # len(bank_mod.O), a stale reference to the previous loop's
            # instance
            bits_to_term = (len(flat_idx_sseq_dim_div.O) -
                            STBankGenerator.addr_width)
            term = TermAnyType(Array[bits_to_term, Bit])
            wire(flat_idx_sseq_dim_div.O[STBankGenerator.addr_width:],
                 term.I)
class _NestedCounters(Circuit):
    # Hierarchical counters tracking position within a nested space-time
    # type t. Emits `valid` (current clock carries real data), optionally
    # `last` (final clock of the element) and `cur_valid` (index of the
    # current valid clock, usable as a RAM address).
    name = 'NestedCounters_{}_hasCE{}_hasReset{}'.format(
        cleanName(str(t)), str(has_ce), str(has_reset))
    IO = ['valid', Out(Bit)] + ClockInterface(has_ce=has_ce,
                                              has_reset=has_reset)
    if has_last:
        IO += ['last', Out(Bit)]
    if has_cur_valid:
        IO += [
            'cur_valid', Out(Array[getRAMAddrWidth(t.valid_clocks()), Bit])
        ]

    @classmethod
    def definition(cls):
        if type(t) == ST_TSeq:
            # Outer counter steps once per inner element (n valid + i
            # invalid elements); inner counters track within an element.
            outer_counter = AESizedCounterModM(t.n + t.i,
                                               has_ce=True,
                                               has_reset=has_reset)
            inner_counters = DefineNestedCounters(
                t.t,
                has_last=True,
                has_cur_valid=False,
                has_ce=has_ce,
                has_reset=has_reset,
                valid_when_ce_off=valid_when_ce_off)()
            if has_last:
                is_last = Decode(t.n + t.i - 1,
                                 outer_counter.O.N)(outer_counter.O)
            if has_cur_valid:
                # Counts only valid clocks — used as an address into RAMs.
                cur_valid_counter = AESizedCounterModM(t.valid_clocks(),
                                                       has_ce=True,
                                                       has_reset=has_reset)
                wire(cur_valid_counter.O, cls.cur_valid)
            # if t.n is a power of 2 and always valid, then outer_counter.O.N not enough bits
            # for valid_length to contain t.n and for is_valid to get the right input
            # always valid in this case, so just emit 1
            if math.pow(2, outer_counter.O.N) - 1 < t.n:
                is_valid = DefineCoreirConst(1, 1)().O[0]
                if not has_last:
                    # never using the outer_counter is not has_last
                    last_term = TermAnyType(type(outer_counter.O))
                    wire(outer_counter.O, last_term.I)
            else:
                # valid iff the outer counter is below the number of
                # valid elements t.n (i.e. not in the invalid tail)
                valid_length = DefineCoreirConst(outer_counter.O.N, t.n)()
                is_valid_cmp = DefineCoreirUlt(outer_counter.O.N)()
                wire(is_valid_cmp.I0, outer_counter.O)
                wire(is_valid_cmp.I1, valid_length.O)
                is_valid = is_valid_cmp.O
            wire(inner_counters.valid & is_valid, cls.valid)
            if has_last:
                wire(is_last & inner_counters.last, cls.last)
            if has_reset:
                wire(cls.RESET, outer_counter.RESET)
                wire(cls.RESET, inner_counters.RESET)
                if has_cur_valid:
                    wire(cls.RESET, cur_valid_counter.RESET)
            if has_ce:
                # Outer counter advances only when the inner element ends.
                wire(bit(cls.CE) & inner_counters.last, outer_counter.CE)
                wire(cls.CE, inner_counters.CE)
                if has_cur_valid:
                    wire(
                        bit(cls.CE) & inner_counters.valid & is_valid,
                        cur_valid_counter.CE)
            else:
                wire(inner_counters.last, outer_counter.CE)
                if has_cur_valid:
                    wire(inner_counters.valid & is_valid,
                         cur_valid_counter.CE)
        elif is_nested(t):
            # Non-TSeq nested type (e.g. SSeq): delegate entirely to the
            # counters for the inner type.
            inner_counters = DefineNestedCounters(
                t.t,
                has_last,
                has_cur_valid,
                has_ce,
                has_reset,
                valid_when_ce_off=valid_when_ce_off)()
            wire(inner_counters.valid, cls.valid)
            if has_last:
                wire(inner_counters.last, cls.last)
            if has_reset:
                wire(cls.RESET, inner_counters.RESET)
            if has_ce:
                wire(cls.CE, inner_counters.CE)
            if has_cur_valid:
                wire(inner_counters.cur_valid, cls.cur_valid)
        else:
            # only 1 element, so always last and valid element
            valid_and_last = DefineCoreirConst(1, 1)()
            if has_last:
                wire(valid_and_last.O[0], cls.last)
            if has_cur_valid:
                # single element: the current valid index is always 0
                cur_valid = DefineCoreirConst(1, 0)()
                wire(cur_valid.O, cls.cur_valid)
            if has_ce:
                if valid_when_ce_off:
                    # always valid regardless of CE; terminate unused CE
                    wire(cls.valid, valid_and_last.O[0])
                    ce_term = TermAnyType(Bit)
                    wire(cls.CE, ce_term.I)
                else:
                    # valid tracks CE directly
                    wire(cls.valid, cls.CE)
            else:
                wire(valid_and_last.O[0], cls.valid)
            if has_reset:
                # nothing stateful to reset; terminate the RESET input
                reset_term = TermAnyType(Bit)
                wire(reset_term.I, cls.RESET)