def _instr(index):
    ib = tvm.ir_builder.create()
    if index == 1:
        ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
        return ib.get()

    a_int8 = ins[0].vload([0], "uint8x4")
    re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8)
    # broadcast the four packed uint8 values of A across all 16 int32 lanes
    vec_ai32 = re_int32.astype('int32x16')
    vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
    vec_b = ins[1].vload([0, 0], "int8x64")
    vec_one = tvm.const(1, "int16x32")
    # pmaddubs.w: multiply adjacent u8*i8 pairs and add them into int16 lanes
    pair_reduction = tvm.call_llvm_intrin('int16x32',
                                          'llvm.x86.avx512.pmaddubs.w.512',
                                          tvm.const(0, 'uint32'),
                                          vec_a, vec_b)
    # pmaddw.d: sum adjacent int16 pairs (multiplied by 1) into int32 accumulators
    quad_reduction = tvm.call_llvm_intrin('int32x16',
                                          'llvm.x86.avx512.pmaddw.d.512',
                                          tvm.const(0, 'uint32'),
                                          pair_reduction, vec_one)
    if index == 0:
        ib.emit(outs[0].vstore(0, quad_reduction))
    else:
        ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
    return ib.get()
def _instr(index):
    ib = tvm.ir_builder.create()
    if index == 1:
        ib.emit(outs[0].vstore(0, tvm.const(0, '%s32x%d' % (dtype, int32_lanes))))
        return ib.get()

    dtype_a = '%s8x%d' % (dtype, num_int8_elements)
    dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
    dtype_c = '%s32x%d' % (dtype, int32_lanes)

    a_int8 = ins[0].vload([0], dtype_a)
    re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8)
    # broadcast a
    vec_ai32 = re_int32.astype(dtype_c)
    vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
    vec_b = ins[1].vload([0, 0], dtype_b)
    vec_c = outs[0].vload([0], dtype_c)

    inst = 'udot' if dtype == 'uint' else 'sdot'
    inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
        inst, int32_lanes, int32_lanes * num_int8_elements)
    vdot = tvm.call_llvm_intrin(dtype_c, inst, tvm.const(2, 'uint32'),
                                vec_c, vec_a, vec_b)
    ib.emit(outs[0].vstore(0, vdot))
    return ib.get()
def _instr(index):
    irb = tvm.ir_builder.create()
    if index == 1:
        irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
        return irb.get()

    cnts8 = [None] * 8
    cnts4 = [None] * 4
    cnts2 = [None] * 2
    for bw in range(w_b):
        for bx in range(x_b):
            if k_i == 16:
                for i in range(m):
                    ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
                    cnts = tvm.popcount(ands)
                    upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
                    lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
                    cnts8[i] = upper_half + lower_half
                for i in range(m // 2):
                    cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                    cnts8[i * 2], cnts8[i * 2 + 1])
                for i in range(m // 4):
                    cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                    cnts4[i * 2], cnts4[i * 2 + 1])
                cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                            cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                           zz.vload(0, 'uint16x8'), shifted_cnts)
            else:  # k_i == 8
                for i in range(m):
                    ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
                    cnts8[i] = tvm.popcount(ands)
                for i in range(m // 2):
                    cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                    cnts8[i * 2], cnts8[i * 2 + 1])
                for i in range(m // 4):
                    cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                    cnts4[i * 2], cnts4[i * 2 + 1])
                cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                            cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                           zz.vload(0, 'uint16x8'), shifted_cnts)
            irb.emit(zz.vstore(0, out))
    return irb.get()
def _instr(index):
    irb = tvm.ir_builder.create()
    if index == 1:  # reduce reset
        irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
        return irb.get()

    # body and reduce update
    cnts8 = [None] * 8
    cnts4 = [None] * 4
    cnts2 = [None] * 2
    for bw in range(w_b):
        for bx in range(x_b):
            if k_i == 16:
                for i in range(m):
                    w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype)
                    x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
                    if unipolar:
                        cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                    else:
                        cnts = tvm.popcount(w_ & x_)
                    upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
                    lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts)
                    cnts8[i] = upper_half + lower_half
                for i in range(m // 2):
                    cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                    cnts8[i * 2], cnts8[i * 2 + 1])
                for i in range(m // 4):
                    cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                    cnts4[i * 2], cnts4[i * 2 + 1])
                cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine',
                                            cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw + bx, pack_dtype)
                out = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                           zz.vload(0, return_dtype), shifted_cnts)
            else:  # k_i == 8
                for i in range(m):
                    w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype)
                    x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
                    if unipolar:
                        cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                    else:
                        cnts8[i] = tvm.popcount(w_ & x_)
                for i in range(m // 2):
                    cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                    cnts8[i * 2], cnts8[i * 2 + 1])
                for i in range(m // 4):
                    cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                    cnts4[i * 2], cnts4[i * 2 + 1])
                cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine',
                                            cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw + bx, pack_dtype)
                out = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                           zz.vload(0, return_dtype), shifted_cnts)
            irb.emit(zz.vstore(0, out))
    return irb.get()
def test_reinterpret():
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n,), name='A', dtype="int32")
    B = tvm.compute(A.shape,
                    lambda *i: tvm.call_pure_intrin("float32", "reinterpret", A(*i)),
                    name='B')
    s = tvm.create_schedule(B.op)

    def check_c():
        mhost = tvm.build(s, [A, B], "c", name="reinterpret")
        temp = util.tempdir()
        path_dso = temp.relpath("temp.so")
        mhost.export_library(path_dso)
        m = tvm.module.load(path_dso)
        fadd = m['reinterpret']
        ctx = tvm.cpu(0)
        n = nn
        a = tvm.nd.array(
            np.random.randint(-2**30, 2**30, size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        fadd(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy().view('float32'))

    check_c()
def _instr(index):
    irb = tvm.ir_builder.create()
    if index == 1:
        irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
        return irb.get()

    cnts8 = [None] * 8
    cnts4 = [None] * 4
    cnts2 = [None] * 2
    for bw in range(w_b):
        for bx in range(x_b):
            if k_i == 16:
                for i in range(m):
                    ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
                    cnts = tvm.popcount(ands)
                    upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
                    lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
                    cnts8[i] = upper_half + lower_half
                for i in range(m // 2):
                    cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                    cnts8[i * 2], cnts8[i * 2 + 1])
                for i in range(m // 4):
                    cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                    cnts4[i * 2], cnts4[i * 2 + 1])
                cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                            cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                           zz.vload(0, 'uint16x8'), shifted_cnts)
            else:  # k_i == 8
                for i in range(m):
                    ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
                    cnts8[i] = tvm.popcount(ands)
                for i in range(m // 2):
                    cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                    cnts8[i * 2], cnts8[i * 2 + 1])
                for i in range(m // 4):
                    cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                    cnts4[i * 2], cnts4[i * 2 + 1])
                cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                            cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                           zz.vload(0, 'uint16x8'), shifted_cnts)
            irb.emit(zz.vstore(0, out))
    return irb.get()
def test_llvm_intrin(): ib = tvm.ir_builder.create() n = tvm.convert(4) A = ib.pointer("float32", name="A") args = [tvm.call_pure_intrin("handle", "tvm_address_of", A[0]), 0, 3, 1] ib.emit( tvm.make.Evaluate( tvm.make.Call("int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0))) body = ib.get() func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True) fcode = tvm.build(func, None, "llvm")
def _instr(index): ib = tvm.ir_builder.create() if index == 1: ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16'))) return ib.get() a_int8 = ins[0].vload([0], "uint8x4") re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8) vec_ai32 = re_int32.astype('int32x16') vec_b = ins[1].vload([0, 0], "int8x64") vnni_inst_name = 'llvm.x86.avx512.vpdpbusd.512' llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id( vnni_inst_name) if llvm_id != 0: # VNNI is available for current LLVM version vec_bi32 = tvm.call_pure_intrin('int32x16', 'reinterpret', vec_b) vec_zero = tvm.const(0, "int32x16") quad_reduction = tvm.call_llvm_intrin( 'int32x16', 'llvm.x86.avx512.vpdpbusd.512', tvm.const(0, 'uint32'), vec_zero, vec_ai32, vec_bi32) else: # Fall back to the normal AVX512 vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32) vec_one = tvm.const(1, "int16x32") pair_reduction = tvm.call_llvm_intrin( 'int16x32', 'llvm.x86.avx512.pmaddubs.w.512', tvm.const(0, 'uint32'), vec_a, vec_b) quad_reduction = tvm.call_llvm_intrin( 'int32x16', 'llvm.x86.avx512.pmaddw.d.512', tvm.const(0, 'uint32'), pair_reduction, vec_one) if index == 0: ib.emit(outs[0].vstore(0, quad_reduction)) else: ib.emit(outs[0].vstore( 0, quad_reduction + outs[0].vload([0], 'int32x16'))) return ib.get()
def test_llvm_intrin(): ib = tvm.ir_builder.create() n = tvm.convert(4) A = ib.pointer("float32", name="A") args = [ tvm.call_pure_intrin("handle", "tvm_address_of", A[0]), 0, 3, 1 ] ib.emit(tvm.make.Evaluate( tvm.make.Call( "int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0))) body = ib.get() func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True) fcode = tvm.build(func, None, "llvm")
def _instr(index):
    ib = tvm.ir_builder.create()
    if index == 1:
        for i in range(4):
            ib.emit(outs[0].vstore([i * 32], tvm.const(0, 'int16x32')))
        return ib.get()

    a_int8 = ins[0].vload([0], "uint8x2")
    re_int16 = tvm.call_pure_intrin('int16', 'reinterpret', a_int8)
    # broadcast the two packed uint8 values of A across all 32 int16 lanes
    vec_ai16 = re_int16.astype('int16x32')
    vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai16)

    for i in range(4):
        vec_b = ins[1].vload([i * 32, 0], "int8x64")
        pair_reduction = tvm.call_llvm_intrin('int16x32',
                                              'llvm.x86.avx512.pmaddubs.w.512',
                                              tvm.const(0, 'uint32'),
                                              vec_a, vec_b)
        if index == 0:
            ib.emit(outs[0].vstore([i * 32], pair_reduction))
        else:
            ib.emit(outs[0].vstore([i * 32],
                                   pair_reduction + outs[0].vload([i * 32], 'int16x32')))
    return ib.get()
def _instr(index): ib = tvm.ir_builder.create() if index == 1: ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16'))) return ib.get() a_int8 = ins[0].vload([0], "uint8x4") re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8) vec_ai32 = re_int32.astype('int32x8') vec_a = tvm.call_pure_intrin('int8x32', 'reinterpret', vec_ai32) vec_b_0 = ins[1].vload([0, 0], "int8x32") vec_b_1 = ins[1].vload([8, 0], "int8x32") vec_one = tvm.const(1, "int16x16") pair_reduction_0 = tvm.call_llvm_intrin( 'int16x16', 'llvm.x86.avx2.pmadd.ub.sw', tvm.const(0, 'uint32'), vec_a, vec_b_0) quad_reduction_0 = tvm.call_llvm_intrin('int32x8', 'llvm.x86.avx2.pmadd.wd', tvm.const(0, 'uint32'), pair_reduction_0, vec_one) pair_reduction_1 = tvm.call_llvm_intrin( 'int16x16', 'llvm.x86.avx2.pmadd.ub.sw', tvm.const(0, 'uint32'), vec_a, vec_b_1) quad_reduction_1 = tvm.call_llvm_intrin('int32x8', 'llvm.x86.avx2.pmadd.wd', tvm.const(0, 'uint32'), pair_reduction_1, vec_one) if index == 0: ib.emit(outs[0].vstore([0], quad_reduction_0)) ib.emit(outs[0].vstore([8], quad_reduction_1)) else: ib.emit(outs[0].vstore([0], quad_reduction_0 + \ outs[0].vload([0], 'int32x8'))) ib.emit(outs[0].vstore([8], quad_reduction_1 + \ outs[0].vload([8], 'int32x8'))) return ib.get()
def _instr(index): ib = tvm.ir_builder.create() if index == 1: ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16'))) return ib.get() a_int8 = ins[0].vload([0], "uint8x4") re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8) vec_ai32 = re_int32.astype('int32x16') vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32) vec_b = ins[1].vload([0, 0], "int8x64") vec_one = tvm.const(1, "int16x32") pair_reduction = tvm.call_llvm_intrin( 'int16x32', 'llvm.x86.avx512.pmaddubs.w.512', tvm.const(0, 'uint32'), vec_a, vec_b) quad_reduction = tvm.call_llvm_intrin( 'int32x16', 'llvm.x86.avx512.pmaddw.d.512', tvm.const(0, 'uint32'), pair_reduction, vec_one) if index == 0: ib.emit(outs[0].vstore(0, quad_reduction)) else: ib.emit(outs[0].vstore( 0, quad_reduction + outs[0].vload([0], 'int32x16'))) return ib.get()
def get_vthread(name):
    tx = tvm.thread_axis(name)
    ty = tvm.thread_axis(name)
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    C = ib.pointer("float32", name="C")
    with ib.for_range(0, n) as i:
        ib.scope_attr(tx, "virtual_thread", nthread)
        ib.scope_attr(ty, "virtual_thread", nthread)
        B = ib.allocate("float32", m, name="B", scope="shared")
        B[i] = A[i * nthread + tx]
        bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
        ib.emit(tvm.call_extern("int32", "Run",
                                bbuffer.access_ptr("r"),
                                tvm.call_pure_intrin("int32", "tvm_context_id")))
        C[i * nthread + tx] = B[i] + 1
    return ib.get()
def mylog(x):
    """customized log intrinsic function"""
    return tvm.call_pure_intrin(x.dtype, "mylog", x)
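# Hedged sketch (not from the original source): a custom intrinsic such as
# "mylog" also needs a per-target lowering rule before codegen can emit it.
# With the same flat-namespace tvm API this can be done via
# tvm.register_intrin_rule; the rule below maps "mylog" onto the CUDA math
# library, and the rule function name is purely illustrative.
def my_cuda_mylog_rule(op):
    """Lower "mylog" to logf/log depending on the operand dtype."""
    if op.dtype == "float32":
        return tvm.call_pure_extern("float32", "logf", op.args[0])
    if op.dtype == "float64":
        return tvm.call_pure_extern("float64", "log", op.args[0])
    return op

tvm.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)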
def vectorize(op):
    if isinstance(op, tvm.stmt.For):
        outer_loops.pop()
        if to_vectorize:
            if str(op.loop_var) == f'{to_vectorize[-1]}.init':
                return tvm.tir.For(op.loop_var, op.min, op.extent,
                                   tvm.stmt.For.Vectorized, op.device_api, op.body)
            elif str(op.loop_var) == str(to_vectorize[-1]):
                loops = []
                loads = []
                store = [None]
                guard = [None]

                def get_loops(op):
                    if isinstance(op, tvm.stmt.For):
                        loops.append(op)
                    elif isinstance(op, tvm.expr.Load):
                        loads.append(op)
                    elif isinstance(op, tvm.stmt.Store):
                        assert store[0] is None
                        store[0] = op
                    elif isinstance(op, tvm.stmt.IfThenElse):
                        guard[0] = op

                tvm.ir_pass.PostOrderVisit(op, get_loops)
                inner, outer = loops
                loops = loops[::-1]
                inner_ext = as_const_int(inner.extent)
                outer_ext = as_const_int(outer.extent)
                assert inner_ext is not None and outer_ext is not None
                assert outer_ext == 16 and inner_ext == 4
                empty = {outer.loop_var: tvm.const(0, 'int32'),
                         inner.loop_var: tvm.const(0, 'int32')}
                operands = []
                indeces = []
                for elem in loads:
                    iters = [i.loop_var for i in outer_loops + loops]
                    coef = tvm.arith.DetectLinearEquation(elem.index, iters)
                    base_index = sum(i * j for i, j in zip(iters[:-2], coef)) + coef[-1]
                    inner_stride = as_const_int(coef[-2])
                    outer_stride = as_const_int(coef[-3])
                    assert inner_stride is not None and outer_stride is not None
                    if tvm.ir_pass.Equal(elem.buffer_var, store[0].buffer_var):
                        index = tvm.tir.Ramp(base_index, tvm.const(1, 'int32'), 16)
                        continue
                    indeces = []
                    for i in range(outer_ext):
                        for j in range(inner_ext):
                            indeces.append(i * outer_stride + j * inner_stride)
                    bound = max(indeces) + 1
                    to_load = tvm.tir.Ramp(base_index, tvm.const(1, 'int32'), bound)
                    value = tvm.tir.Load(elem.dtype + 'x%d' % bound,
                                         elem.buffer_var, to_load,
                                         tvm.const(1, 'int32x%d' % bound))
                    assert 64 % bound == 0
                    operands.append(tvm.tir.Shuffle(
                        [value] * (64 // bound),
                        [tvm.const(i, 'int32') for i in indeces]))
                buffer_var = store[0].buffer_var
                operands = [tvm.tir.Load('int32x16', buffer_var, index,
                                         tvm.const(1, 'int32x16'))] + operands
                operands = [tvm.call_pure_intrin('int32x16', 'reinterpret', i)
                            for i in operands]
                res = tvm.call_llvm_intrin('int32x16',
                                           'llvm.x86.avx512.vpdpbusd.512',
                                           tvm.const(0, 'uint32'), *operands)
                res = tvm.tir.Store(buffer_var, res, index, tvm.const(1, 'int32x16'))
                if guard[0] is not None:
                    res = tvm.tir.IfThenElse(guard[0].condition, res, None)
                return res
    elif isinstance(op, tvm.stmt.AttrStmt):
        if not to_vectorize:
            return None
        if tvm.ir_pass.Equal(op.node.var, to_vectorize[-1]):
            to_vectorize.pop()
            return op.body
    return None
def atomic_add(x, y):
    return tvm.call_pure_intrin(y.dtype, "atomic_add", x, y)
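# Hedged sketch (not from the original source): like "mylog" above,
# "atomic_add" is a custom intrinsic and needs a target-specific lowering rule.
# One plausible option for a CUDA target is to forward it to the atomicAdd
# builtin as an extern call; the rule actually used with this helper may differ.
def _cuda_atomic_add_rule(op):
    # op.args[0] is the destination address handle, op.args[1] the increment
    return tvm.call_pure_extern(op.dtype, "atomicAdd", op.args[0], op.args[1])

tvm.register_intrin_rule("cuda", "atomic_add", _cuda_atomic_add_rule, override=True)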
def get_valid_counts_ir(data, valid_count, flag, score_threshold,
                        id_index, score_index):
    """Low level IR to get valid count of bounding boxes given a score threshold.
    Also prepares to move valid boxes to the top of input data.

    Parameters
    ----------
    data : Buffer
        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].

    valid_count : Buffer
        1D buffer for valid number of boxes with shape [batch_size, ].

    flag : Buffer
        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

    score_threshold : float32
        Lower limit of score for valid bounding boxes.

    id_index : optional, int
        index of the class categories, -1 to disable.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    valid_count = ib.buffer_ptr(valid_count)
    flag = ib.buffer_ptr(flag)
    atomic_add_return = ib.allocate(valid_count.dtype, (1,),
                                    name='atomic_add_return', scope='local')
    one_count = tvm.const(1, dtype=valid_count.dtype)
    score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = batch_size * num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx
    idxd = tvm.indexdiv

    # initialize valid_count
    with ib.if_scope(tid < batch_size):
        valid_count[tid] = 0
    # initialize flag
    with ib.if_scope(tid < batch_size * num_anchors):
        flag[tid] = 0
    with ib.if_scope(tid < batch_size * num_anchors):
        i = idxd(tid, num_anchors)
        with ib.if_scope(
                tvm.all(data[tid * elem_length + score_index] > score_threshold,
                        tvm.any(id_index < 0,
                                data[tid * elem_length + id_index] >= 0))):
            flag[tid] = 1
            atomic_add_return[0] = atomic_add(
                tvm.call_pure_intrin("handle", "tvm_address_of", valid_count[i]),
                one_count)

    return ib.get()
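# Hedged usage sketch (not from the original source): an ir_builder kernel like
# get_valid_counts_ir is typically wrapped in tvm.extern so it becomes a tensor
# op in a compute graph. The shapes, dtypes, variable names and parameter values
# below are assumptions for illustration only.
valid_count, flag = tvm.extern(
    [(batch_size,), (batch_size, num_anchors)], [data],
    lambda ins, outs: get_valid_counts_ir(
        ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
    dtype=["int32", "int32"], name="get_valid_counts")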