def infer_stride(data, kernel, out):
    """Recover the convolution stride from the stage shapes.

    Parameters
    ----------
    data : Tensor
        data stage.

    kernel : Tensor
        kernel stage.

    out : Tensor
        output stage.

    Returns
    -------
    hstride : int
        stride size on height
    wstride : int
        stride size on width
    """
    _, _, in_h, in_w = data.shape
    _, _, kernel_h, kernel_w = kernel.shape
    _, _, out_h, out_w = out.shape

    def _axis_stride(in_dim, kernel_dim, out_dim):
        # The divisor is clamped to at least 1; when the output has a single
        # element along the axis, 1 is added to the quotient.
        return ((in_dim - kernel_dim) // tvm.make.Max(out_dim - 1, 1)
                + tvm.select(out_dim == 1, 1, 0))

    hstride = _axis_stride(in_h, kernel_h, out_h)
    wstride = _axis_stride(in_w, kernel_w, out_w)
    return get_const_int(hstride), get_const_int(wstride)
def infer_stride(data, kernel, out):
    """Work backwards from the stage shapes to the stride that was used.

    Parameters
    ----------
    data : Tensor
        data stage.

    kernel : Tensor
        kernel stage.

    out : Tensor
        output stage.

    Returns
    -------
    hstride : int
        stride size on height
    wstride : int
        stride size on width
    """
    _, _, in_height, in_width = data.shape
    _, _, kernel_height, kernel_width = kernel.shape
    _, _, out_height, out_width = out.shape

    # Height: (input - kernel) // max(output - 1, 1), plus 1 when output == 1.
    height_span = in_height - kernel_height
    height_steps = tvm.make.Max(out_height - 1, 1)
    hstride = height_span // height_steps + tvm.select(out_height == 1, 1, 0)

    # Width: same formula on the last axis.
    width_span = in_width - kernel_width
    width_steps = tvm.make.Max(out_width - 1, 1)
    wstride = width_span // width_steps + tvm.select(out_width == 1, 1, 0)

    return get_const_int(hstride), get_const_int(wstride)
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
    """Low level IR routine for the multibox_prior operator.

    Parameters
    ----------
    data : Buffer
        Input data buffer.

    out : Buffer
        Output buffer.

    sizes : tuple of float
        Tuple of sizes for anchor boxes.

    ratios : tuple of float
        Tuple of ratios for anchor boxes.

    steps : Tuple of float
        Priorbox step across y and x, -1 for auto calculation.

    offsets : tuple of int
        Priorbox center offsets, y and x respectively.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    builder = tvm.ir_builder.create()
    out_ptr = builder.buffer_ptr(out)
    height = data.shape[2]
    width = data.shape[3]

    size_count = len(sizes)
    ratio_count = len(ratios)
    boxes_per_cell = size_count + ratio_count - 1
    combined = sizes + ratios

    # A non-positive step means "derive it from the input extent".
    if steps[0] > 0:
        step_h = steps[0]
    else:
        step_h = 1.0 / height
    if steps[1] > 0:
        step_w = steps[1]
    else:
        step_w = 1.0 / width
    offset_h, offset_w = offsets[0], offsets[1]

    with builder.for_range(0, height, for_type="parallel", name="i") as i:
        center_h = (i + offset_h) * step_h
        with builder.for_range(0, width, name="j") as j:
            center_w = (j + offset_w) * step_w
            for k in range(boxes_per_cell):
                # The first `size_count` boxes use sizes[k]; the rest combine
                # sizes[0] with the square root of the next entry of
                # `combined`.
                from_size = k < size_count
                half_w = tvm.select(
                    from_size,
                    combined[k] * height / width / 2.0,
                    combined[0] * height / width
                    * math.sqrt(combined[k + 1]) / 2.0)
                half_h = tvm.select(
                    from_size,
                    combined[k] / 2.0,
                    combined[0] / math.sqrt(combined[k + 1]) / 2.0)
                # Four consecutive output slots per box.
                base = (i * width * boxes_per_cell
                        + j * boxes_per_cell + k) * 4
                out_ptr[base] = center_w - half_w
                out_ptr[base + 1] = center_h - half_h
                out_ptr[base + 2] = center_w + half_w
                out_ptr[base + 3] = center_h + half_h

    return builder.get()
def test_rewrite_select():
    """Check which selects RewriteUnsafeSelect turns into tvm_if_then_else.

    Selects whose branch or condition loads from the allocated buffer are
    expected to come back named "tvm_if_then_else"; the outer select that
    only chooses between those two expressions is expected to stay a Select.
    """
    builder = tvm.ir_builder.create()
    A = builder.allocate("float32", 100, name="A", scope="global")
    i = tvm.var("i")

    def rewrite(expr):
        # Wrap in Evaluate so the pass receives a statement, then unwrap.
        return tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(expr)).value

    load_select = tvm.select(i > 1, A[i - 1], 1.0)
    load_rewritten = rewrite(load_select)

    cond_select = tvm.select(
        tvm.select(i > 1, A[i - 1], 1.0) > 0.0, A[i], 0.1)
    cond_rewritten = rewrite(cond_select)

    outer_select = tvm.select(i > 10, load_select, cond_select)
    outer_rewritten = rewrite(outer_select)

    assert load_rewritten.name == "tvm_if_then_else"
    assert cond_rewritten.name == "tvm_if_then_else"
    assert isinstance(outer_rewritten, tvm.expr.Select)
def test_rewrite_select():
    """Exercise RewriteUnsafeSelect on three select expressions.

    The first two involve loads from buffer ``A`` and must be rewritten to
    the "tvm_if_then_else" call; the third selects between them and must
    remain a ``tvm.expr.Select``.
    """
    ib = tvm.ir_builder.create()
    buf = ib.allocate("float32", 100, name="A", scope="global")
    idx = tvm.var("i")

    guarded = tvm.select(idx > 1, buf[idx - 1], 1.0)
    guarded_stmt = tvm.make.Evaluate(guarded)
    guarded_out = tvm.ir_pass.RewriteUnsafeSelect(guarded_stmt).value

    inner = tvm.select(idx > 1, buf[idx - 1], 1.0)
    nested = tvm.select(inner > 0.0, buf[idx], 0.1)
    nested_stmt = tvm.make.Evaluate(nested)
    nested_out = tvm.ir_pass.RewriteUnsafeSelect(nested_stmt).value

    combined = tvm.select(idx > 10, guarded, nested)
    combined_stmt = tvm.make.Evaluate(combined)
    combined_out = tvm.ir_pass.RewriteUnsafeSelect(combined_stmt).value

    assert guarded_out.name == "tvm_if_then_else"
    assert nested_out.name == "tvm_if_then_else"
    assert isinstance(combined_out, tvm.expr.Select)
def test_copy_pad_split():
    """Check the copy-intrinsic callback arguments for a padded, split stage.

    ``Apad`` pads ``A`` by one element on each side and is computed at the
    outer axis of a factor-4 split of ``B``; the callback asserts the source
    offset, the padding before/after and the source extent it is handed.
    """
    m = 4 * 3
    A = tvm.placeholder((m, ), name="A")

    def padded(i):
        inside = tvm.all(i >= 1, i <= m)
        return tvm.select(inside, A[i - 1], 0.0)

    Apad = tvm.compute((m + 2, ), padded, "Apad")
    B = tvm.compute((m, ), lambda i: Apad[i] + Apad[i + 1] + Apad[i + 2])

    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=4)
    s[Apad].compute_at(s[B], xo)
    s[Apad].pragma(s[Apad].op.axis[0], "memcpy")

    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    binds = {
        A: tvm.decl_buffer(A.shape, A.dtype, name='A'),
        B: tvm.decl_buffer(B.shape, B.dtype, name='B'),
    }
    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)
    stmt = tvm.ir_pass.Simplify(stmt)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)

    def cb(src, dst, pad_before, pad_after, pad_value):
        expected_before = tvm.max(1 - xo * 4, 0)
        expected_after = tvm.max(xo * 4 - 7, 0)
        assert dst.elem_offset.value == 0
        assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)
        assert_expr_equal(pad_before[0], expected_before)
        assert_expr_equal(pad_after[0], expected_after)
        assert_expr_equal(src.shape[0],
                          6 - expected_before - expected_after)
        return tvm.make.Evaluate(0)

    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def select_array(i, j):
    """Build a select chain that yields ``matrix[i % row][j % col]``.

    Starts from a zero constant and wraps it in one ``tvm.select`` per
    (row, column) pair. ``row``, ``col``, ``matrix`` and ``dtype`` come from
    the enclosing scope.
    """
    chain = tvm.const(0.0, dtype)
    for r in range(row):
        for c in range(col):
            matches = tvm.all(i % row == r, j % col == c)
            entry = tvm.const(matrix[r][c], dtype)
            chain = tvm.select(matches, entry, chain)
    return chain
def test_copy_pad():
    """Check the padding arguments passed to the copy-intrinsic callback.

    ``B`` has two more rows than ``A`` and fills rows outside ``[1, m]``
    with 1.0; the callback asserts one row of padding before and after on
    axis 0, none on axis 1, and a pad value of 1.0.
    """
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')

    def padded(i, j):
        in_rows = tvm.all(i >= 1, i < m + 1)
        return tvm.select(in_rows, A[i - 1, j], 1.0)

    B = tvm.compute((m + 2, l), padded, name='B')

    s = tvm.create_schedule(B.op)
    s[B].pragma(B.op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    binds = {
        A: tvm.decl_buffer(A.shape, A.dtype, name='A'),
        B: tvm.decl_buffer(B.shape, B.dtype, name='B'),
    }
    stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64)

    def cb(src, dst, pad_before, pad_after, pad_value):
        assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
        assert pad_before[0].value == 1
        assert pad_before[1].value == 0
        assert pad_after[0].value == 1
        assert pad_after[1].value == 0
        assert pad_value.value == 1.0
        return tvm.make.Evaluate(0)

    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def test_copy_pad_split():
    """Verify InjectCopyIntrin callback arguments when the padded stage is
    attached to the outer axis of a split.

    The callback checks the destination offset, the source offset and the
    before/after padding expressed in terms of the outer split variable.
    """
    m = 4 * 3
    A = tvm.placeholder((m, ), name="A")
    Apad = tvm.compute(
        (m + 2,),
        lambda i: tvm.select(tvm.all(i >= 1, i <= m), A[i - 1], 0.0),
        "Apad")

    def window_sum(i):
        # Three-element sliding window over the padded stage.
        return Apad[i] + Apad[i + 1] + Apad[i + 2]

    B = tvm.compute((m,), window_sum)
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=4)
    s[Apad].compute_at(s[B], xo)
    s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)

    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    for simplify in (tvm.ir_pass.Simplify, tvm.ir_pass.CanonicalSimplify):
        stmt = simplify(stmt)

    def cb(src, dst, pad_before, pad_after, pad_value):
        assert dst.elem_offset.value == 0
        assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)
        rpad_before = tvm.max(1 - xo * 4, 0)
        rpad_after = tvm.max(xo * 4 - 7, 0)
        assert_expr_equal(pad_before[0], rpad_before)
        assert_expr_equal(pad_after[0], rpad_after)
        assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
        return tvm.make.Evaluate(0)

    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def conv_time(args):
    """Build, schedule and time a padded convolution for one configuration.

    Parameters
    ----------
    args : tuple
        ``(tile_x, tile_y, step, target, dev_id, number)``. ``tile_x`` and
        ``tile_y`` are the split factors for the output-channel and batch
        axes, ``step`` is the split factor for the channel reduction axis,
        and ``target``, ``dev_id`` and ``number`` are forwarded to
        ``evaluate``.

    Returns
    -------
    time_cost
        Whatever ``evaluate`` returns for the scheduled computation.

    Notes
    -----
    The problem sizes (``in_size``, ``in_channel``, ``batch``, ``kernel``,
    ``out_channel``, ``pad``, ``stride``, ``out_size``) and ``evaluate`` are
    not defined in this function; they are read from the enclosing scope.
    """
    tile_x, tile_y, step, target, dev_id, number = args

    # Algorithm
    A = tvm.te.placeholder((in_size, in_size, in_channel, batch), name='A')
    W = tvm.te.placeholder((kernel, kernel, in_channel, out_channel), name='W')

    # Pad input. The rest of this function uses the `tvm.te` namespace, where
    # the legacy `tvm.select` / `tvm.const` helpers are not available, so use
    # `tvm.te.if_then_else` and `tvm.tir.const` instead.
    Apad = tvm.te.compute(
        (in_size + 2 * pad, in_size + 2 * pad, in_channel, batch),
        lambda yy, xx, cc, nn: tvm.te.if_then_else(
            tvm.te.all(yy >= pad, yy - pad < in_size,
                       xx >= pad, xx - pad < in_size),
            A[yy - pad, xx - pad, cc, nn],
            tvm.tir.const(0., "float32")),
        name='Apad')

    # Create reduction variables
    rc = tvm.te.reduce_axis((0, in_channel), name='rc')
    ry = tvm.te.reduce_axis((0, kernel), name='ry')
    rx = tvm.te.reduce_axis((0, kernel), name='rx')

    # Compute the convolution
    B = tvm.te.compute(
        (out_size, out_size, out_channel, batch),
        lambda yy, xx, ff, nn: tvm.te.sum(
            Apad[yy * stride + ry, xx * stride + rx, rc, nn] * W[ry, rx, rc, ff],
            axis=[ry, rx, rc]),
        name='B')

    # Designate the memory hierarchy
    s = tvm.te.create_schedule(B.op)
    BL = s.cache_write(B, "local")

    # Split the workloads
    hi, wi, fi, ni = s[B].op.axis
    bz = s[B].fuse(hi, wi)
    by, fi = s[B].split(fi, factor=tile_x)
    bx, ni = s[B].split(ni, factor=tile_y)
    s[B].reorder(bx, bz, by, fi, ni)
    bp = s[B].fuse(bz, bx)
    s[B].parallel(bp)

    # Schedule the local accumulation stage inside the parallel loop.
    s[BL].compute_at(s[B], bp)
    h, w, f, n = s[BL].op.axis
    xi, yi, ci = s[BL].op.reduce_axis
    co, ci = s[BL].split(ci, factor=step)
    s[BL].reorder(co, xi, yi, f, ci, n)
    s[BL].unroll(ci)
    s[BL].vectorize(n)

    time_cost = evaluate(s, [A, W, B], target, dev_id, number)
    print("args={}, time_cost={}".format(args, time_cost))
    return time_cost
def select_array(i, j):
    """Return an expression equal to ``data[i % row][j % col]``.

    The lookup is expressed as nested ``tvm.select`` calls, one per table
    entry, over an initial zero constant. ``row``, ``col``, ``data`` and
    ``dtype`` are taken from the enclosing scope.
    """
    result = tvm.const(0.0, dtype)
    for row_idx in range(row):
        table_row = data[row_idx]
        for col_idx in range(col):
            is_cell = tvm.all(i % row == row_idx, j % col == col_idx)
            result = tvm.select(
                is_cell, tvm.const(table_row[col_idx], dtype), result)
    return result
def make_relu_gradient(shape, tgt, tgt_host, func_name, dtype="float32"):
    """Build a function computing ``C = B * (1 if A > 0 else 0)`` elementwise.

    ``tgt`` / ``tgt_host`` are passed to ``tvm.build`` as target and target
    host, and ``func_name`` as the generated function's name. The built
    function takes the buffers ``[A, B, C]``.
    """
    A = tvm.placeholder(shape, dtype=dtype, name="A")
    B = tvm.placeholder(shape, dtype=dtype, name="B")

    def relu_grad(*i):
        one = tvm.const(1, A.dtype)
        zero = tvm.const(0, A.dtype)
        return B(*i) * tvm.select(A(*i) > 0, one, zero)

    C = tvm.compute(A.shape, relu_grad, "C")
    schedule = tvm.create_schedule(C.op)
    return tvm.build(
        schedule, [A, B, C], tgt, target_host=tgt_host, name=func_name)
def _compute(*indices):
    """Pick the element of the tensor in ``a_tuple`` that covers ``indices``.

    Walks the tensors in order, subtracting each preceding tensor's extent
    (``axis_sizes``) from the index on ``axis``; a later tensor replaces the
    current choice whenever the shifted index is non-negative. ``a_tuple``,
    ``axis`` and ``axis_sizes`` come from the enclosing scope.
    """
    chosen = a_tuple[0](*indices)
    shifted = indices[axis]
    leading = indices[0:axis]
    trailing = indices[axis + 1:]
    for pos in range(1, len(a_tuple)):
        shifted -= axis_sizes[pos - 1]
        candidate = a_tuple[pos](*(leading + (shifted, ) + trailing))
        chosen = tvm.select(shifted >= 0, candidate, chosen)
    return chosen
def make_relu_gradient(shape, tgt, tgt_host, func_name, dtype="float32"):
    """Build the ReLU gradient kernel ``C = B * (1 if A > 0 else 0)``.

    Parameters
    ----------
    shape : tuple
        Shape of the two inputs and of the output.
    tgt, tgt_host
        Passed to ``tvm.build`` as the target and the target host.
    func_name : str
        Name given to the built function.
    dtype : str
        Element type of both placeholders.

    Returns
    -------
    The result of ``tvm.build`` for the buffers ``[A, B, C]``.
    """
    A = tvm.placeholder(shape, dtype=dtype, name="A")
    B = tvm.placeholder(shape, dtype=dtype, name="B")
    # Type the mask constants with the tensor dtype so the select and the
    # multiplication do not mix an integer mask with `dtype` operands.
    one = tvm.const(1, A.dtype)
    zero = tvm.const(0, A.dtype)
    C = tvm.compute(
        A.shape, lambda *i: B(*i) * tvm.select(A(*i) > 0, one, zero))

    s = tvm.create_schedule(C.op)
    f = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name=func_name)
    return f
def make_relu_gradient(shape, tgt, tgt_host, func_name, dtype="float32"):
    """Build the ReLU gradient kernel.

    The output is ``out_grad`` where ``input > 0`` and zero elsewhere.

    Parameters
    ----------
    shape : tuple
        Shape of the two inputs and of the output.
    tgt, tgt_host
        Passed to ``tvm.build`` as the target and the target host.
    func_name : str
        Name given to the built function.
    dtype : str
        Element type of both placeholders and of the zero constant.

    Returns
    -------
    The result of ``tvm.build`` for the buffers
    ``[inp, out_grad, inp_grad]``.
    """
    inp = tvm.placeholder(shape, dtype, name="input")
    # Distinct name: both placeholders were previously called "input".
    out_grad = tvm.placeholder(shape, dtype, name="out_grad")
    zero = tvm.const(0, dtype)
    # The select needs all three operands, and they must live inside the
    # lambda: the index tuple `i` only exists there.
    inp_grad = tvm.compute(
        shape,
        lambda *i: tvm.select(inp(*i) > 0, out_grad(*i), zero),
        name="inp_grad")

    s = tvm.create_schedule(inp_grad.op)
    # The output tensor has to be part of the argument list so the caller
    # can pass a buffer to receive the result.
    f = tvm.build(s, [inp, out_grad, inp_grad], tgt,
                  target_host=tgt_host, name=func_name)
    return f
def _compute(*indices):
    """Select, per output index, the matching element among ``a_tuple``.

    The index along ``axis`` is reduced by ``axis_sizes[i]`` before tensor
    ``i + 1`` is considered; that tensor wins whenever the reduced index is
    still non-negative. ``a_tuple``, ``axis`` and ``axis_sizes`` are read
    from the enclosing scope.
    """
    result = a_tuple[0](*indices)
    local_index = indices[axis]
    for i in range(len(a_tuple) - 1):
        local_index -= axis_sizes[i]
        next_tensor = a_tuple[i + 1]
        next_indices = indices[0:axis] + (local_index,) + indices[axis + 1:]
        in_next = local_index >= 0
        result = tvm.select(in_next, next_tensor(*next_indices), result)
    return result
def test_schedule_bound_condition():
    """Check that a cache_read under a padded stage is guarded by an if.

    After bound inference, lowering and simplification, the statement found
    at ``stmt.body.body.first.body.body.then_case`` must be an IfThenElse.
    """
    A = tvm.placeholder((64,), name='A', dtype="float32")

    def padded(i):
        inside = tvm.all(i > 0, i < 65)
        return tvm.select(inside, A[i - 1], tvm.const(0.))

    Apad = tvm.compute((66,), padded, name='Apad')
    Apad2 = tvm.compute((66,), lambda i: Apad[i] * 2, name='Apad2')

    s = tvm.create_schedule(Apad2.op)
    s.cache_read(A, "local", [Apad])
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    stmt = tvm.ir_pass.Simplify(stmt)

    guarded = stmt.body.body.first.body.body.then_case
    assert isinstance(guarded, tvm.stmt.IfThenElse)
def schedule_pack_data(input):
    """Declare a zero-padded, channel-blocked copy of ``input``.

    The result has shape
    ``(batch_size, in_channel // bn, osize, bn, osize)`` with
    ``osize = in_size + 2 * padding``; positions inside the border read
    ``input[n, C * bn + c, h - padding, w - padding]`` and the border is 0.0.
    ``in_size``, ``padding``, ``batch_size``, ``in_channel`` and ``bn`` are
    taken from the enclosing scope.

    Returns
    -------
    (schedule, data_pad) : the default schedule for the stage, and the stage.
    """
    padded_size = in_size + 2 * padding
    packed_shape = (batch_size, in_channel // bn, padded_size, bn, padded_size)

    def pack(n, C, h, c, w):
        interior = tvm.all(h >= padding, h < padded_size - padding,
                           w >= padding, w < padded_size - padding)
        source = input[n, C * bn + c, h - padding, w - padding]
        return tvm.select(interior, source, 0.0)

    data_pad = tvm.compute(packed_shape, pack)
    return tvm.create_schedule(data_pad.op), data_pad
def test_schedule_bound_condition():
    """Lower a padded stage with a cached read and check the bound guard.

    The test passes when the node at
    ``stmt.body.body.first.body.body.then_case`` of the simplified statement
    is an ``IfThenElse``.
    """
    A = tvm.placeholder((64,), name='A', dtype="float32")
    zero = tvm.const(0., "float32")
    Apad = tvm.compute(
        (66,),
        lambda i: tvm.select(tvm.all(i > 0, i < 65), A[i - 1], zero),
        name='Apad')

    def doubled(i):
        return Apad[i] * 2

    Apad2 = tvm.compute((66,), doubled, name='Apad2')
    s = tvm.create_schedule(Apad2.op)
    AL1 = s.cache_read(A, "local", [Apad])
    s = s.normalize()
    stmt = tvm.schedule.ScheduleOps(s, tvm.schedule.InferBound(s))
    stmt = tvm.ir_pass.Simplify(stmt)
    inner = stmt.body.body.first.body.body
    assert isinstance(inner.then_case, tvm.stmt.IfThenElse)
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
    """Calculate overlap of two boxes.

    Each box occupies four consecutive entries of ``out_tensor`` starting at
    its index. The result is intersection / union, or 0.0 when the union is
    not positive.
    """
    a_x0 = out_tensor[box_a_idx]
    a_y0 = out_tensor[box_a_idx + 1]
    a_x1 = out_tensor[box_a_idx + 2]
    a_y1 = out_tensor[box_a_idx + 3]
    b_x0 = out_tensor[box_b_idx]
    b_y0 = out_tensor[box_b_idx + 1]
    b_x1 = out_tensor[box_b_idx + 2]
    b_y1 = out_tensor[box_b_idx + 3]

    # Extent of the intersection, clamped at zero on each axis.
    inter_w = tvm.make.Max(
        0.0, tvm.make.Min(a_x1, b_x1) - tvm.make.Max(a_x0, b_x0))
    inter_h = tvm.make.Max(
        0.0, tvm.make.Min(a_y1, b_y1) - tvm.make.Max(a_y0, b_y0))
    intersection = inter_w * inter_h

    # Sum of both areas minus the shared part.
    union = ((a_x1 - a_x0) * (a_y1 - a_y0)
             + (b_x1 - b_x0) * (b_y1 - b_y0)
             - intersection)
    return tvm.select(union <= 0.0, 0.0, intersection / union)
def _dilate(*indices):
    """Map an index of the dilated output back to ``data``.

    For every axis whose stride is not the constant 1, the source index is
    ``indices[i] // strides[i]`` and the position only holds data when
    ``indices[i]`` is a multiple of the stride; all other positions are
    zero. ``n``, ``strides`` and ``data`` come from the enclosing scope.

    Returns
    -------
    The ``data`` element, wrapped in a select against zero when at least one
    axis is strided.
    """
    not_zero = []
    index_tuple = []
    for i in range(n):
        if not util.equal_const_int(strides[i], 1):
            # Indices are integers: use floor division explicitly instead of
            # `/`, whose meaning on integer expressions is not portable.
            index_tuple.append(indices[i] // strides[i])
            not_zero.append((indices[i] % strides[i]).equal(0))
        else:
            index_tuple.append(indices[i])
    if not_zero:
        not_zero = tvm.all(*not_zero)
        return tvm.select(not_zero, data(*index_tuple),
                          tvm.const(0.0, data.dtype))
    return data(*index_tuple)
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
    """Calculate overlap of two boxes.

    Reads four consecutive values per box from ``out_tensor`` and returns
    the ratio of the intersection area to the union area (0.0 if the union
    is not positive).
    """
    def _span(lo_offset, hi_offset):
        # Overlap of the two boxes along one axis, never negative.
        upper = tvm.make.Min(out_tensor[box_a_idx + hi_offset],
                             out_tensor[box_b_idx + hi_offset])
        if lo_offset == 0:
            lower = tvm.make.Max(out_tensor[box_a_idx],
                                 out_tensor[box_b_idx])
        else:
            lower = tvm.make.Max(out_tensor[box_a_idx + lo_offset],
                                 out_tensor[box_b_idx + lo_offset])
        return tvm.make.Max(0.0, upper - lower)

    def _area(base_idx):
        width = out_tensor[base_idx + 2] - out_tensor[base_idx]
        height = out_tensor[base_idx + 3] - out_tensor[base_idx + 1]
        return width * height

    shared = _span(0, 2) * _span(1, 3)
    total = _area(box_a_idx) + _area(box_b_idx) - shared
    return tvm.select(total <= 0.0, 0.0, shared / total)
def _dilate(*indices):
    """Translate a dilated-output index into a read of ``data``.

    Axes with a constant stride of 1 are passed through. On every other
    axis the source index is ``index // stride`` and the element is only
    taken from ``data`` when ``index % stride == 0``; otherwise the result
    is zero. ``n``, ``strides`` and ``data`` come from the enclosing scope.
    """
    on_grid = []
    source_index = []
    for axis_idx in range(n):
        stride = strides[axis_idx]
        index = indices[axis_idx]
        if equal_const_int(stride, 1):
            source_index.append(index)
            continue
        source_index.append(index // stride)
        on_grid.append((index % stride).equal(0))

    if not on_grid:
        return data(*source_index)
    return tvm.select(tvm.all(*on_grid), data(*source_index),
                      tvm.const(0.0, data.dtype))
def _decl_im2col(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
    """declare the Im2Col method for conv2d

    The kernel is flattened to ``A [CO, CI*KH*KW]``, the (padded) input to
    ``B [CI*KH*KW, N*OH*OW]``, both rounded up to a multiple of 16 and
    zero-filled (B) beyond the real extent, and the convolution is the
    matrix product ``C = A x B`` reshaped to ``(N, CO, OH, OW)``.

    Parameters
    ----------
    data : Tensor
        4-D input; the last three dimensions are read as CI, IH, IW.
    kernel : Tensor
        4-D weights read as CO, _, KH, KW.
    stride : int or tuple/list of two ints
        Stride; a scalar is used for both axes.
    padding
        Forwarded to ``get_pad_tuple`` together with ``kernel``.
    layout, out_dtype
        Accepted but not used in this function.

    Returns
    -------
    output : Tensor
        Tensor of shape ``(N, CO, OH, OW)`` with ``N`` fixed to 1.
    """
    _, CI, IH, IW = [x.value for x in data.shape]
    CO, _, KH, KW = [x.value for x in kernel.shape]
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)

    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    N = 1
    OH = (IH + 2*HPAD - KH) // HSTR + 1
    OW = (IW + 2*WPAD - KW) // WSTR + 1

    # OH/OW above assume the padding is applied, so pad as soon as EITHER
    # axis needs it (with `and`, padding on a single axis was skipped).
    DO_PAD = (HPAD != 0 or WPAD != 0)
    if DO_PAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    ALIGN = 16
    def upround(x, align):
        return (x + align - 1) // align * align

    # A [CO, CI * KH * KW]
    reduce_len = upround(CI * KH * KW, ALIGN)
    A = tvm.compute((upround(CO, ALIGN), reduce_len), lambda i, j:
                    kernel[i][j // KW // KH][j // KW % KH][j % KW], name='A')

    # B [CI * KH * KW, N * OH * OW]
    B = tvm.compute((reduce_len, upround(N * OH * OW, ALIGN)), lambda i, j:\
                tvm.select(tvm.all(i < CI * KH * KW, j < N * OH * OW),
                           data_pad[j // (OH*OW)][i // (KH*KW)][j // OW % OH*HSTR + i // KW % KH]
                           [j % OW*WSTR + i % KW],
                           tvm.const(0, data_pad.dtype)), name='B')

    gemm_n, gemm_l, gemm_m = A.shape[0], reduce_len, B.shape[1]

    # C [CO, N * OH * OW]
    k = tvm.reduce_axis((0, gemm_l), name='k')
    C = tvm.compute((gemm_n, gemm_m), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')

    # output
    # the last term C[gemm_n-1, gemm_m-1] is for enabling the alignment,
    # otherwise the alignment above will be eliminated by bound inference
    # The column index mirrors how B decodes it: batch stride is OH * OW.
    output = tvm.compute((N, CO, OH, OW), lambda n, co, h, w:\
                 C[co][n * OH * OW + h * OW + w] + tvm.const(0, C.dtype) * C[gemm_n-1, gemm_m-1],
                         name='output', tag='im2col_conv_output')

    return output
def _decl_im2col(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'):
    """declare the Im2Col method for conv2d

    The kernel is flattened to ``A [CO, CI*KH*KW]``, the (padded) input to
    ``B [CI*KH*KW, N*OH*OW]``, both rounded up to a multiple of 16 and
    zero-filled (B) beyond the real extent, and the convolution is the
    matrix product ``C = A x B`` reshaped to ``(N, CO, OH, OW)``.

    Parameters
    ----------
    data : Tensor
        4-D input; the last three dimensions are read as CI, IH, IW.
    kernel : Tensor
        4-D weights read as CO, _, KH, KW.
    stride : int or tuple/list of two ints
        Stride; a scalar is used for both axes.
    padding
        Forwarded to ``get_pad_tuple`` together with ``kernel``.
    layout, out_dtype
        Accepted but not used in this function.

    Returns
    -------
    output : Tensor
        Tensor of shape ``(N, CO, OH, OW)`` with ``N`` fixed to 1.
    """
    _, CI, IH, IW = [x.value for x in data.shape]
    CO, _, KH, KW = [x.value for x in kernel.shape]
    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)

    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    N = 1
    OH = (IH + 2*HPAD - KH) // HSTR + 1
    OW = (IW + 2*WPAD - KW) // WSTR + 1

    # OH/OW above assume the padding is applied, so pad as soon as EITHER
    # axis needs it (with `and`, padding on a single axis was skipped).
    DO_PAD = (HPAD != 0 or WPAD != 0)
    if DO_PAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    ALIGN = 16
    def upround(x, align):
        return (x + align - 1) // align * align

    # A [CO, CI * KH * KW]
    reduce_len = upround(CI * KH * KW, ALIGN)
    A = tvm.compute((upround(CO, ALIGN), reduce_len), lambda i, j:
                    kernel[i][j // KW // KH][j // KW % KH][j % KW], name='A')

    # B [CI * KH * KW, N * OH * OW]
    B = tvm.compute((reduce_len, upround(N * OH * OW, ALIGN)), lambda i, j:\
                tvm.select(tvm.all(i < CI * KH * KW, j < N * OH * OW),
                           data_pad[j // (OH*OW)][i // (KH*KW)][j // OW % OH*HSTR + i // KW % KH]
                           [j % OW*WSTR + i % KW],
                           tvm.const(0, data_pad.dtype)), name='B')

    gemm_n, gemm_l, gemm_m = A.shape[0], reduce_len, B.shape[1]

    # C [CO, N * OH * OW]
    k = tvm.reduce_axis((0, gemm_l), name='k')
    C = tvm.compute((gemm_n, gemm_m), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')

    # output
    # the last term C[gemm_n-1, gemm_m-1] is for enabling the alignment,
    # otherwise the alignment above will be eliminated by bound inference
    # The column index mirrors how B decodes it: batch stride is OH * OW.
    output = tvm.compute((N, CO, OH, OW), lambda n, co, h, w:\
                 C[co][n * OH * OW + h * OW + w] + tvm.const(0, C.dtype) * C[gemm_n-1, gemm_m-1],
                         name='output', tag='im2col_conv_output')

    return output
def _pad(*indices):
    """Read ``data`` at a padded-output index, or ``pad_value`` outside it.

    Axes whose ``pad_before`` and ``pad_after`` are both the constant 0 are
    passed through; on every other axis the index is shifted back by
    ``pad_before`` and must land inside ``data``'s extent. ``n``,
    ``pad_before``, ``pad_after``, ``data`` and ``pad_value`` come from the
    enclosing scope.
    """
    in_bounds = []
    source_index = []
    for axis_idx in range(n):
        before = pad_before[axis_idx]
        unpadded = (equal_const_int(before, 0)
                    and equal_const_int(pad_after[axis_idx], 0))
        if unpadded:
            source_index.append(indices[axis_idx])
            continue
        source_index.append(indices[axis_idx] - before)
        in_bounds.append(indices[axis_idx] >= before)
        in_bounds.append(indices[axis_idx] < data.shape[axis_idx] + before)

    if not in_bounds:
        return data(*source_index)
    return tvm.select(tvm.all(*in_bounds), data(*source_index), pad_value)
def make_relu_gradient(shape, tgt, tgt_host, func_name, dtype="float32"):
    """Build the ReLU gradient kernel.

    The output is zero where ``A <= 0`` and ``output_grad`` elsewhere.

    Parameters
    ----------
    shape : tuple
        Shape of the two inputs and of the output.
    tgt, tgt_host
        Passed to ``tvm.build`` as the target and the target host.
    func_name : str
        Name given to the built function.
    dtype : str
        Element type of both placeholders and of the zero constant.

    Returns
    -------
    The result of ``tvm.build`` for the buffers ``[A, output_grad, res]``.
    """
    A = tvm.placeholder(shape, dtype=dtype, name="A")
    output_grad = tvm.placeholder(shape, dtype=dtype, name="rg_output_grad")
    # Both select branches must share a type: build the zero with the
    # requested dtype instead of relying on a bare Python float literal.
    zero = tvm.const(0, output_grad.dtype)
    res = tvm.compute(
        A.shape, lambda *i: tvm.select(A(*i) <= 0, zero, output_grad(*i)))

    s = tvm.create_schedule([res.op])
    f = tvm.build(s, [A, output_grad, res], tgt,
                  target_host=tgt_host, name=func_name)
    return f
def compute_temp(k, p, eps, nu):
    """Combine the four rows of ``M`` into two and pick entry ``(eps, nu)``.

    For each column ``j`` in 0..3 the two combinations are
    ``M0 + M1 + M2`` and ``M1 - M2 - M3`` (rows of ``M[.][j][k][p]``); the
    result is a select chain over ``eps`` in 0..1 and ``nu`` in 0..3 with a
    float32 zero as fallback. ``M`` comes from the enclosing scope.
    """
    combos = {}
    for col in range(4):
        row0 = M[0][col][k][p]
        row1 = M[1][col][k][p]
        row2 = M[2][col][k][p]
        row3 = M[3][col][k][p]
        combos[(0, col)] = (row0 + row1) + row2
        combos[(1, col)] = (row1 - row2) - row3

    picked = tvm.const(0.0, "float32")
    for row in range(2):
        for col in range(4):
            is_entry = tvm.all(eps == row, nu == col)
            picked = tvm.select(is_entry, combos[(row, col)], picked)
    return picked
def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
    """Transform prior anchor box to output box through location predictions.

    The anchor is read as (left, top, right, bottom) from four consecutive
    entries of ``anchor``; the prediction as (x, y, w, h) from four
    consecutive entries of ``loc``. Returns the four output corners, each
    clamped to [0, 1] when ``clip`` holds.
    """
    left = anchor[anchor_base_idx]
    top = anchor[anchor_base_idx + 1]
    right = anchor[anchor_base_idx + 2]
    bottom = anchor[anchor_base_idx + 3]
    anchor_w = right - left
    anchor_h = bottom - top
    anchor_cx = (left + right) / 2.0
    anchor_cy = (top + bottom) / 2.0

    pred_x = loc[loc_base_idx]
    pred_y = loc[loc_base_idx + 1]
    pred_w = loc[loc_base_idx + 2]
    pred_h = loc[loc_base_idx + 3]

    center_x = pred_x * vx * anchor_w + anchor_cx
    center_y = pred_y * vy * anchor_h + anchor_cy
    half_w = tvm.exp(pred_w * vw) * anchor_w / 2.0
    half_h = tvm.exp(pred_h * vh) * anchor_h / 2.0

    def _maybe_clip(coord):
        return tvm.select(
            clip, tvm.make.Max(0, tvm.make.Min(1, coord)), coord)

    return (_maybe_clip(center_x - half_w),
            _maybe_clip(center_y - half_h),
            _maybe_clip(center_x + half_w),
            _maybe_clip(center_y + half_h))
def test_prim(reducer, np_reducer):
    """Compare a predicated row reduction against its numpy counterpart.

    ``reducer`` is applied along the second axis of ``A`` with the predicate
    ``R[i] == 1`` (``R`` is 1 for ``i > 1``); ``np_reducer`` provides the
    reference. The first two rows are zeroed on both sides before comparing.
    """
    # graph
    n = tvm.var('n')
    m = tvm.var('m')
    A = tvm.placeholder((n, m), name='A')
    R = tvm.compute((n, ), lambda i: tvm.select((i > 1), 1, 0), name='R')
    k = tvm.reduce_axis((0, m))

    def reduce_row(i):
        return reducer(A[i, k], axis=k, where=(R[i] == 1))

    B = tvm.compute((n, ), reduce_row, name='B')

    # schedule: one thread per block, predicate stage inlined
    s = tvm.create_schedule(B.op)
    num_thread = 1
    xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[R].compute_inline()

    def check_device(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not tvm.module.enabled(host):
            return
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        freduce = tvm.build(s,
                            args=[A, B],
                            target=device, target_host=host,
                            name="myreduce")
        # launch the kernel.
        n = 1028
        m = 129
        host_input = np.random.uniform(size=(n, m)).astype(A.dtype)
        x = tvm.nd.array(host_input, ctx)
        y = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        freduce(x, y)
        npy = y.asnumpy()
        npy[:2] = 0
        res = np_reducer(x.asnumpy(), axis=1)
        res[:2] = 0
        tvm.testing.assert_allclose(npy, res, rtol=1e-4)

    for device in ("metal", "vulkan", "cuda", "opencl"):
        check_device(device)
def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
    """Transform prior anchor box to output box through location predictions.

    Returns a 4-tuple of corner expressions computed from the anchor entries
    at ``anchor_base_idx .. +3`` and the prediction entries at
    ``loc_base_idx .. +3``; each corner is clamped to [0, 1] when ``clip``
    holds.
    """
    al = anchor[anchor_base_idx]
    at = anchor[anchor_base_idx + 1]
    ar = anchor[anchor_base_idx + 2]
    ab = anchor[anchor_base_idx + 3]

    # Anchor size and centre.
    aw, ah = ar - al, ab - at
    ax, ay = (al + ar) / 2.0, (at + ab) / 2.0

    # Predicted centre offset and log-scale, weighted by the v* factors.
    ox = loc[loc_base_idx] * vx * aw + ax
    oy = loc[loc_base_idx + 1] * vy * ah + ay
    ow = tvm.exp(loc[loc_base_idx + 2] * vw) * aw / 2.0
    oh = tvm.exp(loc[loc_base_idx + 3] * vh) * ah / 2.0

    corners = (ox - ow, oy - oh, ox + ow, oy + oh)
    return tuple(
        tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, corner)), corner)
        for corner in corners)
def conv_compute(n, m, h_, w_):
    """Compute one output element of the convolution.

    The input coordinates are ``h_ * stride - padding + r`` and
    ``w_ * stride - padding + s``; products whose coordinates fall outside
    ``[0, H) x [0, W)`` contribute zero. The sum runs over ``c``, ``r`` and
    ``s``. ``stride``, ``padding``, ``H``, ``W``, ``dtype``, ``data``,
    ``filters`` and the reduction axes come from the enclosing scope.
    """
    row = h_ * stride - padding + r
    col = w_ * stride - padding + s
    outside = tvm.any(row < 0, col < 0, row >= H, col >= W)
    term = tvm.select(
        outside,
        tvm.const(0, dtype),  # padding
        data[n, c, row, col] * filters[m, c, r, s])
    return tvm.sum(term, axis=[c, r, s])
def _pad(*indices):
    """Return the padded tensor's value at ``indices``.

    On each axis with non-zero padding the index is shifted by
    ``pad_before`` and two bound conditions are collected; if any were
    collected the read of ``data`` is guarded by them and falls back to
    ``pad_value``. ``n``, ``pad_before``, ``pad_after``, ``data`` and
    ``pad_value`` come from the enclosing scope.
    """
    conditions = []
    shifted = []
    for i in range(n):
        has_padding = not (equal_const_int(pad_before[i], 0)
                           and equal_const_int(pad_after[i], 0))
        if has_padding:
            shifted.append(indices[i] - pad_before[i])
            conditions.append(indices[i] >= pad_before[i])
            conditions.append(indices[i] < data.shape[i] + pad_before[i])
        else:
            shifted.append(indices[i])

    value = data(*shifted)
    if conditions:
        return tvm.select(tvm.all(*conditions), value, pad_value)
    return value
def check_llvm(n, offset):
    """Build ``C[i] = A[i] if i >= offset else 0`` for llvm and verify it.

    Returns without doing anything when the llvm backend is not enabled.
    """
    if not tvm.module.enabled("llvm"):
        return
    A = tvm.placeholder((n, ), name='A')

    def masked(i):
        return tvm.select(i >= offset, A[i], 0.0)

    C = tvm.compute((n,), masked, name='C')
    s = tvm.create_schedule(C.op)
    # build and invoke the kernel.
    f = tvm.build(s, [A, C], "llvm")
    ctx = tvm.cpu(0)
    # launch the kernel.
    host_input = np.random.uniform(size=(n,)).astype(A.dtype)
    a = tvm.nd.array(host_input, ctx)
    c = tvm.nd.empty((n,), A.dtype, ctx)
    f(a, c)
    # Reference: the input with its first `offset` entries zeroed.
    expected = a.asnumpy()
    expected[:offset] = 0
    np.testing.assert_allclose(c.asnumpy(), expected)
def dilate_kernel(*indices):
    """Map an index of the dilated kernel back to ``kernel``.

    This function is the same as topi.nn.dilate, but inlined. For every
    axis whose entry in ``dilate_args`` is not the constant 1, the source
    index is ``indices[i] // dilate_args[i]`` and the position only holds a
    kernel element when ``indices[i]`` is a multiple of that entry; all
    other positions are zero. ``dilate_args`` and ``kernel`` come from the
    enclosing scope.

    Returns
    -------
    The ``kernel`` element, wrapped in a select against zero when at least
    one axis is dilated.
    """
    not_zero = []
    index_tuple = []
    for i in range(len(dilate_args)):
        if not topi.util.equal_const_int(dilate_args[i], 1):
            index_tuple.append(indices[i] // dilate_args[i])
            not_zero.append((indices[i] % dilate_args[i]).equal(0))
        else:
            index_tuple.append(indices[i])
    if not_zero:
        not_zero = tvm.all(*not_zero)
        # The zero fill stands in for kernel elements, so it must carry the
        # kernel's dtype (not the dtype of an unrelated outer tensor).
        return tvm.select(not_zero, kernel(*index_tuple),
                          tvm.const(0.0, kernel.dtype))
    return kernel(*index_tuple)
def check_llvm(n, offset):
    """Run a select-based masking kernel on the llvm target and check it.

    The kernel keeps ``A[i]`` for ``i >= offset`` and writes 0.0 otherwise;
    the reference is the input with its first ``offset`` entries zeroed.
    Does nothing when the llvm backend is not enabled.
    """
    llvm_available = tvm.module.enabled("llvm")
    if not llvm_available:
        return

    A = tvm.placeholder((n, ), name='A')
    C = tvm.compute(
        (n,), lambda i: tvm.select(i >= offset, A[i], 0.0), name='C')
    schedule = tvm.create_schedule(C.op)
    kernel_fn = tvm.build(schedule, [A, C], "llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
    c = tvm.nd.empty((n,), A.dtype, ctx)
    kernel_fn(a, c)

    c_np = a.asnumpy()
    c_np[:offset] = 0
    tvm.testing.assert_allclose(c.asnumpy(), c_np)
def compute_output(n, k, h, w):
    """Produce output element ``(n, k, h, w)`` from the ``temp`` tiles.

    The tile index is ``n * nH * nW + (h // m) * nW + w // m`` and the
    position inside the tile is ``(h % m, w % m)``. For each of the two tile
    rows the columns of ``temp`` are combined as ``t0 + t1 + t2`` and
    ``t1 - t2 - t3``; the requested position is then picked with a select
    chain over a float32 zero. ``nH``, ``nW``, ``m`` and ``temp`` come from
    the enclosing scope.
    """
    tile = n * nH * nW + (h // m) * nW + w // m
    eps = h % m
    nu = w % m

    combos = {}
    for row in range(2):
        cols = temp[k][tile][row]
        combos[(row, 0)] = (cols[0] + cols[1]) + cols[2]
        combos[(row, 1)] = (cols[1] - cols[2]) - cols[3]

    picked = tvm.const(0.0, "float32")
    for row in range(2):
        for col in range(2):
            is_entry = tvm.all(eps == row, nu == col)
            picked = tvm.select(is_entry, combos[(row, col)], picked)
    return picked
def test_prim(reducer, np_reducer):
    """Check a predicated reduction on every enabled GPU-style backend.

    Rows with ``i > 1`` are reduced with ``reducer`` along axis 1; the
    result is compared with ``np_reducer`` after zeroing the first two rows
    of both arrays.
    """
    # graph
    n = tvm.var('n')
    m = tvm.var('m')
    A = tvm.placeholder((n, m), name='A')
    R = tvm.compute((n, ), lambda i: tvm.select((i > 1), 1, 0), name='R')
    k = tvm.reduce_axis((0, m))
    B = tvm.compute(
        (n,),
        lambda i: reducer(A[i, k], axis=k, where=(R[i] == 1)),
        name='B')

    # schedule
    s = tvm.create_schedule(B.op)
    # create iter var and assign them tags.
    num_thread = 1
    xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
    for axis, tag in ((xo, "blockIdx.x"), (xi, "threadIdx.x")):
        s[B].bind(axis, tvm.thread_axis(tag))
    s[R].compute_inline()

    # one line to build the function.
    def check_device(device, host="stackvm"):
        ctx = tvm.context(device, 0)
        if not tvm.module.enabled(host):
            return
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        freduce = tvm.build(s, args=[A, B], target=device,
                            target_host=host, name="myreduce")
        # launch the kernel.
        n = 1028
        m = 129
        x = tvm.nd.array(
            np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
        y = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        freduce(x, y)
        actual = y.asnumpy()
        actual[:2] = 0
        expected = np_reducer(x.asnumpy(), axis=1)
        expected[:2] = 0
        np.testing.assert_allclose(actual, expected, rtol=1e-4)

    check_device("metal")
    check_device("vulkan")
    check_device("cuda")
    check_device("opencl")
def make_conv2d_unoptimized(shapeX, shapeF, tgt, tgt_host, func_name, dtype="float32"):
    """Build a padded (pad=1, stride=1) convolution with a trivial schedule.

    ``shapeX`` is unpacked as (size, size, in_channel, batch) and ``shapeF``
    as (kernel, kernel, in_channel, out_channel); when a name repeats, the
    later value wins. The first two axes of both stages are bound to
    ``blockIdx.x`` / ``threadIdx.x``. ``dtype`` is accepted but not used.

    Returns the result of ``_export_module(f, func_name, remote)``;
    ``_export_module`` and ``remote`` are not defined in this function.
    """
    _, in_size, in_channel, batch = shapeX
    # The filter's channel count overrides the one read from shapeX.
    _, kernel, in_channel, out_channel = shapeF
    pad = 1
    stride = 1

    A = tvm.placeholder((in_size, in_size, in_channel, batch), name='A')
    W = tvm.placeholder((kernel, kernel, in_channel, out_channel), name='W')
    out_size = (in_size - kernel + 2 * pad) // stride + 1
    padded_size = in_size + 2 * pad

    # Pad input
    def pad_input(yy, xx, cc, nn):
        inside = tvm.all(yy >= pad, yy - pad < in_size,
                         xx >= pad, xx - pad < in_size)
        return tvm.select(inside, A[yy - pad, xx - pad, cc, nn],
                          tvm.const(0.))

    Apad = tvm.compute(
        (padded_size, padded_size, in_channel, batch), pad_input, name='Apad')

    # Create reduction variables
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel), name='ry')
    rx = tvm.reduce_axis((0, kernel), name='rx')

    # Compute the convolution
    def convolve(yy, xx, ff, nn):
        window = Apad[yy * stride + ry, xx * stride + rx, rc, nn]
        return tvm.sum(window * W[ry, rx, rc, ff], axis=[ry, rx, rc])

    B = tvm.compute(
        (out_size, out_size, out_channel, batch), convolve, name='B')

    s = tvm.create_schedule(B.op)
    for stage_tensor in (Apad, B):
        axes = stage_tensor.op.axis
        s[stage_tensor].bind(axes[0], tvm.thread_axis("blockIdx.x"))
        s[stage_tensor].bind(axes[1], tvm.thread_axis("threadIdx.x"))

    f = tvm.build(s, [A, W, B], tgt, target_host=tgt_host, name=func_name)
    return _export_module(f, func_name, remote)
def make_relu_gradient(shape, tgt, tgt_host, func_name, dtype="float32"):
    """Build the ReLU gradient kernel ``C = B if A > 0 else 0``.

    ``A`` and ``B`` are placeholders of the given shape and dtype. ``tgt``
    and ``tgt_host`` are passed to ``tvm.build`` as target and target host,
    ``func_name`` as the function name; the built function takes
    ``[A, B, C]``.
    """
    # describe
    A = tvm.placeholder(shape, dtype=dtype, name="A")
    B = tvm.placeholder(shape, dtype=dtype, name="B")
    zero = tvm.const(0, A.dtype)

    def gradient(*i):
        positive = A(*i) > zero
        return tvm.select(positive, B(*i), zero)

    C = tvm.compute(A.shape, gradient)

    # schedule and compile
    schedule = tvm.create_schedule(C.op)
    return tvm.build(
        schedule, [A, B, C], tgt, target_host=tgt_host, name=func_name)
def compute_X_dot_A(k, b, eps, nu, kk, bb):
    """Combine the eight columns of ``A_T_dot_M`` into six and pick one entry.

    For every row ``i`` the entries ``A_T_dot_M[k][b][i][0..7][kk][bb]`` are
    folded into six weighted sums ``s0..s5``; the result is a select chain
    returning the sum stored at ``(eps, nu)``, with a float32 zero as the
    fallback. ``A_T_dot_M`` and ``m`` come from the enclosing scope.
    """
    temp_expr = {}
    for i in range(m):
        # Pairwise sums and differences of columns 1..6, reused below.
        m1_add_m2 = A_T_dot_M[k][b][i][1][kk][bb] + A_T_dot_M[k][b][i][2][
            kk][bb]
        m1_sub_m2 = A_T_dot_M[k][b][i][1][kk][bb] - A_T_dot_M[k][b][i][2][
            kk][bb]
        m3_add_m4 = A_T_dot_M[k][b][i][3][kk][bb] + A_T_dot_M[k][b][i][4][
            kk][bb]
        m3_sub_m4 = A_T_dot_M[k][b][i][3][kk][bb] - A_T_dot_M[k][b][i][4][
            kk][bb]
        m5_add_m6 = A_T_dot_M[k][b][i][5][kk][bb] + A_T_dot_M[k][b][i][6][
            kk][bb]
        m5_sub_m6 = A_T_dot_M[k][b][i][5][kk][bb] - A_T_dot_M[k][b][i][6][
            kk][bb]
        # Each s* is built up over several statements; the order of the
        # updates below determines the shape of the resulting expression.
        s0 = A_T_dot_M[k][b][i][0][kk][bb] + m1_add_m2
        s5 = A_T_dot_M[k][b][i][7][kk][bb] + m1_sub_m2
        s1 = m1_sub_m2 + m5_sub_m6 * 16
        s4 = m1_add_m2 + m3_add_m4 * 16
        s2 = m1_add_m2 + 8 * m5_add_m6
        s3 = m1_sub_m2 + 8 * m3_sub_m4
        s0 = s0 + m5_add_m6 * 32
        s5 = s5 + m3_sub_m4 * 32
        s1 = s1 + m3_sub_m4 * 2
        s4 = s4 + m5_add_m6 * 2
        s0 = s0 + m3_add_m4
        s5 = s5 + m5_sub_m6
        s2 = s2 + m3_add_m4 * 4
        s3 = s3 + m5_sub_m6 * 4
        temp_expr[(i, 0)] = s0
        temp_expr[(i, 1)] = s1
        temp_expr[(i, 2)] = s2
        temp_expr[(i, 3)] = s3
        temp_expr[(i, 4)] = s4
        temp_expr[(i, 5)] = s5
    now = tvm.const(0.0, "float32")
    # NOTE(review): only columns 0..5 are stored above, so this lookup
    # requires m <= 6 -- presumably m == 6; confirm against the caller.
    for ii in range(m):
        for jj in range(m):
            now = tvm.select(tvm.all(eps == ii, nu == jj),
                             temp_expr[(ii, jj)], now)
    return now
def decl_winograd(data, U, stride, padding, out_dtype):
    """declare winograd fast convolution F(2x2, 3x3) for conv2d

    ``data`` is read as (N, C, H, W) and ``U`` as (_, _, C, K). Only stride
    1 is accepted (asserted) and the padding is fixed to 1 on both axes;
    the ``padding`` and ``out_dtype`` arguments are not used here. The input
    is packed into tiles, transformed by ``decl_V_minimal``, multiplied with
    ``U`` per tile position and unpacked by ``decl_output_minimal``.
    """
    N, C, H, W = [util.get_const_int(dim) for dim in data.shape]
    _, _, C, K = [util.get_const_int(dim) for dim in U.shape]
    HPAD, WPAD = 1, 1
    HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride)

    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    m = 2
    r = 3
    alpha = m + r - 1

    # Number of m x m output tiles per axis (rounded up) and in total.
    nH = (H + m - 1) // m
    nW = (W + m - 1) // m
    P = N * nH * nW

    # pack input tile
    def pack_tile(c, b, eps, nu):
        return tvm.select(
            b < P,
            data_pad[b // (nH * nW)][c][b // nW % nH * m + eps][
                b % nW * m + nu],
            tvm.const(0, data_pad.dtype))

    input_tile = tvm.compute((C, P, alpha, alpha), pack_tile, name='d')

    V = decl_V_minimal(input_tile, alpha, C, P)

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')

    def batch_gemm(eps, nu, k, b):
        return tvm.sum(U[eps][nu][c][k] * V[eps][nu][c][b], axis=c)

    M = tvm.compute((alpha, alpha, K, P), batch_gemm, name='M')

    # inverse transform and unpack
    return decl_output_minimal(M, N, K, H, W, P, m, nH, nW)
def compute_X_dot_B(b, eps, nu, c, bb):
    """Combine the eight columns of ``B_T_dot_X`` and pick one entry.

    For every row ``i`` the entries ``B_T_dot_X[b][c][i][0..7][bb]`` are
    folded into the weighted terms ``wd0..wd7`` and stored as eight output
    columns; the result is a select chain returning the value stored at
    ``(eps, nu)``, with a float32 zero as the fallback. ``B_T_dot_X`` and
    ``alpha`` come from the enclosing scope.
    """
    temp_expr = {}
    for i in range(alpha):
        wd0 = B_T_dot_X[b][c][i][0][bb] - B_T_dot_X[b][c][i][6][bb]
        d4_sub_d2 = B_T_dot_X[b][c][i][4][bb] - B_T_dot_X[b][c][i][2][bb]
        wd7 = B_T_dot_X[b][c][i][7][bb] - B_T_dot_X[b][c][i][1][bb]
        d3_sub_d5 = B_T_dot_X[b][c][i][3][bb] - B_T_dot_X[b][c][i][5][bb]
        wd1 = B_T_dot_X[b][c][i][2][bb] + B_T_dot_X[b][c][i][6][bb]
        wd2 = B_T_dot_X[b][c][i][1][bb] + B_T_dot_X[b][c][i][5][bb]
        wd4 = B_T_dot_X[b][c][i][5][bb] + B_T_dot_X[b][c][i][1][bb] * 0.25
        wd5 = B_T_dot_X[b][c][i][6][bb] - B_T_dot_X[b][c][i][4][bb] * 5
        wd3 = B_T_dot_X[b][c][i][6][bb] + B_T_dot_X[b][c][i][2][bb] * 0.25
        wd6 = B_T_dot_X[b][c][i][1][bb] + B_T_dot_X[b][c][i][5][bb] * 0.25
        # Second pass: each wd* receives its remaining term. The update
        # order determines the shape of the resulting expression.
        wd0 = wd0 + d4_sub_d2 * 5.25
        wd7 = wd7 + d3_sub_d5 * 5.25
        wd1 = wd1 - B_T_dot_X[b][c][i][4][bb] * 4.25
        wd2 = wd2 - B_T_dot_X[b][c][i][3][bb] * 4.25
        wd3 = wd3 - B_T_dot_X[b][c][i][4][bb] * 1.25
        wd5 = wd5 + B_T_dot_X[b][c][i][2][bb] * 4
        wd4 = wd4 - B_T_dot_X[b][c][i][3][bb] * 1.25
        wd6 = wd6 - B_T_dot_X[b][c][i][3][bb] * 1.25
        # Columns 1..6 are sum/difference pairs of the wd* terms.
        temp_expr[(i, 0)] = wd0
        temp_expr[(i, 1)] = wd1 + wd2
        temp_expr[(i, 2)] = wd1 - wd2
        temp_expr[(i, 3)] = wd3 + wd4 * 2
        temp_expr[(i, 4)] = wd3 - wd4 * 2
        temp_expr[(i, 5)] = wd5 + wd6 * 2
        temp_expr[(i, 6)] = wd5 - wd6 * 2
        temp_expr[(i, 7)] = wd7
    now = tvm.const(0.0, "float32")
    # NOTE(review): only columns 0..7 are stored above, so this lookup
    # requires alpha <= 8 -- presumably alpha == 8; confirm with the caller.
    for ii in range(alpha):
        for jj in range(alpha):
            now = tvm.select(tvm.all(eps == ii, nu == jj),
                             temp_expr[(ii, jj)], now)
    return now
def less(lhs, rhs, out_type=tvm.int8):
    """Compare two input tensors element-wise and return a mask tensor
       which contains 1 if lhs < rhs holds else 0

    Parameters
    ----------
    lhs : tvm.Tensor
        Left input argument.
    rhs : tvm.Tensor
        Right argument.
    out_type: str
        Output data type. Default is int8

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    def _mask(*i):
        one = tvm.const(1, out_type)
        zero = tvm.const(0, out_type)
        return tvm.select(lhs(*i) < rhs(*i), one, zero)

    return tvm.compute(lhs.shape, _mask)
def less(lhs, rhs, out_type=tvm.int8):
    """Build a mask tensor that is 1 where ``lhs < rhs`` and 0 otherwise.

    Parameters
    ----------
    lhs : tvm.Tensor
        Left input argument; determines the output shape.
    rhs : tvm.Tensor
        Right input argument, indexed identically to ``lhs``.
    out_type : str
        Output data type. Default is int8.

    Returns
    -------
    y : tvm.Tensor
        The element-wise comparison mask.
    """
    def _compare(*i):
        left, right = lhs(*i), rhs(*i)
        return tvm.select(left < right,
                          tvm.const(1, out_type),
                          tvm.const(0, out_type))

    mask = tvm.compute(lhs.shape, _compare)
    return mask
def test_copy_pad():
    """Check the padding reported to the copy callback for a row-padded copy.

    B has two more rows than A; rows 1..m copy A and the remaining rows are
    filled with 1.0. The callback asserts one row of padding before and after,
    none along the second axis, and a pad value of 1.0.
    """
    rows = tvm.var('m')
    cols = tvm.var('l')
    src_tensor = tvm.placeholder((rows, cols), name='A')

    def _padded(i, j):
        inside = tvm.all(i >= 1, i < rows + 1)
        return tvm.select(inside, src_tensor[i - 1, j], 1.0)

    dst_tensor = tvm.compute((rows + 2, cols), _padded, name='B')

    sched = tvm.create_schedule(dst_tensor.op)
    sched[dst_tensor].pragma(dst_tensor.op.axis[0], "memcpy")
    inferred = tvm.schedule.InferBound(sched)
    stmt = tvm.schedule.ScheduleOps(sched, inferred)

    src_buf = tvm.decl_buffer(src_tensor.shape, src_tensor.dtype, name='A')
    dst_buf = tvm.decl_buffer(dst_tensor.shape, dst_tensor.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {src_tensor: src_buf, dst_tensor: dst_buf}, 64)

    def cb(src, dst, pad_before, pad_after, pad_value):
        assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
        # One padded row on each side of axis 0, nothing on axis 1.
        for dim, expected in enumerate((1, 0)):
            assert pad_before[dim].value == expected
        for dim, expected in enumerate((1, 0)):
            assert pad_after[dim].value == expected
        assert pad_value.value == 1.0
        return tvm.make.Evaluate(0)

    stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
def fcombine(x, y):
    """Combine two pairs, keeping the pair whose second element is larger.

    Both elements are taken from ``x`` when ``x[1] >= y[1]`` and from ``y``
    otherwise, so ties resolve in favour of ``x``.
    """
    keep_x = x[1] >= y[1]
    first = tvm.select(keep_x, x[0], y[0])
    second = tvm.select(keep_x, x[1], y[1])
    return first, second
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
    """Low-level IR routine for the multibox_prior operator (thread-bound).

    Each (threadIdx, blockIdx) pair along x/y handles one (row, column) of
    the feature map and writes ``len(sizes) + len(ratios) - 1`` boxes of four
    coordinates each.

    Parameters
    ----------
    data : Buffer
        Input data buffer; ``shape[2]`` and ``shape[3]`` are read as height
        and width.

    out : Buffer
        Output buffer.

    sizes : tuple of float
        Tuple of sizes for anchor boxes.

    ratios : tuple of float
        Tuple of ratios for anchor boxes.

    steps : Tuple of float
        Priorbox step across y and x. A non-positive entry selects
        ``1 / in_height`` (respectively ``1 / in_width``).

    offsets : tuple of int
        Priorbox center offsets, y and x respectively.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    # Square thread block: sqrt(max threads) along each of x and y.
    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    tx = tvm.thread_axis("threadIdx.x")
    ty = tvm.thread_axis("threadIdx.y")
    bx = tvm.thread_axis("blockIdx.x")
    by = tvm.thread_axis("blockIdx.y")

    ib = tvm.ir_builder.create()
    p_out = ib.buffer_ptr(out)
    in_height, in_width = data.shape[2], data.shape[3]

    launch = (
        (tx, max_threads),
        (ty, max_threads),
        (bx, in_height // max_threads + 1),
        (by, in_width // max_threads + 1),
    )
    for thread_axis, extent in launch:
        ib.scope_attr(thread_axis, "thread_extent", extent)

    num_sizes = len(sizes)
    boxes_per_cell = num_sizes + len(ratios) - 1
    size_ratio_concat = sizes + ratios
    steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
    steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
    offset_h = offsets[0]
    offset_w = offsets[1]

    i = bx * max_threads + tx
    j = by * max_threads + ty
    # The grid is rounded up, so guard against threads past the feature map.
    with ib.if_scope((i < in_height)):
        with ib.if_scope((j < in_width)):
            center_h = (i + offset_h) * steps_h
            center_w = (j + offset_w) * steps_w

            for k in range(boxes_per_cell):
                from_size = k < num_sizes
                # Half-extents: the first num_sizes boxes use sizes[k]; the
                # remaining ones use sizes[0] scaled by sqrt of a ratio.
                w = tvm.select(
                    from_size,
                    size_ratio_concat[k] * in_height / in_width / 2.0,
                    size_ratio_concat[0] * in_height / in_width *
                    math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                h = tvm.select(
                    from_size,
                    size_ratio_concat[k] / 2.0,
                    size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                count = (i * in_width * boxes_per_cell + j * boxes_per_cell + k) * 4
                p_out[count] = center_w - w
                p_out[count + 1] = center_h - h
                p_out[count + 2] = center_w + w
                p_out[count + 3] = center_h + h

    return ib.get()
# Problem sizes for the convolution declared below.
# NOTE(review): `batch` is used below but not assigned in this fragment --
# presumably it is defined just above; confirm.
in_channel = 256
out_channel = 512
in_size = 14
kernel = 3
pad = 1
stride = 1

# Algorithm
# A is indexed (y, x, channel, batch); W is (ky, kx, in_channel, out_channel).
A = tvm.placeholder((in_size, in_size, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, out_channel), name='W')
out_size = (in_size - kernel + 2*pad) // stride + 1
# Pad input: copy A inside the border, write zero on the `pad`-wide border.
Apad = tvm.compute(
    (in_size + 2*pad, in_size + 2*pad, in_channel, batch),
    lambda yy, xx, cc, nn: tvm.select(
        tvm.all(yy >= pad, yy - pad < in_size,
                xx >= pad, xx - pad < in_size),
        A[yy - pad, xx - pad, cc, nn], tvm.const(0.)),
    name='Apad')
# Create reduction variables
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel), name='ry')
rx = tvm.reduce_axis((0, kernel), name='rx')
# Compute the convolution: sum over the kernel window and input channels.
B = tvm.compute(
    (out_size, out_size, out_channel, batch),
    lambda yy, xx, ff, nn: tvm.sum(
        Apad[yy * stride + ry, xx * stride + rx, rc, nn] * W[ry, rx, rc, ff],
        axis=[ry, rx, rc]),
    name='B')
def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances):
    """Low-level IR routine for location transform in multibox_detection.

    For every batch and anchor, the best non-background class (index > 0) is
    found. If its score reaches ``threshold`` a record
    ``[id, prob, xmin, ymin, xmax, ymax]`` is appended to ``out`` and the
    batch's valid count is incremented.

    Parameters
    ----------
    cls_prob : Buffer
        Buffer of class probabilities; its shape is read as
        (batch, classes, anchors).

    loc_pred : Buffer
        Buffer of location regression predictions.

    anchor : Buffer
        Buffer of prior anchor boxes.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    out : Buffer
        Output buffer.

    clip : boolean
        Whether to clip out-of-boundary boxes.

    threshold : float
        Threshold to be a positive prediction.

    variances : tuple of float
        Variances to be decoded from box regression output.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
        """Decode one predicted box from its prior anchor box."""
        left = anchor[anchor_base_idx]
        top = anchor[anchor_base_idx + 1]
        right = anchor[anchor_base_idx + 2]
        bottom = anchor[anchor_base_idx + 3]
        anchor_w = right - left
        anchor_h = bottom - top
        anchor_cx = (left + right) / 2.0
        anchor_cy = (top + bottom) / 2.0

        pred_x = loc[loc_base_idx]
        pred_y = loc[loc_base_idx + 1]
        pred_w = loc[loc_base_idx + 2]
        pred_h = loc[loc_base_idx + 3]

        ox = pred_x * vx * anchor_w + anchor_cx
        oy = pred_y * vy * anchor_h + anchor_cy
        ow = tvm.exp(pred_w * vw) * anchor_w / 2.0
        oh = tvm.exp(pred_h * vh) * anchor_h / 2.0

        def _maybe_clip(coord):
            # Clamp to [0, 1] only when clipping is requested.
            return tvm.select(clip, tvm.max(0, tvm.min(1, coord)), coord)

        return (_maybe_clip(ox - ow), _maybe_clip(oy - oh),
                _maybe_clip(ox + ow), _maybe_clip(oy + oh))

    batch_size, num_classes, num_anchors = (
        cls_prob.shape[0], cls_prob.shape[1], cls_prob.shape[2])

    ib = tvm.ir_builder.create()
    p_cls_prob = ib.buffer_ptr(cls_prob)
    p_loc_pred = ib.buffer_ptr(loc_pred)
    p_anchor = ib.buffer_ptr(anchor)
    p_valid_count = ib.buffer_ptr(valid_count)
    p_out = ib.buffer_ptr(out)

    with ib.for_range(0, batch_size, for_type="parallel", name="n") as n:
        p_valid_count[n] = 0
        with ib.for_range(0, num_anchors, name="i") as i:
            # Find the predicted class id and probability
            score = ib.allocate('float32', (1,), name="score", scope="local")
            cls_id = ib.allocate('int32', (1,), name="id", scope="local")
            score[0] = -1.0
            cls_id[0] = 0
            with ib.for_range(0, num_classes, name="j") as j:
                # Class 0 is skipped here (treated as background below).
                with ib.if_scope(j > 0):
                    prob = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i]
                    cls_id[0] = tvm.select(prob > score[0], j, cls_id[0])
                    score[0] = tvm.max(prob, score[0])
            with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)):
                cls_id[0] = 0
            # [id, prob, xmin, ymin, xmax, ymax]
            # Remove background, restore original id
            with ib.if_scope(cls_id[0] > 0):
                out_base_idx = n * num_anchors * 6 + p_valid_count[n] * 6
                p_out[out_base_idx] = cls_id[0] - 1.0
                p_out[out_base_idx + 1] = score[0]
                offset = i * 4
                box = transform_loc(p_loc_pred, n * num_anchors * 4 + offset,
                                    p_anchor, offset, clip,
                                    variances[0], variances[1],
                                    variances[2], variances[3])
                for pos, coord in enumerate(box):
                    p_out[out_base_idx + (2 + pos)] = coord
                p_valid_count[n] = p_valid_count[n] + 1

    return ib.get()
def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
    """Declare a tunable Winograd convolution for 3x3 kernels.

    Parameters
    ----------
    cfg : config object
        Tuning configuration. Knobs ``tile_bna``, ``tile_bnb``, ``yt`` and
        splits ``tile_t1``, ``tile_t2``, ``c_unroll`` are defined on it here,
        and fallback values are filled in when ``cfg.is_fallback`` is set.
    data : tvm.Tensor
        4-D input with constant shape, unpacked as (N, CI, IH, IW).
    kernel : tvm.Tensor
        Either a 4-D raw kernel unpacked as (CO, CI, KH, KW), or a 5-D
        pre-transformed kernel unpacked as (H_CAT, W_CAT, CO, CI, VC).
    strides : int or tuple/list of two ints
        Convolution strides; only 1 is accepted.
    padding : int, tuple/list or str
        Padding specification passed to ``get_pad_tuple``; the resulting
        top/left padding must be 1.
    dilation : int or pair of ints
        Kernel dilation. Must be 1 for a pre-transformed kernel.
    layout : str
        Must be 'NCHW'.
    out_dtype : str
        Data type of the transform matrices and of the zero term in the output.
    tile_size : int
        Output tile size; 2 and 4 are supported.

    Returns
    -------
    output : tvm.Tensor
        4-D tensor of shape (N, CO, H, W) tagged 'winograd_conv2d_output'.

    Raises
    ------
    ValueError
        If ``tile_size`` is neither 2 nor 4.
    """
    N, CI, IH, IW = get_const_tuple(data.shape)
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        if dilation_h != 1 or dilation_w != 1:
            kernel = dilate(kernel, (1, 1, dilation_h, dilation_w))
        pre_computed = False
        CO, _, KH, KW = get_const_tuple(kernel.shape)
    else:
        assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
        pre_computed = True
        H_CAT, W_CAT, CO, CI, VC = get_const_tuple(kernel.shape)
        CO *= VC
        KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    # Pass the kernel *size*, not the kernel tensor: for string padding modes
    # the helper indexes its second argument as (kernel_h, kernel_w), and for
    # a pre-transformed kernel the tensor's leading dims are not the size.
    HPAD, WPAD, _, _ = get_pad_tuple(padding, (KH, KW))

    assert layout == 'NCHW'
    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    if tile_size == 4:
        G_data = np.array([
            [1 / 4.0, 0, 0],
            [-1 / 6.0, -1 / 6.0, -1 / 6.0],
            [-1 / 6.0, 1 / 6.0, -1 / 6.0],
            [1 / 24.0, 1 / 12.0, 1 / 6.0],
            [1 / 24.0, -1 / 12.0, 1 / 6.0],
            [0, 0, 1]], out_dtype)

        B_data = np.array([
            [4, 0, 0, 0, 0, 0],
            [0, -4, 4, -2, 2, 4],
            [-5, -4, -4, -1, -1, 0],
            [0, 1, -1, 2, -2, -5],
            [1, 1, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]], out_dtype)

        A_data = np.array([
            [1, 0, 0, 0],
            [1, 1, 1, 1],
            [1, -1, 1, -1],
            [1, 2, 4, 8],
            [1, -2, 4, -8],
            [0, 0, 0, 1]], out_dtype)
    elif tile_size == 2:
        G_data = np.array([
            [1, 0, 0],
            [1.0/2, 1.0/2, 1.0/2],
            [1.0/2, -1.0/2, 1.0/2],
            [0, 0, 1]], out_dtype)

        B_data = np.array([
            [1, 0, 0, 0],
            [0, 1, -1, 1],
            [-1, 1, 1, 0],
            [0, 0, 0, -1]], out_dtype)

        A_data = np.array([
            [1, 0],
            [1, 1],
            [1, -1],
            [0, -1]], out_dtype)
    else:
        raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

    m = A_data.shape[1]
    r = 3
    alpha = m + r - 1

    H = (IH + 2 * HPAD - 3) // HSTR + 1
    W = (IW + 2 * WPAD - 3) // WSTR + 1
    nH, nW = (H + m-1) // m, (W + m-1) // m
    P = N * nH * nW

    ##### space definition begin #####
    tile_bna_candidates = [1, 2, 4, 8, 16]
    factors = get_factors(CO)
    cfg.define_knob('tile_bna', [x for x in tile_bna_candidates if x in factors])
    cfg.define_knob('tile_bnb', [1, 2, 4, 8, 16])
    cfg.define_split('tile_t1', CI, num_outputs=2, max_factor=128)
    cfg.define_split('tile_t2', CO, num_outputs=2, max_factor=128)
    cfg.define_split('c_unroll', CI, num_outputs=2, max_factor=8)
    cfg.define_knob('yt', [1, 2, 4, 8, 16, 32])
    ##### space definition end #####

    if cfg.is_fallback:
        cfg['tile_bnb'].val = 4
        cfg['tile_bna'].val = 4
        # Shrink the output-channel block until it divides CO.
        while CO % cfg['tile_bna'].val != 0:
            cfg['tile_bna'].val //= 2
        cfg['yt'].val = 8
        cfg.fallback_split('tile_t1', [-1, 128])
        cfg.fallback_split('tile_t2', [-1, 128])
        cfg.fallback_split('c_unroll', [-1, 8])

    bna = cfg['tile_bna'].val
    bnb = cfg['tile_bnb'].val

    # Round the tile count up to a multiple of the tile block size.
    P_round = (P + bnb - 1) // bnb * bnb
    assert CO % bna == 0 and P_round % bnb == 0

    # pack input tile
    input_tile = tvm.compute((CI, P_round // bnb, alpha, alpha, bnb), lambda ci, b, eps, nu, bb: \
         tvm.select(b * bnb + bb < P,
                    data_pad[(b*bnb+bb) // (nH*nW)][ci][(b*bnb+bb) // nW % nH * m + eps]
                    [(b*bnb+bb) % nW * m + nu], tvm.const(0, data_pad.dtype)), name='d')

    # transform kernel
    if pre_computed:
        U = kernel
    else:
        G = const_matrix(G_data, 'G')
        r_kh = tvm.reduce_axis((0, KH), 'r_kh')
        r_kw = tvm.reduce_axis((0, KW), 'r_kw')
        U = tvm.compute((alpha, alpha, CO // bna, CI, bna), lambda eps, nu, co, ci, vco:
                        tvm.sum(kernel[co * bna + vco][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
                                axis=[r_kh, r_kw]), name='U')

    # transform image
    B = const_matrix(B_data, 'B')
    r_a = tvm.reduce_axis((0, alpha), 'r_a')
    r_b = tvm.reduce_axis((0, alpha), 'r_b')
    V = tvm.compute((alpha, alpha, P_round // bnb, CI, bnb), lambda eps, nu, p, ci, vp:
                    tvm.sum(input_tile[ci][p][r_a][r_b][vp] * B[r_a][eps] * B[r_b][nu],
                            axis=[r_a, r_b]), name='V')

    # batch gemm
    ci = tvm.reduce_axis((0, CI), name='c')
    M = tvm.compute((alpha, alpha, CO, P_round), lambda eps, nu, co, p:
                    tvm.sum(U[eps][nu][co // bna][ci][co % bna] *
                            V[eps][nu][p // bnb][ci][p % bnb], axis=ci), name='M')

    # inverse transform
    A = const_matrix(A_data, 'A')
    r_a = tvm.reduce_axis((0, alpha), 'r_a')
    r_b = tvm.reduce_axis((0, alpha), 'r_b')
    Y = tvm.compute((CO, P, m, m), lambda co, p, vh, vw:
                    tvm.sum(M[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw],
                            axis=[r_a, r_b]), name='Y')

    # unpack output
    output = tvm.compute((N, CO, H, W), lambda n, co, h, w:
                         Y[co][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
                         # the following term is used to make the padding effective,
                         # otherwise the padding will be eliminated by bound inference
                         + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][CO-1][P_round-1],
                         name='output', tag='winograd_conv2d_output')

    # we have to manually assign effective GFLOP for winograd
    cfg.add_flop(2 * N * CO * H * W * KH * KW * CI)
    return output
def _decl_winograd(data, kernel, stride, padding, layout, out_dtype):
    """Declare the Winograd F(2x2, 3x3) fast convolution for conv2d.

    Parameters
    ----------
    data : tvm.Tensor
        4-D input with constant shape, unpacked as (N, CI, H, W).
    kernel : tvm.Tensor
        4-D kernel with constant shape, unpacked as (CO, CI, KH, KW);
        only 3x3 kernels are accepted.
    stride : int or tuple/list of two ints
        Convolution stride; only 1 is accepted.
    padding : int, tuple/list or str
        Padding specification passed to ``get_pad_tuple``; the resulting
        top/left padding must be 1.
    layout : str
        Not read by this function.
    out_dtype : str
        Data type of the transform matrices and of the zero term in the output.

    Returns
    -------
    output : tvm.Tensor
        4-D tensor of shape (N, CO, H, W) tagged 'winograd_conv_output'.
    """
    N, CI, H, W = [util.get_const_int(x) for x in data.shape]
    CO, CI, KH, KW = [util.get_const_int(x) for x in kernel.shape]
    # Pass the kernel *size*, not the kernel tensor: for string padding modes
    # the helper indexes its second argument as (kernel_h, kernel_w).
    HPAD, WPAD, _, _ = get_pad_tuple(padding, (KH, KW))
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1 and KH == 3 and KW == 3
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    B_data = np.array([
        [1, 0, 0, 0],
        [0, 1, -1, 1],
        [-1, 1, 1, 0],
        [0, 0, 0, -1]
    ], out_dtype)

    G_data = np.array([
        [1, 0, 0],
        [1.0/2, 1.0/2, 1.0/2],
        [1.0/2, -1.0/2, 1.0/2],
        [0, 0, 1],
    ], out_dtype)

    A_data = np.array([
        [1, 0],
        [1, 1],
        [1, -1],
        [0, -1],
    ], out_dtype)

    m = 2
    r = 3
    alpha = m + r - 1
    K = CO
    C = CI

    nH, nW = (H + m-1) // m, (W + m-1) // m
    P = N * nH * nW

    # Block sizes for output channels (bna) and tiles (bnb).
    bna, bnb = 4, 4
    if data.dtype == 'float16':
        bnb *= 2
    # Round the tile count up to a multiple of the tile block size.
    P_round = (P + bnb - 1) // bnb * bnb
    assert K % bna == 0 and P_round % bnb == 0

    # pack input tile
    input_tile = tvm.compute((C, P_round // bnb, alpha, alpha, bnb),
                             lambda c, b, eps, nu, bb:
                             tvm.select(b * bnb + bb < P,\
                             data_pad[(b*bnb+bb) // (nH*nW)][c][(b*bnb+bb) // nW % nH * m + eps]\
                             [(b*bnb+bb) % nW * m + nu], tvm.const(0, data_pad.dtype)),
                             name='d')

    # transform kernel
    G = const_array(G_data, 'G')
    r_kh = tvm.reduce_axis((0, KH), 'r_kh')
    r_kw = tvm.reduce_axis((0, KW), 'r_kw')
    U = tvm.compute((alpha, alpha, K // bna, C, bna), lambda eps, nu, k, c, kk:
                    tvm.sum(kernel[k * bna + kk][c][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw],
                            axis=[r_kh, r_kw]), name='U')

    # transform image
    B = const_array(B_data, 'B')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    V = tvm.compute((alpha, alpha, P_round // bnb, C, bnb), lambda eps, nu, b, c, bb:
                    tvm.sum(input_tile[c][b][r_eps][r_nu][bb] * B[r_eps][eps] * B[r_nu][nu],
                            axis=[r_eps, r_nu]), name='V')

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute((alpha, alpha, K, P_round), lambda eps, nu, k, b:
                    tvm.sum(U[eps][nu][k // bna][c][k % bna] *
                            V[eps][nu][b // bnb][c][b % bnb], axis=c), name='M')

    # inverse transform
    A = const_array(A_data, 'A')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
                    tvm.sum(M[r_eps][r_nu][k][b] * A[r_eps][vh] * A[r_nu][vw],
                            axis=[r_eps, r_nu]), name='Y')

    # unpack output
    output = tvm.compute((N, K, H, W), lambda n, k, h, w:
                         Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m]
                         # the following term is used to make the padding effective,
                         # otherwise the padding will be eliminated by bound inference
                         + tvm.const(0, out_dtype) * M[alpha-1][alpha-1][K-1][P_round-1],
                         name='output', tag='winograd_conv_output')

    return output
def _compute(*indices):
    """Return ``x`` at ``indices``, scaled by ``alpha`` where it is not positive.

    ``x`` and ``alpha`` come from the enclosing scope; ``alpha`` is cast to the
    element's dtype before multiplying.
    """
    elem = x(*indices)
    slope = tvm.const(alpha, elem.dtype)
    scaled = elem * slope
    return tvm.select(elem > 0, elem, scaled)
def argmax_comp(x, y):
    """Combine two (index, value) pairs, keeping the pair with the larger value.

    Ties (``x[1] >= y[1]``) resolve in favour of ``x``.
    """
    x_wins = x[1] >= y[1]
    return tvm.select(x_wins, x[0], y[0]), tvm.select(x_wins, x[1], y[1])
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
    """Low-level IR routine for non-maximum suppression (thread-bound).

    Records are six values wide; element 0 is the class id (-1 marks a
    suppressed or invalid entry) and elements 2..5 are the box corners used
    for the overlap computation.

    Parameters
    ----------
    data: Buffer
        Buffer of output boxes with class and score.

    sort_result : Buffer
        Buffer of output box indexes sorted by score.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    out : Buffer
        Output buffer; ``shape[0]`` and ``shape[1]`` are read as the batch
        size and the number of anchors.

    nms_threshold : float
        Non-maximum suppression threshold.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    nms_topk : int
        Keep maximum top k detections before nms, -1 for no limit.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Calculate overlap (intersection over union) of two boxes.
        """
        w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
                         - tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
        h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
                         - tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
        i = w * h
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
        # Guard against a degenerate (non-positive) union area.
        return tvm.select(u <= 0.0, 0.0, i / u)

    # Square thread block: sqrt(max threads) along each of x and y.
    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    tx = tvm.thread_axis("threadIdx.x")
    ty = tvm.thread_axis("threadIdx.y")
    bx = tvm.thread_axis("blockIdx.x")
    by = tvm.thread_axis("blockIdx.y")
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data)
    p_sort_result = ib.buffer_ptr(sort_result)
    p_valid_count = ib.buffer_ptr(valid_count)
    p_out = ib.buffer_ptr(out)
    batch_size = out.shape[0]
    num_anchors = out.shape[1]
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    nthread_ty = max_threads
    nthread_by = 6 // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(ty, "thread_extent", nthread_ty)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    ib.scope_attr(by, "thread_extent", nthread_by)
    i = bx * max_threads + tx
    j = by * max_threads + ty

    nms_threshold_node = tvm.make.node(
        "FloatImm", dtype="float32", value=nms_threshold)
    nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
    force_suppress_node = tvm.make.node(
        "IntImm", dtype="int32", value=1 if force_suppress else 0)
    with ib.for_range(0, batch_size, for_type="unroll", name="n") as n:
        # Check the valid count of *this* batch; testing batch 0 would skip
        # or run NMS for every batch based on the first one only.
        with ib.if_scope(
            tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
                    p_valid_count[n] > 0)):
            # Reorder output
            nkeep = tvm.select(
                tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
                nms_topk, p_valid_count[n])
            with ib.if_scope(i < nkeep):
                with ib.if_scope(j < 6):
                    p_out[(n * num_anchors * 6
                           + i * 6 + j)] = p_data[(n * num_anchors * 6
                                                   + p_sort_result[n * num_anchors + i] * 6 + j)]
            with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])):
                with ib.if_scope(i < p_valid_count[n] - nkeep):
                    with ib.if_scope(j < 6):
                        p_out[(n * num_anchors * 6
                               + (i + nkeep) * 6 + j)] = p_data[(n * num_anchors * 6
                                                                 + (i + nkeep) * 6 + j)]
            # Apply nms
            with ib.if_scope(i < p_valid_count[n]):
                offset_i = i * 6
                with ib.if_scope(p_out[n * num_anchors * 6 + offset_i] >= 0):
                    with ib.if_scope(j < p_valid_count[n]):
                        offset_j = j * 6
                        with ib.if_scope(tvm.all(j > i, p_out[n * num_anchors * 6
                                                              + offset_j] >= 0)):
                            with ib.if_scope(tvm.any(force_suppress_node > 0,
                                                     p_out[n * num_anchors * 6 + offset_i] ==
                                                     p_out[n * num_anchors * 6 + offset_j])):
                                # When force_suppress == True or class_id equals
                                iou = calculate_overlap(
                                    p_out, n * num_anchors * 6 + offset_i + 2,
                                    n * num_anchors * 6 + offset_j + 2)
                                with ib.if_scope(iou >= nms_threshold):
                                    p_out[
                                        n * num_anchors * 6 + offset_j] = -1.0
        with ib.else_scope():
            # No NMS requested or nothing valid: copy the valid records as-is.
            with ib.if_scope(i < p_valid_count[n]):
                with ib.if_scope(j < 6):
                    p_out[(n * num_anchors * 6 + i * 6 + j)] = p_data[n * num_anchors * 6
                                                                      + i * 6 + j]
        # Set invalid entry to be -1
        with ib.if_scope(i < num_anchors - p_valid_count[n]):
            with ib.if_scope(j < 6):
                p_out[n * num_anchors * 6 + (i + p_valid_count[n]) * 6 + j] = -1.0
    body = ib.get()
    return body
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
    """Low-level IR routine for non-maximum suppression.

    Records are six values wide; element 0 is the class id (-1 marks a
    suppressed or invalid entry) and elements 2..5 are the box corners used
    for the overlap computation.

    Parameters
    ----------
    data: Buffer
        Buffer of output boxes with class and score.

    sort_result : Buffer
        Buffer of output box indexes sorted by score.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    out : Buffer
        Output buffer; ``shape[0]`` and ``shape[1]`` are read as the batch
        size and the number of anchors.

    nms_threshold : float
        Non-maximum suppression threshold.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    nms_topk : int
        Keep maximum top k detections before nms, -1 for no limit.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Calculate overlap (intersection over union) of two boxes.
        """
        w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
                         - tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
        h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
                         - tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
        i = w * h
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
        # Guard against a degenerate (non-positive) union area.
        return tvm.select(u <= 0.0, 0.0, i / u)

    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data)
    p_sort_result = ib.buffer_ptr(sort_result)
    p_valid_count = ib.buffer_ptr(valid_count)
    p_out = ib.buffer_ptr(out)
    batch_size = out.shape[0]
    num_anchors = out.shape[1]

    nms_threshold_node = tvm.make.node("FloatImm", dtype="float32", value=nms_threshold)
    nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
    force_suppress_node = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
    with ib.for_range(0, batch_size, for_type="parallel", name="n") as n:
        # Check the valid count of *this* batch; testing batch 0 would skip
        # or run NMS for every batch based on the first one only.
        with ib.if_scope(tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
                                 p_valid_count[n] > 0)):
            # Reorder output
            nkeep = tvm.select(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
                               nms_topk, p_valid_count[n])
            with ib.for_range(0, nkeep, name="l") as l:
                with ib.for_range(0, 6, name="m") as m:
                    p_out[(n * num_anchors * 6
                           + l * 6 + m)] = p_data[(n * num_anchors * 6
                                                   + p_sort_result[n * num_anchors + l] * 6 + m)]
            with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])):
                with ib.for_range(0, p_valid_count[n] - nkeep, name="l") as l:
                    with ib.for_range(0, 6, name="m") as m:
                        p_out[(n * num_anchors * 6
                               + (l + nkeep) * 6 + m)] = p_data[(n * num_anchors * 6
                                                                 + (l + nkeep) * 6 + m)]
            # Apply nms
            with ib.for_range(0, p_valid_count[n], name="l") as l:
                offset_l = l * 6
                with ib.if_scope(p_out[n * num_anchors * 6 + offset_l] >= 0):
                    with ib.for_range(0, p_valid_count[n], name="m") as m:
                        offset_m = m * 6
                        with ib.if_scope(tvm.all(m > l, p_out[n * num_anchors * 6
                                                              + offset_m] >= 0)):
                            with ib.if_scope(tvm.any(force_suppress_node > 0,
                                                     p_out[n * num_anchors * 6 + offset_l] ==
                                                     p_out[n * num_anchors * 6 + offset_m])):
                                # When force_suppress == True or class_id equals
                                iou = calculate_overlap(p_out, n * num_anchors * 6 + offset_l + 2,
                                                        n * num_anchors * 6 + offset_m + 2)
                                with ib.if_scope(iou >= nms_threshold):
                                    p_out[n * num_anchors * 6 + offset_m] = -1.0
        with ib.else_scope():
            # No NMS requested or nothing valid: copy the valid records as-is.
            with ib.for_range(0, p_valid_count[n], name="l") as l:
                with ib.for_range(0, 6, name="m") as m:
                    p_out[(n * num_anchors * 6 + l * 6 + m)] = p_data[n * num_anchors * 6
                                                                      + l * 6 + m]
        # Set invalid entry to be -1
        with ib.for_range(0, num_anchors - p_valid_count[n], name="l") as l:
            with ib.for_range(0, 6, name="m") as m:
                p_out[n * num_anchors * 6 + (l + p_valid_count[n]) * 6 + m] = -1.0
    return ib.get()
def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_after, axis):
    """Low level IR routing subfunction 4/4 for writing sorted indices to output format.

    Parameters
    ----------
    data: Buffer
        Buffer of output boxes with class and score.

    index : Buffer
        Buffer of number of valid output boxes.

    new_index : Buffer
        Buffer of sorted indices in a flatten format.

    loc : Buffer
        Buffer of start locations of each sorting segment.

    output : Buffer
        Output buffer of output box indexes sorted by score.

    axis_mul_before : int
        The multiplication result of axis dimensions before axis.

    axis_mul_after : int
        The multiplication result of axis dimensions after axis.

    axis : int
        The axis used for sorting.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    dshape = tvm.max(loc.shape[0], data.shape[axis])
    p_index = ib.buffer_ptr(index)
    index_new = ib.buffer_ptr(new_index)
    sizes = ib.buffer_ptr(loc)
    p_out = ib.buffer_ptr(output)
    nthread_tx = max_threads
    nthread_bx = dshape // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
            i = tid / axis_mul_after
            j = tid % axis_mul_after
            base_idx = i * data.shape[axis] * axis_mul_after + j
            # Segment 0 starts at offset 0; segment tid starts where the
            # previous one ends. This must be an IR-level select: assigning a
            # Python variable inside ib.if_scope / ib.else_scope does not
            # branch, it just keeps the last assignment (sizes[tid - 1]) for
            # every thread, including tid == 0.
            # NOTE(review): a select may still evaluate sizes[tid - 1] for
            # tid == 0; the loaded value is discarded in that case.
            start = tvm.select(tid == 0, tvm.const(0, loc.dtype), sizes[tid - 1])
            with ib.for_range(0, data.shape[axis], name="k") as k:
                # Positions past the valid count keep their own index.
                p_out[base_idx + k * axis_mul_after] = tvm.select(
                    k < p_index[tid], index_new[k+start], k)
    with ib.else_scope():
        with ib.if_scope(tid < data.shape[axis]):
            p_out[tid] = tvm.select(tid < p_index[0], index_new[tid], tid)

    body = ib.get()
    return body