예제 #1
0
def test_tseq_2_0_sseq_50_stuple_3_to_tseq_2_0_sseq_50_sseq_3():
    """
    Tests turning an innermost SSeq_Tuple of 3 into an SSeq of 3 while the
    TSeq(2, 0) and SSeq(50) layers around it are the same on both sides
    """
    tuple_layer = ST_SSeq_Tuple(3, ST_Int())
    src_type = ST_TSeq(2, 0, ST_SSeq(50, tuple_layer))
    sseq_layer = ST_SSeq(3, ST_Int())
    dst_type = ST_TSeq(2, 0, ST_SSeq(50, sseq_layer))
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    check_reshape(permutation, 2, reshape_def.output_delay, reshape_tester,
                  1, 1)
예제 #2
0
def test_tseq_2_0_tseq_1_2_sseq_1_to_tseq_1_5_sseq_2():
    """
    Tests reshaping TSeq(2, 0) of TSeq(1, 2) of a 1-wide SSeq into a single
    TSeq(1, 5) of a 2-wide SSeq
    """
    narrow_lanes = ST_SSeq(1, ST_Int())
    src_type = ST_TSeq(2, 0, ST_TSeq(1, 2, narrow_lanes))
    wide_lanes = ST_SSeq(2, ST_Int())
    dst_type = ST_TSeq(1, 5, wide_lanes)
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    check_reshape(permutation, 2, reshape_def.output_delay, reshape_tester,
                  0, 0)
예제 #3
0
def test_shared_outer_sseq_2_tseq_3_1_diff_2_3_shared_inner_sseq_2_tseq_3_3_flip_reshape(
):
    """
    Tests swapping a tseq with an sseq when identical sseq and tseq layers
    are present both outside and inside the swapped pair
    """
    flip_t_n, flip_t_i = 3, 0
    flip_s_len = 2
    inner_s_len = 2
    inner_t_n, inner_t_i = 3, 3

    def common_inner():
        # layers below the swapped pair, built fresh for every use
        return ST_SSeq(inner_s_len, ST_TSeq(inner_t_n, inner_t_i, ST_Int()))

    def common_outer(middle):
        # layers above the swapped pair
        return ST_SSeq(2, ST_TSeq(3, 1, middle))

    src_type = common_outer(
        ST_TSeq(flip_t_n, flip_t_i, ST_SSeq(flip_s_len, common_inner())))
    dst_type = common_outer(
        ST_SSeq(flip_s_len, ST_TSeq(flip_t_n, flip_t_i, common_inner())))
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    check_reshape(permutation, 2, reshape_def.output_delay, reshape_tester,
                  2, 2)
예제 #4
0
def test_sseq_2_tseq_2_to_sseq_4_tseq_1_reshape():
    """
    Tests a reshape whose input and output port widths differ:
    SSeq(2) of TSeq(2, 0) into SSeq(4) of TSeq(1, 1)
    """
    src_type = ST_SSeq(2, ST_TSeq(2, 0, ST_Int()))
    dst_type = ST_SSeq(4, ST_TSeq(1, 1, ST_Int()))
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    check_reshape(permutation, 2, reshape_def.output_delay, reshape_tester,
                  0, 0)
예제 #5
0
def test_diff_sseq_on_diff_component_of_type():
    """
    Smoke test: only verifies that building and running the circuit raises
    no wiring errors. The test bench would need improving before the values
    could actually be checked, so that has to be done manually for now.
    """

    def common_tail():
        # identical innermost layers of the input and the output type
        return ST_TSeq(4, 0, ST_SSeq(1, ST_SSeq(1, ST_Int())))

    src_type = ST_TSeq(4, 0, ST_TSeq(1, 1, common_tail()))
    dst_type = ST_TSeq(4, 4, ST_SSeq(1, common_tail()))
    # result is not inspected; the call is kept so graph construction still runs
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    reshape_tester.step(2)
    compile_and_run(reshape_tester)
예제 #6
0
def test_2_3_flip_reshape():
    """
    Tests the simplest flip: an sseq of tseqs into a tseq of sseqs
    """
    tseq_n, sseq_n = 3, 2
    src_type = ST_SSeq(sseq_n, ST_TSeq(tseq_n, 0, ST_Int()))
    dst_type = ST_TSeq(tseq_n, 0, ST_SSeq(sseq_n, ST_Int()))
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    check_reshape(permutation, 2, reshape_def.output_delay, reshape_tester,
                  0, 0)
예제 #7
0
def test_shared_sseq_2_2_3_flip_reshape():
    """
    Tests a tseq/sseq flip nested under an sseq that is the same on both sides
    """
    outer_s_len = 2
    flip_t_n, flip_t_i = 3, 0
    flip_s_len = 2

    def under_outer(flipped):
        # wraps the flipped pair in the sseq common to input and output
        return ST_SSeq(outer_s_len, flipped)

    src_type = under_outer(
        ST_TSeq(flip_t_n, flip_t_i, ST_SSeq(flip_s_len, ST_Int())))
    dst_type = under_outer(
        ST_SSeq(flip_s_len, ST_TSeq(flip_t_n, flip_t_i, ST_Int())))
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    check_reshape(permutation, 2, reshape_def.output_delay, reshape_tester,
                  1, 1)
예제 #8
0
def test_tseq_3_6_tseq_1_2_to_tseq_3_24():
    """
    Tests collapsing TSeq(3, 6) of TSeq(1, 2) into a single TSeq(3, 24),
    checking with both ports marked as not iterable
    """
    nested_t = ST_TSeq(1, 2, ST_Int())
    src_type = ST_TSeq(3, 6, nested_t)
    dst_type = ST_TSeq(3, 24, ST_Int())
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    check_reshape(permutation, 2, reshape_def.output_delay, reshape_tester,
                  0, 0,
                  input_port_iterable=False,
                  output_port_iterable=False)
예제 #9
0
def test_2_2_3_flip_reshape():
    """
    Tests a flip that moves two nested sseqs past a tseq in one reshape
    """
    tseq_n = 3
    outer_s_len, inner_s_len = 2, 2
    timed_elem = ST_TSeq(tseq_n, 0, ST_Int())
    src_type = ST_SSeq(outer_s_len, ST_SSeq(inner_s_len, timed_elem))
    parallel_elems = ST_SSeq(outer_s_len, ST_SSeq(inner_s_len, ST_Int()))
    dst_type = ST_TSeq(tseq_n, 0, parallel_elems)
    permutation = build_permutation_graph(src_type, dst_type)
    reshape_def = DefineReshape_ST(src_type, dst_type)
    reshape_tester = fault.Tester(reshape_def, reshape_def.CLK)
    check_reshape(permutation, 2, reshape_def.output_delay, reshape_tester,
                  1, 1)
예제 #10
0
        def definition(cls):
            """
            Instantiates and wires the hardware for the reshape from t_in to t_out.

            Builds, in order: one RAM per bank; the LUTs holding, for every
            clock, the write address, write valid, write bank, read address and
            output lane; the counters that index those LUTs; and the input and
            output bitonic sorting networks. The circuit's I port is then wired
            through the input sorting network into the RAMs, and the RAM read
            data is wired through the output sorting network to the O port.

            Also sets cls.output_delay so that tests can read the delay.

            NOTE(review): t_in, t_out, has_ce, has_reset and has_valid are not
            defined in this function; presumably they come from the enclosing
            scope - confirm.
            """
            # first section creates the RAMs and LUTs that set values in them and the sorting network
            shared_and_diff_subtypes = get_shared_and_diff_subtypes(
                t_in, t_out)
            t_in_diff = shared_and_diff_subtypes.diff_input
            t_out_diff = shared_and_diff_subtypes.diff_output
            # both diff types are wrapped in an ST_TSeq(2, 0, ...) before the graph is built
            # NOTE(review): the reason for the extra 2-long outer TSeq is not visible here - confirm
            graph = build_permutation_graph(ST_TSeq(2, 0, t_in_diff),
                                            ST_TSeq(2, 0, t_out_diff))
            banks_write_addr_per_input_lane = get_banks_addr_per_lane(
                graph.input_nodes)
            input_lane_write_addr_per_bank = get_lane_addr_per_banks(
                graph.input_nodes)
            output_lane_read_addr_per_bank = get_lane_addr_per_banks(
                graph.output_nodes)

            # each ram only needs to be large enough to handle the number of addresses assigned to it
            # all rams receive the same number of writes
            # but some of those writes don't happen as the data is invalid, so don't need storage for them
            max_ram_addrs = [
                max([bank_clock_data.addr for bank_clock_data in bank_data])
                for bank_data in output_lane_read_addr_per_bank
            ]
            # rams also handle parallelism from outer_shared type as this affects all banks the same
            outer_shared_sseqs = remove_tseqs(
                shared_and_diff_subtypes.shared_outer)
            if outer_shared_sseqs == ST_Tombstone():
                ram_element_type = shared_and_diff_subtypes.shared_inner
            else:
                ram_element_type = replace_tombstone(
                    outer_shared_sseqs, shared_and_diff_subtypes.shared_inner)
            # can use wider rams rather than duplicate for outer_shared_sseqs because will
            # transpose dimensions of input wires below to wire up as if outer, shared dimensions
            # were on the inside
            # one ram per bank, with ram_max_addr + 1 entries so the highest address fits
            rams = [
                DefineRAM_ST(ram_element_type, ram_max_addr + 1)()
                for ram_max_addr in max_ram_addrs
            ]
            rams_addr_widths = [ram.WADDR.N for ram in rams]

            # for bank, the addresses to write to each clock
            write_addr_for_bank_luts = []
            for bank_idx in range(len(rams)):
                ram_addr_width = rams_addr_widths[bank_idx]
                num_addrs = len(input_lane_write_addr_per_bank[bank_idx])
                #assert num_addrs == t_in_diff.time()
                write_addrs = [
                    builtins.tuple(
                        int2seq(write_data_per_bank_per_clock.addr,
                                ram_addr_width))
                    for write_data_per_bank_per_clock in
                    input_lane_write_addr_per_bank[bank_idx]
                ]
                write_addr_for_bank_luts.append(
                    DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                                     builtins.tuple(write_addrs))())

            # for bank, whether to actually write this clock
            write_valid_for_bank_luts = []
            for bank_idx in range(len(rams)):
                num_valids = len(input_lane_write_addr_per_bank[bank_idx])
                #assert num_valids == t_in_diff.time()
                valids = [
                    builtins.tuple([write_data_per_bank_per_clock.valid])
                    for write_data_per_bank_per_clock in
                    input_lane_write_addr_per_bank[bank_idx]
                ]
                write_valid_for_bank_luts.append(
                    DefineLUTAnyType(Bit, num_valids,
                                     builtins.tuple(valids))())

            # for each input lane, the bank to write to each clock
            write_bank_for_input_lane_luts = []
            bank_idx_width = getRAMAddrWidth(len(rams))
            for lane_idx in range(len(banks_write_addr_per_input_lane)):
                num_bank_idxs = len(banks_write_addr_per_input_lane[lane_idx])
                #assert num_bank_idxs == t_in_diff.time()
                bank_idxs = [
                    builtins.tuple(
                        int2seq(write_data_per_lane_per_clock.bank,
                                bank_idx_width))
                    for write_data_per_lane_per_clock in
                    banks_write_addr_per_input_lane[lane_idx]
                ]
                write_bank_for_input_lane_luts.append(
                    DefineLUTAnyType(Array[bank_idx_width, Bit], num_bank_idxs,
                                     builtins.tuple(bank_idxs))())

            # for each bank, the address to read from each clock
            read_addr_for_bank_luts = []
            for bank_idx in range(len(rams)):
                ram_addr_width = rams_addr_widths[bank_idx]
                num_addrs = len(output_lane_read_addr_per_bank[bank_idx])
                #assert num_addrs == t_in_diff.time()
                read_addrs = [
                    builtins.tuple(
                        int2seq(read_data_per_bank_per_clock.addr,
                                ram_addr_width))
                    for read_data_per_bank_per_clock in
                    output_lane_read_addr_per_bank[bank_idx]
                ]
                read_addr_for_bank_luts.append(
                    DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                                     builtins.tuple(read_addrs))())

            # for each bank, the lane to send each read to
            output_lane_for_bank_luts = []
            # number of lanes equals number of banks
            # some of the lanes are just always invalid, added so input lane width equals output lane width
            lane_idx_width = getRAMAddrWidth(len(rams))
            for bank_idx in range(len(rams)):
                num_lane_idxs = len(output_lane_read_addr_per_bank[bank_idx])
                #assert num_lane_idxs == t_in_diff.time()
                lane_idxs = [
                    builtins.tuple(
                        int2seq(read_data_per_bank_per_clock.s,
                                lane_idx_width))
                    for read_data_per_bank_per_clock in
                    output_lane_read_addr_per_bank[bank_idx]
                ]
                output_lane_for_bank_luts.append(
                    DefineLUTAnyType(Array[lane_idx_width, Bit], num_lane_idxs,
                                     builtins.tuple(lane_idxs))())

            # second part creates the counters that index into the LUTs
            # elem_per counts time per element of the reshape
            elem_per_reshape_counter = AESizedCounterModM(
                ram_element_type.time(), has_ce=True)
            # compares the element counter against ram_element_type.time() - 1, i.e. its last count
            end_cur_elem = Decode(ram_element_type.time() - 1,
                                  elem_per_reshape_counter.O.N)(
                                      elem_per_reshape_counter.O)
            # reshape counts which element in the reshape
            num_clocks = len(output_lane_read_addr_per_bank[0])
            reshape_write_counter = AESizedCounterModM(num_clocks,
                                                       has_ce=True,
                                                       has_reset=has_reset)
            reshape_read_counter = AESizedCounterModM(num_clocks,
                                                      has_ce=True,
                                                      has_reset=has_reset)

            # delay is the first output latency of the graph scaled by the time of one ram element
            output_delay = (
                get_output_latencies(graph)[0]) * ram_element_type.time()
            # this is present so testing knows the delay
            cls.output_delay = output_delay
            reshape_read_delay_counter = DefineInitialDelayCounter(
                output_delay, has_ce=True, has_reset=has_reset)()
            # outer counter that repeats the reshape
            #wire(reshape_write_counter.O, cls.reshape_write_counter)

            # enable starts as a constant and is ANDed with valid_up and CE when those ports exist
            enabled = DefineCoreirConst(1, 1)().O[0]
            if has_valid:
                enabled = cls.valid_up & enabled
                wire(reshape_read_delay_counter.valid, cls.valid_down)
            if has_ce:
                enabled = bit(cls.CE) & enabled
            wire(enabled, elem_per_reshape_counter.CE)
            wire(enabled, reshape_read_delay_counter.CE)
            # write counter advances only on the last clock of each element
            wire(enabled & end_cur_elem, reshape_write_counter.CE)
            # read counter additionally waits for the delay counter's valid
            wire(enabled & end_cur_elem & reshape_read_delay_counter.valid,
                 reshape_read_counter.CE)

            if has_reset:
                wire(cls.RESET, elem_per_reshape_counter.RESET)
                wire(cls.RESET, reshape_read_delay_counter.RESET)
                wire(cls.RESET, reshape_write_counter.RESET)
                wire(cls.RESET, reshape_read_counter.RESET)

            # wire read and write counters to all LUTs
            for lut in write_bank_for_input_lane_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in write_addr_for_bank_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in write_valid_for_bank_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in read_addr_for_bank_luts:
                wire(reshape_read_counter.O, lut.addr)

            for lut in output_lane_for_bank_luts:
                wire(reshape_read_counter.O, lut.addr)

            # third and final instance creation part creates the sorting networks that map lanes to banks
            # input network carries (bank, val) tuples and sorts on bank
            input_sorting_network_t = Tuple(
                bank=Array[write_bank_for_input_lane_luts[0].data.N, Bit],
                val=ram_element_type.magma_repr())
            input_sorting_network = DefineBitonicSort(input_sorting_network_t,
                                                      len(rams),
                                                      lambda x: x.bank)()

            # output network carries (lane, val) tuples and sorts on lane
            output_sorting_network_t = Tuple(
                lane=Array[output_lane_for_bank_luts[0].data.N, Bit],
                val=ram_element_type.magma_repr())
            output_sorting_network = DefineBitonicSort(
                output_sorting_network_t, len(rams), lambda x: x.lane)()

            # wire luts, sorting networks, inputs, and rams
            # flatten all the sseq_layers to get flat magma type of inputs and outputs
            # tseqs don't affect magma types
            num_sseq_layers_inputs = num_nested_layers(
                remove_tseqs(shared_and_diff_subtypes.diff_input))
            num_sseq_layers_to_remove_inputs = max(0,
                                                   num_sseq_layers_inputs - 1)
            num_sseq_layers_outputs = num_nested_layers(
                remove_tseqs(shared_and_diff_subtypes.diff_output))
            num_sseq_layers_to_remove_outputs = max(
                0, num_sseq_layers_outputs - 1)
            if remove_tseqs(
                    shared_and_diff_subtypes.shared_outer) != ST_Tombstone():
                #num_sseq_layers_inputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
                #num_sseq_layers_outputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
                # shared outer type contains sseqs: transpose the ports before flattening them
                input_ports = flatten_ports(
                    transpose_outer_dimensions(
                        shared_and_diff_subtypes.shared_outer,
                        shared_and_diff_subtypes.diff_input, cls.I),
                    num_sseq_layers_to_remove_inputs)
                output_ports = flatten_ports(
                    transpose_outer_dimensions(
                        shared_and_diff_subtypes.shared_outer,
                        shared_and_diff_subtypes.diff_output, cls.O),
                    num_sseq_layers_to_remove_outputs)
            else:
                input_ports = flatten_ports(cls.I,
                                            num_sseq_layers_to_remove_inputs)
                output_ports = flatten_ports(
                    cls.O, num_sseq_layers_to_remove_outputs)
            # this is only used if the shared outer layers contains any sseqs
            sseq_layers_to_flatten = max(
                num_nested_layers(
                    remove_tseqs(shared_and_diff_subtypes.shared_outer)) - 1,
                0)
            for idx in range(len(rams)):
                # wire input and bank to input sorting network
                wire(write_bank_for_input_lane_luts[idx].data,
                     input_sorting_network.I[idx].bank)
                #if idx == 0:
                #    wire(cls.first_valid, write_valid_for_bank_luts[idx].data)
                if idx < t_in_diff.port_width():
                    # since the input_ports are lists, need to wire them individually to the sorting ports
                    if remove_tseqs(shared_and_diff_subtypes.shared_outer
                                    ) != ST_Tombstone():
                        cur_input_port = flatten_ports(input_ports[idx],
                                                       sseq_layers_to_flatten)
                        cur_sort_port = flatten_ports(
                            input_sorting_network.I[idx].val,
                            sseq_layers_to_flatten)
                        for i in range(len(cur_input_port)):
                            wire(cur_input_port[i], cur_sort_port[i])
                    else:
                        if num_sseq_layers_inputs == 0:
                            # input_ports will be an array of bits for 1 element
                            # if no sseq in t_in
                            wire(input_ports, input_sorting_network.I[idx].val)
                        else:
                            wire(input_ports[idx],
                                 input_sorting_network.I[idx].val)
                    #wire(cls.ram_wr, input_sorting_network.O[idx].val)
                    #wire(cls.ram_rd, rams[idx].RDATA)
                else:
                    # no input lane for this bank: drive the sorting network input with a zero constant
                    zero_const = DefineCoreirConst(
                        ram_element_type.magma_repr().size(), 0)().O
                    cur_sn_input = input_sorting_network.I[idx].val
                    # index into element 0 until the port's length matches the constant's length
                    while len(cur_sn_input) != len(zero_const):
                        cur_sn_input = cur_sn_input[0]
                    wire(zero_const, cur_sn_input)

                # wire input sorting network, write addr, and write valid luts to banks
                wire(input_sorting_network.O[idx].val, rams[idx].WDATA)
                wire(write_addr_for_bank_luts[idx].data, rams[idx].WADDR)
                #wire(write_addr_for_bank_luts[idx].data[0], cls.addr_wr[idx])
                if has_ce:
                    wire(write_valid_for_bank_luts[idx].data & bit(cls.CE),
                         rams[idx].WE)
                else:
                    wire(write_valid_for_bank_luts[idx].data, rams[idx].WE)

                # wire output sorting network, read addr, read bank, and read enable
                wire(rams[idx].RDATA, output_sorting_network.I[idx].val)
                wire(output_lane_for_bank_luts[idx].data,
                     output_sorting_network.I[idx].lane)
                wire(read_addr_for_bank_luts[idx].data, rams[idx].RADDR)
                #wire(read_addr_for_bank_luts[idx].data[0], cls.addr_rd[idx])
                # ok to read invalid things, so no read valid LUT
                if has_ce:
                    wire(bit(cls.CE), rams[idx].RE)
                else:
                    wire(DefineCoreirConst(1, 1)().O[0], rams[idx].RE)
                if has_reset:
                    wire(cls.RESET, rams[idx].RESET)

                # wire output sorting network value to output or term
                if idx < t_out_diff.port_width():
                    # since the output_ports are lists, need to wire them individually to the sorting ports
                    if remove_tseqs(shared_and_diff_subtypes.shared_outer
                                    ) != ST_Tombstone():
                        cur_output_port = flatten_ports(
                            output_ports[idx], sseq_layers_to_flatten)
                        cur_sort_port = flatten_ports(
                            output_sorting_network.O[idx].val,
                            sseq_layers_to_flatten)
                        for i in range(len(cur_output_port)):
                            wire(cur_output_port[i], cur_sort_port[i])
                    else:
                        if num_sseq_layers_outputs == 0:
                            # output_ports will be an array of bits for 1 element
                            # if no sseq in t_out
                            wire(output_sorting_network.O[idx].val,
                                 output_ports)
                        else:
                            wire(output_sorting_network.O[idx].val,
                                 output_ports[idx])
                else:
                    # no output lane for this bank: terminate the unused sorting network output
                    wire(output_sorting_network.O[idx].val,
                         TermAnyType(type(output_sorting_network.O[idx].val)))

                # wire sorting networks bank/lane to term as not used on outputs, just used for sorting
                wire(input_sorting_network.O[idx].bank,
                     TermAnyType(type(input_sorting_network.O[idx].bank)))
                wire(output_sorting_network.O[idx].lane,
                     TermAnyType(type(output_sorting_network.O[idx].lane)))