def _instr(index):
    """Emit the micro-kernel body (index 0), reduce-init (index 1), or
    reduce-update (index 2) for a u8 x s8 -> i32 dot product on AVX-512.

    Reads ``ins``/``outs`` from the enclosing tensorize scope.
    """
    irb = tvm.ir_builder.create()
    if index == 1:
        # Reduce-init: zero the 16-lane int32 accumulator.
        irb.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
        return irb.get()
    # Load 4 uint8 values, view them as a single int32 and broadcast the
    # word across all 16 output lanes, then reinterpret as 64 int8 lanes.
    packed = ins[0].vload([0], "uint8x4")
    as_word = tvm.call_pure_intrin('int32', 'reinterpret', packed)
    splat = as_word.astype('int32x16')
    lhs = tvm.call_pure_intrin('int8x64', 'reinterpret', splat)
    rhs = ins[1].vload([0, 0], "int8x64")
    ones = tvm.const(1, "int16x32")
    # u8*s8 multiply with horizontal pairwise add -> 32 x int16.
    pair_sum = tvm.call_llvm_intrin(
        'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
        tvm.const(0, 'uint32'), lhs, rhs)
    # Multiply by one and pairwise add again -> 16 x int32 quad sums.
    quad_sum = tvm.call_llvm_intrin(
        'int32x16', 'llvm.x86.avx512.pmaddw.d.512',
        tvm.const(0, 'uint32'), pair_sum, ones)
    if index == 0:
        irb.emit(outs[0].vstore(0, quad_sum))
    else:
        # Reduce-update: accumulate onto the existing output vector.
        irb.emit(outs[0].vstore(
            0, quad_sum + outs[0].vload([0], 'int32x16')))
    return irb.get()
def _instr(index):
    """Emit the micro-kernel for an AArch64 dot-product tensorize intrinsic.

    index 0: body, 1: reduce-init, 2: reduce-update. Vector widths come
    from the enclosing scope (``dtype``, ``int32_lanes``,
    ``num_int8_elements``).
    """
    irb = tvm.ir_builder.create()
    if index == 1:
        # Reduce-init: clear the int32 accumulator lanes.
        irb.emit(outs[0].vstore(
            0, tvm.const(0, '%s32x%d' % (dtype, int32_lanes))))
        return irb.get()
    # Derived vector types for operand A, operand B and the accumulator.
    a_type = '%s8x%d' % (dtype, num_int8_elements)
    b_type = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
    c_type = '%s32x%d' % (dtype, int32_lanes)
    data = ins[0].vload([0], a_type)
    as_word = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret', data)
    # broadcast a
    splat = as_word.astype(c_type)
    lhs = tvm.call_pure_intrin(b_type, 'reinterpret', splat)
    rhs = ins[1].vload([0, 0], b_type)
    acc = outs[0].vload([0], c_type)
    # Select the signed or unsigned NEON dot-product intrinsic.
    base = 'udot' if dtype == 'uint' else 'sdot'
    intrin_name = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
        base, int32_lanes, int32_lanes * num_int8_elements)
    dotted = tvm.call_llvm_intrin(c_type, intrin_name,
                                  tvm.const(2, 'uint32'), acc, lhs, rhs)
    irb.emit(outs[0].vstore(0, dotted))
    return irb.get()
def test_llvm_lookup_intrin():
    """Check that an LLVM intrinsic referenced by name is looked up and built.

    Builds a tiny IR body that calls 'llvm.ctpop.i8' through
    tvm.call_llvm_intrin and compiles it with the LLVM backend; success
    means the name-to-id lookup worked at codegen time.
    """
    ib = tvm.ir_builder.create()
    # (removed unused local: m = tvm.size_var("m") was never referenced)
    A = ib.pointer("uint8x8", name="A")
    # num_signature = 1: one type argument used to mangle the intrinsic.
    x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8",
                             tvm.const(1, 'uint32'), A)
    ib.emit(x)
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
    fcode = tvm.build(func, None, "llvm")
def test_llvm_lookup_intrin():
    """Check that an LLVM intrinsic referenced by name is looked up and built.

    Builds a tiny IR body that calls 'llvm.ctpop.i8' through
    tvm.call_llvm_intrin and compiles it with the LLVM backend; success
    means the name-to-id lookup worked at codegen time.
    """
    ib = tvm.ir_builder.create()
    # (removed unused local: m = tvm.var("m") was never referenced)
    A = ib.pointer("uint8x8", name="A")
    # num_signature = 1: one type argument used to mangle the intrinsic.
    x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8",
                             tvm.const(1, 'uint32'), A)
    ib.emit(x)
    body = ib.get()
    func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
    fcode = tvm.build(func, None, "llvm")
def _instr(index):
    """Emit the u8 x s8 -> i32 dot-product micro-kernel, preferring the
    AVX-512 VNNI instruction when the current LLVM build exposes it and
    falling back to pmaddubs/pmaddw otherwise.

    index 0: body, 1: reduce-init, 2: reduce-update.
    """
    bldr = tvm.ir_builder.create()
    if index == 1:
        # Reduce-init: zero the 16-lane int32 accumulator.
        bldr.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
        return bldr.get()
    # Broadcast the packed 4-byte group of A across all 16 lanes.
    packed_a = ins[0].vload([0], "uint8x4")
    word_a = tvm.call_pure_intrin('int32', 'reinterpret', packed_a)
    splat_a = word_a.astype('int32x16')
    full_b = ins[1].vload([0, 0], "int8x64")
    vnni_inst_name = 'llvm.x86.avx512.vpdpbusd.512'
    llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
    if llvm_id != 0:
        # VNNI is available for current LLVM version
        b_as_i32 = tvm.call_pure_intrin('int32x16', 'reinterpret', full_b)
        zeros = tvm.const(0, "int32x16")
        quad_sum = tvm.call_llvm_intrin(
            'int32x16', 'llvm.x86.avx512.vpdpbusd.512',
            tvm.const(0, 'uint32'), zeros, splat_a, b_as_i32)
    else:
        # Fall back to the normal AVX512
        bytes_a = tvm.call_pure_intrin('int8x64', 'reinterpret', splat_a)
        ones = tvm.const(1, "int16x32")
        pair_sum = tvm.call_llvm_intrin(
            'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
            tvm.const(0, 'uint32'), bytes_a, full_b)
        quad_sum = tvm.call_llvm_intrin(
            'int32x16', 'llvm.x86.avx512.pmaddw.d.512',
            tvm.const(0, 'uint32'), pair_sum, ones)
    if index == 0:
        bldr.emit(outs[0].vstore(0, quad_sum))
    else:
        # Reduce-update: add onto the existing accumulator vector.
        bldr.emit(outs[0].vstore(
            0, quad_sum + outs[0].vload([0], 'int32x16')))
    return bldr.get()
def _instr(index):
    """Emit the NEON bitserial popcount micro-kernel.

    index 0: body, 1: reduce-init, 2: reduce-update. Weight/activation
    bit-planes, vector widths and the pairwise-add intrinsic ids come
    from the enclosing scope.
    """
    builder = tvm.ir_builder.create()
    if index == 1:
        # Reduce-init: zero the uint16x8 accumulator.
        builder.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
        return builder.get()
    sums8 = [None] * 8
    sums4 = [None] * 4
    sums2 = [None] * 2
    for wb in range(w_b):
        for xb in range(x_b):
            if k_i == 16:
                for idx in range(m):
                    masked = ww.vload([wb, idx, 0], 'uint8x16') & xx.vload(
                        [xb, 0], 'uint8x16')
                    counted = tvm.popcount(masked)
                    # Fold the 16-lane counts down to 8 lanes.
                    hi = tvm.call_pure_intrin(
                        'uint8x8', 'vectorhigh', counted)
                    lo = tvm.call_pure_intrin(
                        'uint8x8', 'vectorlow', counted)
                    sums8[idx] = hi + lo
                # Pairwise-add reduction tree: 8 vectors -> 4 -> 2.
                for idx in range(m // 2):
                    sums4[idx] = tvm.call_llvm_intrin(
                        'uint8x8', vpadd, args_1,
                        sums8[idx * 2], sums8[idx * 2 + 1])
                for idx in range(m // 4):
                    sums2[idx] = tvm.call_llvm_intrin(
                        'uint8x8', vpadd, args_1,
                        sums4[idx * 2], sums4[idx * 2 + 1])
                combined = tvm.call_pure_intrin(
                    'uint8x16', 'vectorcombine', sums2[0], sums2[1])
                # Weight counts by the significance of this bit-plane pair.
                shifted = combined << tvm.const(wb + xb, dtype)
                acc_out = tvm.call_llvm_intrin(
                    'uint16x8', vpadalu, args_2,
                    zz.vload(0, 'uint16x8'), shifted)
            else:  # ki == 8
                for idx in range(m):
                    masked = ww.vload([wb, idx, 0], 'uint8x8') & xx.vload(
                        [xb, 0], 'uint8x8')
                    sums8[idx] = tvm.popcount(masked)
                for idx in range(m // 2):
                    sums4[idx] = tvm.call_llvm_intrin(
                        'uint8x8', vpadd, args_1,
                        sums8[idx * 2], sums8[idx * 2 + 1])
                for idx in range(m // 4):
                    sums2[idx] = tvm.call_llvm_intrin(
                        'uint8x8', vpadd, args_1,
                        sums4[idx * 2], sums4[idx * 2 + 1])
                combined = tvm.call_pure_intrin(
                    'uint8x16', 'vectorcombine', sums2[0], sums2[1])
                shifted = combined << tvm.const(wb + xb, dtype)
                acc_out = tvm.call_llvm_intrin(
                    'uint16x8', vpadalu, args_2,
                    zz.vload(0, 'uint16x8'), shifted)
            # Store after every bit-plane pair so the vpadalu accumulation
            # reading zz above sees the running total.
            builder.emit(zz.vstore(0, acc_out))
    return builder.get()
def _instr(index):
    """AVX-512 u8 x s8 -> i32 dot-product micro-kernel.

    index 0 emits the body, index 1 the reduce-init, anything else the
    reduce-update. ``ins``/``outs`` come from the tensorize closure.
    """
    builder = tvm.ir_builder.create()
    if index == 1:
        # Reduce-init path: store a zero vector and stop.
        builder.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
        return builder.get()
    # View four uint8 values as one int32, splat it over 16 lanes and
    # reinterpret the result as 64 int8 lanes for the multiply.
    src_u8 = ins[0].vload([0], "uint8x4")
    word32 = tvm.call_pure_intrin('int32', 'reinterpret', src_u8)
    a_bytes = tvm.call_pure_intrin(
        'int8x64', 'reinterpret', word32.astype('int32x16'))
    b_bytes = ins[1].vload([0, 0], "int8x64")
    one_vec = tvm.const(1, "int16x32")
    # Stage 1: u8*s8 with horizontal pairwise add -> int16x32.
    step1 = tvm.call_llvm_intrin(
        'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
        tvm.const(0, 'uint32'), a_bytes, b_bytes)
    # Stage 2: multiply by one and pairwise add -> int32x16.
    step2 = tvm.call_llvm_intrin(
        'int32x16', 'llvm.x86.avx512.pmaddw.d.512',
        tvm.const(0, 'uint32'), step1, one_vec)
    if index == 0:
        builder.emit(outs[0].vstore(0, step2))
    else:
        builder.emit(outs[0].vstore(
            0, step2 + outs[0].vload([0], 'int32x16')))
    return builder.get()
def _instr(index):
    """AVX2 u8 x s8 -> i32 dot-product micro-kernel.

    AVX2 vectors are half the AVX-512 width, so the 16 output lanes are
    produced as two independent 8-lane halves. index 0: body, 1:
    reduce-init, 2: reduce-update.
    """
    bldr = tvm.ir_builder.create()
    if index == 1:
        # Reduce-init: a single int32x16 store clears both halves.
        bldr.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
        return bldr.get()
    # Broadcast the packed 4-byte group of A over 8 int32 lanes, then
    # view it as 32 int8 lanes.
    packed = ins[0].vload([0], "uint8x4")
    word = tvm.call_pure_intrin('int32', 'reinterpret', packed)
    lhs = tvm.call_pure_intrin('int8x32', 'reinterpret',
                               word.astype('int32x8'))
    rhs_lo = ins[1].vload([0, 0], "int8x32")
    rhs_hi = ins[1].vload([8, 0], "int8x32")
    ones = tvm.const(1, "int16x16")
    # Low half: pmaddubsw then pmaddwd -> int32x8.
    pair_lo = tvm.call_llvm_intrin(
        'int16x16', 'llvm.x86.avx2.pmadd.ub.sw',
        tvm.const(0, 'uint32'), lhs, rhs_lo)
    quad_lo = tvm.call_llvm_intrin(
        'int32x8', 'llvm.x86.avx2.pmadd.wd',
        tvm.const(0, 'uint32'), pair_lo, ones)
    # High half: same pipeline against the second B tile.
    pair_hi = tvm.call_llvm_intrin(
        'int16x16', 'llvm.x86.avx2.pmadd.ub.sw',
        tvm.const(0, 'uint32'), lhs, rhs_hi)
    quad_hi = tvm.call_llvm_intrin(
        'int32x8', 'llvm.x86.avx2.pmadd.wd',
        tvm.const(0, 'uint32'), pair_hi, ones)
    if index == 0:
        bldr.emit(outs[0].vstore([0], quad_lo))
        bldr.emit(outs[0].vstore([8], quad_hi))
    else:
        # Reduce-update: accumulate each half onto the output.
        bldr.emit(outs[0].vstore(
            [0], quad_lo + outs[0].vload([0], 'int32x8')))
        bldr.emit(outs[0].vstore(
            [8], quad_hi + outs[0].vload([8], 'int32x8')))
    return bldr.get()
def _instr(index):
    """Emit the NEON bitserial popcount micro-kernel.

    index 0 emits the compute body, index 1 the reduce-reset, anything
    else the reduce-update. All operands (ww, xx, zz), bit-plane counts
    (w_b, x_b), tile sizes (m, k_i), dtypes and the vpadd/vpadalu
    intrinsic ids are closed over from the enclosing scope.
    """
    irb = tvm.ir_builder.create()
    if index == 1: # reduce reset
        irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
        return irb.get()
    # body and reduce update
    cnts8 = [None] * 8
    cnts4 = [None] * 4
    cnts2 = [None] * 2
    for bw in range(w_b):
        for bx in range(x_b):
            if k_i == 16:
                for i in range(m):
                    w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype)
                    x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
                    if unipolar:
                        # Unipolar: matching bits count +1, mismatching -1.
                        cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                    else:
                        cnts = tvm.popcount(w_ & x_)
                    # Fold the full-width vector to half width by adding halves.
                    upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
                    lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts)
                    cnts8[i] = upper_half + lower_half
                # Pairwise-add reduction tree: 8 vectors -> 4 -> 2.
                for i in range(m//2):
                    cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                    args_1, cnts8[i*2], cnts8[i*2+1])
                for i in range(m//4):
                    cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                    args_1, cnts4[i*2], cnts4[i*2+1])
                cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine',
                                            cnts2[0], cnts2[1])
                # Weight by the combined significance of this bit-plane pair.
                shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
                # Pairwise add-accumulate onto the running output in zz.
                out = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                           zz.vload(0, return_dtype), shifted_cnts)
            else: # ki == 8
                for i in range(m):
                    w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype)
                    x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
                    if unipolar:
                        cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                    else:
                        cnts8[i] = tvm.popcount(w_ & x_)
                for i in range(m//2):
                    cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                    args_1, cnts8[i*2], cnts8[i*2+1])
                for i in range(m//4):
                    cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd,
                                                    args_1, cnts4[i*2], cnts4[i*2+1])
                cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine',
                                            cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype)
                out = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                           zz.vload(0, return_dtype), shifted_cnts)
            # Store per (bw, bx) so the next vpadalu sees the running total.
            irb.emit(zz.vstore(0, out))
    return irb.get()
def _instr(index):
    """Emit the NEON bitserial popcount micro-kernel (bipolar variant).

    index 0 emits the compute body, index 1 the reduce-reset, anything
    else the reduce-update. Operands (ww, xx, zz), bit-plane counts
    (w_b, x_b), tile sizes (m, k_i) and the vpadd/vpadalu intrinsic ids
    are closed over from the enclosing scope.
    """
    irb = tvm.ir_builder.create()
    if index == 1:
        # Reduce-reset: zero the uint16x8 accumulator.
        irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
        return irb.get()
    cnts8 = [None] * 8
    cnts4 = [None] * 4
    cnts2 = [None] * 2
    for bw in range(w_b):
        for bx in range(x_b):
            if k_i == 16:
                for i in range(m):
                    ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
                    cnts = tvm.popcount(ands)
                    # Fold 16 lanes to 8 by adding the vector's two halves.
                    upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
                    lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
                    cnts8[i] = upper_half + lower_half
                # Pairwise-add reduction tree: 8 vectors -> 4 -> 2.
                for i in range(m//2):
                    cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                    args_1, cnts8[i*2], cnts8[i*2+1])
                for i in range(m//4):
                    cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                    args_1, cnts4[i*2], cnts4[i*2+1])
                cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                            cnts2[0], cnts2[1])
                # Weight by the combined significance of this bit-plane pair.
                shifted_cnts = cnts << tvm.const(bw+bx, dtype)
                # Pairwise add-accumulate onto the running output in zz.
                out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                           zz.vload(0, 'uint16x8'), shifted_cnts)
            else: # ki == 8
                for i in range(m):
                    ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
                    cnts8[i] = tvm.popcount(ands)
                for i in range(m//2):
                    cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                    args_1, cnts8[i*2], cnts8[i*2+1])
                for i in range(m//4):
                    cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd,
                                                    args_1, cnts4[i*2], cnts4[i*2+1])
                cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                            cnts2[0], cnts2[1])
                shifted_cnts = cnts << tvm.const(bw+bx, dtype)
                out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                           zz.vload(0, 'uint16x8'), shifted_cnts)
            # Store per (bw, bx) so the next vpadalu sees the running total.
            irb.emit(zz.vstore(0, out))
    return irb.get()
def _instr(index):
    """AVX-512 u8 x s8 -> i16 micro-kernel producing four int16x32 segments.

    index 0: body, 1: reduce-init, 2: reduce-update. ``ins``/``outs``
    come from the tensorize closure.
    """
    bldr = tvm.ir_builder.create()
    if index == 1:
        # Reduce-init: clear all four 32-lane int16 output segments.
        for seg in range(4):
            bldr.emit(outs[0].vstore([seg * 32], tvm.const(0, 'int16x32')))
        return bldr.get()
    # View two uint8 values as one int16 and splat it over 32 lanes,
    # then reinterpret as 64 int8 lanes for the multiply.
    packed = ins[0].vload([0], "uint8x2")
    word16 = tvm.call_pure_intrin('int16', 'reinterpret', packed)
    lhs = tvm.call_pure_intrin('int8x64', 'reinterpret',
                               word16.astype('int16x32'))
    for seg in range(4):
        rhs = ins[1].vload([seg * 32, 0], "int8x64")
        # u8*s8 multiply with horizontal pairwise add -> int16x32.
        pair_sum = tvm.call_llvm_intrin(
            'int16x32', 'llvm.x86.avx512.pmaddubs.w.512',
            tvm.const(0, 'uint32'), lhs, rhs)
        if index == 0:
            bldr.emit(outs[0].vstore([seg * 32], pair_sum))
        else:
            # Reduce-update: accumulate onto the segment's current value.
            bldr.emit(outs[0].vstore(
                [seg * 32],
                pair_sum + outs[0].vload([seg * 32], 'int16x32')))
    return bldr.get()
def vectorize(op):
    """IR-mutator callback: rewrite a matched 16x4 reduction loop nest into
    a single AVX-512 VNNI (vpdpbusd) update.

    Reads/updates ``outer_loops`` and ``to_vectorize`` from the enclosing
    scope. Returns a replacement stmt, or None to leave ``op`` unchanged.
    """
    if isinstance(op, tvm.stmt.For):
        # Post-order: this loop's body has been visited, drop it from the stack.
        outer_loops.pop()
        if to_vectorize:
            if str(op.loop_var) == f'{to_vectorize[-1]}.init':
                # The init loop is simply marked vectorized.
                return tvm.tir.For(op.loop_var, op.min, op.extent,
                                   tvm.stmt.For.Vectorized, op.device_api,
                                   op.body)
            elif str(op.loop_var) == str(to_vectorize[-1]):
                # Collect the loops, loads, the single store and an optional
                # guard from the loop nest rooted at op.
                loops = []
                loads = []
                store = [None]
                guard = [None]

                def get_loops(op):
                    if isinstance(op, tvm.stmt.For):
                        loops.append(op)
                    elif isinstance(op, tvm.expr.Load):
                        loads.append(op)
                    elif isinstance(op, tvm.stmt.Store):
                        # The pattern must contain exactly one store.
                        assert store[0] is None
                        store[0] = op
                    elif isinstance(op, tvm.stmt.IfThenElse):
                        guard[0] = op

                tvm.ir_pass.PostOrderVisit(op, get_loops)
                # Post-order visit yields inner loop first.
                inner, outer = loops
                loops = loops[::-1]
                inner_ext = as_const_int(inner.extent)
                outer_ext = as_const_int(outer.extent)
                assert inner_ext is not None and outer_ext is not None
                # Pattern is fixed to a 16 x 4 tile (vpdpbusd shape).
                assert outer_ext == 16 and inner_ext == 4
                # NOTE(review): `empty` is never used in this function —
                # looks like leftover scaffolding; confirm before removing.
                empty = {
                    outer.loop_var: tvm.const(0, 'int32'),
                    inner.loop_var: tvm.const(0, 'int32')
                }
                operands = []
                indeces = []
                for elem in loads:
                    # Decompose each load index into strides over the
                    # surrounding loop variables.
                    iters = [i.loop_var for i in outer_loops + loops]
                    coef = tvm.arith.DetectLinearEquation(
                        elem.index, iters)
                    # Base address: contribution of all loops except the
                    # two being vectorized, plus the constant term.
                    base_index = sum(
                        i * j for i, j in zip(iters[:-2], coef)) + coef[-1]
                    inner_stride = as_const_int(coef[-2])
                    outer_stride = as_const_int(coef[-3])
                    assert inner_stride is not None and outer_stride is not None
                    if tvm.ir_pass.Equal(elem.buffer_var,
                                         store[0].buffer_var):
                        # The accumulator load: remember its ramp index for
                        # the final load/store, but don't shuffle it.
                        index = tvm.tir.Ramp(base_index,
                                             tvm.const(1, 'int32'), 16)
                        continue
                    # Enumerate the 64 element offsets the 16x4 tile touches.
                    indeces = []
                    for i in range(outer_ext):
                        for j in range(inner_ext):
                            indeces.append(i * outer_stride +
                                           j * inner_stride)
                    bound = max(indeces) + 1
                    # Load the covering contiguous span once, then shuffle
                    # it into the 64-lane operand layout.
                    to_load = tvm.tir.Ramp(base_index,
                                           tvm.const(1, 'int32'), bound)
                    value = tvm.tir.Load(elem.dtype + 'x%d' % bound,
                                         elem.buffer_var, to_load,
                                         tvm.const(1, 'int32x%d' % bound))
                    assert 64 % bound == 0
                    operands.append(
                        tvm.tir.Shuffle(
                            [value] * (64 // bound),
                            [tvm.const(i, 'int32') for i in indeces]))
                buffer_var = store[0].buffer_var
                # Prepend the accumulator as vpdpbusd's first operand.
                operands = [
                    tvm.tir.Load('int32x16', buffer_var, index,
                                 tvm.const(1, 'int32x16'))
                ] + operands
                operands = [
                    tvm.call_pure_intrin('int32x16', 'reinterpret', i)
                    for i in operands
                ]
                res = tvm.call_llvm_intrin('int32x16',
                                           'llvm.x86.avx512.vpdpbusd.512',
                                           tvm.const(0, 'uint32'), *operands)
                res = tvm.tir.Store(buffer_var, res, index,
                                    tvm.const(1, 'int32x16'))
                if guard[0] is not None:
                    # Re-wrap in the original guard condition.
                    res = tvm.tir.IfThenElse(guard[0].condition, res, None)
                return res
    elif isinstance(op, tvm.stmt.AttrStmt):
        if not to_vectorize:
            return None
        if tvm.ir_pass.Equal(op.node.var, to_vectorize[-1]):
            # Leaving the annotated region: pop it and strip the attr.
            to_vectorize.pop()
            return op.body
    return None