コード例 #1
0
def _get_max_threads(batch_size):
    """Choose the thread extent to launch for ``batch_size`` elements.

    Parameters
    ----------
    batch_size : int or tir expression
        Number of elements to process. Anything that is not a
        ``tvm.tir.IntImm`` is treated as non-constant by the Vulkan check
        below.

    Returns
    -------
    The current target's ``max_num_threads`` when the target string contains
    "vulkan" and ``batch_size`` is not an ``IntImm``; otherwise
    ``tir.min(batch_size, max_num_threads)``.

    Notes
    -----
    A current target is required: ``Target.current(allow_none=False)`` is
    used, so the lookup fails when no target is in scope.
    """
    # One lookup is enough: a missing target is rejected here, so there is no
    # need to fetch the target a second time with a different allow_none.
    target = tvm.target.Target.current(allow_none=False)
    max_threads = target.max_num_threads
    if "vulkan" in str(target) and not isinstance(batch_size, tvm.tir.IntImm):
        # SPIR-V does not support dynamic thread group size
        return max_threads
    return tir.min(batch_size, max_threads)
コード例 #2
0
ファイル: unique.py プロジェクト: Xuxue1/tvm
def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence):
    """Build the low level IR that records the first occurence of each unique element.

    Parameters
    ----------
    argsorted_indices : Buffer
        Argsorted indices of the input data.

    inc_scan : Buffer
        Inclusive scan of the binary tir.NE adjacent difference of the sorted data.

    first_occurence : Buffer
        Output buffer. Every entry is first set to the batch size; then, for each
        position where ``inc_scan`` changes value (and for position 0), the entry
        indexed by ``inc_scan`` at that position receives the matching argsorted index.

    Returns
    -------
    The statement assembled by the IR builder.
    """
    ib = tir.ir_builder.create()
    sorted_idx = ib.buffer_ptr(argsorted_indices)
    scan = ib.buffer_ptr(inc_scan)
    first_seen = ib.buffer_ptr(first_occurence)
    batch_size = argsorted_indices.shape[0]
    max_threads = tir.min(
        batch_size, tvm.target.Target.current(allow_none=False).max_num_threads
    )

    def bind_threads():
        # Attach threadIdx.x / blockIdx.x extents to the current scope and
        # return the flattened thread index.
        num_blocks = ceil_div(batch_size, max_threads)
        thread_x = te.thread_axis("threadIdx.x")
        block_x = te.thread_axis("blockIdx.x")
        ib.scope_attr(thread_x, "thread_extent", max_threads)
        ib.scope_attr(block_x, "thread_extent", num_blocks)
        return block_x * max_threads + thread_x

    # Pass 1: fill every slot with batch_size.
    with ib.new_scope():
        tid = bind_threads()
        with ib.if_scope(tid < batch_size):
            first_seen[tid] = batch_size

    # Pass 2: write the argsorted index wherever the scan value changes.
    with ib.new_scope():
        tid = bind_threads()
        with ib.if_scope(tid < batch_size):
            with ib.if_scope(tid == 0):
                # tid - 1 does not exist for the first element, so it is
                # written without comparing against a predecessor.
                first_seen[scan[tid]] = sorted_idx[tid]
            with ib.else_scope():
                with ib.if_scope(scan[tid] != scan[tid - 1]):
                    first_seen[scan[tid]] = sorted_idx[tid]

    return ib.get()
コード例 #3
0
def read_out_of_bound_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # A and C are 16-element float32 buffers matched to the handles `a` and `c`;
    # B is a 16-element float32 buffer allocated inside the function.
    A = tir.match_buffer(a, [16], "float32")
    B = tir.alloc_buffer([16], "float32")
    C = tir.match_buffer(c, [16], "float32")
    for j in tir.serial(0, 16):
        # Inner extent is min(1, 15 - j) + 1: two iterations while j <= 14 and a
        # single iteration at j == 15, so v = j + i stays within 0..15.
        for i in tir.serial(0, tir.min(1, 15 - j) + 1):
            with tir.block([16], "B") as [v]:
                tir.bind(v, j + i)
                B[v] = A[v]
        with tir.block([16], "C") as [v]:
            tir.bind(v, j)
            # Declared read region is B[v:v + 2]. NOTE(review): assuming the
            # slice is half-open, at v == 15 it names index 16, one past the
            # 16-element B, even though the expression below only touches
            # B[v + 1] when v < 15.
            tir.reads([B[v:v + 2]])
            C[v] = tir.if_then_else(v < 15,
                                    tir.max(B[v], B[v + 1]),
                                    B[v],
                                    dtype="float32")
コード例 #4
0
ファイル: unique.py プロジェクト: Xuxue1/tvm
def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
    """Build the low level IR computing the adjacent difference of a 1-D array.

    Parameters
    ----------
    data : Buffer
        Input 1-D Buffer.

    output : Buffer
        Destination buffer with the same shape as ``data``. It receives
        output[0] = 0 and output[i] = binop(data[i], data[i-1]) for
        0 < i < len(data), cast to ``output.dtype``.

    binop : function, optional
        Binary op applied to each element and its predecessor. It takes two
        TIR expressions and produces a new TIR expression. Defaults to
        tvm.tir.Sub.

    Returns
    -------
    The statement assembled by the IR builder.
    """
    ib = tir.ir_builder.create()
    src = ib.buffer_ptr(data)
    dst = ib.buffer_ptr(output)
    batch_size = data.shape[0]
    max_threads = tir.min(
        batch_size, tvm.target.Target.current(allow_none=False).max_num_threads
    )

    with ib.new_scope():
        num_blocks = ceil_div(batch_size, max_threads)
        thread_x = te.thread_axis("threadIdx.x")
        block_x = te.thread_axis("blockIdx.x")
        ib.scope_attr(thread_x, "thread_extent", max_threads)
        ib.scope_attr(block_x, "thread_extent", num_blocks)
        tid = block_x * max_threads + thread_x

        with ib.if_scope(tid < batch_size):
            with ib.if_scope(tid == 0):
                # The first element has no predecessor.
                dst[tid] = 0
            with ib.else_scope():
                diff = binop(src[tid], src[tid - 1])
                dst[tid] = tir.Cast(output.dtype, diff)

    return ib.get()
コード例 #5
0
ファイル: unique.py プロジェクト: Xuxue1/tvm
def _calc_unique_ir(data, argsorted_indices, inc_scan, index_converter,
                    unique_elements, indices, counts):
    """Build the low level IR producing unique elements, inverse indices and,
    optionally, per-element counts for a 1-D array.

    Parameters
    ----------
    data : Buffer
        Input 1-D Buffer.

    argsorted_indices : Buffer
        Argsorted indices of the input data.

    inc_scan : Buffer
        Inclusive scan of the binary tir.NE adjacent difference of the sorted data.

    index_converter (optional) : Buffer
        When this is a tir.Buffer, unique element indices are remapped through it
        as new_idx = index_converter[old_idx]; otherwise indices are used as is.

    unique_elements : Buffer
        Output buffer that receives the unique elements.

    indices : Buffer
        Output buffer that receives, for each input element, its index in the
        unique element array. It doubles as scratch space while counts are
        being computed.

    counts (optional) : Buffer
        When this is a tir.Buffer, it receives the count of each unique element;
        otherwise the counting kernels are not emitted.

    Returns
    -------
    The statement assembled by the IR builder.
    """
    ib = tir.ir_builder.create()
    data_ptr = ib.buffer_ptr(data)
    sorted_idx = ib.buffer_ptr(argsorted_indices)
    scan = ib.buffer_ptr(inc_scan)
    uniques = ib.buffer_ptr(unique_elements)
    inverse = ib.buffer_ptr(indices)

    converter = None
    if isinstance(index_converter, tir.Buffer):
        converter = ib.buffer_ptr(index_converter)

    want_counts = isinstance(counts, tir.Buffer)
    if want_counts:
        counts_out = ib.buffer_ptr(counts)
        # `indices` is borrowed as a temporary buffer holding the tids where
        # inc_scan[tid] != inc_scan[tid-1]
        boundaries = ib.buffer_ptr(indices)

    batch_size = data.shape[0]
    max_threads = tir.min(
        batch_size, tvm.target.Target.current(allow_none=False).max_num_threads
    )

    def bind_threads():
        # Attach threadIdx.x / blockIdx.x extents to the current scope and
        # return the flattened thread index.
        num_blocks = ceil_div(batch_size, max_threads)
        thread_x = te.thread_axis("threadIdx.x")
        block_x = te.thread_axis("blockIdx.x")
        ib.scope_attr(thread_x, "thread_extent", max_threads)
        ib.scope_attr(block_x, "thread_extent", num_blocks)
        return block_x * max_threads + thread_x

    def remap(idx):
        # Route an index through the converter when one was supplied.
        return idx if not converter else converter[idx]

    if want_counts:
        num_unique = scan[inc_scan.shape[0] - 1] + 1
        num_elements = data.shape[0]

        # Kernel 1: record the position at which each run of equal values ends.
        with ib.new_scope():
            tid = bind_threads()
            with ib.if_scope(tid < batch_size):
                with ib.if_scope(tid == 0):
                    boundaries[num_unique - 1] = num_elements
                with ib.else_scope():
                    with ib.if_scope(scan[tid] != scan[tid - 1]):
                        boundaries[scan[tid] - 1] = tid

        # Kernel 2: counts are differences between consecutive boundaries.
        with ib.new_scope():
            tid = bind_threads()
            with ib.if_scope(tid < num_unique):
                unique_idx = remap(tid)
                with ib.if_scope(tid == 0):
                    counts_out[unique_idx] = boundaries[tid]
                with ib.else_scope():
                    counts_out[unique_idx] = boundaries[tid] - boundaries[tid - 1]

    # Final kernel: unique elements and inverse indices.
    with ib.new_scope():
        tid = bind_threads()
        with ib.if_scope(tid < batch_size):
            data_idx = sorted_idx[tid]
            unique_idx = remap(scan[tid])
            inverse[data_idx] = unique_idx
            with ib.if_scope(tid == 0):
                uniques[unique_idx] = data_ptr[data_idx]
            with ib.else_scope():
                with ib.if_scope(scan[tid] != scan[tid - 1]):
                    uniques[unique_idx] = data_ptr[data_idx]

    return ib.get()
コード例 #6
0
def _clamp_tvm(e, low, high):
    """Clamp ``e`` as ``tir.min(tir.max(e, low), high)``.

    The lower bound is applied first, so ``high`` wins if the bounds cross.
    """
    raised = tir.max(e, low)
    return tir.min(raised, high)
コード例 #7
0
 def apply(lhs, rhs):
     """Left-shift ``|lhs|`` by ``min(30, |rhs|)``.

     Both operands go through ``_force_int`` before ``tir.abs`` is applied.
     """
     magnitude = tir.abs(_force_int(lhs))
     # The shift amount is capped at 30.
     shift = tir.min(30, tir.abs(_force_int(rhs)))
     return magnitude << shift
コード例 #8
0
 def apply(lhs, rhs):
     """Return ``tir.min(lhs, rhs)``."""
     return tir.min(lhs, rhs)
コード例 #9
0
def _get_max_threads(batch_row):
    """Return ``tir.min`` of ``batch_row`` and the current target's ``max_num_threads``.

    A current target is required: the lookup uses ``allow_none=False``.
    """
    return tir.min(
        batch_row, tvm.target.Target.current(allow_none=False).max_num_threads
    )