def _get_max_threads(batch_size):
    """Return the thread-block size to launch for ``batch_size`` elements.

    Normally this is ``min(batch_size, max_num_threads)`` for the current
    target. On Vulkan targets with a dynamic (non-``IntImm``) batch size the
    raw hardware limit is returned instead, because SPIR-V does not support
    dynamic thread-group sizes.

    Parameters
    ----------
    batch_size : PrimExpr
        Number of elements the kernel will process; may be dynamic.

    Returns
    -------
    PrimExpr
        The thread extent to use for ``threadIdx.x``.
    """
    # Fetch the target once with allow_none=False: the original fetched it
    # twice (the first time allowing None), so the vulkan substring check
    # could have been evaluated against str(None).
    target = tvm.target.Target.current(allow_none=False)
    max_threads = target.max_num_threads
    if "vulkan" in str(target) and not isinstance(batch_size, tvm.tir.IntImm):
        # SPIR-V does not support dynamic thread group size
        return max_threads
    return tir.min(batch_size, max_threads)
def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence):
    """Low level IR to calculate the first occurence
    of each unique element in the input data.

    Parameters
    ----------
    argsorted_indices : Buffer
        A buffer that stores the argsorted indices of the input data.

    inc_scan : Buffer
        A buffer that stores the inclusive scan of the binary tir.NE
        adjacent difference of the sorted data.

    first_occurence : Buffer
        A buffer that stores the first occurence of each unique element
        in the input data.
    """
    ib = tir.ir_builder.create()
    argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
    inc_scan_ptr = ib.buffer_ptr(inc_scan)
    first_occurence_ptr = ib.buffer_ptr(first_occurence)
    batch_size = argsorted_indices.shape[0]
    # Use the shared helper instead of inlining
    # tir.min(batch_size, max_num_threads): on Vulkan targets with a dynamic
    # batch size the helper falls back to a static thread-group size, since
    # SPIR-V does not support dynamic workgroup sizes.
    max_threads = _get_max_threads(batch_size)
    # Pass 1: initialize every slot to batch_size as a "not seen" sentinel.
    with ib.new_scope():
        nthread_tx = max_threads
        nthread_bx = ceil_div(batch_size, max_threads)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < batch_size):
            first_occurence_ptr[tid] = batch_size
    # Pass 2: wherever inc_scan steps (a new unique value starts in the
    # sorted order), record the original index of that first occurrence.
    with ib.new_scope():
        nthread_tx = max_threads
        nthread_bx = ceil_div(batch_size, max_threads)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < batch_size):
            with ib.if_scope(tid == 0):
                first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid]
            with ib.else_scope():
                with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
                    first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid]
    return ib.get()
def read_out_of_bound_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: B's producer has been computed at C's loop, so its
    # extent is clamped with tir.min to avoid writing past index 15, while C
    # still declares a two-element read window B[v : v+2].
    A = tir.match_buffer(a, [16], "float32")
    B = tir.alloc_buffer([16], "float32")
    C = tir.match_buffer(c, [16], "float32")
    for j in tir.serial(0, 16):
        # Inner extent min(1, 15 - j) + 1 produces 2 iterations except at
        # j == 15, where only one element of B remains to be computed.
        for i in tir.serial(0, tir.min(1, 15 - j) + 1):
            with tir.block([16], "B") as [v]:
                tir.bind(v, j + i)
                B[v] = A[v]
        with tir.block([16], "C") as [v]:
            tir.bind(v, j)
            # Explicit read annotation: the declared region extends one past
            # v, which is out of bounds at v == 15 — hence the guarded
            # if_then_else below (this is the behavior under test).
            tir.reads([B[v:v + 2]])
            C[v] = tir.if_then_else(v < 15, tir.max(B[v], B[v + 1]), B[v], dtype="float32")
def _calc_adjacent_diff_ir(data, output, binop=tir.Sub):
    """Low level IR to calculate adjacent difference in an 1-D array.

    Parameters
    ----------
    data : Buffer
        Input 1-D Buffer.

    output: Buffer
        A buffer to store adjacent difference, of the same shape as data.
        The adjacent difference is defined as:
        output[0] = 0, output[i] = binop(data[i], data[i-1])
        where i > 0 and i < len(data).

    binop: function, optional
        A binary associative op to use for calculating adjacent difference.
        The function takes two TIR expressions and produce a new TIR
        expression. By default it uses tvm.tir.Sub to compute the adjacent
        difference.
    """
    ib = tir.ir_builder.create()
    data_ptr = ib.buffer_ptr(data)
    output_ptr = ib.buffer_ptr(output)
    batch_size = data.shape[0]
    # Use the shared helper instead of inlining
    # tir.min(batch_size, max_num_threads): on Vulkan targets with a dynamic
    # batch size the helper falls back to a static thread-group size, since
    # SPIR-V does not support dynamic workgroup sizes.
    max_threads = _get_max_threads(batch_size)
    with ib.new_scope():
        nthread_tx = max_threads
        nthread_bx = ceil_div(batch_size, max_threads)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < batch_size):
            with ib.if_scope(tid == 0):
                # First element has no left neighbor; defined as 0.
                output_ptr[tid] = 0
            with ib.else_scope():
                # Cast because binop's result dtype may differ from the
                # output buffer's dtype (e.g. tir.NE yields a boolean).
                output_ptr[tid] = tir.Cast(
                    output.dtype, binop(data_ptr[tid], data_ptr[tid - 1]))
    return ib.get()
def _calc_unique_ir(data, argsorted_indices, inc_scan, index_converter,
                    unique_elements, indices, counts):
    """Low level IR to calculate unique elements, inverse indices, and
    counts (optional) of unique elements of 1-D array.

    Parameters
    ----------
    data : Buffer
        Input 1-D Buffer.

    argsorted_indices : Buffer
        A buffer that stores the argsorted indices of the input data.

    inc_scan : Buffer
        A buffer that stores the inclusive scan of the binary tir.NE
        adjacent difference of the sorted data.

    index_converter (optional) : Buffer
        An optional index converter that transforms the unique element index
        such that new_idx = index_converter[old_idx].

    unique_elements : Buffer
        A buffer that stores the unique elements.

    indices : Buffer
        A buffer that stores the the index of each input data element in the
        unique element array.

    counts (optional) : Buffer
        A buffer that stores the count of each unique element.
    """
    ib = tir.ir_builder.create()
    data_ptr = ib.buffer_ptr(data)
    argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices)
    inc_scan_ptr = ib.buffer_ptr(inc_scan)
    unique_elements_ptr = ib.buffer_ptr(unique_elements)
    indices_ptr = ib.buffer_ptr(indices)

    index_converter_ptr = None
    if isinstance(index_converter, tir.Buffer):
        index_converter_ptr = ib.buffer_ptr(index_converter)

    if isinstance(counts, tir.Buffer):
        counts_ptr = ib.buffer_ptr(counts)
        # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1]
        unique_seq_indices_ptr = ib.buffer_ptr(indices)

    batch_size = data.shape[0]
    # Use the shared helper instead of inlining
    # tir.min(batch_size, max_num_threads): on Vulkan targets with a dynamic
    # batch size the helper falls back to a static thread-group size, since
    # SPIR-V does not support dynamic workgroup sizes.
    max_threads = _get_max_threads(batch_size)

    # if need to return counts
    if isinstance(counts, tir.Buffer):
        num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1
        num_elements = data.shape[0]
        # Pass A: record, for each unique value, the (exclusive) end position
        # of its run in the sorted order, using indices as scratch space.
        with ib.new_scope():
            nthread_tx = max_threads
            nthread_bx = ceil_div(batch_size, max_threads)
            tx = te.thread_axis("threadIdx.x")
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(tx, "thread_extent", nthread_tx)
            ib.scope_attr(bx, "thread_extent", nthread_bx)
            tid = bx * max_threads + tx
            with ib.if_scope(tid < batch_size):
                with ib.if_scope(tid == 0):
                    unique_seq_indices_ptr[num_unique - 1] = num_elements
                with ib.else_scope():
                    with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
                        unique_seq_indices_ptr[inc_scan_ptr[tid] - 1] = tid
        # Pass B: counts are adjacent differences of the run-end positions.
        with ib.new_scope():
            nthread_tx = max_threads
            nthread_bx = ceil_div(batch_size, max_threads)
            tx = te.thread_axis("threadIdx.x")
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(tx, "thread_extent", nthread_tx)
            ib.scope_attr(bx, "thread_extent", nthread_bx)
            tid = bx * max_threads + tx
            with ib.if_scope(tid < num_unique):
                unique_idx = tid if not index_converter_ptr else index_converter_ptr[tid]
                with ib.if_scope(tid == 0):
                    counts_ptr[unique_idx] = unique_seq_indices_ptr[tid]
                with ib.else_scope():
                    counts_ptr[unique_idx] = (unique_seq_indices_ptr[tid] -
                                              unique_seq_indices_ptr[tid - 1])
    # calculate unique elements and inverse indices
    with ib.new_scope():
        nthread_tx = max_threads
        nthread_bx = ceil_div(batch_size, max_threads)
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < batch_size):
            data_idx = argsorted_indices_ptr[tid]
            unique_idx = (inc_scan_ptr[tid] if not index_converter_ptr else
                          index_converter_ptr[inc_scan_ptr[tid]])
            indices_ptr[data_idx] = unique_idx
            # The first element of each sorted run writes the unique value.
            with ib.if_scope(tid == 0):
                unique_elements_ptr[unique_idx] = data_ptr[data_idx]
            with ib.else_scope():
                with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]):
                    unique_elements_ptr[unique_idx] = data_ptr[data_idx]
    return ib.get()
def _clamp_tvm(e, low, high):
    """Clamp the TIR expression ``e`` into the inclusive range [low, high]."""
    lower_bounded = tir.max(e, low)
    return tir.min(lower_bounded, high)
def apply(lhs, rhs):
    """Left-shift |lhs| by min(30, |rhs|) bits.

    The shift amount is capped at 30 — presumably to keep the result within
    a 32-bit value's range; confirm against the surrounding fuzzing/test
    harness.
    """
    shift_amount = tir.min(30, tir.abs(_force_int(rhs)))
    return tir.abs(_force_int(lhs)) << shift_amount
def apply(lhs, rhs):
    """Return ``tir.min`` of the two operand expressions."""
    result = tir.min(lhs, rhs)
    return result
def _get_max_threads(batch_row):
    """Thread extent for ``batch_row`` elements, capped at the current
    target's ``max_num_threads``."""
    hw_limit = tvm.target.Target.current(allow_none=False).max_num_threads
    return tir.min(batch_row, hw_limit)