def test_tseq_2_0_sseq_50_stuple_3_to_tseq_2_0_sseq_50_sseq_3():
    """
    Reshape TSeq(2, 0, SSeq(50, SSeq_Tuple(3, Int)))
    into TSeq(2, 0, SSeq(50, SSeq(3, Int))).
    """
    src_t = ST_TSeq(2, 0, ST_SSeq(50, ST_SSeq_Tuple(3, ST_Int())))
    dst_t = ST_TSeq(2, 0, ST_SSeq(50, ST_SSeq(3, ST_Int())))
    perm_graph = build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    check_reshape(perm_graph, 2, dut.output_delay, sim, 1, 1)
def test_tseq_2_0_tseq_1_2_sseq_1_to_tseq_1_5_sseq_2():
    """
    Reshape TSeq(2, 0, TSeq(1, 2, SSeq(1, Int)))
    into TSeq(1, 5, SSeq(2, Int)).
    """
    src_t = ST_TSeq(2, 0, ST_TSeq(1, 2, ST_SSeq(1, ST_Int())))
    dst_t = ST_TSeq(1, 5, ST_SSeq(2, ST_Int()))
    perm_graph = build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    check_reshape(perm_graph, 2, dut.output_delay, sim, 0, 0)
def test_shared_outer_sseq_2_tseq_3_1_diff_2_3_shared_inner_sseq_2_tseq_3_3_flip_reshape():
    """
    Flip a TSeq(3, 0) with an SSeq(2) while both the surrounding outer part
    (SSeq(2, TSeq(3, 1, ...))) and the inner part (SSeq(2, TSeq(3, 3, Int)))
    contain an SSeq and a TSeq.
    """
    # the two layers that swap places
    flip_tseq_n, flip_tseq_i = 3, 0
    flip_sseq_n = 2
    # the inner layers that stay in place
    inner_sseq_n = 2
    inner_tseq_n, inner_tseq_i = 3, 3

    src_t = ST_SSeq(
        2,
        ST_TSeq(
            3, 1,
            ST_TSeq(
                flip_tseq_n, flip_tseq_i,
                ST_SSeq(
                    flip_sseq_n,
                    ST_SSeq(inner_sseq_n,
                            ST_TSeq(inner_tseq_n, inner_tseq_i, ST_Int()))))))
    dst_t = ST_SSeq(
        2,
        ST_TSeq(
            3, 1,
            ST_SSeq(
                flip_sseq_n,
                ST_TSeq(
                    flip_tseq_n, flip_tseq_i,
                    ST_SSeq(inner_sseq_n,
                            ST_TSeq(inner_tseq_n, inner_tseq_i, ST_Int()))))))
    perm_graph = build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    check_reshape(perm_graph, 2, dut.output_delay, sim, 2, 2)
def test_sseq_2_tseq_2_to_sseq_4_tseq_1_reshape():
    """
    Reshape SSeq(2, TSeq(2, 0, Int)) into SSeq(4, TSeq(1, 1, Int)), so the
    input and output sides of the circuit have different ports.
    """
    src_t = ST_SSeq(2, ST_TSeq(2, 0, ST_Int()))
    dst_t = ST_SSeq(4, ST_TSeq(1, 1, ST_Int()))
    perm_graph = build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    check_reshape(perm_graph, 2, dut.output_delay, sim, 0, 0)
def test_diff_sseq_on_diff_component_of_type():
    """
    Wiring-only smoke test: the output has an SSeq(1) where the input has a
    TSeq(1, 1).

    The test bench cannot yet check values for this reshape, so this only
    confirms the circuit builds and runs through a couple of tester steps;
    the results still have to be inspected by hand.
    """
    src_t = ST_TSeq(
        4, 0,
        ST_TSeq(1, 1, ST_TSeq(4, 0, ST_SSeq(1, ST_SSeq(1, ST_Int())))))
    dst_t = ST_TSeq(
        4, 4,
        ST_SSeq(1, ST_TSeq(4, 0, ST_SSeq(1, ST_SSeq(1, ST_Int())))))
    # The permutation graph is not consumed by this test; the call is kept
    # so building it is still exercised for this type pair.
    build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    sim.step(2)
    compile_and_run(sim)
def test_2_3_flip_reshape():
    """
    The most basic flip: SSeq(2, TSeq(3, 0, Int)) becomes
    TSeq(3, 0, SSeq(2, Int)).
    """
    seq_len = 3
    par_len = 2
    src_t = ST_SSeq(par_len, ST_TSeq(seq_len, 0, ST_Int()))
    dst_t = ST_TSeq(seq_len, 0, ST_SSeq(par_len, ST_Int()))
    perm_graph = build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    check_reshape(perm_graph, 2, dut.output_delay, sim, 0, 0)
def test_shared_sseq_2_2_3_flip_reshape():
    """
    Flip beneath a shared outer SSeq(2): the TSeq(3, 0) and the inner
    SSeq(2) swap places, while the outer SSeq(2) is left untouched.
    """
    outer_n = 2
    tseq_n, tseq_i = 3, 0
    inner_n = 2
    src_t = ST_SSeq(outer_n,
                    ST_TSeq(tseq_n, tseq_i, ST_SSeq(inner_n, ST_Int())))
    dst_t = ST_SSeq(outer_n,
                    ST_SSeq(inner_n, ST_TSeq(tseq_n, tseq_i, ST_Int())))
    perm_graph = build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    check_reshape(perm_graph, 2, dut.output_delay, sim, 1, 1)
def test_tseq_3_6_tseq_1_2_to_tseq_3_24():
    """
    Flatten TSeq(3, 6, TSeq(1, 2, Int)) into TSeq(3, 24, Int). Neither the
    input nor the output port is checked as iterable.
    """
    src_t = ST_TSeq(3, 6, ST_TSeq(1, 2, ST_Int()))
    dst_t = ST_TSeq(3, 24, ST_Int())
    perm_graph = build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    check_reshape(perm_graph, 2, dut.output_delay, sim, 0, 0,
                  input_port_iterable=False,
                  output_port_iterable=False)
def test_2_2_3_flip_reshape():
    """
    Flip that moves two nested SSeq(2) layers past a TSeq(3, 0) in a single
    reshape.
    """
    seq_len = 3
    outer_par, inner_par = 2, 2
    src_t = ST_SSeq(outer_par,
                    ST_SSeq(inner_par, ST_TSeq(seq_len, 0, ST_Int())))
    dst_t = ST_TSeq(seq_len, 0,
                    ST_SSeq(outer_par, ST_SSeq(inner_par, ST_Int())))
    perm_graph = build_permutation_graph(src_t, dst_t)
    dut = DefineReshape_ST(src_t, dst_t)
    sim = fault.Tester(dut, dut.CLK)
    check_reshape(perm_graph, 2, dut.output_delay, sim, 1, 1)
def definition(cls):
    """Build the internals of the reshape circuit ``cls``.

    Work happens in three stages, in this creation order:

    1. Banked RAMs plus constant LUTs. For each clock of the reshape, the
       LUTs give each bank's write address, write-valid bit and read
       address, the bank each input lane writes to, and the output lane
       each bank's read is sent to.
    2. Counters that index those LUTs:
       - an element-time counter;
       - a write counter;
       - a read counter;
       - an initial-delay counter whose ``valid`` gates the read counter
         (and drives ``valid_down`` when ``has_valid``).
    3. Two bitonic sorting networks. One routes input lanes to banks
       (sorted on ``bank``); the other routes bank reads to output lanes
       (sorted on ``lane``).

    Side effect: sets ``cls.output_delay`` so tests can read the latency.

    ``t_in``, ``t_out``, ``has_ce``, ``has_valid`` and ``has_reset`` are not
    defined in this block; they are read from the enclosing scope.
    """
    # first section creates the RAMs, the LUTs that set values in them, and the sorting networks
    shared_and_diff_subtypes = get_shared_and_diff_subtypes(
        t_in, t_out)
    t_in_diff = shared_and_diff_subtypes.diff_input
    t_out_diff = shared_and_diff_subtypes.diff_output
    # NOTE(review): both diff types are wrapped in TSeq(2, 0, ...) before the
    # graph is built; the reason for the factor 2 is not visible here.
    graph = build_permutation_graph(ST_TSeq(2, 0, t_in_diff),
                                    ST_TSeq(2, 0, t_out_diff))
    banks_write_addr_per_input_lane = get_banks_addr_per_lane(
        graph.input_nodes)
    input_lane_write_addr_per_bank = get_lane_addr_per_banks(
        graph.input_nodes)
    output_lane_read_addr_per_bank = get_lane_addr_per_banks(
        graph.output_nodes)
    # each ram only needs to be large enough to handle the number of addresses assigned to it
    # all rams receive the same number of writes
    # but some of those writes don't happen as the data is invalid, so don't need storage for them
    max_ram_addrs = [
        max([bank_clock_data.addr for bank_clock_data in bank_data])
        for bank_data in output_lane_read_addr_per_bank
    ]
    # rams also handle parallelism from outer_shared type as this affects all banks the same
    outer_shared_sseqs = remove_tseqs(
        shared_and_diff_subtypes.shared_outer)
    if outer_shared_sseqs == ST_Tombstone():
        ram_element_type = shared_and_diff_subtypes.shared_inner
    else:
        ram_element_type = replace_tombstone(
            outer_shared_sseqs, shared_and_diff_subtypes.shared_inner)
    # can use wider rams rather than duplicate for outer_shared_sseqs because will
    # transpose dimensions of input wires below to wire up as if outer, shared dimensions
    # were on the inside
    # NOTE(review): DefineRAM_ST is called without a has_reset argument, but
    # rams[idx].RESET is wired below when has_reset is true -- confirm the RAM
    # exposes a RESET port in that case.
    rams = [
        DefineRAM_ST(ram_element_type, ram_max_addr + 1)()
        for ram_max_addr in max_ram_addrs
    ]
    rams_addr_widths = [ram.WADDR.N for ram in rams]

    # for each bank, the address to write to each clock
    write_addr_for_bank_luts = []
    for bank_idx in range(len(rams)):
        ram_addr_width = rams_addr_widths[bank_idx]
        num_addrs = len(input_lane_write_addr_per_bank[bank_idx])
        #assert num_addrs == t_in_diff.time()
        write_addrs = [
            builtins.tuple(
                int2seq(write_data_per_bank_per_clock.addr, ram_addr_width))
            for write_data_per_bank_per_clock in input_lane_write_addr_per_bank[bank_idx]
        ]
        write_addr_for_bank_luts.append(
            DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                             builtins.tuple(write_addrs))())

    # for each bank, whether to actually write this clock
    write_valid_for_bank_luts = []
    for bank_idx in range(len(rams)):
        num_valids = len(input_lane_write_addr_per_bank[bank_idx])
        #assert num_valids == t_in_diff.time()
        valids = [
            builtins.tuple([write_data_per_bank_per_clock.valid])
            for write_data_per_bank_per_clock in input_lane_write_addr_per_bank[bank_idx]
        ]
        write_valid_for_bank_luts.append(
            DefineLUTAnyType(Bit, num_valids, builtins.tuple(valids))())

    # for each input lane, the bank to write to each clock
    write_bank_for_input_lane_luts = []
    bank_idx_width = getRAMAddrWidth(len(rams))
    for lane_idx in range(len(banks_write_addr_per_input_lane)):
        num_bank_idxs = len(banks_write_addr_per_input_lane[lane_idx])
        #assert num_bank_idxs == t_in_diff.time()
        bank_idxs = [
            builtins.tuple(
                int2seq(write_data_per_lane_per_clock.bank, bank_idx_width))
            for write_data_per_lane_per_clock in banks_write_addr_per_input_lane[lane_idx]
        ]
        write_bank_for_input_lane_luts.append(
            DefineLUTAnyType(Array[bank_idx_width, Bit], num_bank_idxs,
                             builtins.tuple(bank_idxs))())

    # for each bank, the address to read from each clock
    read_addr_for_bank_luts = []
    for bank_idx in range(len(rams)):
        ram_addr_width = rams_addr_widths[bank_idx]
        num_addrs = len(output_lane_read_addr_per_bank[bank_idx])
        #assert num_addrs == t_in_diff.time()
        read_addrs = [
            builtins.tuple(
                int2seq(read_data_per_bank_per_clock.addr, ram_addr_width))
            for read_data_per_bank_per_clock in output_lane_read_addr_per_bank[bank_idx]
        ]
        read_addr_for_bank_luts.append(
            DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                             builtins.tuple(read_addrs))())

    # for each bank, the lane to send each read to
    output_lane_for_bank_luts = []
    # number of lanes equals number of banks
    # some of the lanes are always invalid; they are added so the input lane width equals the output lane width
    lane_idx_width = getRAMAddrWidth(len(rams))
    for bank_idx in range(len(rams)):
        num_lane_idxs = len(output_lane_read_addr_per_bank[bank_idx])
        #assert num_lane_idxs == t_in_diff.time()
        lane_idxs = [
            builtins.tuple(
                int2seq(read_data_per_bank_per_clock.s, lane_idx_width))
            for read_data_per_bank_per_clock in output_lane_read_addr_per_bank[bank_idx]
        ]
        output_lane_for_bank_luts.append(
            DefineLUTAnyType(Array[lane_idx_width, Bit], num_lane_idxs,
                             builtins.tuple(lane_idxs))())

    # second part creates the counters that index into the LUTs
    # elem_per counts time per element of the reshape
    # NOTE(review): unlike reshape_write_counter / reshape_read_counter below,
    # this counter is built without has_reset=has_reset, yet its RESET pin is
    # wired when has_reset is true -- confirm AESizedCounterModM exposes RESET.
    elem_per_reshape_counter = AESizedCounterModM(
        ram_element_type.time(), has_ce=True)
    # presumably high on the last clock of each ram element (counter value == time() - 1)
    end_cur_elem = Decode(ram_element_type.time() - 1,
                          elem_per_reshape_counter.O.N)(
        elem_per_reshape_counter.O)
    # reshape counts which element in the reshape
    # (uses bank 0's LUT length; assumes there is at least one bank)
    num_clocks = len(output_lane_read_addr_per_bank[0])
    reshape_write_counter = AESizedCounterModM(num_clocks, has_ce=True,
                                               has_reset=has_reset)
    reshape_read_counter = AESizedCounterModM(num_clocks, has_ce=True,
                                              has_reset=has_reset)

    # latency of the first output, in graph steps, scaled by clocks per ram element
    output_delay = (
        get_output_latencies(graph)[0]) * ram_element_type.time()
    # this is present so testing knows the delay
    cls.output_delay = output_delay
    reshape_read_delay_counter = DefineInitialDelayCounter(
        output_delay, has_ce=True, has_reset=has_reset)()
    # outer counter that repeats the reshape
    #wire(reshape_write_counter.O, cls.reshape_write_counter)

    # enabled = 1, ANDed with valid_up (if has_valid) and CE (if has_ce)
    enabled = DefineCoreirConst(1, 1)().O[0]
    if has_valid:
        enabled = cls.valid_up & enabled
        wire(reshape_read_delay_counter.valid, cls.valid_down)
    if has_ce:
        enabled = bit(cls.CE) & enabled
    wire(enabled, elem_per_reshape_counter.CE)
    wire(enabled, reshape_read_delay_counter.CE)
    # write counter advances once per ram element; the read counter
    # additionally waits until the initial delay counter reports valid
    wire(enabled & end_cur_elem, reshape_write_counter.CE)
    wire(enabled & end_cur_elem & reshape_read_delay_counter.valid,
         reshape_read_counter.CE)

    if has_reset:
        wire(cls.RESET, elem_per_reshape_counter.RESET)
        wire(cls.RESET, reshape_read_delay_counter.RESET)
        wire(cls.RESET, reshape_write_counter.RESET)
        wire(cls.RESET, reshape_read_counter.RESET)

    # wire read and write counters to all LUTs
    for lut in write_bank_for_input_lane_luts:
        wire(reshape_write_counter.O, lut.addr)

    for lut in write_addr_for_bank_luts:
        wire(reshape_write_counter.O, lut.addr)

    for lut in write_valid_for_bank_luts:
        wire(reshape_write_counter.O, lut.addr)

    for lut in read_addr_for_bank_luts:
        wire(reshape_read_counter.O, lut.addr)

    for lut in output_lane_for_bank_luts:
        wire(reshape_read_counter.O, lut.addr)

    # third and final instance creation part creates the sorting networks that map lanes to banks
    input_sorting_network_t = Tuple(
        bank=Array[write_bank_for_input_lane_luts[0].data.N, Bit],
        val=ram_element_type.magma_repr())
    input_sorting_network = DefineBitonicSort(input_sorting_network_t,
                                              len(rams), lambda x: x.bank)()

    output_sorting_network_t = Tuple(
        lane=Array[output_lane_for_bank_luts[0].data.N, Bit],
        val=ram_element_type.magma_repr())
    output_sorting_network = DefineBitonicSort(
        output_sorting_network_t, len(rams), lambda x: x.lane)()

    # wire luts, sorting networks, inputs, and rams
    # flatten all the sseq_layers to get flat magma type of inputs and outputs
    # tseqs don't affect magma types
    num_sseq_layers_inputs = num_nested_layers(
        remove_tseqs(shared_and_diff_subtypes.diff_input))
    num_sseq_layers_to_remove_inputs = max(0, num_sseq_layers_inputs - 1)
    num_sseq_layers_outputs = num_nested_layers(
        remove_tseqs(shared_and_diff_subtypes.diff_output))
    num_sseq_layers_to_remove_outputs = max(
        0, num_sseq_layers_outputs - 1)
    if remove_tseqs(
            shared_and_diff_subtypes.shared_outer) != ST_Tombstone():
        # shared outer sseqs exist: transpose them to the inside so each flat
        # port lines up with one (wide) ram element
        #num_sseq_layers_inputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
        #num_sseq_layers_outputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
        input_ports = flatten_ports(
            transpose_outer_dimensions(
                shared_and_diff_subtypes.shared_outer,
                shared_and_diff_subtypes.diff_input, cls.I),
            num_sseq_layers_to_remove_inputs)
        output_ports = flatten_ports(
            transpose_outer_dimensions(
                shared_and_diff_subtypes.shared_outer,
                shared_and_diff_subtypes.diff_output, cls.O),
            num_sseq_layers_to_remove_outputs)
    else:
        input_ports = flatten_ports(cls.I, num_sseq_layers_to_remove_inputs)
        output_ports = flatten_ports(
            cls.O, num_sseq_layers_to_remove_outputs)
    # this is only used if the shared outer layers contain any sseqs
    sseq_layers_to_flatten = max(
        num_nested_layers(
            remove_tseqs(shared_and_diff_subtypes.shared_outer)) - 1, 0)

    # one iteration per bank; lane idx and bank idx share the same range
    for idx in range(len(rams)):
        # wire input and bank to input sorting network
        wire(write_bank_for_input_lane_luts[idx].data,
             input_sorting_network.I[idx].bank)
        #if idx == 0:
        #    wire(cls.first_valid, write_valid_for_bank_luts[idx].data)

        if idx < t_in_diff.port_width():
            # since the input_ports are lists, need to wire them individually to the sorting ports
            if remove_tseqs(shared_and_diff_subtypes.shared_outer
                            ) != ST_Tombstone():
                cur_input_port = flatten_ports(input_ports[idx],
                                               sseq_layers_to_flatten)
                cur_sort_port = flatten_ports(
                    input_sorting_network.I[idx].val, sseq_layers_to_flatten)
                for i in range(len(cur_input_port)):
                    wire(cur_input_port[i], cur_sort_port[i])
            else:
                if num_sseq_layers_inputs == 0:
                    # input_ports will be an array of bits for 1 element
                    # if no sseq in t_in
                    wire(input_ports, input_sorting_network.I[idx].val)
                else:
                    wire(input_ports[idx], input_sorting_network.I[idx].val)
            #wire(cls.ram_wr, input_sorting_network.O[idx].val)
            #wire(cls.ram_rd, rams[idx].RDATA)
        else:
            # padding lane with no matching input port: drive it with zeros
            # NOTE(review): the loop descends through element [0] until the
            # lengths match, so only that sub-element receives the constant
            # for nested element types -- confirm this is intended.
            zero_const = DefineCoreirConst(
                ram_element_type.magma_repr().size(), 0)().O
            cur_sn_input = input_sorting_network.I[idx].val
            while len(cur_sn_input) != len(zero_const):
                cur_sn_input = cur_sn_input[0]
            wire(zero_const, cur_sn_input)

        # wire input sorting network, write addr, and write valid luts to banks
        wire(input_sorting_network.O[idx].val, rams[idx].WDATA)
        wire(write_addr_for_bank_luts[idx].data, rams[idx].WADDR)
        #wire(write_addr_for_bank_luts[idx].data[0], cls.addr_wr[idx])
        if has_ce:
            wire(write_valid_for_bank_luts[idx].data & bit(cls.CE),
                 rams[idx].WE)
        else:
            wire(write_valid_for_bank_luts[idx].data, rams[idx].WE)

        # wire output sorting network, read addr, read bank, and read enable
        wire(rams[idx].RDATA, output_sorting_network.I[idx].val)
        wire(output_lane_for_bank_luts[idx].data,
             output_sorting_network.I[idx].lane)
        wire(read_addr_for_bank_luts[idx].data, rams[idx].RADDR)
        #wire(read_addr_for_bank_luts[idx].data[0], cls.addr_rd[idx])
        # reading invalid entries is harmless, so unlike WE there is no
        # per-clock valid LUT on RE; it is only gated by CE (or held high)
        if has_ce:
            wire(bit(cls.CE), rams[idx].RE)
        else:
            wire(DefineCoreirConst(1, 1)().O[0], rams[idx].RE)
        if has_reset:
            wire(cls.RESET, rams[idx].RESET)

        # wire output sorting network value to output or term
        if idx < t_out_diff.port_width():
            # since the output_ports are lists, need to wire them individually to the sorting ports
            if remove_tseqs(shared_and_diff_subtypes.shared_outer
                            ) != ST_Tombstone():
                cur_output_port = flatten_ports(
                    output_ports[idx], sseq_layers_to_flatten)
                cur_sort_port = flatten_ports(
                    output_sorting_network.O[idx].val, sseq_layers_to_flatten)
                for i in range(len(cur_output_port)):
                    wire(cur_output_port[i], cur_sort_port[i])
            else:
                if num_sseq_layers_outputs == 0:
                    # output_ports will be an array of bits for 1 element
                    # if no sseq in t_out
                    wire(output_sorting_network.O[idx].val, output_ports)
                else:
                    wire(output_sorting_network.O[idx].val, output_ports[idx])
        else:
            # padding lane with no matching output port: terminate it
            wire(output_sorting_network.O[idx].val,
                 TermAnyType(type(output_sorting_network.O[idx].val)))

        # wire sorting networks bank/lane to term as not used on outputs, just used for sorting
        wire(input_sorting_network.O[idx].bank,
             TermAnyType(type(input_sorting_network.O[idx].bank)))
        wire(output_sorting_network.O[idx].lane,
             TermAnyType(type(output_sorting_network.O[idx].lane)))