def _unipolar_conv(n, h, w, co, vh, vw, vc):
    """Unipolar bit-serial convolution body for one output element.

    Computes, over the reduce axes (dh, dw, kb, ib, ci)::

        (popcount(K & D) - popcount(~K & D)) << (kb + ib)

    with K (kernel word) and D (data word) both cast to int16.

    ``kernel_vec``, ``data_vec``, ``HSTR``, ``WSTR`` and the reduce axes
    ``dh, dw, kb, ib, ci`` are free variables from the enclosing scope
    (not visible here).
    """
    kernel = kernel_vec[co, dh, dw, kb, vc, ci].astype('int16')
    data = data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ib, ci].astype('int16')
    # Both popcount terms use the same explicitly cast int16 operands.
    # Previously the second term left `data` uncast and depended on implicit
    # dtype promotion. This form also mirrors the int16 unipolar popcount
    # tensor intrinsic, popcount(w & x) - popcount(~w & x).
    return tvm.sum(
        (tvm.popcount(kernel & data) - tvm.popcount(~kernel & data))
        << (kb + ib).astype('int16'),
        axis=[dh, dw, kb, ib, ci])
def _instr(index):
    """Emit one stage of the uint8 NEON popcount tensor intrinsic.

    ``index == 1`` emits the reduction reset, which zeroes the ``uint16x8``
    accumulator. Any other index (0 = body, 2 = update) emits the
    AND / popcount / pairwise-add / accumulate sequence for every
    (weight bit, input bit) pair.

    ``zz``, ``ww``, ``xx``, ``w_b``, ``x_b``, ``k_i``, ``m``, ``dtype``,
    ``vpadd``, ``vpadalu``, ``args_1`` and ``args_2`` are free variables from
    the enclosing scope, which is not visible here.
    """
    builder = tvm.ir_builder.create()
    if index == 1:
        builder.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
        return builder.get()

    def _pairwise(lhs, rhs):
        # Pairwise add of two uint8x8 vectors through the `vpadd` LLVM intrinsic.
        return tvm.call_llvm_intrin('uint8x8', vpadd, args_1, lhs, rhs)

    # Fixed-size scratch lists: m must not exceed 8.
    level0 = [None] * 8
    level1 = [None] * 4
    level2 = [None] * 2
    for wbit in range(w_b):
        for xbit in range(x_b):
            if k_i == 16:
                for row in range(m):
                    ands = (ww.vload([wbit, row, 0], 'uint8x16')
                            & xx.vload([xbit, 0], 'uint8x16'))
                    byte_counts = tvm.popcount(ands)
                    # Fold 16 byte counts into 8 by adding the high and low halves.
                    high = tvm.call_pure_intrin('uint8x8', 'vectorhigh', byte_counts)
                    low = tvm.call_pure_intrin('uint8x8', 'vectorlow', byte_counts)
                    level0[row] = high + low
            else:  # k_i == 8
                for row in range(m):
                    ands = (ww.vload([wbit, row, 0], 'uint8x8')
                            & xx.vload([xbit, 0], 'uint8x8'))
                    level0[row] = tvm.popcount(ands)
            # Pairwise-add tree: m row vectors -> m // 2 -> m // 4, then concatenate.
            for i in range(m // 2):
                level1[i] = _pairwise(level0[2 * i], level0[2 * i + 1])
            for i in range(m // 4):
                level2[i] = _pairwise(level1[2 * i], level1[2 * i + 1])
            combined = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                            level2[0], level2[1])
            shifted = combined << tvm.const(wbit + xbit, dtype)
            acc = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                       zz.vload(0, 'uint16x8'), shifted)
            # Store after every (wbit, xbit) pair so that the next pair's
            # zz.vload sees this contribution.
            builder.emit(zz.vstore(0, acc))
    return builder.get()
def _conv(nn, ff, yy, xx):
    """Unipolar bit-serial convolution body (NCHW-packed input) for one output.

    Computes, over the reduce axes (rc, ry, rx, b2, b1)::

        (popcount(I & F) - popcount(I & ~F)) << (b1 + b2)

    in ``out_dtype``, where I is the padded, packed input word and F the
    packed filter word.

    ``PadInput_q``, ``Filter_q``, ``stride_h``, ``stride_w``, ``out_dtype`` and
    the reduce axes ``rc, ry, rx, b1, b2`` are free variables from the
    enclosing scope.
    """
    b1b2 = (b1 + b2).astype(out_dtype)
    in_word = PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx]
    filt_word = Filter_q[ff, rc, ry, rx, b2]
    # Cast each count to out_dtype *before* subtracting. This assumes the packed
    # words are unsigned (typical for bitpacked data; not visible here). In that
    # case the subtraction would otherwise wrap in the packed dtype, and when the
    # packed dtype is narrower than out_dtype (e.g. uint8 -> int16) the wrapped
    # value survives the cast as a large positive number.
    and_count = tvm.popcount(in_word & filt_word).astype(out_dtype)
    and_not_count = tvm.popcount(in_word & ~filt_word).astype(out_dtype)
    return tvm.sum((and_count - and_not_count) << b1b2,
                   axis=[rc, ry, rx, b2, b1]).astype(out_dtype)
def _conv(n, h, w, co, vh, vw, vc):
    """Bit-serial convolution body for one (n, h, w, co, vh, vw, vc) element.

    With ``dorefa`` set, computes
    ``(popcount(D & K) - popcount(D & ~K)) << (b1 + b2)``. Otherwise computes
    ``popcount(D & K) << (b1 + b2)``. The sum runs over (dh, dw, ci, b1, b2).

    ``data_vec``, ``kernel_vec``, ``HSTR``, ``WSTR``, ``out_dtype``, ``dorefa``
    and the reduce axes ``dh, dw, ci, b1, b2`` come from the enclosing scope.
    """
    b1b2 = (b1 + b2).astype(out_dtype)
    if dorefa:
        data = data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ci, b1].astype(out_dtype)
        kernel = kernel_vec[co, dh, dw, ci, vc, b2].astype(out_dtype)
        # Complement the kernel *after* casting so both popcount terms operate on
        # identical out_dtype operands. ANDing an out_dtype value with a raw
        # packed-dtype complement depends on implicit promotion and can count
        # bits the first term never sees.
        # NOTE(review): casting before popcount keeps only out_dtype's bits; if
        # the packed dtype is wider than out_dtype, high bits are dropped in
        # both terms. Confirm the pack/out dtypes used by callers.
        return tvm.sum(
            (tvm.popcount(data & kernel) - tvm.popcount(data & ~kernel)) << b1b2,
            axis=[dh, dw, ci, b1, b2])
    return tvm.sum(
        tvm.popcount(data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ci, b1]
                     & kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2,
        axis=[dh, dw, ci, b1, b2])
def _conv(n, h, w, co, vh, vw, vc):
    """Bit-serial convolution body for one (n, h, w, co, vh, vw, vc) element.

    Popcounts are taken in the packed dtype and then cast to ``out_dtype``.
    With ``unipolar`` set, the result is
    ``popcount(D & K) - popcount(D & ~K)``; otherwise it is ``popcount(D & K)``.
    Either value is shifted by ``b1 + b2`` and summed over
    (dh, dw, ci, b1, b2).

    ``data_vec``, ``kernel_vec``, ``HSTR``, ``WSTR``, ``out_dtype``,
    ``unipolar`` and the reduce axes come from the enclosing scope.
    """
    shift = (b1 + b2).astype(out_dtype)
    reduce_axes = [dh, dw, ci, b1, b2]
    d_word = data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ci, b1]
    k_word = kernel_vec[co, dh, dw, ci, vc, b2]
    if not unipolar:
        return tvm.sum(tvm.popcount(d_word & k_word).astype(out_dtype) << shift,
                       axis=reduce_axes)
    same = tvm.popcount(d_word & k_word).astype(out_dtype)
    other = tvm.popcount(d_word & ~k_word).astype(out_dtype)
    return tvm.sum(((same - other) << shift), axis=reduce_axes)
def _conv(n, co, h, w, vh, vw, vc):
    """Bit-serial convolution body for one (n, co, h, w, vh, vw, vc) element.

    With ``unipolar`` set, computes
    ``(popcount(D & K) - popcount(D & ~K)) << (b1 + b2)``. Otherwise computes
    ``popcount(D & K) << (b1 + b2)``. The sum runs over (ci, dh, dw, b1, b2).

    ``data_vec``, ``kernel_vec``, ``HSTR``, ``WSTR``, ``out_dtype``,
    ``unipolar`` and the reduce axes come from the enclosing scope.
    """
    b1b2 = (b1 + b2).astype(out_dtype)
    if unipolar:
        data = data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw, b1].astype(out_dtype)
        kernel = kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)
        # Complement the kernel *after* casting so both popcount terms see the
        # same out_dtype operands. Previously the second term ANDed an
        # out_dtype value with the raw packed-dtype complement, relying on
        # implicit promotion.
        # NOTE(review): casting before popcount keeps only out_dtype's bits;
        # confirm the packed dtype is not wider than out_dtype.
        return tvm.sum(
            (tvm.popcount(data & kernel) - tvm.popcount(data & ~kernel)) << b1b2,
            axis=[ci, dh, dw, b1, b2])
    return tvm.sum(
        tvm.popcount(data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw, b1]
                     & kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype) << b1b2,
        axis=[ci, dh, dw, b1, b2])
def binary_dense(data, weight):
    """Binary matrix multiplication using xor and bit-count.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim], dtype is uint32.

    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim], dtype is uint32.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim], dtype is float32.
    """
    assert data.dtype == 'uint32' and weight.dtype == 'uint32', \
        "dtype of data and weight should be uint32"
    assert len(data.shape) == 2 and len(weight.shape) == 2, \
        "only support 2-dim binary dense"
    batch, in_dim = data.shape
    out_dim = weight.shape[0]
    oshape = (batch, out_dim)
    k = tvm.reduce_axis((0, in_dim), name='k')

    def _mismatches(i, j):
        # Count of differing bits between row i of data and row j of weight.
        return tvm.sum(tvm.popcount(data[i, k] ^ weight[j, k]), axis=k)

    matmul = tvm.compute(oshape, _mismatches, tag='binary_dense')

    def _to_dot(i, j):
        # Each uint32 word holds 32 bits, so 32 * in_dim bits are compared.
        # Under the usual +/-1 encoding (assumption), dot = total - 2 * mismatches.
        return 32 * in_dim - 2. * matmul(i, j)

    return tvm.compute(oshape, _to_dot, tag=tag.ELEMWISE)
def run(dtype):
    """Build an element-wise popcount kernel and check it on each enabled device.

    The result is compared against Python's ``bin(x).count('1')``. Metal and
    Vulkan are only exercised for ``uint32``.
    """
    size = 1024
    A = tvm.placeholder((tvm.convert(size),), name='A', dtype=dtype)
    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
    s = tvm.create_schedule(B.op)
    # Simple schedule: split the single axis into chunks of 8 threads.
    bx, tx = s[B].split(B.op.axis[0], factor=8)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        if "cpu" not in tvm.target.create(device).keys:
            s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
        func = tvm.build(s, [A, B], device)
        host_input = np.random.randint(low=0, high=1000, size=size, dtype=A.dtype)
        a = tvm.nd.array(host_input, ctx)
        b = tvm.nd.array(np.zeros(shape=size, dtype=B.dtype), ctx)
        func(a, b)
        expected = [bin(value).count('1') for value in a.asnumpy()]
        np.testing.assert_allclose(b.asnumpy(), expected, rtol=1e-5)

    for device in ("llvm", "cuda", "opencl"):
        check_device(device)
    if dtype == "uint32":
        check_device("metal")
        check_device("vulkan")
def run(dtype):
    """Build an element-wise popcount kernel and check it on each enabled device.

    The result is compared against Python's ``bin(x).count('1')``. Vulkan is
    only exercised for ``uint32``.
    """
    size = 1024
    A = tvm.placeholder((tvm.convert(size),), name='A', dtype=dtype)
    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
    s = tvm.create_schedule(B.op)
    # Simple schedule: split the single axis into chunks of 8 threads.
    bx, tx = s[B].split(B.op.axis[0], factor=8)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        if "cpu" not in tvm.target.create(device).keys:
            s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
            s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
        func = tvm.build(s, [A, B], device)
        host_input = np.random.randint(low=0, high=1000, size=size, dtype=A.dtype)
        a = tvm.nd.array(host_input, ctx)
        b = tvm.nd.array(np.zeros(shape=size, dtype=B.dtype), ctx)
        func(a, b)
        expected = [bin(value).count('1') for value in a.asnumpy()]
        np.testing.assert_allclose(b.asnumpy(), expected, rtol=1e-5)

    for device in ("llvm", "cuda", "opencl", "metal"):
        check_device(device)
    if dtype == "uint32":
        check_device("vulkan")
def _conv(n, co, h, w, vh, vw, vc):
    """Bit-serial convolution body for one (n, co, h, w, vh, vw, vc) element.

    With ``dorefa`` set, computes
    ``(popcount(D & K) - popcount(D & ~K)) << (b1 + b2)``. Otherwise computes
    ``popcount(D & K) << (b1 + b2)``. The sum runs over (ci, dh, dw, b1, b2).

    ``data_vec``, ``kernel_vec``, ``HSTR``, ``WSTR``, ``out_dtype``, ``dorefa``
    and the reduce axes come from the enclosing scope.
    """
    b1b2 = (b1 + b2).astype(out_dtype)
    if dorefa:
        data = data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw, b1].astype(out_dtype)
        kernel = kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)
        # Complement the kernel *after* casting so both popcount terms see the
        # same out_dtype operands. Previously the second term ANDed an
        # out_dtype value with a raw packed-dtype complement.
        # NOTE(review): casting before popcount keeps only out_dtype's bits;
        # confirm the packed dtype is not wider than out_dtype.
        return tvm.sum(
            (tvm.popcount(data & kernel) - tvm.popcount(data & ~kernel)) << b1b2,
            axis=[ci, dh, dw, b1, b2])
    return tvm.sum(
        tvm.popcount(data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw, b1]
                     & kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype) << b1b2,
        axis=[ci, dh, dw, b1, b2])
def bgemm_topi(Y, X, K, dtype="uint64"):
    """AutoTVM template for a 1-bit x 1-bit bit-serial GEMM.

    Parameters
    ----------
    Y, X : int
        Output rows (data) and columns (weight).
    K : int
        Number of packed words along the reduction axis.
    dtype : str
        Packed word dtype, also used as the output dtype (e.g. "uint64").

    Returns
    -------
    (schedule, [data_packed, weight_packed, matmul])
    """
    import re

    DB = 1
    WB = 1
    out_dtype = dtype
    data_packed = tvm.placeholder((Y, DB, K), dtype=dtype, name="A")
    weight_packed = tvm.placeholder((X, WB, K), dtype=dtype, name="B")

    oshape = (Y, X)
    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    matmul = tvm.compute(
        oshape,
        lambda i, j: tvm.sum(
            tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype)
            << (db + wb).astype(out_dtype),
            axis=[wb, db, k]),
        tag='bitserial_dense')

    s = tvm.create_schedule(matmul.op)
    cfg = autotvm.get_config()
    CC = s.cache_write(matmul, "global")

    y, x = s[matmul].op.axis
    # Register the split knobs. The placeholder axes they return are not needed,
    # because the concrete splits come from apply() below.
    cfg.define_split("tile_y", y, num_outputs=2, filter=lambda e: e.size[-1] <= 8)
    cfg.define_split("tile_x", x, num_outputs=2, filter=lambda e: e.size[-1] <= 8)
    yo, yi = cfg["tile_y"].apply(s, matmul, y)
    xo, xi = cfg["tile_x"].apply(s, matmul, x)
    s[matmul].reorder(yo, xo, yi, xi)

    # Knob value selects where the cache-write stage is attached.
    cfg.define_knob("compute_at_axis", [0, 1, 2])
    s[CC].compute_at(s[matmul], (xo, yi, xi)[cfg["compute_at_axis"].val])

    yc, xc = s[CC].op.axis
    _, _, kc = s[CC].op.reduce_axis
    cfg.define_reorder("reorder_0", [kc, yc, xc], policy="all")
    cfg["reorder_0"].apply(s, CC, [kc, yc, xc])

    # Bit width of the packed word. `dtype[4:]` only worked for "uint*" dtypes
    # (it gives "2" for "int32"), so the first run of digits is parsed instead.
    bits = re.search(r'\d+', dtype)
    if bits is None:
        raise ValueError("cannot infer bit width from dtype %r" % dtype)
    cfg.add_flop(2 * Y * X * K * int(bits.group()))
    return s, [data_packed, weight_packed, matmul]
def _instr(index):
    """Emit one stage of the uint8 NEON popcount tensor intrinsic.

    Stage 1 is the reduction reset (zero the ``uint16x8`` accumulator). Stages
    0 (body) and 2 (update) emit, for every (weight bit, input bit) pair:
    AND, popcount, a pairwise-add reduction over the ``m`` rows, a shift by
    the combined bit position, and an accumulate into ``zz``.

    Free variables (from the enclosing, not-shown scope): ``zz``, ``ww``,
    ``xx``, ``w_b``, ``x_b``, ``k_i``, ``m``, ``dtype``, ``vpadd``,
    ``vpadalu``, ``args_1``, ``args_2``.
    """
    ib = tvm.ir_builder.create()
    if index == 1:
        ib.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
        return ib.get()

    def _accumulate(row_counts, shift):
        # Reduce the per-row uint8x8 counts pairwise and accumulate into zz.
        quarter = [None] * 4
        half = [None] * 2
        for j in range(m // 2):
            quarter[j] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                              row_counts[2 * j], row_counts[2 * j + 1])
        for j in range(m // 4):
            half[j] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                           quarter[2 * j], quarter[2 * j + 1])
        merged = tvm.call_pure_intrin('uint8x16', 'vectorcombine', half[0], half[1])
        return tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                    zz.vload(0, 'uint16x8'),
                                    merged << tvm.const(shift, dtype))

    per_row = [None] * 8  # m must not exceed 8
    for w_bit in range(w_b):
        for x_bit in range(x_b):
            for r in range(m):
                if k_i == 16:
                    lhs = ww.vload([w_bit, r, 0], 'uint8x16')
                    rhs = xx.vload([x_bit, 0], 'uint8x16')
                    counts16 = tvm.popcount(lhs & rhs)
                    per_row[r] = (tvm.call_pure_intrin('uint8x8', 'vectorhigh', counts16)
                                  + tvm.call_pure_intrin('uint8x8', 'vectorlow', counts16))
                else:  # k_i == 8
                    lhs = ww.vload([w_bit, r, 0], 'uint8x8')
                    rhs = xx.vload([x_bit, 0], 'uint8x8')
                    per_row[r] = tvm.popcount(lhs & rhs)
            # Persist each pair's contribution before the next pair reloads zz.
            ib.emit(zz.vstore(0, _accumulate(per_row, w_bit + x_bit)))
    return ib.get()
def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
                    out_dtype='int16', unipolar=True):
    """The default implementation of bitserial dense in topi.

    Parameters
    ----------
    data : tvm.Tensor
        2-D with shape [batch, in_dim]

    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim] (already packed; used as-is)

    data_bits, weight_bits : int
        Bit counts passed to ``bitpack`` for data and (2-D) weight.

    pack_dtype : str
        Word dtype used for packing (default 'uint32').

    out_dtype : str
        Accumulation/output dtype (default 'int16').

    unipolar : bool
        If True, accumulate ``popcount(w & d) - popcount(~w & d)``;
        otherwise ``popcount(w & d)``.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1,
                                pack_type=pack_dtype)
    else:
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)

    oshape = (Y, X)
    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')
    shift = (db + wb).astype(out_dtype)

    # Only the requested formulation is built. Previously both computes were
    # constructed and one was discarded.
    if unipolar:
        def _unipolar(i, j):
            w_word = weight_packed[j, wb, k]
            d_word = data_packed[i, db, k]
            # Cast each count to out_dtype *before* subtracting. pack_dtype is
            # unsigned by default, so subtracting in it wraps on underflow. When
            # pack_dtype is narrower than out_dtype (e.g. 'uint8' -> 'int16'),
            # the wrapped value survives the cast as a large positive number.
            diff = (tvm.popcount(w_word & d_word).astype(out_dtype)
                    - tvm.popcount(~w_word & d_word).astype(out_dtype))
            return tvm.sum(diff << shift, axis=[wb, db, k])
        return tvm.compute(oshape, _unipolar, tag='bitserial_dense_unipolar')

    def _plain(i, j):
        count = tvm.popcount(weight_packed[j, wb, k] & data_packed[i, db, k])
        return tvm.sum(count.astype(out_dtype) << shift, axis=[wb, db, k])
    return tvm.compute(oshape, _plain, tag='bitserial_dense')
def check_correct_assembly(type, elements, counts):
    """Build a vectorized popcount and inspect the generated assembly.

    Asserts that ``vpaddl`` occurs exactly ``counts`` times and ``vcnt``
    exactly once. ``type`` (the element dtype; the name shadows the builtin
    but is kept for callers) and ``elements`` (vector length) define the input.
    ``target`` is a free variable from the enclosing scope.
    """
    A = tvm.placeholder(tvm.convert(elements), dtype=type, name='A')
    B = tvm.compute(A.shape, lambda i: tvm.popcount(A[i]), name='B')
    sched = tvm.create_schedule(B.op)
    sched[B].vectorize(sched[B].op.axis[0])
    asm = tvm.build(sched, [A, B], target).get_source('asm')
    # Verify we see the expected number of vpaddl and vcnt instructions.
    assert len(re.findall("vpaddl", asm)) == counts
    assert len(re.findall("vcnt", asm)) == 1
def test_popcount_llvm():
    """Check tvm.popcount on uint32 against bin(x).count('1') via the LLVM backend."""
    length = tvm.var('n')
    A = tvm.placeholder((length,), name='A', dtype="uint32")
    B = tvm.compute(A.shape, lambda *i: tvm.popcount(A(*i)), name='B')
    sched = tvm.create_schedule(B.op)
    if not tvm.module.enabled("llvm"):
        return
    func = tvm.build(sched, [A, B], "llvm")
    ctx = tvm.cpu(0)

    size = 1024
    host_input = np.random.randint(low=0, high=1000, size=size, dtype=A.dtype)
    a = tvm.nd.array(host_input, ctx)
    b = tvm.nd.array(np.zeros(shape=size, dtype=B.dtype), ctx)
    func(a, b)
    expected = [bin(value).count('1') for value in a.asnumpy()]
    np.testing.assert_allclose(b.asnumpy(), expected, rtol=1e-5)
def _instr(index):
    """Emit one stage of the (optionally signed) NEON popcount tensor intrinsic.

    Stage 1 zeroes the accumulator (reduce reset). Stages 0 (body) and
    2 (reduce update) emit, for every (weight bit, input bit) pair, either
    ``popcount(w & x) - popcount(~w & x)`` (``unipolar``) or
    ``popcount(w & x)``. The row counts are then reduced pairwise, shifted by
    the bit position and accumulated into ``zz``.

    Free variables from the enclosing scope: ``zz``, ``ww``, ``xx``, ``w_b``,
    ``x_b``, ``k_i``, ``m``, ``unipolar``, ``full_dtype``, ``half_dtype``,
    ``return_dtype``, ``pack_dtype``, ``vpadd``, ``vpadalu``, ``args_1``,
    ``args_2``.
    """
    builder = tvm.ir_builder.create()
    if index == 1:  # reduce reset
        builder.emit(zz.vstore(0, tvm.const(0, return_dtype)))
        return builder.get()

    def _counts(w_vec, x_vec):
        if unipolar:
            return tvm.popcount(w_vec & x_vec) - tvm.popcount(~w_vec & x_vec)
        return tvm.popcount(w_vec & x_vec)

    def _pairwise(lhs, rhs):
        return tvm.call_llvm_intrin(half_dtype, vpadd, args_1, lhs, rhs)

    # body and reduce update; m must not exceed 8 (size of level0)
    level0 = [None] * 8
    level1 = [None] * 4
    level2 = [None] * 2
    for wbit in range(w_b):
        for xbit in range(x_b):
            if k_i == 16:
                for row in range(m):
                    w_vec = ww.vload([wbit, row, 0], 'uint8x16').astype(full_dtype)
                    x_vec = xx.vload([xbit, 0], 'uint8x16').astype(full_dtype)
                    wide = _counts(w_vec, x_vec)
                    # Fold the 16 lane counts to 8 by adding high and low halves.
                    level0[row] = (tvm.call_pure_intrin(half_dtype, 'vectorhigh', wide)
                                   + tvm.call_pure_intrin(half_dtype, 'vectorlow', wide))
            else:  # ki == 8
                for row in range(m):
                    w_vec = ww.vload([wbit, row, 0], 'uint8x8').astype(half_dtype)
                    x_vec = xx.vload([xbit, 0], 'uint8x8').astype(half_dtype)
                    level0[row] = _counts(w_vec, x_vec)
            for i in range(m // 2):
                level1[i] = _pairwise(level0[2 * i], level0[2 * i + 1])
            for i in range(m // 4):
                level2[i] = _pairwise(level1[2 * i], level1[2 * i + 1])
            combined = tvm.call_pure_intrin(full_dtype, 'vectorcombine', level2[0], level2[1])
            shifted = combined << tvm.const(wbit + xbit, pack_dtype)
            acc = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                       zz.vload(0, return_dtype), shifted)
            # Store per pair so the next pair's zz.vload includes this one.
            builder.emit(zz.vstore(0, acc))
    return builder.get()
def _intrin_popcount(m, k_i, w_b, x_b):
    """Declare the ARM NEON tensor intrinsic for an unsigned bit-serial popcount.

    The declared computation is::

        z[i] = sum over (bw, bx, k) of
               popcount(w[bw, i, k] & x[bx, k]) << (bw + bx)

    with uint8 inputs widened to uint16. The lowering ANDs ``k_i`` bytes per
    row (16 -> 128-bit loads, otherwise 64-bit), reduces the ``m`` row counts
    with ``vpadd`` and accumulates into a ``uint16x8`` value with ``vpadalu``.

    Parameters
    ----------
    m : int
        Rows of ``w``. The row scratch list has 8 slots; the final
        ``vectorcombine`` presumably needs m == 8 (confirm).
    k_i : int
        Inner length in bytes (16 or 8).
    w_b, x_b : int
        Number of bit planes in ``w`` and ``x``.
    """
    dtype = 'uint8'
    w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w')
    x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x')
    k = tvm.reduce_axis((0, k_i), name='k')
    bw = tvm.reduce_axis((0, w_b), name='bw')
    bx = tvm.reduce_axis((0, x_b), name='bx')

    def _declared(i):
        masked = w[bw, i, k].astype('uint16') & x[bx, k].astype('uint16')
        return tvm.sum(tvm.popcount(masked) << (bw + bx).astype('uint16'),
                       axis=[bw, bx, k])

    z = tvm.compute((m,), _declared, name='z')
    Wb = tvm.decl_buffer(w.shape, w.dtype, name="W", offset_factor=k_i,
                         strides=[tvm.var('ldw'), tvm.var('ldw'), 1])
    Xb = tvm.decl_buffer(x.shape, x.dtype, name="X", offset_factor=k_i,
                         strides=[tvm.var('ldw'), 1])

    def _intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        vpadd = "llvm.arm.neon.vpadd.v8u8"
        vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
        args_1 = tvm.const(1, 'uint32')
        args_2 = tvm.const(2, 'uint32')

        def _pairwise(lhs, rhs):
            return tvm.call_llvm_intrin('uint8x8', vpadd, args_1, lhs, rhs)

        def _instr(index):
            builder = tvm.ir_builder.create()
            if index == 1:
                builder.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return builder.get()
            level0 = [None] * 8
            level1 = [None] * 4
            level2 = [None] * 2
            for wbit in range(w_b):
                for xbit in range(x_b):
                    if k_i == 16:
                        for row in range(m):
                            ands = (ww.vload([wbit, row, 0], 'uint8x16')
                                    & xx.vload([xbit, 0], 'uint8x16'))
                            wide = tvm.popcount(ands)
                            level0[row] = (tvm.call_pure_intrin('uint8x8', 'vectorhigh', wide)
                                           + tvm.call_pure_intrin('uint8x8', 'vectorlow', wide))
                    else:  # ki == 8
                        for row in range(m):
                            ands = (ww.vload([wbit, row, 0], 'uint8x8')
                                    & xx.vload([xbit, 0], 'uint8x8'))
                            level0[row] = tvm.popcount(ands)
                    for i in range(m // 2):
                        level1[i] = _pairwise(level0[2 * i], level0[2 * i + 1])
                    for i in range(m // 4):
                        level2[i] = _pairwise(level1[2 * i], level1[2 * i + 1])
                    combined = tvm.call_pure_intrin('uint8x16', 'vectorcombine',
                                                    level2[0], level2[1])
                    shifted = combined << tvm.const(wbit + xbit, dtype)
                    acc = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                               zz.vload(0, 'uint16x8'), shifted)
                    # Store per pair so the next pair's zz.vload includes it.
                    builder.emit(zz.vstore(0, acc))
            return builder.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x: Xb})
def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
    """Declare the ARM NEON tensor intrinsic for a bit-serial popcount product.

    With ``unipolar`` set, the declared computation (in int16) is::

        z[i] = sum over (bw, bx, k) of
               (popcount(w & x) - popcount(~w & x)) << (bw + bx)

    Otherwise it is ``popcount(w & x) << (bw + bx)`` in uint16. The lowering
    uses signed (``vpadd.v8i8`` / ``vpadals``) or unsigned (``vpadd.v8u8`` /
    ``vpadalu``) NEON intrinsics to match.

    Parameters
    ----------
    m : int
        Rows of ``w``. The row scratch list has 8 slots; the final
        ``vectorcombine`` presumably needs m == 8 (confirm).
    k_i : int
        Inner length in bytes (16 or 8).
    w_b, x_b : int
        Number of bit planes in ``w`` and ``x``.
    unipolar : bool
        Select the signed AND minus AND-NOT formulation.
    """
    pack_dtype = 'uint8'
    w = tvm.placeholder((w_b, m, k_i), dtype=pack_dtype, name='w')
    x = tvm.placeholder((x_b, k_i,), dtype=pack_dtype, name='x')
    k = tvm.reduce_axis((0, k_i), name='k')
    bw = tvm.reduce_axis((0, w_b), name='bw')
    bx = tvm.reduce_axis((0, x_b), name='bx')

    dtype = 'int16' if unipolar else 'uint16'

    def _declared(i):
        w_val = w[bw, i, k].astype(dtype)
        x_val = x[bx, k].astype(dtype)
        if unipolar:
            body = tvm.popcount(w_val & x_val) - tvm.popcount(~w_val & x_val)
        else:
            body = tvm.popcount(w_val & x_val)
        return tvm.sum(body << (bw + bx).astype(dtype), axis=[bw, bx, k])

    z = tvm.compute((m,), _declared, name='z')

    Wb = tvm.decl_buffer(w.shape, w.dtype, name="W", offset_factor=k_i,
                         strides=[tvm.var('ldw'), tvm.var('ldw'), 1])  # stride can be inferred
    Xb = tvm.decl_buffer(x.shape, x.dtype, name="X", offset_factor=k_i,
                         strides=[tvm.var('ldw'), 1])
    Zb = tvm.decl_buffer(z.shape, z.dtype, name="Z", offset_factor=1, strides=[1])

    def _intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        args_1 = tvm.const(1, 'uint32')
        args_2 = tvm.const(2, 'uint32')

        if unipolar:
            vpadd, vpadalu = ("llvm.arm.neon.vpadd.v8i8",
                              "llvm.arm.neon.vpadals.v16i8.v8i16")
            full_dtype, half_dtype, return_dtype = 'int8x16', 'int8x8', 'int16x8'
        else:
            vpadd, vpadalu = ("llvm.arm.neon.vpadd.v8u8",
                              "llvm.arm.neon.vpadalu.v16u8.v8u16")
            full_dtype, half_dtype, return_dtype = 'uint8x16', 'uint8x8', 'uint16x8'

        def _counts(w_vec, x_vec):
            if unipolar:
                return tvm.popcount(w_vec & x_vec) - tvm.popcount(~w_vec & x_vec)
            return tvm.popcount(w_vec & x_vec)

        def _pairwise(lhs, rhs):
            return tvm.call_llvm_intrin(half_dtype, vpadd, args_1, lhs, rhs)

        def _instr(index):
            builder = tvm.ir_builder.create()
            if index == 1:  # reduce reset
                builder.emit(zz.vstore(0, tvm.const(0, return_dtype)))
                return builder.get()
            # body and reduce update
            level0 = [None] * 8
            level1 = [None] * 4
            level2 = [None] * 2
            for wbit in range(w_b):
                for xbit in range(x_b):
                    if k_i == 16:
                        for row in range(m):
                            w_vec = ww.vload([wbit, row, 0], 'uint8x16').astype(full_dtype)
                            x_vec = xx.vload([xbit, 0], 'uint8x16').astype(full_dtype)
                            wide = _counts(w_vec, x_vec)
                            level0[row] = (tvm.call_pure_intrin(half_dtype, 'vectorhigh', wide)
                                           + tvm.call_pure_intrin(half_dtype, 'vectorlow', wide))
                    else:  # ki == 8
                        for row in range(m):
                            w_vec = ww.vload([wbit, row, 0], 'uint8x8').astype(half_dtype)
                            x_vec = xx.vload([xbit, 0], 'uint8x8').astype(half_dtype)
                            level0[row] = _counts(w_vec, x_vec)
                    for i in range(m // 2):
                        level1[i] = _pairwise(level0[2 * i], level0[2 * i + 1])
                    for i in range(m // 4):
                        level2[i] = _pairwise(level1[2 * i], level1[2 * i + 1])
                    combined = tvm.call_pure_intrin(full_dtype, 'vectorcombine',
                                                    level2[0], level2[1])
                    shifted = combined << tvm.const(wbit + xbit, pack_dtype)
                    acc = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                               zz.vload(0, return_dtype), shifted)
                    # Store per pair so the next pair's zz.vload includes it.
                    builder.emit(zz.vstore(0, acc))
            return builder.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x: Xb, z: Zb})
def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtype,
                            out_dtype, unipolar):
    """The default implementation of bitserial dense in topi.

    Parameters
    ----------
    cfg : autotvm config
        Used to define and read the tile_k / tile_x / tile_y / reorder_0 knobs.
    data : tvm.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim], or an already-packed 3-D tensor
        [out_dim, weight_bits, in_dim] that is used as-is.
    data_bits, weight_bits : int
        Bit counts passed to ``bitpack``.
    pack_dtype, out_dtype : str
        Packing word dtype and accumulation/output dtype.
    unipolar : bool
        Accumulate ``popcount(w & d) - popcount(~w & d)`` instead of
        ``popcount(w & d)``.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim_padded], where out_dim_padded is
        out_dim rounded up to a multiple of 8. Any trailing padded columns are
        not meaningful.
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1,
                                pack_type=pack_dtype)
    else:
        weight_packed = weight

    batch, DB, in_dim = get_const_tuple(data_packed.shape)
    out_dim, WB, in_dim = get_const_tuple(weight_packed.shape)

    # Pad so the microkernel can be used: out_dim must be a multiple of 8
    # (tile_y is filtered to size 8 and weight_vec is tiled by VY). Axis 0 of
    # the *weight* is out_dim, so the weight is what gets padded, and the pad
    # rounds up to the next multiple of 8. Padding data_packed (axis 0 = batch)
    # by out_dim % 8 left weight_vec indexing past the end of the weight.
    # NOTE(review): in_dim must also be divisible by tile_k (8 or 16); it is
    # not padded here.
    out_dim_pad = -out_dim % 8
    if out_dim_pad:
        weight_packed = pad(weight_packed, [0, 0, 0], [out_dim_pad, 0, 0],
                            name='PaddedWeight')
        out_dim += out_dim_pad

    ######## Search space
    x, y = cfg.axis(batch), cfg.axis(out_dim)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim)

    ko, ki = cfg.define_split(
        'tile_k', k, num_outputs=2,
        filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
    yo, yi = cfg.define_split('tile_y', y, num_outputs=2,
                              filter=lambda xx: xx.size[-1] == 8)

    # The first and third candidates are identical. They are kept so that the
    # candidate indices (and presumably existing tuning records) stay unchanged.
    cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
                       policy='candidate', candidate=[
                           [yo, xo, ko, xi, wb, db, yi, ki],
                           [yo, xo, xi, ko, wb, db, yi, ki],
                           [yo, xo, ko, xi, wb, db, yi, ki]])

    ###### Compute rule
    VY = cfg['tile_y'].size[-1]
    VK = cfg['tile_k'].size[-1]

    wvshape = (out_dim // VY, in_dim // VK, WB, VY, VK)
    oshape = (batch, out_dim)

    k = tvm.reduce_axis((0, in_dim), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    # Tile data and weights
    weight_vec = tvm.compute(wvshape, lambda yo, ko, wb, vy, vk:
                             weight_packed[yo * VY + vy][wb][ko * VK + vk],
                             name='weight_vec')

    def _w(y):
        return weight_vec[y // VY, k // VK, wb, y % VY, k % VK].astype(out_dtype)

    def _d(x):
        return data_packed[x, db, k].astype(out_dtype)

    # Only the requested formulation is built; the other was dead code.
    if unipolar:
        matmul = tvm.compute(
            oshape,
            lambda x, y: tvm.sum(
                (tvm.popcount(_w(y) & _d(x)) - tvm.popcount(~_w(y) & _d(x)))
                << (wb + db).astype(out_dtype),
                axis=[wb, db, k]),
            tag='bitserial_dense_unipolar')
    else:
        matmul = tvm.compute(
            oshape,
            lambda x, y: tvm.sum(
                tvm.popcount(_w(y) & _d(x)) << (wb + db).astype(out_dtype),
                axis=[wb, db, k]),
            tag='bitserial_dense')

    cfg.add_flop(batch * out_dim * in_dim * binary_op_multiplier(pack_dtype))
    return matmul
def _conv(nn, yy, xx, ff):
    """NHWC bit-serial convolution body for output element (nn, yy, xx, ff).

    Sums ``popcount(I & F) << (b1 + b2)`` over (rc, ry, rx, b2, b1) and casts
    each term to ``out_dtype``. ``PadInput_q``, ``Filter_q``, ``stride_h``,
    ``stride_w``, ``out_dtype`` and the reduce axes come from the enclosing
    scope.
    """
    shift = (b1 + b2).astype(out_dtype)
    in_row = yy * stride_h + ry
    in_col = xx * stride_w + rx
    hits = tvm.popcount(PadInput_q[nn, in_row, in_col, rc, b1]
                        & Filter_q[ry, rx, rc, ff, b2])
    return tvm.sum((hits << shift).astype(out_dtype), axis=[rc, ry, rx, b2, b1])
def _conv(nn, yy, xx, ff):
    """NHWC bit-serial convolution body for output element (nn, yy, xx, ff).

    The per-term value is ``popcount(input_word & filter_word)`` shifted left
    by ``b1 + b2`` and cast to ``out_dtype``; it is summed over
    (rc, ry, rx, b2, b1). All tensors, strides and reduce axes are free
    variables from the enclosing scope.
    """
    input_word = PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1]
    filter_word = Filter_q[ry, rx, rc, ff, b2]
    term = tvm.popcount(input_word & filter_word) << (b1 + b2).astype(out_dtype)
    return tvm.sum(term.astype(out_dtype), axis=[rc, ry, rx, b2, b1])
def _conv(n, h, w, co, vh, vw, vc):
    """Unsigned bit-serial convolution body (uint16) for one output element.

    Sums ``popcount(K & D) << (kb + ib)`` over (dh, dw, kb, ib, ci), with the
    kernel word K and data word D cast to uint16 before the AND. ``kernel_vec``,
    ``data_vec``, ``HSTR``, ``WSTR`` and the reduce axes come from the
    enclosing scope.
    """
    kernel_word = kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16')
    data_word = data_vec[n, h, w, vh * HSTR + dh, vw * WSTR + dw, ib, ci].astype('uint16')
    bit_weight = (kb + ib).astype('uint16')
    return tvm.sum(tvm.popcount(kernel_word & data_word) << bit_weight,
                   axis=[dh, dw, kb, ib, ci])
def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
                    out_dtype='int16', unipolar=True):
    """Bitserial dense implementation.

    TODO: Why are these separate

    Parameters
    ----------
    cfg : autotvm config
        Used to define and read the tiling / reorder / annotate knobs.
    data : tvm.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim] (already packed; used as-is)
    data_bits, weight_bits : int
        Bit counts passed to ``bitpack``.
    pack_dtype : str
        Word dtype used for packing (default 'uint32').
    out_dtype : str
        Accumulation/output dtype (default 'int16').
    unipolar : bool
        Accumulate ``popcount(w & d) - popcount(~w & d)`` instead of
        ``popcount(w & d)``.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1,
                                pack_type=pack_dtype)
    else:
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)

    ######## Search space
    x, y = cfg.axis(X), cfg.axis(Y)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
    ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
    yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)

    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
                       policy='candidate', candidate=[
                           [yo, xo, ko, yi, wb, db, ki, xi],
                           [yo, xo, yi, ko, wb, db, ki, xi]])

    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')

    ###### Compute rule
    VX = cfg['tile_x'].size[-1]

    wvshape = (X // VX, WB, VX, K)
    oshape = (Y, X)

    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    # Tile data and weights
    weight_vec = tvm.compute(
        wvshape, lambda xo, wb, vx, k: weight_packed[xo * VX + vx][wb][k], name='weight_vec')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

    def _w_word(j):
        return weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k]

    # Only the requested formulation is built; the other was dead code.
    if unipolar:
        # Cast each popcount to out_dtype *before* subtracting. pack_dtype is
        # unsigned by default, so the subtraction would wrap on underflow, and
        # when pack_dtype is narrower than out_dtype (e.g. 'uint8' -> 'int16')
        # the wrapped value survives the cast as a large positive number.
        matmul = tvm.compute(
            oshape,
            lambda i, j: tvm.sum(
                (tvm.popcount(_w_word(j) & data_packed[i, db, k]).astype(out_dtype)
                 - tvm.popcount(~_w_word(j) & data_packed[i, db, k]).astype(out_dtype))
                << (db + wb).astype(out_dtype),
                axis=[wb, db, k]),
            tag='bitserial_dense_unipolar')
    else:
        matmul = tvm.compute(
            oshape,
            lambda i, j: tvm.sum(
                tvm.popcount(_w_word(j) & data_packed[i, db, k]).astype(out_dtype)
                << (db + wb).astype(out_dtype),
                axis=[wb, db, k]),
            tag='bitserial_dense')

    # binary ops
    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))
    return matmul
def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
                            out_dtype='int16', unipolar=True):
    """Bitserial dense implementation.

    TODO: Why are these separate

    Parameters
    ----------
    cfg : autotvm config
        Used to define and read the tiling / reorder / annotate knobs.
    data : tvm.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim] (already packed; used as-is)
    data_bits, weight_bits : int
        Bit counts passed to ``bitpack``.
    pack_dtype : str
        Word dtype used for packing (default 'uint32').
    out_dtype : str
        Accumulation/output dtype (default 'int16').
    unipolar : bool
        Accumulate ``popcount(w & d) - popcount(~w & d)`` instead of
        ``popcount(w & d)``.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1,
                                pack_type=pack_dtype)
    else:
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)

    ######## Search space
    x, y = cfg.axis(X), cfg.axis(Y)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
    ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2)
    yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2)
    xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)

    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
                       policy='candidate', candidate=[
                           [yo, xo, ko, yi, wb, db, ki, xi],
                           [yo, xo, yi, ko, wb, db, ki, xi]])

    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')

    ###### Compute rule
    VX = cfg['tile_x'].size[-1]

    wvshape = (X // VX, WB, VX, K)
    oshape = (Y, X)

    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    # Tile data and weights
    weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k: weight_packed[xo * VX + vx][wb][k],
                             name='weight_vec')

    def _w_word(j):
        return weight_vec[j // VX, wb, j % VX, k]

    # Only the requested formulation is built; the other was dead code.
    if unipolar:
        # Cast each popcount to out_dtype *before* subtracting. pack_dtype is
        # unsigned by default, so the subtraction would wrap on underflow, and
        # when pack_dtype is narrower than out_dtype (e.g. 'uint8' -> 'int16')
        # the wrapped value survives the cast as a large positive number.
        matmul = tvm.compute(
            oshape,
            lambda i, j: tvm.sum(
                (tvm.popcount(_w_word(j) & data_packed[i, db, k]).astype(out_dtype)
                 - tvm.popcount(~_w_word(j) & data_packed[i, db, k]).astype(out_dtype))
                << (db + wb).astype(out_dtype),
                axis=[wb, db, k]),
            tag='bitserial_dense_unipolar')
    else:
        matmul = tvm.compute(
            oshape,
            lambda i, j: tvm.sum(
                tvm.popcount(_w_word(j) & data_packed[i, db, k]).astype(out_dtype)
                << (db + wb).astype(out_dtype),
                axis=[wb, db, k]),
            tag='bitserial_dense')

    # binary ops
    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))
    return matmul