コード例 #1
0
def sort_with_ub(instance: tik.Tik, src_ub_list, dst_ub, sorted_num):
    """
    sort_with_ub
    Merge up to four sorted source buffers into dst_ub with one vmrgsort4 call.

    Parameters
    ----------
    instance: object exposing vmrgsort4 (annotated as tik.Tik)
    src_ub_list: sequence of source buffers (list or tuple). It is never
        modified; when it holds fewer than four entries, a private copy is
        padded by repeating the last entry so that four slots are passed.
    dst_ub: destination buffer handed to vmrgsort4
    sorted_num: element count reported for each of the four slots

    Raises
    ------
    IndexError: if src_ub_list is empty (there is no entry to pad with)
    """
    ub_count = len(src_ub_list)
    # Pad a private copy: extending the caller's list in place would leak the
    # filler entries back to the caller, and `+=` fails outright on a tuple.
    padded_src_list = list(src_ub_list)
    if ub_count < 4:
        padded_src_list.extend([padded_src_list[-1]] * (4 - ub_count))
    element_count_list = [sorted_num] * 4
    # Only the low ub_count bits are set, i.e. one bit per real source
    # (presumably this is how vmrgsort4 is told to skip the padded slots).
    valid_bit = 2**ub_count - 1
    instance.vmrgsort4(dst_ub,
                       padded_src_list,
                       element_count_list,
                       False,
                       valid_bit,
                       repeat_times=1)
コード例 #2
0
def _merge_recur(instance: tik.Tik,
                 out_ub,
                 dst_ub,
                 src_ub,
                 last_dim,
                 total_region_list,
                 level,
                 region_offset=0):
    """
    _merge_recur
    Merge several sorted region proposal lists into one sorted list.

    One call performs one merge pass: the lists in src_ub are merged in
    groups of up to four into dst_ub. While more than one list remains, the
    function calls itself for the next level with the two buffers swapped.

    Parameters
    ----------
    instance: tik instance used to issue vmrgsort4 / data_move
    out_ub: buffer the final pass writes to whenever more than one pass runs
    dst_ub: buffer written by this pass
    src_ub: buffer read by this pass
    last_dim: total number of entries; the last list of a pass holds
        whatever is left after the full-length lists before it
    total_region_list: number of sorted lists present at this pass
    level: pass number starting at 1; a full-length list at this pass holds
        16 * 4**(level - 1) entries
    region_offset: start offset applied to both the source and the
        destination buffer

    Returns
    -------
    the buffer written by the last pass
    """
    # vmrgsort4 merges at most four lists, so a pass leaves ceil(n / 4) lists.
    lists_after_pass = math.ceil(total_region_list / 4)
    # True when the following pass is the one that produces the single list.
    feeds_final_pass = 1 < lists_after_pass <= 4
    full_groups, leftover = divmod(total_region_list, 4)

    # The following pass will read this pass's output and write out_ub, so
    # this pass needs a scratch destination if it was pointed at out_ub.
    if feeds_final_pass and dst_ub.name == out_ub.name:
        dst_ub = instance.Tensor(out_ub.dtype,
                                 out_ub.shape,
                                 scope=tik.scope_ubuf,
                                 name="ub_merge_recur")

    region_len = 16 * (4**(level - 1))

    def _four_sources(start, len0, len1, len2):
        # Offsets are scaled by 8 throughout this function (presumably the
        # number of buffer elements per proposal -- confirm against caller).
        second = start + len0 * 8
        third = second + len1 * 8
        fourth = third + len2 * 8
        return (src_ub[start], src_ub[second], src_ub[third], src_ub[fourth])

    # Groups of four full-length lists go out as one repeated instruction.
    # If the lists split evenly into groups but the data is shorter than
    # that, the last list is short: hold its group back and merge it
    # separately with the true length of that last list.
    repeat = full_groups
    tail_len = None
    if (full_groups > 0 and leftover == 0
            and region_len * 4 * full_groups > last_dim):
        repeat = full_groups - 1
        tail_len = last_dim - (region_len * 4 * repeat + 3 * region_len)

    if repeat > 0:
        sources = _four_sources(region_offset, region_len, region_len,
                                region_len)
        instance.vmrgsort4(dst_ub[region_offset], sources, (region_len,) * 4,
                           False, 15, repeat)

    if tail_len is not None:
        tail_start = region_offset + 4 * region_len * repeat * 8
        sources = _four_sources(tail_start, region_len, region_len,
                                region_len)
        instance.vmrgsort4(dst_ub[tail_start],
                           sources,
                           (region_len, region_len, region_len, tail_len),
                           False,
                           15,
                           repeat_times=1)

    # Entries already covered by the complete groups of four.
    consumed = 4 * full_groups * region_len if full_groups > 0 else 0

    if leftover == 1:
        # A single leftover list has nothing to merge with: copy it across.
        lone_len = last_dim - consumed
        burst_blocks = (
            lone_len * 8 * common_util.get_data_size(src_ub.dtype) + 31) // 32
        rest_start = region_offset + consumed * 8
        instance.data_move(dst_ub[rest_start], src_ub[rest_start], 0, 1,
                           burst_blocks, 0, 0)
    elif leftover in (2, 3):
        if leftover == 3:
            second_len = region_len
            third_len = last_dim - (consumed + region_len + second_len)
            counts = (region_len, second_len, third_len, 0)
        else:
            second_len = last_dim - (consumed + region_len)
            # The third slot is masked out by the valid bits; its length is
            # only used to place the (also masked) fourth slot and stays at
            # the full region length.
            third_len = region_len
            counts = (region_len, second_len, 0, 0)
        rest_start = region_offset + consumed * 8
        sources = _four_sources(rest_start, region_len, second_len, third_len)
        instance.vmrgsort4(dst_ub[rest_start],
                           sources,
                           counts,
                           False,
                           2**leftover - 1,
                           repeat_times=1)

    if lists_after_pass <= 1:
        return dst_ub

    # Swap roles for the next level; the pass that yields the single list
    # writes straight into out_ub.
    next_dst = out_ub if feeds_final_pass else src_ub
    return _merge_recur(instance, out_ub, next_dst, dst_ub, last_dim,
                        lists_after_pass, level + 1, region_offset)