示例#1
0
        def definition(cls):
            # Registers the result of reducing n inputs. With a single input
            # (n <= 1) the input is registered directly; otherwise it first
            # passes through a sequential reduce that is clock-enabled by a
            # nested counter over the input element type.
            output_reg = DefineRegisterAnyType(cls.st_out_t.magma_repr())()
            needs_reduce = n > 1
            if not needs_reduce:
                wire(cls.I, output_reg.I)
            else:
                in_elem_t = cls.st_in_t[0]
                two_input_op = tupleToTwoInputsForReduce(
                    op, num_nested_space_layers(in_elem_t))
                inner_reduce = DefineReduceSequential(n,
                                                      two_input_op,
                                                      has_ce=True)()
                element_counter = DefineNestedCounters(in_elem_t,
                                                       has_last=False,
                                                       has_ce=True)()
                # the counter's valid output is the reduce's clock enable
                wire(element_counter.valid, inner_reduce.CE)
                wire(cls.valid_up, element_counter.CE)
                wire(cls.I, inner_reduce.I)

                wire(inner_reduce.out, output_reg.I)
            wire(cls.O, output_reg.O)

            # downstream valid goes high once the first full valid input has
            # been collected
            clocks_until_first_output = time_last_valid(cls.st_in_t[0]) + 1
            valid_delay = InitialDelayCounter(clocks_until_first_output)
            wire(cls.valid_up, valid_delay.CE)
            wire(cls.valid_down, valid_delay.valid)

            if needs_reduce:
                # the inner reduce's ready and valid are not used; terminate
                # each with its own Term
                for unused_port in (inner_reduce.valid, inner_reduce.ready):
                    wire(unused_port, TermAnyType(Bit).I)
示例#2
0
        def definition(cls):
            # n RAMs, one per bank. Each valid clock a magma_repr value comes
            # in and is read from or written to a location in one of them.
            # Nested counters over t supply the in-bank read/write addresses;
            # WADDR/RADDR of this circuit pick the bank.
            bank_rams = [
                DefineRAMAnyType(t.magma_repr(),
                                 t.valid_clocks(),
                                 read_latency=read_latency)()
                for _ in range(n)
            ]
            counter_options = dict(has_cur_valid=True,
                                   has_ce=True,
                                   has_reset=has_reset)
            read_position = DefineNestedCounters(t, **counter_options)()
            unused_read_valid = TermAnyType(Bit)
            unused_read_last = TermAnyType(Bit)
            write_position = DefineNestedCounters(t, **counter_options)()
            unused_write_valid = TermAnyType(Bit)
            unused_write_last = TermAnyType(Bit)
            bank_read_mux = DefineMuxAnyType(t.magma_repr(), n)()

            for bank_idx, ram in enumerate(bank_rams):
                wire(cls.WDATA, ram.WDATA)
                wire(write_position.cur_valid, ram.WADDR)
                wire(bank_read_mux.data[bank_idx], ram.RDATA)
                wire(read_position.cur_valid, ram.RADDR)
                # only the bank addressed by WADDR is written, and only while
                # the write counter reports valid
                bank_selected = Decode(bank_idx, cls.WADDR.N)(cls.WADDR)
                wire(bank_selected & write_position.valid, ram.WE)

            wire(cls.RADDR, bank_read_mux.sel)
            wire(cls.RDATA, bank_read_mux.out)

            wire(cls.WE, write_position.CE)
            wire(cls.RE, read_position.CE)

            # counter outputs that nothing consumes are terminated
            for unused_output, term in (
                (read_position.valid, unused_read_valid),
                (read_position.last, unused_read_last),
                (write_position.valid, unused_write_valid),
                (write_position.last, unused_write_last),
            ):
                wire(unused_output, term.I)

            if has_reset:
                for position_counter in (write_position, read_position):
                    wire(cls.RESET, position_counter.RESET)
示例#3
0
 def definition(cls):
     # Output is the constant 0 as a 1-bit array. Each optional input port
     # that exists (CE, RESET, CIN) is wired to its own Term, and COUT, when
     # present, is wired with bit 0 of the constant.
     zero = array(0, 1)
     wire(zero, cls.O)
     optional_inputs = ((has_ce, "CE"), (has_reset, "RESET"), (cin, "CIN"))
     for is_present, port_name in optional_inputs:
         if is_present:
             wire(getattr(cls, port_name), TermAnyType(Bit))
     if cout:
         wire(cls.COUT, zero[0])
示例#4
0
        def definition(cls):
            # Shift built on a RAM_ST that is read and written at the same
            # address. Additional nested counters gate the RAM so the invalid
            # clocks of the inner TSeqs are skipped.
            always_on = DefineCoreirConst(1, 1)().O[0]
            enabled = always_on
            if has_valid:
                enabled = cls.valid_up & enabled
                wire(cls.valid_up, cls.valid_down)
            if has_ce:
                enabled = bit(cls.CE) & enabled

            value_store = DefineRAM_ST(elem_t,
                                       shift_amount,
                                       has_reset=has_reset)()

            # Reads and writes share one location: the first pass over an
            # element only writes, later passes write and read. The output of
            # the first pass is undefined, so reading anything is acceptable.
            element_clock_counter = DefineNestedCounters(elem_t,
                                                         has_ce=True,
                                                         has_reset=has_reset)()
            # It is fine that this ignores the invalid clocks of the outer
            # TSeq: after them the next iteration may start from a non-zero
            # index, which does not matter because the address just loops
            # around.
            ram_addr = AESizedCounterModM(shift_amount,
                                          has_ce=True,
                                          has_reset=has_reset)
            # Nest the inner TSeqs from the innermost layer outwards; the
            # counter over this type handles their invalid clocks.
            inner_valid_t = ST_Int()
            for layer in reversed(range(len(nis))):
                inner_valid_t = ST_TSeq(nis[layer], iis[layer], inner_valid_t)
            inner_valid = DefineNestedCounters(inner_valid_t,
                                               has_last=False,
                                               has_ce=True,
                                               has_reset=has_reset,
                                               valid_when_ce_off=True)()

            for addr_port in (value_store.WADDR, value_store.RADDR):
                wire(ram_addr.O, addr_port)

            wire(enabled & inner_valid.valid, value_store.WE)
            wire(enabled & element_clock_counter.last, inner_valid.CE)
            wire(enabled & inner_valid.valid, value_store.RE)
            wire(enabled & element_clock_counter.last & inner_valid.valid,
                 ram_addr.CE)
            wire(enabled, element_clock_counter.CE)

            # the element counter's valid output is unused
            unused_counter_valid = TermAnyType(Bit)
            wire(element_clock_counter.valid, unused_counter_valid.I)

            wire(cls.I, value_store.WDATA)
            wire(value_store.RDATA, cls.O)
            if has_reset:
                for resettable in (value_store, ram_addr,
                                   element_clock_counter, inner_valid):
                    wire(resettable.RESET, cls.RESET)
示例#5
0
 def definition(cls):
     # Input lane idx is wired to O[0]; every other input lane is wired into
     # a Term built for an SSeq of the remaining n - 1 elements.
     if n > 1:
         unused_lanes = TermAnyType(ST_SSeq(n - 1, elem_t).magma_repr())
     lanes_passed_through = 0
     for lane in range(len(cls.I)):
         if lane != idx:
             # lanes after the selected one land one slot lower in the Term
             wire(cls.I[lane], unused_lanes.I[lane - lanes_passed_through])
             continue
         wire(cls.I[lane], cls.O[0])
         lanes_passed_through += 1
     if has_valid:
         wire(cls.valid_up, cls.valid_down)
示例#6
0
 def definition(downsampleParallel):
     # Input idx is passed through to O; all the other inputs are dehydrated
     # into a Term sized for n - 1 elements of T.
     unused_inputs = TermAnyType(Array[n - 1, T])
     inputs_passed_through = 0
     for lane in range(len(downsampleParallel.I)):
         if lane != idx:
             # inputs after the selected one land one slot lower in the Term
             wire(downsampleParallel.I[lane],
                  unused_inputs.I[lane - inputs_passed_through])
             continue
         wire(downsampleParallel.I[lane], downsampleParallel.O)
         inputs_passed_through += 1
     if has_ready_valid:
         # the handshake signals are forwarded unchanged
         wire(downsampleParallel.ready_up, downsampleParallel.ready_down)
         wire(downsampleParallel.valid_up, downsampleParallel.valid_down)
def test_term():
    # Builds a pass-through circuit that also contains a Term driven by a
    # zero constant, then checks a value poked on I is seen on O.
    bit_width = 11
    port_t = Array[bit_width, BitIn]

    testcircuit = DefineCircuit('Test_Term', 'I', In(port_t), 'O',
                                Out(port_t))
    wire(testcircuit.I, testcircuit.O)
    terminator = TermAnyType(port_t)
    zero_const = DefineCoreirConst(bit_width, 0)()
    wire(zero_const.O, terminator.I)
    EndCircuit()

    poked_value = 2
    tester = fault.Tester(testcircuit)
    tester.circuit.I = poked_value
    tester.eval()
    tester.circuit.O.expect(poked_value)
    compile_and_run(tester)
示例#8
0
        def definition(cls):
            # Shift built on a RAM_ST that is read and written at the same
            # address; the address counter is enabled whenever the nested
            # element counter reports `last`.
            always_on = DefineCoreirConst(1, 1)().O[0]
            enabled = always_on
            if has_valid:
                enabled = cls.valid_up & enabled
                wire(cls.valid_up, cls.valid_down)
            if has_ce:
                enabled = bit(cls.CE) & enabled

            value_store = DefineRAM_ST(elem_t,
                                       shift_amount,
                                       has_reset=has_reset)()

            # Reads and writes share one location: the first pass over an
            # element only writes, later passes write and read. The output of
            # the first pass is undefined, so reading anything is acceptable.
            element_clock_counter = DefineNestedCounters(elem_t,
                                                         has_ce=True,
                                                         has_reset=has_reset)()
            # Ignoring the invalid clocks is fine: after them the next
            # iteration may start from a non-zero index, which does not matter
            # because the address just loops around.
            ram_addr = AESizedCounterModM(shift_amount,
                                          has_ce=True,
                                          has_reset=has_reset)

            for addr_port in (value_store.WADDR, value_store.RADDR):
                wire(ram_addr.O, addr_port)

            wire(enabled, value_store.WE)
            wire(enabled, value_store.RE)
            wire(enabled & element_clock_counter.last, ram_addr.CE)
            wire(enabled, element_clock_counter.CE)

            # the element counter's valid output is unused
            unused_counter_valid = TermAnyType(Bit)
            wire(element_clock_counter.valid, unused_counter_valid.I)

            wire(cls.I, value_store.WDATA)
            wire(value_store.RDATA, cls.O)
            if has_reset:
                for resettable in (value_store, ram_addr,
                                   element_clock_counter):
                    wire(resettable.RESET, cls.RESET)
示例#9
0
        def definition(cls):
            # Stores the first element in a single-entry RAM_ST. While the
            # first element is arriving the input is sent straight to the
            # output; for the following elements the stored value is output.
            always_on = DefineCoreirConst(1, 1)().O[0]
            enabled = always_on
            if has_valid:
                enabled = cls.valid_up & enabled
                wire(cls.valid_up, cls.valid_down)
            if has_ce:
                enabled = bit(cls.CE) & enabled

            value_store = DefineRAM_ST(elem_t, 1, has_reset=has_reset)()

            # one counter tracks the clock within an element, the other
            # tracks which element is current
            element_time_counter = DefineNestedCounters(elem_t,
                                                        has_ce=True,
                                                        has_reset=has_reset)()
            element_idx_counter = AESizedCounterModM(n + i,
                                                     has_ce=True,
                                                     has_reset=has_reset)
            is_first_element = Decode(0, element_idx_counter.O.N)(
                element_idx_counter.O)

            # the store has a single entry, so both addresses are constant 0
            zero_addr = DefineCoreirConst(1, 0)().O
            for addr_port in (value_store.WADDR, value_store.RADDR):
                wire(zero_addr, addr_port)

            # write only during the first element, read on every enabled clock
            wire(enabled & is_first_element, value_store.WE)
            wire(enabled, value_store.RE)
            wire(enabled, element_time_counter.CE)
            wire(enabled & element_time_counter.last, element_idx_counter.CE)

            # the time counter's valid output is unused
            unused_time_valid = TermAnyType(Bit)
            wire(element_time_counter.valid, unused_time_valid.I)

            wire(cls.I, value_store.WDATA)

            output_selector = DefineMuxAnyType(elem_t.magma_repr(), 2)()

            # data[1] (the raw input) is selected on the first element,
            # data[0] (the stored value) otherwise
            wire(is_first_element, output_selector.sel[0])
            wire(value_store.RDATA, output_selector.data[0])
            wire(cls.I, output_selector.data[1])
            wire(output_selector.out, cls.O)

            if has_reset:
                for resettable in (value_store, element_time_counter,
                                   element_idx_counter):
                    wire(resettable.RESET, cls.RESET)
示例#10
0
        def definition(BitonicSort):
            # Wraps a power-of-2 bitonic sorting network. Network inputs not
            # used by the n real inputs are fed the max value (all 1's), and
            # the matching network outputs are terminated.
            elem_bits = T.size()
            padded_n = 2 ** ceil(log2(n))
            if padded_n > n:
                all_ones_flat = DefineCoreirConst(elem_bits,
                                                  2 ** elem_bits - 1)()
                max_const = Hydrate(T)
                wire(all_ones_flat.O, max_const.I)

            pow2_sort = DefineBitonicSortPow2(T, padded_n, cmp_component)()
            for lane in range(padded_n):
                if lane >= n:
                    # padding lane: constant in, terminated out
                    wire(max_const.out, pow2_sort.I[lane])
                    unused_output = TermAnyType(T)
                    wire(unused_output.I, pow2_sort.O[lane])
                    continue
                wire(BitonicSort.I[lane], pow2_sort.I[lane])
                wire(BitonicSort.O[lane], pow2_sort.O[lane])
示例#11
0
        def definition(cls):
            # Emits values from a LUT addressed by a position counter. The
            # counter only advances while enabled: after an optional initial
            # delay, and, when a CE port exists, only while CE is high.
            always_on = DefineCoreirConst(1, 1)().O[0]
            enabled = always_on
            if delay != 0:
                startup_counter = InitialDelayCounter(delay)
                wire(startup_counter.CE, always_on)
                enabled = startup_counter.valid
            if has_ce:
                enabled = bit(cls.CE) & enabled

            value_lut = DefineLUTAnyType(t.magma_repr(), t.time(),
                                         ts_arrays_to_bits(ts_values))()
            lut_position = AESizedCounterModM(t.time(),
                                              has_ce=True,
                                              has_reset=has_reset)

            wire(lut_position.O, value_lut.addr)
            wire(cls.O, value_lut.data)
            wire(enabled, lut_position.CE)

            if has_reset:
                wire(cls.RESET, lut_position.RESET)
            if has_valid:
                # upstream valid is not consumed; downstream valid follows the
                # enable signal
                unused_valid_up = TermAnyType(Bit)
                wire(cls.valid_up, unused_valid_up.I)
                wire(enabled, cls.valid_down)
        def definition(STBankGenerator):
            # For each of the `no` lanes, wires a bank index and an address
            # computed from a running flat index to the `bank` and `addr`
            # outputs. Unused high bits of each result are terminated.
            flat_idx_width = getRAMAddrWidth(no * ni)
            # The per-element clock counter only exists when an element takes
            # more than one clock, so every wire that touches it is guarded by
            # this flag (previously has_ce / has_reset with single-clock
            # elements referenced a counter that was never created).
            has_element_counter = time_per_element > 1
            # With single-clock elements there is no inner counter for CE to
            # gate, so CE gates the lane index counter directly.
            ce_gates_lane_idx = has_ce and not has_element_counter

            # next element each time_per_element clock
            if has_element_counter:
                index_in_cur_element = SizedCounterModM(time_per_element,
                                                        has_ce=has_ce,
                                                        has_reset=has_reset)
                next_element = Decode(time_per_element - 1,
                                      index_in_cur_element.O.N)(
                                          index_in_cur_element.O)
            elif not ce_gates_lane_idx:
                next_element = DefineCoreirConst(1, 1)()
            # each element of the SSeq is a separate vector lane
            first_lane_flat_idx = DefineCounterModM(ni + ii,
                                                    flat_idx_width,
                                                    cout=False,
                                                    has_ce=True,
                                                    has_reset=has_reset)()
            if ce_gates_lane_idx:
                wire(STBankGenerator.CE, first_lane_flat_idx.CE)
            else:
                wire(next_element.O[0], first_lane_flat_idx.CE)
            if has_ce and has_element_counter:
                wire(STBankGenerator.CE, index_in_cur_element.CE)
            if has_reset:
                if has_element_counter:
                    wire(STBankGenerator.RESET, index_in_cur_element.RESET)
                wire(STBankGenerator.RESET, first_lane_flat_idx.RESET)

            lane_flat_idxs = [first_lane_flat_idx.O]

            # compute the current flat_idx for each lane: lane i is offset
            # from the first lane by i * ni
            for i in range(1, no):
                cur_lane_flat_idx_adder = DefineAdd(flat_idx_width)()
                wire(cur_lane_flat_idx_adder.I0, first_lane_flat_idx.O)
                wire(cur_lane_flat_idx_adder.I1,
                     DefineCoreirConst(flat_idx_width, i * ni)().O)

                lane_flat_idxs.append(cur_lane_flat_idx_adder.O)

            lane_flat_div_lcms = []
            lcm_dim = DefineCoreirConst(flat_idx_width, lcm(no, ni))()
            # compute flat_idx / lcm_dim for each lane
            for i in range(no):
                cur_lane_lcm_div = DefineUDiv(flat_idx_width)()
                wire(cur_lane_lcm_div.I0, lane_flat_idxs[i])
                wire(cur_lane_lcm_div.I1, lcm_dim.O)

                lane_flat_div_lcms.append(cur_lane_lcm_div.O)

            # compute ((flat_idx % sseq_dim) + (flat_idx / lcm_dim)) % sseq_dim for each lane
            # only need to mod sseq_dim at end as that is same as also doing it flat_idx before addition
            for i in range(no):
                pre_mod_add = DefineAdd(flat_idx_width)()
                wire(pre_mod_add.I0, lane_flat_idxs[i])
                wire(pre_mod_add.I1, lane_flat_div_lcms[i])

                bank_mod = DefineUMod(flat_idx_width)()
                wire(bank_mod.I0, pre_mod_add.O)
                wire(bank_mod.I1, DefineCoreirConst(flat_idx_width, no)().O)

                wire(STBankGenerator.bank[i],
                     bank_mod.O[0:STBankGenerator.bank_width])
                # terminate the high bits that do not fit in the bank output
                if len(bank_mod.O) > STBankGenerator.bank_width:
                    bits_to_term = len(bank_mod.O) - STBankGenerator.bank_width
                    term = TermAnyType(Array[bits_to_term, Bit])
                    wire(bank_mod.O[STBankGenerator.bank_width:], term.I)

            # compute flat_idx / sseq_dim for each lane addr
            # NOTE(review): every lane divides lane 0's flat index rather than
            # lane_flat_idxs[i] -- confirm this is intended.
            for i in range(no):
                flat_idx_sseq_dim_div = DefineUDiv(flat_idx_width)()
                wire(flat_idx_sseq_dim_div.I0, lane_flat_idxs[0])
                wire(flat_idx_sseq_dim_div.I1,
                     DefineCoreirConst(flat_idx_width, no)().O)

                wire(STBankGenerator.addr[i],
                     flat_idx_sseq_dim_div.O[0:STBankGenerator.addr_width])
                # terminate the high bits that do not fit in the addr output;
                # sized from the divider being sliced, not from bank_mod
                if len(flat_idx_sseq_dim_div.O) > STBankGenerator.addr_width:
                    bits_to_term = (len(flat_idx_sseq_dim_div.O) -
                                    STBankGenerator.addr_width)
                    term = TermAnyType(Array[bits_to_term, Bit])
                    wire(flat_idx_sseq_dim_div.O[STBankGenerator.addr_width:],
                         term.I)
示例#13
0
        def definition(cls):
            # Wires up one row of a line buffer: per-clock pixels are pushed
            # into parallel SIPO shift registers, selected register taps are
            # wired to the window outputs (and to next_row when this is not
            # the last row), unused taps are terminated, and counters drive
            # the valid output.

            shift_register = MapParallel(
                pixel_per_clock,
                SIPOAnyType(image_size // pixel_per_clock,
                            pixel_type,
                            0,
                            has_ce=True))

            # reverse the pixels per clock. Since greater index_in_shift_register
            # mean earlier inputted pixels, also want greater current_shift_register
            # to mean earlier inputted pixels. This accomplishes that by making
            # pixels earlier each clock go to higher number shift register
            if first_row:
                wire(cls.I[::-1], shift_register.I)
            else:
                # don't need to reverse if not first row as prior rows have already done reversing
                wire(cls.I, shift_register.I)

            # every parallel SIPO shares this circuit's clock enable
            for i in range(pixel_per_clock):
                wire(cls.CE, shift_register.CE[i])

            # these two variables provide a 2D coordinate system for the SIPOs.
            # the inner dimension is current_shift_register
            # the outer dimension is index_in_shift_register
            # greater values in current_shift_register are inputs from older clocks
            # greater values in index_in_shift_register are inputs from lower index
            # values in the inputs in a single clock (due to above cls.I, type_to_bits reversing)
            # the index_in_shift_register is reversed so that bigger number always
            # means lower indexed value in the input image. For example, if asking
            # for location 0 with a 2 px per clock, 3 window width, then the
            # 2D location is index_in_shift_register = 1, current_shift_register = 1
            # and walking back in 2D space as increasing 1D location.
            current_shift_register = 0
            index_in_shift_register = 0

            # since current_shift_register and index_in_shift_register form a
            # 2D shape where current_shift_registers is inner dimension and
            # index_in_shift_register is outer, get_shift_register_location_in_1D_coordinates
            # and set_shift_register_location_using_1D_coordinates  convert between
            # 2D coordinates in the SIPOs and 1D coordinates in the 1D image

            # To do the reversing of 1D coordinates, need to find the oldest pixel that should be output,
            # ignoring origin as origin doesn't impact this computation.
            # This is done by finding the number of relevant pixels for outputting and adjusting it
            # so that it aligns with the number of pixels per clock cycle.
            # That coordinates position is treated as a 0 in the reverse coordinates
            # and requested coordinates (going in the opposite direction) are reversed
            # and adjusted to fit the new coordinate system by subtracting their values
            # from the 0's value in the original, forward coordinate system.

            # need to be able to handle situations with swizzling. Swizzling is
            # where a pixel inputted this clock is not used until next clock.
            # This is handled by wiring up in reverse order. If a pixel is inputted
            # in a clock but not used, it will have a high 1D location as it will be
            # one of the first registers in the first index_in_shift_register.
            # The swizzled pixel's large 1D location ensures it isn't wired directly
            # to an output

            # get needed pixels (ignoring origin as that can be garbage)
            # to determine number of clock cycles needed to satisfy input
            if cls.windows_per_active_clock == 1:
                needed_pixels = window_width
            else:
                needed_pixels = window_width + stride * (
                    cls.windows_per_active_clock - 1)

            # get the maximum 1D coordinate when aligning needed pixels to the number
            # of pixels per clock
            if needed_pixels % pixel_per_clock == 0:
                oldest_needed_pixel_forward_1D_coordinates = needed_pixels
            else:
                oldest_needed_pixel_forward_1D_coordinates = ceil(needed_pixels / pixel_per_clock) * \
                                                             pixel_per_clock

            # adjust by 1 for 0 indexing
            oldest_needed_pixel_forward_1D_coordinates -= 1

            # Converts the current 2D (index_in_shift_register,
            # current_shift_register) position into a 1D location.
            def get_shift_register_location_in_1D_coordinates() -> int:
                return oldest_needed_pixel_forward_1D_coordinates - \
                       (index_in_shift_register * pixel_per_clock +
                        current_shift_register)

            # Inverse of the getter: updates the two enclosing 2D coordinate
            # variables in place for the given 1D location. Returns nothing.
            def set_shift_register_location_using_1D_coordinates(
                    location: int) -> None:
                nonlocal current_shift_register, index_in_shift_register
                location_reversed_indexing = oldest_needed_pixel_forward_1D_coordinates - location
                index_in_shift_register = location_reversed_indexing // pixel_per_clock
                current_shift_register = location_reversed_indexing % pixel_per_clock

            # (index_in_shift_register, current_shift_register) pairs that
            # have been wired to an output; everything else is terminated below
            used_coordinates = set()

            for current_window_index in range(cls.windows_per_active_clock):
                # stride is handled by wiring if there are multiple windows emitted per clock,
                # aka if stride is less than number of pixels per clock.
                # In this case, multiple windows are emitted but they must be overlapped
                # less than normal
                strideMultiplier = stride if stride < pixel_per_clock else 1
                set_shift_register_location_using_1D_coordinates(
                    strideMultiplier * current_window_index +
                    # handle origin across multiple clocks by changing valid, but within a single clock
                    # need to adjust where the windows start
                    # need neg conversion twice due to issues taking mod of negative number
                    ((origin * -1) % pixel_per_clock * -1))
                for index_in_window in range(window_width):
                    wire(
                        shift_register.O[current_shift_register]
                        [index_in_shift_register],
                        cls.O[current_window_index][index_in_window])

                    used_coordinates.add(
                        (index_in_shift_register, current_shift_register))

                    # step to the next 1D location for the next window entry
                    set_shift_register_location_using_1D_coordinates(
                        get_shift_register_location_in_1D_coordinates() + 1)

            # if not last row, have output ports for ends of all shift_registers so next
            # 1D can accept them
            if not last_row:
                index_in_shift_register = image_size // pixel_per_clock - 1
                for current_shift_register in range(pixel_per_clock):
                    wire(
                        shift_register.O[current_shift_register]
                        [index_in_shift_register],
                        cls.next_row[current_shift_register])
                    used_coordinates.add(
                        (index_in_shift_register, current_shift_register))

            # wire up all non-used coordinates to terms
            for sr in range(pixel_per_clock):
                for sr_index in range(image_size // pixel_per_clock):
                    if (sr_index, sr) in used_coordinates:
                        continue
                    term = TermAnyType(pixel_type)
                    wire(shift_register.O[sr][sr_index], term.I)

            # valid when the maximum coordinate used (minus origin, as origin can in
            # invalid space when emitting) gets data
            # add 1 here as coordinates are 0 indexed, and the denominator of this
            # fraction is the last register accessed
            # would add 1 outside fraction as it takes 1 clock for data
            # to get through registers but won't as 0 indexed
            valid_counter_max_value = ceil(
                (oldest_needed_pixel_forward_1D_coordinates + 1 + origin) /
                pixel_per_clock)

            # add 1 as sizedcounter counts to 1 less than the provided max
            valid_counter = SizedCounterModM(valid_counter_max_value + 1,
                                             has_ce=True)

            valid_counter_max_instance = DefineCoreirConst(
                len(valid_counter.O), valid_counter_max_value)()

            # the counter only advances while CE is high and it is still
            # below its maximum, so it stops once the maximum is reached
            wire(
                enable(
                    bit(cls.CE)
                    & (valid_counter.O < valid_counter_max_instance.O)),
                valid_counter.CE)

            # if stride is greater than pixels_per_clock, then need a stride counter as
            # not active every clock. Invalid clocks create striding in this case
            if stride > pixel_per_clock:

                stride_counter = SizedCounterModM(stride // pixel_per_clock,
                                                  has_ce=True)
                stride_counter_0 = DefineCoreirConst(len(stride_counter.O),
                                                     0)()

                # valid only when the stride counter is at 0 and the valid
                # counter has reached its maximum
                wire(
                    enable((stride_counter.O == stride_counter_0.O) &
                           (valid_counter.O == valid_counter_max_instance.O)),
                    cls.valid)

                # only increment stride if trying to emit data this clock cycle
                wire(valid_counter.O == valid_counter_max_instance.O,
                     stride_counter.CE)

            else:
                wire((valid_counter.O == valid_counter_max_instance.O),
                     cls.valid)
示例#14
0
 def definition(cls):
     # Keeps the second input: I[0] is wired into a Term built from t.t0,
     # I[1] is wired to the output.
     unused_first = TermAnyType(t.t0.magma_repr())
     dropped_input, kept_input = cls.I[0], cls.I[1]
     wire(dropped_input, unused_first.I)
     wire(kept_input, cls.O)
     if has_valid:
         wire(cls.valid_up, cls.valid_down)
示例#15
0
 def definition(cls):
     # Keeps the first input: I[0] is wired to the output, I[1] is wired
     # into a Term built from t.t1.
     kept_input = cls.I[0]
     wire(kept_input, cls.O)
     unused_second = TermAnyType(t.t1.magma_repr())
     dropped_input = cls.I[1]
     wire(dropped_input, unused_second.I)
     if has_valid:
         wire(cls.valid_up, cls.valid_down)
示例#16
0
        def definition(cls):
            # first section creates the RAMs and LUTs that set values in them and the sorting network
            shared_and_diff_subtypes = get_shared_and_diff_subtypes(
                t_in, t_out)
            t_in_diff = shared_and_diff_subtypes.diff_input
            t_out_diff = shared_and_diff_subtypes.diff_output
            graph = build_permutation_graph(ST_TSeq(2, 0, t_in_diff),
                                            ST_TSeq(2, 0, t_out_diff))
            banks_write_addr_per_input_lane = get_banks_addr_per_lane(
                graph.input_nodes)
            input_lane_write_addr_per_bank = get_lane_addr_per_banks(
                graph.input_nodes)
            output_lane_read_addr_per_bank = get_lane_addr_per_banks(
                graph.output_nodes)

            # each ram only needs to be large enough to handle the number of addresses assigned to it
            # all rams receive the same number of writes
            # but some of those writes don't happen as the data is invalid, so don't need storage for them
            max_ram_addrs = [
                max([bank_clock_data.addr for bank_clock_data in bank_data])
                for bank_data in output_lane_read_addr_per_bank
            ]
            # rams also handle parallelism from outer_shared type as this affects all banks the same
            outer_shared_sseqs = remove_tseqs(
                shared_and_diff_subtypes.shared_outer)
            if outer_shared_sseqs == ST_Tombstone():
                ram_element_type = shared_and_diff_subtypes.shared_inner
            else:
                ram_element_type = replace_tombstone(
                    outer_shared_sseqs, shared_and_diff_subtypes.shared_inner)
            # can use wider rams rather than duplicate for outer_shared_sseqs because will
            # transpose dimenions of input wires below to wire up as if outer, shared dimensions
            # were on the inside
            rams = [
                DefineRAM_ST(ram_element_type, ram_max_addr + 1)()
                for ram_max_addr in max_ram_addrs
            ]
            rams_addr_widths = [ram.WADDR.N for ram in rams]

            # for bank, the addresses to write to each clock
            write_addr_for_bank_luts = []
            for bank_idx in range(len(rams)):
                ram_addr_width = rams_addr_widths[bank_idx]
                num_addrs = len(input_lane_write_addr_per_bank[bank_idx])
                #assert num_addrs == t_in_diff.time()
                write_addrs = [
                    builtins.tuple(
                        int2seq(write_data_per_bank_per_clock.addr,
                                ram_addr_width))
                    for write_data_per_bank_per_clock in
                    input_lane_write_addr_per_bank[bank_idx]
                ]
                write_addr_for_bank_luts.append(
                    DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                                     builtins.tuple(write_addrs))())

            # for bank, whether to actually write this clock
            write_valid_for_bank_luts = []
            for bank_idx in range(len(rams)):
                num_valids = len(input_lane_write_addr_per_bank[bank_idx])
                #assert num_valids == t_in_diff.time()
                valids = [
                    builtins.tuple([write_data_per_bank_per_clock.valid])
                    for write_data_per_bank_per_clock in
                    input_lane_write_addr_per_bank[bank_idx]
                ]
                write_valid_for_bank_luts.append(
                    DefineLUTAnyType(Bit, num_valids,
                                     builtins.tuple(valids))())

            # for each input lane, the bank to write to each clock
            write_bank_for_input_lane_luts = []
            bank_idx_width = getRAMAddrWidth(len(rams))
            for lane_idx in range(len(banks_write_addr_per_input_lane)):
                num_bank_idxs = len(banks_write_addr_per_input_lane[lane_idx])
                #assert num_bank_idxs == t_in_diff.time()
                bank_idxs = [
                    builtins.tuple(
                        int2seq(write_data_per_lane_per_clock.bank,
                                bank_idx_width))
                    for write_data_per_lane_per_clock in
                    banks_write_addr_per_input_lane[lane_idx]
                ]
                write_bank_for_input_lane_luts.append(
                    DefineLUTAnyType(Array[bank_idx_width, Bit], num_bank_idxs,
                                     builtins.tuple(bank_idxs))())

            # for each bank, the address to read from each clock
            read_addr_for_bank_luts = []
            for bank_idx in range(len(rams)):
                ram_addr_width = rams_addr_widths[bank_idx]
                num_addrs = len(output_lane_read_addr_per_bank[bank_idx])
                #assert num_addrs == t_in_diff.time()
                read_addrs = [
                    builtins.tuple(
                        int2seq(read_data_per_bank_per_clock.addr,
                                ram_addr_width))
                    for read_data_per_bank_per_clock in
                    output_lane_read_addr_per_bank[bank_idx]
                ]
                read_addr_for_bank_luts.append(
                    DefineLUTAnyType(Array[ram_addr_width, Bit], num_addrs,
                                     builtins.tuple(read_addrs))())

            # for each bank, the lane to send each read to
            output_lane_for_bank_luts = []
            # number of lanes equals number of banks
            # some of the lanes are just always invalid, added so input lane width equals output lane width
            lane_idx_width = getRAMAddrWidth(len(rams))
            for bank_idx in range(len(rams)):
                num_lane_idxs = len(output_lane_read_addr_per_bank[bank_idx])
                #assert num_lane_idxs == t_in_diff.time()
                lane_idxs = [
                    builtins.tuple(
                        int2seq(read_data_per_bank_per_clock.s,
                                lane_idx_width))
                    for read_data_per_bank_per_clock in
                    output_lane_read_addr_per_bank[bank_idx]
                ]
                output_lane_for_bank_luts.append(
                    DefineLUTAnyType(Array[lane_idx_width, Bit], num_lane_idxs,
                                     builtins.tuple(lane_idxs))())

            # second part creates the counters that index into the LUTs
            # elem_per counts time per element of the reshape
            elem_per_reshape_counter = AESizedCounterModM(
                ram_element_type.time(), has_ce=True)
            end_cur_elem = Decode(ram_element_type.time() - 1,
                                  elem_per_reshape_counter.O.N)(
                                      elem_per_reshape_counter.O)
            # reshape counts which element in the reshape
            num_clocks = len(output_lane_read_addr_per_bank[0])
            reshape_write_counter = AESizedCounterModM(num_clocks,
                                                       has_ce=True,
                                                       has_reset=has_reset)
            reshape_read_counter = AESizedCounterModM(num_clocks,
                                                      has_ce=True,
                                                      has_reset=has_reset)

            output_delay = (
                get_output_latencies(graph)[0]) * ram_element_type.time()
            # this is present so testing knows the delay
            cls.output_delay = output_delay
            reshape_read_delay_counter = DefineInitialDelayCounter(
                output_delay, has_ce=True, has_reset=has_reset)()
            # outer counter that repeats the reshape
            #wire(reshape_write_counter.O, cls.reshape_write_counter)

            enabled = DefineCoreirConst(1, 1)().O[0]
            if has_valid:
                enabled = cls.valid_up & enabled
                wire(reshape_read_delay_counter.valid, cls.valid_down)
            if has_ce:
                enabled = bit(cls.CE) & enabled
            wire(enabled, elem_per_reshape_counter.CE)
            wire(enabled, reshape_read_delay_counter.CE)
            wire(enabled & end_cur_elem, reshape_write_counter.CE)
            wire(enabled & end_cur_elem & reshape_read_delay_counter.valid,
                 reshape_read_counter.CE)

            if has_reset:
                wire(cls.RESET, elem_per_reshape_counter.RESET)
                wire(cls.RESET, reshape_read_delay_counter.RESET)
                wire(cls.RESET, reshape_write_counter.RESET)
                wire(cls.RESET, reshape_read_counter.RESET)

            # wire read and write counters to all LUTs
            for lut in write_bank_for_input_lane_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in write_addr_for_bank_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in write_valid_for_bank_luts:
                wire(reshape_write_counter.O, lut.addr)

            for lut in read_addr_for_bank_luts:
                wire(reshape_read_counter.O, lut.addr)

            for lut in output_lane_for_bank_luts:
                wire(reshape_read_counter.O, lut.addr)

            # third and final instance creation part creates the sorting networks that map lanes to banks
            input_sorting_network_t = Tuple(
                bank=Array[write_bank_for_input_lane_luts[0].data.N, Bit],
                val=ram_element_type.magma_repr())
            input_sorting_network = DefineBitonicSort(input_sorting_network_t,
                                                      len(rams),
                                                      lambda x: x.bank)()

            output_sorting_network_t = Tuple(
                lane=Array[output_lane_for_bank_luts[0].data.N, Bit],
                val=ram_element_type.magma_repr())
            output_sorting_network = DefineBitonicSort(
                output_sorting_network_t, len(rams), lambda x: x.lane)()

            # wire luts, sorting networks, inputs, and rams
            # flatten all the sseq_layers to get flat magma type of inputs and outputs
            # tseqs don't affect magma types
            num_sseq_layers_inputs = num_nested_layers(
                remove_tseqs(shared_and_diff_subtypes.diff_input))
            num_sseq_layers_to_remove_inputs = max(0,
                                                   num_sseq_layers_inputs - 1)
            num_sseq_layers_outputs = num_nested_layers(
                remove_tseqs(shared_and_diff_subtypes.diff_output))
            num_sseq_layers_to_remove_outputs = max(
                0, num_sseq_layers_outputs - 1)
            if remove_tseqs(
                    shared_and_diff_subtypes.shared_outer) != ST_Tombstone():
                #num_sseq_layers_inputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
                #num_sseq_layers_outputs += num_nested_layers(remove_tseqs(shared_and_diff_subtypes.shared_outer))
                input_ports = flatten_ports(
                    transpose_outer_dimensions(
                        shared_and_diff_subtypes.shared_outer,
                        shared_and_diff_subtypes.diff_input, cls.I),
                    num_sseq_layers_to_remove_inputs)
                output_ports = flatten_ports(
                    transpose_outer_dimensions(
                        shared_and_diff_subtypes.shared_outer,
                        shared_and_diff_subtypes.diff_output, cls.O),
                    num_sseq_layers_to_remove_outputs)
            else:
                input_ports = flatten_ports(cls.I,
                                            num_sseq_layers_to_remove_inputs)
                output_ports = flatten_ports(
                    cls.O, num_sseq_layers_to_remove_outputs)
            # this is only used if the shared outer layers contains any sseqs
            sseq_layers_to_flatten = max(
                num_nested_layers(
                    remove_tseqs(shared_and_diff_subtypes.shared_outer)) - 1,
                0)
            for idx in range(len(rams)):
                # wire input and bank to input sorting network
                wire(write_bank_for_input_lane_luts[idx].data,
                     input_sorting_network.I[idx].bank)
                #if idx == 0:
                #    wire(cls.first_valid, write_valid_for_bank_luts[idx].data)
                if idx < t_in_diff.port_width():
                    # since the input_ports are lists, need to wire them individually to the sorting ports
                    if remove_tseqs(shared_and_diff_subtypes.shared_outer
                                    ) != ST_Tombstone():
                        cur_input_port = flatten_ports(input_ports[idx],
                                                       sseq_layers_to_flatten)
                        cur_sort_port = flatten_ports(
                            input_sorting_network.I[idx].val,
                            sseq_layers_to_flatten)
                        for i in range(len(cur_input_port)):
                            wire(cur_input_port[i], cur_sort_port[i])
                    else:
                        if num_sseq_layers_inputs == 0:
                            # input_ports will be an array of bits for 1 element
                            # if no sseq in t_in
                            wire(input_ports, input_sorting_network.I[idx].val)
                        else:
                            wire(input_ports[idx],
                                 input_sorting_network.I[idx].val)
                    #wire(cls.ram_wr, input_sorting_network.O[idx].val)
                    #wire(cls.ram_rd, rams[idx].RDATA)
                else:
                    zero_const = DefineCoreirConst(
                        ram_element_type.magma_repr().size(), 0)().O
                    cur_sn_input = input_sorting_network.I[idx].val
                    while len(cur_sn_input) != len(zero_const):
                        cur_sn_input = cur_sn_input[0]
                    wire(zero_const, cur_sn_input)

                # wire input sorting network, write addr, and write valid luts to banks
                wire(input_sorting_network.O[idx].val, rams[idx].WDATA)
                wire(write_addr_for_bank_luts[idx].data, rams[idx].WADDR)
                #wire(write_addr_for_bank_luts[idx].data[0], cls.addr_wr[idx])
                if has_ce:
                    wire(write_valid_for_bank_luts[idx].data & bit(cls.CE),
                         rams[idx].WE)
                else:
                    wire(write_valid_for_bank_luts[idx].data, rams[idx].WE)

                # wire output sorting network, read addr, read bank, and read enable
                wire(rams[idx].RDATA, output_sorting_network.I[idx].val)
                wire(output_lane_for_bank_luts[idx].data,
                     output_sorting_network.I[idx].lane)
                wire(read_addr_for_bank_luts[idx].data, rams[idx].RADDR)
                #wire(read_addr_for_bank_luts[idx].data[0], cls.addr_rd[idx])
                # ok to read invalid things, so in read value LUT
                if has_ce:
                    wire(bit(cls.CE), rams[idx].RE)
                else:
                    wire(DefineCoreirConst(1, 1)().O[0], rams[idx].RE)
                if has_reset:
                    wire(cls.RESET, rams[idx].RESET)

                # wire output sorting network value to output or term
                if idx < t_out_diff.port_width():
                    # since the output_ports are lists, need to wire them individually to the sorting ports
                    if remove_tseqs(shared_and_diff_subtypes.shared_outer
                                    ) != ST_Tombstone():
                        cur_output_port = flatten_ports(
                            output_ports[idx], sseq_layers_to_flatten)
                        cur_sort_port = flatten_ports(
                            output_sorting_network.O[idx].val,
                            sseq_layers_to_flatten)
                        for i in range(len(cur_output_port)):
                            wire(cur_output_port[i], cur_sort_port[i])
                    else:
                        if num_sseq_layers_outputs == 0:
                            # output_ports will be an array of bits for 1 element
                            # if no sseq in t_out
                            wire(output_sorting_network.O[idx].val,
                                 output_ports)
                        else:
                            wire(output_sorting_network.O[idx].val,
                                 output_ports[idx])
                else:
                    wire(output_sorting_network.O[idx].val,
                         TermAnyType(type(output_sorting_network.O[idx].val)))

                # wire sorting networks bank/lane to term as not used on outputs, just used for sorting
                wire(input_sorting_network.O[idx].bank,
                     TermAnyType(type(input_sorting_network.O[idx].bank)))
                wire(output_sorting_network.O[idx].lane,
                     TermAnyType(type(output_sorting_network.O[idx].lane)))
示例#17
0
# Tail of the wiring for the partialParallel16Convolution circuit.
# NOTE(review): the magmaInstance* objects and the circuit itself are created
# outside this excerpt; this looks like generated code -- confirm before
# editing by hand.

# Input ports I11..I15 feed elements 3..7 of the second group (index 1) of
# magmaInstance0's input.
wire(partialParallel16Convolution.I11, magmaInstance0.I[1][3])
wire(partialParallel16Convolution.I12, magmaInstance0.I[1][4])
wire(partialParallel16Convolution.I13, magmaInstance0.I[1][5])
wire(partialParallel16Convolution.I14, magmaInstance0.I[1][6])
wire(partialParallel16Convolution.I15, magmaInstance0.I[1][7])
# Output ports O0..O15 are wired one-to-one to the `out` ports of
# magmaInstance189..magmaInstance204.
wire(partialParallel16Convolution.O0, magmaInstance189.out)
wire(partialParallel16Convolution.O1, magmaInstance190.out)
wire(partialParallel16Convolution.O2, magmaInstance191.out)
wire(partialParallel16Convolution.O3, magmaInstance192.out)
wire(partialParallel16Convolution.O4, magmaInstance193.out)
wire(partialParallel16Convolution.O5, magmaInstance194.out)
wire(partialParallel16Convolution.O6, magmaInstance195.out)
wire(partialParallel16Convolution.O7, magmaInstance196.out)
wire(partialParallel16Convolution.O8, magmaInstance197.out)
wire(partialParallel16Convolution.O9, magmaInstance198.out)
wire(partialParallel16Convolution.O10, magmaInstance199.out)
wire(partialParallel16Convolution.O11, magmaInstance200.out)
wire(partialParallel16Convolution.O12, magmaInstance201.out)
wire(partialParallel16Convolution.O13, magmaInstance202.out)
wire(partialParallel16Convolution.O14, magmaInstance203.out)
wire(partialParallel16Convolution.O15, magmaInstance204.out)
# Handshake: magmaInstance0's ready/valid are wired to the circuit's
# ready_data_in / valid_data_out ports.
wire(magmaInstance0.ready, partialParallel16Convolution.ready_data_in)
wire(magmaInstance0.valid, partialParallel16Convolution.valid_data_out)
# magmaInstance0 is clock-enabled only when the upstream data is valid, the
# downstream side is ready, and the circuit's own CE is asserted.
wire(
    partialParallel16Convolution.valid_data_in
    & partialParallel16Convolution.ready_data_out
    & bit(partialParallel16Convolution.CE), magmaInstance0.CE)
# The CE port is additionally wired to a term instance of type Enable.
# NOTE(review): presumably so the Enable-typed port has a direct connection
# besides the bit() conversion above -- confirm.
ceTerm = TermAnyType(Enable)
wire(ceTerm.I, partialParallel16Convolution.CE)
# Close the circuit definition; no wiring may follow this call.
EndCircuit()
示例#18
0
        def definition(cls):
            """Wire up the counters that produce this circuit's valid/last signals.

            All configuration (`t`, `has_last`, `has_cur_valid`, `has_ce`,
            `has_reset`, `valid_when_ce_off`) is read from the enclosing scope,
            not from parameters. Three cases are handled, chosen by `t`:

            * `t` is an `ST_TSeq`: an outer counter modulo `t.n + t.i` wraps a
              recursively built counter circuit for the element type `t.t`.
            * `t` is otherwise nested (`is_nested(t)`): every port is passed
              straight through to a recursively built circuit for `t.t`.
            * otherwise: a single element, driven from constants.

            `cls.valid` is always driven. `cls.last` and `cls.cur_valid` are
            driven only when `has_last` / `has_cur_valid` are set, and
            `cls.CE` / `cls.RESET` are only touched when `has_ce` /
            `has_reset` are set.
            """
            # NOTE(review): exact type comparison; a subclass of ST_TSeq would
            # fall through to the branches below -- confirm that is intended.
            if type(t) == ST_TSeq:
                # outer counter steps through the t.n + t.i slots of the TSeq
                outer_counter = AESizedCounterModM(t.n + t.i,
                                                   has_ce=True,
                                                   has_reset=has_reset)
                # inner counters always expose `last` (it is needed to advance
                # the outer counter) and never expose `cur_valid` (that is
                # tracked by a dedicated counter at this level)
                inner_counters = DefineNestedCounters(
                    t.t,
                    has_last=True,
                    has_cur_valid=False,
                    has_ce=has_ce,
                    has_reset=has_reset,
                    valid_when_ce_off=valid_when_ce_off)()
                if has_last:
                    # presumably high when the outer counter equals its final
                    # value t.n + t.i - 1 -- verify against Decode
                    is_last = Decode(t.n + t.i - 1,
                                     outer_counter.O.N)(outer_counter.O)
                if has_cur_valid:
                    # separate counter modulo t.valid_clocks(); its value is
                    # exported directly as cur_valid
                    cur_valid_counter = AESizedCounterModM(t.valid_clocks(),
                                                           has_ce=True,
                                                           has_reset=has_reset)
                    wire(cur_valid_counter.O, cls.cur_valid)

                # if t.n is a power of 2 and always valid, then outer_counter.O.N not enough bits
                # for valid_length to contain t.n and for is_valid to get the right input
                # always valid in this case, so just emit 1
                # NOTE(review): math.pow returns a float; the comparison is
                # float vs. int here.
                if math.pow(2, outer_counter.O.N) - 1 < t.n:
                    is_valid = DefineCoreirConst(1, 1)().O[0]
                    if not has_last:
                        # never using the outer_counter if not has_last,
                        # so send its output to a term
                        last_term = TermAnyType(type(outer_counter.O))
                        wire(outer_counter.O, last_term.I)
                else:
                    # is_valid is the output of a comparator fed the outer
                    # counter value (I0) and the constant t.n (I1)
                    valid_length = DefineCoreirConst(outer_counter.O.N, t.n)()
                    is_valid_cmp = DefineCoreirUlt(outer_counter.O.N)()
                    wire(is_valid_cmp.I0, outer_counter.O)
                    wire(is_valid_cmp.I1, valid_length.O)
                    is_valid = is_valid_cmp.O

                # valid only when both the inner element and the outer slot are valid
                wire(inner_counters.valid & is_valid, cls.valid)
                if has_last:
                    # last only on the last inner step of the last outer slot
                    wire(is_last & inner_counters.last, cls.last)
                if has_reset:
                    wire(cls.RESET, outer_counter.RESET)
                    wire(cls.RESET, inner_counters.RESET)
                    if has_cur_valid:
                        wire(cls.RESET, cur_valid_counter.RESET)
                if has_ce:
                    # outer counter advances once per completed inner sequence,
                    # and only while CE is asserted
                    wire(bit(cls.CE) & inner_counters.last, outer_counter.CE)
                    wire(cls.CE, inner_counters.CE)
                    if has_cur_valid:
                        # cur_valid counter advances only on enabled, valid clocks
                        wire(
                            bit(cls.CE) & inner_counters.valid & is_valid,
                            cur_valid_counter.CE)
                else:
                    # without a CE port the same enables are used ungated
                    wire(inner_counters.last, outer_counter.CE)
                    if has_cur_valid:
                        wire(inner_counters.valid & is_valid,
                             cur_valid_counter.CE)
            elif is_nested(t):
                # nested type that is not a TSeq: build the circuit for t.t
                # with the same options and forward every port unchanged
                inner_counters = DefineNestedCounters(
                    t.t,
                    has_last,
                    has_cur_valid,
                    has_ce,
                    has_reset,
                    valid_when_ce_off=valid_when_ce_off)()

                wire(inner_counters.valid, cls.valid)
                if has_last:
                    wire(inner_counters.last, cls.last)
                if has_reset:
                    wire(cls.RESET, inner_counters.RESET)
                if has_ce:
                    wire(cls.CE, inner_counters.CE)
                if has_cur_valid:
                    wire(inner_counters.cur_valid, cls.cur_valid)
            else:
                # only 1 element, so always last and valid element
                valid_and_last = DefineCoreirConst(1, 1)()
                if has_last:
                    wire(valid_and_last.O[0], cls.last)
                if has_cur_valid:
                    # cur_valid is driven by a 1-bit constant 0
                    cur_valid = DefineCoreirConst(1, 0)()
                    wire(cur_valid.O, cls.cur_valid)
                if has_ce:
                    if valid_when_ce_off:
                        # valid is constant 1 regardless of CE; CE is unused,
                        # so it is wired to a term
                        wire(cls.valid, valid_and_last.O[0])
                        ce_term = TermAnyType(Bit)
                        wire(cls.CE, ce_term.I)
                    else:
                        # valid simply follows CE
                        wire(cls.valid, cls.CE)
                else:
                    wire(valid_and_last.O[0], cls.valid)
                if has_reset:
                    # nothing to reset in this case; wire RESET to a term
                    reset_term = TermAnyType(Bit)
                    wire(reset_term.I, cls.RESET)