def intrin_gemv(m, n):
    w = tvm.placeholder((m, n), name='w')
    x = tvm.placeholder((n,), name='x')
    k = tvm.reduce_axis((0, n), name='k')
    z = tvm.compute((m,), lambda i: tvm.sum(w[i, k] * x[k], axis=k), name='z')
    Wb = tvm.decl_buffer(w.shape, w.dtype,
                         name="W",
                         offset_factor=16,
                         strides=[tvm.var('ldw'), 1])

    def intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        ww_ptr = ww.access_ptr("r")
        xx_ptr = xx.access_ptr("r")
        zz_ptr = zz.access_ptr("w")
        body = tvm.call_packed("gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        reset = tvm.call_packed("fill_zero", zz_ptr, n)
        update = tvm.call_packed("gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
        return body, reset, update

    with tvm.build_config(data_alignment=16, offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func, binds={w: Wb})
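# A minimal usage sketch (assuming the legacy TVM 0.x schedule API used above):
# split the rows of a larger GEMV into tiles of 16 and replace the inner tile
# with the intrinsic returned by intrin_gemv. The shapes and names here are
# illustrative assumptions, not part of the original snippet.
def gemv_tensorize_example():
    M, N = 1024, 64
    W = tvm.placeholder((M, N), name='W')
    X = tvm.placeholder((N,), name='X')
    k = tvm.reduce_axis((0, N), name='k')
    Z = tvm.compute((M,), lambda i: tvm.sum(W[i, k] * X[k], axis=k), name='Z')
    s = tvm.create_schedule(Z.op)
    io, ii = s[Z].split(Z.op.axis[0], factor=16)
    s[Z].tensorize(ii, intrin_gemv(16, N))
    # Inspect the lowered IR to check that the inner tile became a packed call.
    return tvm.lower(s, [W, X, Z], simple_mode=True)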
def intrin_wmma_gemm():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    B = tvm.placeholder((n, n), name='B', dtype='float16')
    k = tvm.reduce_axis((0, n), name="k")
    C = tvm.compute((n, n),
                    lambda ii, jj: tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
                    name='C')
    BA = tvm.decl_buffer(A.shape, A.dtype, name='BA',
                         scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
    BB = tvm.decl_buffer(B.shape, B.dtype, name='BB',
                         scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
    BC = tvm.decl_buffer(C.shape, C.dtype, name='BC',
                         scope='wmma.accumulator', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        BA, BB = ins
        BC, = outs

        def init():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment',
                                    BC.data, n, n, n, BC.elem_offset // 256, 0.0))
            return ib.get()

        def update():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
                                    BC.data, BC.elem_offset // 256,
                                    BA.data, BA.elem_offset // 256,
                                    BB.data, BB.elem_offset // 256,
                                    BC.data, BC.elem_offset // 256))
            return ib.get()

        # body, reduce reset, reduce update
        return update(), init(), update()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
def intrin_vadd(n, cache_read=False, cache_write=False):
    scope_ubuf = 'local'
    dtype = 'float32'
    x = tvm.placeholder((n,), dtype=dtype, name='vx')
    y = tvm.placeholder((n,), dtype=dtype, name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
    s = tvm.create_schedule(z.op)

    def create_buffer(t):
        return tvm.decl_buffer(t.shape, t.dtype, name='W' + t.name,
                               scope=scope_ubuf, offset_factor=16)

    binds = {}
    if cache_read:
        binds[x] = create_buffer(x)
        binds[y] = create_buffer(y)
    if cache_write:
        binds[z] = create_buffer(z)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr("r"),
                                ins[1].access_ptr('r'),
                                outs[0].access_ptr('wr')))
        return ib.get()

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
def test_tensor_intrin_scalar_params():
    n = tvm.size_var("n")
    x = tvm.placeholder((n,), name='x')
    v = tvm.size_var("v")
    w = tvm.size_var("w")
    z = tvm.compute((n,), lambda i: x[i] * v + w, name='z')

    def intrin_func(ins, outs, sp):
        assert isinstance(ins[0], tvm.schedule.Buffer)
        assert ins[0].shape[0] == n
        assert sp[0] == v
        assert sp[1] == w
        return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])

    with tvm.build_config(offset_factor=1):
        intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
    assert intrin.op == z.op
    assert intrin.reduce_init is None
    assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
    assert intrin.buffers[0].shape[0] == n
    assert tuple(intrin.scalar_params) == tuple((v, w))

    A = tvm.placeholder((10, 10), name='A')
    # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
    C = tvm.compute((10, 10), lambda i, j: intrin(i * i, A[i, j], i + j), name="C")
    s = tvm.create_schedule(C.op)
    stmt = tvm.lower(s, [A, C], simple_mode=True)
    assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
    assert len(stmt.body.body.body.value.args) == 5
    assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
    assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
def intrin_vadd(n):
    dtype = 'float32'
    x = tvm.placeholder((n,), dtype=dtype, name='vx')
    y = tvm.placeholder((n,), dtype=dtype, name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
    s = tvm.create_schedule(z.op)

    def create_buffer(t):
        return tvm.decl_buffer(t.shape, t.dtype, name='W' + t.name, offset_factor=16)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern("float32", 'vadd',
                                ins[0].access_ptr("r"),
                                ins[1].access_ptr('r'),
                                outs[0].access_ptr('wr')))
        return ib.get()

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func,
                                      binds={x: create_buffer(x),
                                             y: create_buffer(y),
                                             z: create_buffer(z)})
def intrin_wmma_load_matrix(shape, scope):
    n, m, l = shape
    if scope == "wmma.matrix_a":
        row, col = n, l
    elif scope == "wmma.matrix_b":
        row, col = l, m
    A = tvm.placeholder((row, col), name='A', dtype='float16')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared',
                         data_alignment=32, offset_factor=row * col)
    C = tvm.compute((row, col), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope,
                         data_alignment=32, offset_factor=row * col)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
                                BC.data, n, m, l, BC.elem_offset // (row * col),
                                BA.access_ptr('r'), col, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def test_tensor_intrin():
    n = 16
    x = tvm.placeholder((n,), name='x')
    y = tvm.placeholder((n,), name='y')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        assert isinstance(ins[0], tvm.schedule.Buffer)
        assert ins[0].shape[0].value == n
        return tvm.call_packed("vadd", ins[0].data, outs[0].data, ins[0].shape[0])

    intrin = tvm.decl_tensor_intrin(z.op, intrin_func)
    assert intrin.op == z.op
    assert intrin.reduce_init is None
    assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
    assert intrin.buffers[0].shape[0].value == n

    m = 32
    x = tvm.placeholder((m,), name='x')
    y = tvm.placeholder((m,), name='y')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
    s = tvm.create_schedule(z.op)
    xo, xi = s[z].split(z.op.axis[0], factor=n)
    s[z].tensorize(xi, intrin)
    assert s[z].iter_var_attrs[xi].tensor_intrin == intrin
    assert s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized
def intrin_gemv(m, l):
    a = tvm.placeholder((l,), name='a')
    b = tvm.placeholder((m, l), name='b')
    k = tvm.reduce_axis((0, l), name='k')
    c = tvm.compute((m,), lambda i: tvm.sum(a[k] * b[i, k], axis=k), name='c')
    Ab = tvm.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])
    Bb = tvm.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[tvm.var("s1"), 1])
    Cb = tvm.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        ib.emit(tvm.call_extern("int32", "gemv_update",
                                cc.access_ptr("w"),
                                aa.access_ptr("r"),
                                bb.access_ptr("r"),
                                m, l, bb.strides[0]))
        return ib.get()

    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
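# A usage sketch in the spirit of the TVM tensorize tutorial (legacy 0.x API):
# compute C = A * B^T, split the j axis by the intrinsic width, and tensorize
# the inner tile with intrin_gemv above. Sizes below are illustrative assumptions.
def gemv_update_tensorize_example():
    N, M, L = 1024, 512, 64
    A = tvm.placeholder((N, L), name='A')
    B = tvm.placeholder((M, L), name='B')
    k = tvm.reduce_axis((0, L), name='k')
    C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k), name='C')
    s = tvm.create_schedule(C.op)
    factor = 16
    x, y = C.op.axis
    jo, ji = s[C].split(y, factor=factor)
    gemv = intrin_gemv(factor, L)
    s[C].tensorize(ji, gemv)
    return tvm.lower(s, [A, B, C], simple_mode=True)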
def intrin_dot():
    n = 4  # dp4a requires operands packed by 4
    x = tvm.placeholder((n,), name='x', dtype='int8')
    y = tvm.placeholder((n,), name='y', dtype='int8')
    k = tvm.reduce_axis((0, n), name='k')
    z = tvm.compute((1,),
                    lambda _: tvm.sum(x[k].astype('int32') * y[k].astype('int32'), axis=k))

    def intrin_func(ins, outs):
        xx, yy = ins
        zz = outs[0]
        ib = tvm.ir_builder.create()
        dp4a = zz.vstore(0, tvm.call_pure_extern('int32', '__dp4a',
                                                 xx.vload(0, dtype='int8x4'),
                                                 yy.vload(0, dtype='int8x4'),
                                                 zz.vload(0)))
        ib.emit(dp4a)
        body = ib.get()
        # body, reduce reset, reduce update
        return body, zz.vstore(0, 0), body

    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
                                    data_alignment=cfg.data_alignment,
                                    offset_factor=cfg.offset_factor,
                                    scope='local') for t in [x, y, z]}
        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
def intrin_test():
    m1 = tvm.var("m1")
    n1 = tvm.var("n1")
    a = tvm.placeholder((m1, n1), name='a')
    c = tvm.compute((1, n1), lambda i, j: a[0, j] + a[1, j] + a[2, j], name='c')
    Ab = tvm.decl_buffer(a.shape, name="Abuf", offset_factor=1)
    Cb = tvm.decl_buffer(c.shape, name="Cbuf", offset_factor=1)

    def intrin_func(ins, outs):
        aa = ins[0]
        cc = outs[0]

        def _body():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("int32", "test",
                                    cc.access_ptr("w"),
                                    aa.access_ptr("r")))
            return ib.get()

        return _body()

    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, c: Cb})
def intrin_wmma_store_matrix(shape):
    n, m, l = shape
    A = tvm.placeholder((n, m), name='A', dtype='float32')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator',
                         data_alignment=32, offset_factor=n * m)
    C = tvm.compute((n, m), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope='global',
                         data_alignment=32, offset_factor=n * m)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
                                BA.data, n, m, l, BA.elem_offset // (n * m),
                                BC.access_ptr('w'), m, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def dp4a(x_scope='local', y_scope='local', z_scope='local'):
    """
    Int8 dot product reduced by every 4 elements using __dp4a

    Parameters
    ----------
    x_scope : str, optional
        The storage scope of buffer for lhs
    y_scope : str, optional
        The storage scope of buffer for rhs
    z_scope : str, optional
        The storage scope of buffer for result

    Returns
    -------
    intrin : TensorIntrin
        The dp4a TensorIntrin that can be used in tensorizing schedule.
    """
    n = 4  # dp4a requires operands packed by 4
    x = tvm.placeholder((n,), name='x', dtype='int8')
    y = tvm.placeholder((n,), name='y', dtype='int8')
    k = tvm.reduce_axis((0, n), name='rc')
    z = tvm.compute((1,), lambda i: tvm.sum(x[k].astype('int32') * y[k].astype('int32'), axis=[k]))

    def _intrin_func(ins, outs):
        def _instr(index):
            xx, yy = ins
            zz = outs[0]
            if index == 1:
                return zz.vstore(0, 0)
            ib = tvm.ir_builder.create()
            vec_x = xx.vload(0, dtype='int8x4')
            vec_y = yy.vload(0, dtype='int8x4')
            prev_z = 0 if index == 0 else zz.vload(0)
            new_z = tvm.call_pure_extern('int32', '__dp4a', vec_x, vec_y, prev_z)
            ib.emit(zz.vstore(0, new_z))
            return ib.get()

        return _instr(0), _instr(1), _instr(2)  # body, reset, update

    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
        scopes = {x: x_scope, y: y_scope, z: z_scope}
        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
                                    data_alignment=cfg.data_alignment,
                                    offset_factor=cfg.offset_factor,
                                    scope=scopes[t]) for t in [x, y, z]}
        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
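# A hedged usage sketch (legacy TVM 0.x API): dp4a is typically applied to the
# innermost reduction axis after splitting it by 4, similar to how the CUDA
# int8 schedules in topi use it. The shapes, names, and 'global' scopes below
# are illustrative assumptions, not part of the original snippet.
def dp4a_tensorize_example():
    K = 256
    a = tvm.placeholder((K,), name='a', dtype='int8')
    b = tvm.placeholder((K,), name='b', dtype='int8')
    k = tvm.reduce_axis((0, K), name='k')
    c = tvm.compute((1,),
                    lambda _: tvm.sum(a[k].astype('int32') * b[k].astype('int32'), axis=k),
                    name='c')
    s = tvm.create_schedule(c.op)
    ko, ki = s[c].split(c.op.reduce_axis[0], factor=4)
    s[c].tensorize(ki, dp4a('global', 'global', 'global'))
    return tvm.lower(s, [a, b, c], simple_mode=True)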
def mma_sync_wmma():
    factor1 = tvm.placeholder((1,), name='factor1', dtype='float16')
    factor2 = tvm.placeholder((1,), name='factor2', dtype='float16')
    product = tvm.placeholder((1,), name='product', dtype='float32')
    # product = tvm.placeholder((1,), name='product', dtype='float16')
    schedule = tvm.compute(
        (1,), lambda _: (factor1[0] + factor2[0] + product[0].astype('float16')))

    def mma_sync(inputs, outputs):
        factor1_, factor2_, product_ = inputs
        schedule_ = outputs[0]
        # get address for matrix A
        A_ptr = factor1_.access_ptr("r")
        # get address for matrix B
        B_ptr = factor2_.access_ptr("r")
        # get address for matrix C
        C_ptr = product_.access_ptr("w")
        body = tvm.call_extern('float32', "wmma_call", A_ptr, B_ptr, C_ptr)
        # body = tvm.call_extern('float32', "__INIT_TILE_WARP__")
        init = tvm.call_extern('float32', "__INIT_TILE_WARP__")
        # product_.vstore((0, 0, 0, 0), 0.)
        return body, init, body

    with tvm.build_config(data_alignment=1, offset_factor=1) as cfg:
        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
                                    data_alignment=cfg.data_alignment,
                                    offset_factor=cfg.offset_factor,
                                    scope='global')
                 for t in [factor1, factor2]}
        binds.update({product: tvm.decl_buffer(product.shape, product.dtype, product.op.name,
                                               data_alignment=cfg.data_alignment,
                                               offset_factor=cfg.offset_factor,
                                               scope='global')})
        binds.update({schedule: tvm.decl_buffer(schedule.shape, schedule.dtype, schedule.op.name,
                                                data_alignment=cfg.data_alignment,
                                                offset_factor=cfg.offset_factor,
                                                scope='global')})
        return tvm.decl_tensor_intrin(schedule.op, mma_sync, binds=binds)
def intrin_vadd(n):
    x = tvm.placeholder((n,), name='vx')
    y = tvm.placeholder((n,), name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        xx, yy = ins
        zz = outs[0]
        return tvm.call_packed("vadd", xx, yy, zz)

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func)
def intrin_vadd(n):
    x = tvm.placeholder((n,))
    y = tvm.placeholder((n,))
    z = tvm.compute(x.shape, lambda i: x[i] + y[i])

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr("r"),
                                ins[1].access_ptr('r'),
                                outs[0].access_ptr('wr')))
        return ib.get()

    with tvm.build_config(offset_factor=n):
        return tvm.decl_tensor_intrin(z.op, intrin_func)
def intrin_gemv(m, l):
    a = tvm.placeholder((l,), name='a')
    b = tvm.placeholder((m, l), name='b')
    k = tvm.reduce_axis((0, l), name='k')
    c = tvm.compute((m,), lambda i: tvm.sum(a[k] * b[i, k], axis=k), name='c')
    Ab = tvm.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])
    Bb = tvm.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[tvm.var("s1"), 1])
    Cb = tvm.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]

        def _body():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("int32", "gemv_update",
                                    cc.access_ptr("w"),
                                    aa.access_ptr("r"),
                                    bb.access_ptr("r"),
                                    m, l, bb.strides[0]))
            return ib.get()

        def _reduce_reset():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern("int32", "gemv_reset", cc.access_ptr("w"), m))
            return ib.get()

        def _reduce_update():
            return _body()

        return _body(), _reduce_reset(), _reduce_update()

    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
def op_intrin():
    bh = 9
    bw = 9
    x = tvm.placeholder((5, 5), name='A')
    y = tvm.compute((bh, bw), lambda i, j: x[j / 3 + i % 3, j % 3 + i / 3])

    def intrin_func(ins, outs):
        xx, = ins
        zz = outs[0]
        return tvm.call_packed("op", xx, zz)

    with tvm.build_config(offset_factor=2):
        return tvm.decl_tensor_intrin(y.op, intrin_func)
def intrin_pool():
    A = tvm.placeholder((64, 16, 16), name='A')
    kh = tvm.reduce_axis((0, 3), name='kh')
    kw = tvm.reduce_axis((0, 3), name='kw')
    P = tvm.compute((64, 14, 14),
                    lambda c, oh, ow: tvm.max(A[c, oh + kh, ow + kw], axis=[kh, kw]),
                    name='p')

    def intrin_func(ins, outs):
        dinp = ins[0]
        dout = outs[0]
        return tvm.call_packed("op", dinp, dout)

    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(P.op, intrin_func)
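# A hedged usage sketch (legacy TVM 0.x API): tensorize the per-image 3x3 max
# pooling with intrin_pool above. The batch size and tensor names are
# illustrative assumptions, not part of the original snippet.
def pool_tensorize_example():
    batch = 8
    A = tvm.placeholder((batch, 64, 16, 16), name='A')
    kh = tvm.reduce_axis((0, 3), name='kh')
    kw = tvm.reduce_axis((0, 3), name='kw')
    P = tvm.compute((batch, 64, 14, 14),
                    lambda n, c, oh, ow: tvm.max(A[n, c, oh + kh, ow + kw], axis=[kh, kw]),
                    name='P')
    s = tvm.create_schedule(P.op)
    # The intrinsic covers the (c, oh, ow) sub-region of one image.
    s[P].tensorize(P.op.axis[1], intrin_pool())
    return tvm.lower(s, [A, P], simple_mode=True)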
def intrin_conv(in_h, in_w, kern_h, kern_w):
    in_height = in_h
    in_width = in_w
    kernel_h = kern_h
    kernel_w = kern_w
    stride_h = 1
    stride_w = 1
    out_h = (in_height - kernel_h) // stride_h + 1
    out_w = (in_width - kernel_w) // stride_w + 1

    Input = tvm.placeholder((in_height, in_width), name='input')
    Filter = tvm.placeholder((kernel_h, kernel_w), name='filter')
    kh = tvm.reduce_axis((0, kernel_h), name='kh')
    kw = tvm.reduce_axis((0, kernel_w), name='kw')
    conv = tvm.compute(
        (out_h, out_w),
        lambda oh, ow: tvm.sum(Filter[kh, kw] * Input[oh + kh, ow + kw], axis=[kh, kw]),
        name='c')

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        inp, filt = ins
        outp = outs[0]
        ib.emit(tvm.call_extern("int32", "inst_conv",
                                outp.access_ptr("w"),
                                inp.access_ptr("r"),
                                filt.access_ptr("r")))
        return ib.get()

    with tvm.build_config(offset_factor=1) as cfg:
        scopes = {Input: "local", Filter: "local", conv: "local"}
        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name, offset_factor=1)
                 for t in [Input, Filter, conv]}
        return tvm.decl_tensor_intrin(conv.op, intrin_func, binds=binds)
def intrin_gemm(m, n, l):
    k = tvm.reduce_axis((0, l))
    x = tvm.placeholder((m, l))
    y = tvm.placeholder((n, l))  # in theory, no relation
    z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k))

    def intrin_func(ins, outs):
        x_ptr = ins[0].access_ptr("r")
        y_ptr = ins[1].access_ptr("r")
        z_ptr = outs[0].access_ptr("w")
        body = tvm.call_packed("gemv", x_ptr, y_ptr, z_ptr, m, n, l)
        reset = tvm.call_packed("fill_zero", z_ptr, m, n)
        update = tvm.call_packed("gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
        return body, reset, update

    with tvm.build_config(offset_factor=n):
        return tvm.decl_tensor_intrin(z.op, intrin_func)
def int4_copy(x_scope="global", y_scope="global", add_on=None,
              bidx=tvm.thread_axis('blockIdx.x'), npq=1024, blk_size=64):
    n = 8  # int4_copy requires operands packed by 8 for fp16
    x = tvm.placeholder((n,), name='x', dtype='float16')
    y = tvm.compute((n,), lambda i: x[i])
    xb = tvm.decl_buffer(x.shape, x.dtype, name="x", scope=x_scope, offset_factor=1)
    yb = tvm.decl_buffer(y.shape, y.dtype, name="y", scope=y_scope, offset_factor=1)
    offset_x = -(blk_size) * (bidx / npq)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        if add_on is None:
            int4copy = tvm.call_intrin('float32', 'int4_copy',
                                       yb.access_ptr("w"), 0,
                                       xb.access_ptr("r"), offset_x)
            ib.emit(int4copy)
        elif add_on == "relu":
            relu = tvm.call_intrin('float32', 'relu', xb.access_ptr("w"), 0)
            ib.emit(relu)
            int4copy = tvm.call_intrin('float32', 'int4_copy',
                                       yb.access_ptr("w"), 0,
                                       xb.access_ptr("r"), offset_x)
            ib.emit(int4copy)
        return ib.get()

    with tvm.build_config() as cfg:
        return tvm.decl_tensor_intrin(y.op, intrin_func, binds={x: xb, y: yb})
def intrin_multivadd(n):
    n_a = tvm.var("n_a")
    Ab = tvm.decl_buffer((n,), tvm.float32, strides=[n_a])
    n_b = tvm.var("n_b")
    Bb = tvm.decl_buffer((n,), tvm.float32, strides=[n_b])
    n_c = tvm.var("n_c")
    Cb = tvm.decl_buffer((n,), tvm.float32, strides=[n_c])
    z = tvm.compute((n,),
                    lambda i: tvm.call_extern("float32", 'vadd',
                                              Ab.access_ptr("w", offset=n_a * i),
                                              Bb.access_ptr("r", offset=n_b * i),
                                              Cb.access_ptr("r", offset=n_c * i)))

    # replace the pattern with the multivadd call. I need to figure out
    # how to pass it the right parameters.
    def intrin_func(ins, outs):
        return tvm.call_packed("multivadd")

    with tvm.build_config():
        return tvm.decl_tensor_intrin(z.op, intrin_func, name="multivadd")
def dot_16x1x16_int8_int8_int32():
    """
    Int8 dot product by every 4 elements using AVX512 Skylake instructions.

    This function takes two arrays of int8 datatype -- data[4] and
    kernel[16][4] -- and computes a dot product of data[4] with every
    4 elements of kernels, resulting in output[16] of int32 datatype.
    The pseudo code is as follows.

    .. code-block:: c

        void dot_16x1x16_int8_int8_int32(int8 data[4], int8 kernel[16][4],
                                         int32 output[16]){
            for (int i = 0; i < 16; i++){
                out[i] = 0;
                for (int k = 0; k < 4; k++){
                    out[i] += data[k] * kernel[i][k]
                }
            }
        }

    Physically, the kernel array sits in an AVX512 vector register and
    the data[4] is broadcasted to another AVX512 vector register. This
    function returns a TensorIntrin that can be used to tensorize
    a schedule.

    Returns
    -------
    intrin : TensorIntrin
        The Skylake int8 TensorIntrin that can be used in tensorizing schedule
    """
    int32_lanes = 16  # 16 int32 lanes in AVX512
    num_int8_elements = 4  # 4 int8 elements in int32
    data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data')
    kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
    k = tvm.reduce_axis((0, num_int8_elements), name='k')
    C = tvm.compute((int32_lanes,),
                    lambda i: tvm.sum(data[k].astype('int32') * kernel[i, k].astype('int32'),
                                      axis=k),
                    name="C")

    a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
                               offset_factor=1, strides=[1])
    b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
                               offset_factor=1, strides=[tvm.var('ldw'), 1])

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return ib.get()

            a_int8 = ins[0].vload([0], "uint8x4")
            re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8)
            vec_ai32 = re_int32.astype('int32x16')
            vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
            vec_b = ins[1].vload([0, 0], "int8x64")
            vec_one = tvm.const(1, "int16x32")
            pair_reduction = tvm.call_llvm_intrin('int16x32',
                                                  'llvm.x86.avx512.pmaddubs.w.512',
                                                  tvm.const(0, 'uint32'),
                                                  vec_a, vec_b)
            quad_reduction = tvm.call_llvm_intrin('int32x16',
                                                  'llvm.x86.avx512.pmaddw.d.512',
                                                  tvm.const(0, 'uint32'),
                                                  pair_reduction, vec_one)
            if index == 0:
                ib.emit(outs[0].vstore(0, quad_reduction))
            else:
                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(C.op, _intrin_func,
                                      binds={data: a_buffer, kernel: b_buffer})
def intrin_conv(args):
    """intrin_conv is a conv inner interface"""
    (
        ndim, ci_, vh_, vw_, vc_,
        _, _, _, _,
        _, _, _, _,
        dtype, acum_dtype, opname, core_id,
    ) = args
    ci_ = tvm.var("ci_") if ci_ is None else ci_
    kvshape = (ci_, vc_)
    if ndim == 2:
        dvshape = (vw_, ci_)
        ovshape = (vw_, vc_)
        data_vec = tvm.placeholder(dvshape, name="a", dtype=dtype)
        kernel_vec = tvm.placeholder(kvshape, name="b", dtype=dtype)
        ci = tvm.reduce_axis((0, ci_), name="ci")
        conv = tvm.compute(
            ovshape,
            lambda vw, vc: tvm.sum(
                data_vec[vw, ci].astype(acum_dtype) * kernel_vec[ci, vc].astype(acum_dtype),
                axis=[ci]),
            name="conv")
    else:
        dvshape = (vh_, vw_, ci_)
        ovshape = (vh_, vw_, vc_)
        data_vec = tvm.placeholder(dvshape, name="a", dtype=dtype)
        kernel_vec = tvm.placeholder(kvshape, name="b", dtype=dtype)
        ci = tvm.reduce_axis((0, ci_), name="ci")
        conv = tvm.compute(
            ovshape,
            lambda vh, vw, vc: tvm.sum(
                data_vec[vh, vw, ci].astype(acum_dtype) * kernel_vec[ci, vc].astype(acum_dtype),
                axis=[ci]),
            name="conv")

    stride_a = [functools.reduce(lambda x, y: x * y, dvshape[i + 1:len(dvshape)])
                for i in range(0, len(dvshape) - 1)]
    stride_a.append(1)
    stride_b = [functools.reduce(lambda x, y: x * y, kvshape[i + 1:len(kvshape)])
                for i in range(0, len(kvshape) - 1)]
    stride_b.append(1)
    stride_c = [functools.reduce(lambda x, y: x * y, ovshape[i + 1:len(ovshape)])
                for i in range(0, len(ovshape) - 1)]
    stride_c.append(1)

    ab_ = tvm.decl_buffer(data_vec.shape, data_vec.dtype, name="a_",
                          offset_factor=1, strides=stride_a)
    bb_ = tvm.decl_buffer(kernel_vec.shape, kernel_vec.dtype, name="b_",
                          offset_factor=1, strides=stride_b)
    cb_ = tvm.decl_buffer(conv.shape, conv.dtype, name="C",
                          offset_factor=1, strides=stride_c)

    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]

        def _body():
            b_ = tvm.ir_builder.create()
            b_.emit(tvm.call_extern(
                "int32", opname,
                cc.access_ptr("w"),
                aa.access_ptr("r"),
                bb.access_ptr("r"),
                ci_, vh_, vw_, vc_, core_id))
            return b_.get()

        return _body()

    return tvm.decl_tensor_intrin(conv.op, intrin_func,
                                  binds={data_vec: ab_, kernel_vec: bb_, conv: cb_})
def intrin_conv(args):
    """intrin_conv"""
    (
        ci_, vh_, vw_, vc_,
        kh_, kw_, sh_, sw_,
        dila_h, dila_w,
        dtype, acum_dtype, opname, core_id,
    ) = args
    hstr, wstr = sh_, sw_
    ci_ = tvm.var("ci_") if ci_ is None else ci_
    kvshape = (ci_, kh_, kw_, vc_)
    ovshape = (vh_, vw_, vc_)
    if dila_h != 1 or dila_w != 1:
        dvshape = (kh_, kw_, vh_, vw_, ci_)
    else:
        dvshape = ((vh_ - 1) * hstr + kh_, (vw_ - 1) * wstr + kw_, ci_)

    data_vec = tvm.placeholder(dvshape, name="a", dtype=dtype)
    kernel_vec = tvm.placeholder(kvshape, name="b", dtype=dtype)
    ci = tvm.reduce_axis((0, ci_), name="ci")
    kh = tvm.reduce_axis((0, kh_), name="kh")
    kw = tvm.reduce_axis((0, kw_), name="kw")
    if dila_h != 1 or dila_w != 1:
        conv = tvm.compute(
            ovshape,
            lambda vh, vw, vc: tvm.sum(
                data_vec[kh, kw, vh, vw, ci].astype(acum_dtype) *
                kernel_vec[ci, kh, kw, vc].astype(acum_dtype),
                axis=[ci, kh, kw]),
            name="conv")
    else:
        conv = tvm.compute(
            ovshape,
            lambda vh, vw, vc: tvm.sum(
                data_vec[vh * hstr + kh, vw * wstr + kw, ci].astype(acum_dtype) *
                kernel_vec[ci, kh, kw, vc].astype(acum_dtype),
                axis=[ci, kh, kw]),
            name="conv")

    stride_a = [functools.reduce(lambda x, y: x * y, dvshape[i + 1:len(dvshape)])
                for i in range(0, len(dvshape) - 1)]
    stride_a.append(1)
    stride_b = [functools.reduce(lambda x, y: x * y, kvshape[i + 1:len(kvshape)])
                for i in range(0, len(kvshape) - 1)]
    stride_b.append(1)
    stride_c = [functools.reduce(lambda x, y: x * y, ovshape[i + 1:len(ovshape)])
                for i in range(0, len(ovshape) - 1)]
    stride_c.append(1)

    a_buffer = tvm.decl_buffer(data_vec.shape, data_vec.dtype, name="A",
                               offset_factor=1, strides=stride_a)
    b_buffer = tvm.decl_buffer(kernel_vec.shape, kernel_vec.dtype, name="B",
                               offset_factor=1, strides=stride_b)
    c_buffer = tvm.decl_buffer(conv.shape, conv.dtype, name="C",
                               offset_factor=1, strides=stride_c)

    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]

        def _body():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_extern(
                "int32", opname,
                cc.access_ptr("w"),
                aa.access_ptr("r"),
                bb.access_ptr("r"),
                ci_, vh_, vw_, vc_, kh_, sh_, core_id))
            return ib.get()

        return _body()

    return tvm.decl_tensor_intrin(conv.op, intrin_func,
                                  binds={data_vec: a_buffer, kernel_vec: b_buffer, conv: c_buffer})
def intrin_col2im(input_shape, output_shape, kernel, stride, pad, dtype):
    '''
    Compute col2im via cce col2im intrin function call directly

    Args:
        input_shape: the shape of the image
        output_shape: the shape of the result of im2col given the input image
        kernel: kernel sizes for im2col
        stride: stride sizes for im2col
        pad: padding sizes for im2col, including padding top, bottom, left, and right
        dtype: type of the data

    Return:
        cce intrin function call for col2im
    '''
    _, _, _, _, WINDOW_H, WINDOW_W, _ = input_shape
    _, _, H, W, _ = output_shape
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    pad_t, pad_b, pad_l, pad_r = pad

    assert (WINDOW_H * WINDOW_W) % 16 == 0, \
        "Number of windows over the input must be divisible by 16 (col2im repeat)"
    assert (H * W * 16) % 64 == 0, \
        "Input size must be divisible by 64 (vector_dup repeat)"

    # FCOL2IMG -------------------------------------------
    INPUT_W = W
    INPUT_H = H
    PAD_LEFT = pad_l
    PAD_RIGHT = pad_r
    PAD_TOP = pad_t
    PAD_BOTTOM = pad_b
    # ---------------------------------------------------

    # Xm ------------------------------------------------
    W_IDX_KERNEL = 0
    H_IDX_KERNEL = 0
    H_IDX = (-pad_l) & 0xFFFF  # fix negative numbers
    W_IDX = (-pad_t) & 0xFFFF
    C1_IDX = 0
    # ---------------------------------------------------

    # Xt ------------------------------------------------
    STRIDE_H = stride_h
    STRIDE_W = stride_w
    KERNEL_H = kernel_h
    KERNEL_W = kernel_w
    DILATION_H = 1
    DILATION_W = 1
    JUMP_OFFSET = 0
    REPEAT_MODE = 1
    REPEAT_TIME = (WINDOW_H * WINDOW_W) // 16
    # ---------------------------------------------------

    INPUT_B = 1
    INPUT_C1 = 1
    INPUT_C0 = 16

    input_data = tvm.placeholder(
        (INPUT_B, INPUT_C1, KERNEL_H, KERNEL_W, WINDOW_H, WINDOW_W, INPUT_C0),
        dtype=dtype)

    result = tvm.compute(
        (INPUT_B, INPUT_C1, INPUT_H, INPUT_W, INPUT_C0),
        lambda b, c1, h, w, c0: input_data[
            b, c1, h % KERNEL_H, w % KERNEL_W, h % WINDOW_H, w % WINDOW_W, c0],
        name="col2im_intrinsic")

    input_data_buff = tvm.decl_buffer(input_data.shape, input_data.dtype,
                                      name="input_data_buff",
                                      offset_factor=1, scope="local.UB")
    result_buff = tvm.decl_buffer(result.shape, result.dtype,
                                  name="result_buff",
                                  offset_factor=1, scope="local.UB")

    def pack_args(sp):
        assert len(sp) == 20
        fcol2img = (akg.tvm.const(sp[0], "uint64") +
                    akg.tvm.const(sp[1] * 2**16, "uint64") +
                    akg.tvm.const(sp[2] * 2**32, "uint64") +
                    akg.tvm.const(sp[3] * 2**40, "uint64") +
                    akg.tvm.const(sp[4] * 2**48, "uint64") +
                    akg.tvm.const(sp[5] * 2**56, "uint64"))
        Xm = (akg.tvm.const(sp[6] * 2**16, "uint64") +
              akg.tvm.const(sp[7] * 2**24, "uint64") +
              akg.tvm.const(sp[8] * 2**32, "uint64") +
              akg.tvm.const(sp[9] * 2**48, "uint64") +
              akg.tvm.const(sp[10], "uint64"))
        Xt = (akg.tvm.const(sp[11], "uint64") +
              akg.tvm.const(sp[12] * 2**6, "uint64") +
              akg.tvm.const(sp[13] * 2**12, "uint64") +
              akg.tvm.const(sp[14] * 2**20, "uint64") +
              akg.tvm.const(sp[15] * 2**28, "uint64") +
              akg.tvm.const(sp[16] * 2**36, "uint64") +
              akg.tvm.const(sp[17] * 2**44, "uint64") +
              akg.tvm.const(sp[18] * 2**52, "uint64") +
              akg.tvm.const(sp[19] * 2**56, "uint64"))
        return (fcol2img, Xm, Xt)

    def intrin_func(ins, outs):
        sp = [
            INPUT_W, INPUT_H, PAD_LEFT, PAD_RIGHT, PAD_TOP, PAD_BOTTOM,        # FMATRIX
            W_IDX_KERNEL, H_IDX_KERNEL, W_IDX, H_IDX, C1_IDX,                  # Xm
            STRIDE_W, STRIDE_H, KERNEL_W, KERNEL_H, DILATION_W, DILATION_H,
            JUMP_OFFSET, REPEAT_MODE, REPEAT_TIME,                             # Xt
        ]
        aa = ins[0]
        bb = outs[0]
        ib = tvm.ir_builder.create()
        fcol2img, Xm, Xt = pack_args(sp)
        ib.emit(tvm.call_extern(dtype, "set_fcol2img", fcol2img))
        ib.emit(tvm.call_extern(dtype, "vector_dup", bb.access_ptr("w"), 0,
                                (INPUT_H * INPUT_W * 16) // 64, 1, 1, 8, 8))

        c = 0
        for kh in range(KERNEL_H):
            for kw in range(KERNEL_W):
                sp[6] = kw
                sp[7] = kh
                fcol2img, Xm, Xt = pack_args(sp)
                ib.emit(tvm.call_extern(
                    dtype, "col2img",
                    bb.access_ptr("rw"),
                    aa.access_ptr("r", offset=c * 16 * INPUT_C0 * REPEAT_TIME),
                    Xm, Xt))
                c += 1
        return ib.get()

    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(result.op, intrin_func,
                                      binds={input_data: input_data_buff,
                                             result: result_buff})
def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
    pack_dtype = 'uint8'
    w = tvm.placeholder((w_b, m, k_i), dtype=pack_dtype, name='w')
    x = tvm.placeholder((x_b, k_i,), dtype=pack_dtype, name='x')
    k = tvm.reduce_axis((0, k_i), name='k')
    bw = tvm.reduce_axis((0, w_b), name='bw')
    bx = tvm.reduce_axis((0, x_b), name='bx')
    if unipolar:
        dtype = 'int16'
        z = tvm.compute(
            (m,),
            lambda i: tvm.sum(
                (tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)) -
                 tvm.popcount(~w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)))
                << (bw + bx).astype(dtype),
                axis=[bw, bx, k]),
            name='z')
    else:
        dtype = 'uint16'
        z = tvm.compute(
            (m,),
            lambda i: tvm.sum(
                tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))
                << (bw + bx).astype(dtype),
                axis=[bw, bx, k]),
            name='z')

    Wb = tvm.decl_buffer(w.shape, w.dtype, name="W",
                         offset_factor=k_i,
                         strides=[tvm.var('ldw'), tvm.var('ldw'), 1])  # stride can be inferred
    Xb = tvm.decl_buffer(x.shape, x.dtype, name="X",
                         offset_factor=k_i,
                         strides=[tvm.var('ldw'), 1])
    Zb = tvm.decl_buffer(z.shape, z.dtype, name="Z",
                         offset_factor=1,
                         strides=[1])

    def _intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]

        args_1 = tvm.const(1, 'uint32')
        args_2 = tvm.const(2, 'uint32')

        if unipolar:
            vpadd = "llvm.arm.neon.vpadd.v8i8"
            vpadalu = "llvm.arm.neon.vpadals.v16i8.v8i16"
            full_dtype = 'int8x16'
            half_dtype = 'int8x8'
            return_dtype = 'int16x8'
        else:
            vpadd = "llvm.arm.neon.vpadd.v8u8"
            vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
            full_dtype = 'uint8x16'
            half_dtype = 'uint8x8'
            return_dtype = 'uint16x8'

        def _instr(index):
            irb = tvm.ir_builder.create()
            if index == 1:  # reduce reset
                irb.emit(zz.vstore(0, tvm.const(0, return_dtype)))
                return irb.get()
            # body and reduce update
            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    if k_i == 16:
                        for i in range(m):
                            w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype)
                            x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype)
                            if unipolar:
                                cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                            else:
                                cnts = tvm.popcount(w_ & x_)
                            upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
                            lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts)
                            cnts8[i] = upper_half + lower_half
                        for i in range(m // 2):
                            cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                            cnts8[i * 2], cnts8[i * 2 + 1])
                        for i in range(m // 4):
                            cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                            cnts4[i * 2], cnts4[i * 2 + 1])
                        cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine',
                                                    cnts2[0], cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw + bx, pack_dtype)
                        out = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                                   zz.vload(0, return_dtype), shifted_cnts)
                    else:  # ki == 8
                        for i in range(m):
                            w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype)
                            x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype)
                            if unipolar:
                                cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_)
                            else:
                                cnts8[i] = tvm.popcount(w_ & x_)
                        for i in range(m // 2):
                            cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                            cnts8[i * 2], cnts8[i * 2 + 1])
                        for i in range(m // 4):
                            cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1,
                                                            cnts4[i * 2], cnts4[i * 2 + 1])
                        cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine',
                                                    cnts2[0], cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw + bx, pack_dtype)
                        out = tvm.call_llvm_intrin(return_dtype, vpadalu, args_2,
                                                   zz.vload(0, return_dtype), shifted_cnts)
                    irb.emit(zz.vstore(0, out))
            return irb.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x: Xb, z: Zb})
def intrin_gemm(M, N, K):
    assert (M, N) in MNTiles, (M, N)
    dtype = 'float32'
    A = tvm.placeholder((K, M), dtype=dtype, name='A')
    B = tvm.placeholder((K, N), dtype=dtype, name='B')
    k = tvm.reduce_axis((0, K), name='k')
    C = tvm.compute((M, N), lambda m, n: tvm.sum(A[k, m] * B[k, n], axis=[k]), name='C')
    Ab = tvm.decl_buffer(A.shape, A.dtype, name="A", offset_factor=M, strides=[M, 1])
    Bb = tvm.decl_buffer(B.shape, B.dtype, name="B", offset_factor=N, strides=[N, 1])
    Cb = tvm.decl_buffer(C.shape, C.dtype, name="C", offset_factor=1, strides=[tvm.var('ldc'), 1])

    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]

        def body():
            irb = tvm.ir_builder.create()
            extern_call = tvm.call_extern(
                "int32",
                "sgemm_compute_{M}x{N}__{ARCH}".format(M=M, N=N, ARCH=ARCH),
                K,
                irb.buffer_ptr(aa), aa.elem_offset,
                irb.buffer_ptr(bb), bb.elem_offset,
                irb.buffer_ptr(cc), cc.elem_offset, cc.strides[0])
            irb.emit(extern_call)
            return irb.get()

        def reset():
            irb = tvm.ir_builder.create()
            extern_call = tvm.call_extern(
                "int32",
                "sgemm_reset_{M}x{N}__{ARCH}".format(M=M, N=N, ARCH=ARCH),
                irb.buffer_ptr(cc), cc.elem_offset, cc.strides[0])
            irb.emit(extern_call)
            return irb.get()

        def update():
            irb = tvm.ir_builder.create()
            extern_call = tvm.call_extern(
                "int32",
                "sgemm_update_{M}x{N}__{ARCH}".format(M=M, N=N, ARCH=ARCH),
                K,
                irb.buffer_ptr(aa), aa.elem_offset,
                irb.buffer_ptr(bb), bb.elem_offset,
                irb.buffer_ptr(cc), cc.elem_offset, cc.strides[0])
            irb.emit(extern_call)
            return irb.get()

        return body(), reset(), update()

    with tvm.build_config():
        return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: Ab, B: Bb, C: Cb})
def dot_16x1x16_uint8_int8_int32_skylake():
    """
    Int8 dot product by every 4 elements using AVX512 Skylake instructions.

    This function takes two arrays of uint8 and int8 datatype -- data[4] and
    kernel[16][4] -- and computes a dot product of data[4] with every
    4 elements of kernels, resulting in output[16] of int32 datatype.
    The pseudo code is as follows.

    .. code-block:: c

        void dot_16x1x16_uint8_int8_int32(uint8 data[4], int8 kernel[16][4],
                                          int32 output[16]){
            for (int i = 0; i < 16; i++){
                output[i] = 0;
                for (int k = 0; k < 4; k++){
                    output[i] += data[k] * kernel[i][k]
                }
            }
        }

    Physically, the kernel array sits in an AVX512 vector register and
    the data[4] is broadcasted to another AVX512 vector register. This
    function returns a TensorIntrin that can be used to tensorize
    a schedule.

    Returns
    -------
    intrin : TensorIntrin
        The Skylake int8 TensorIntrin that can be used in tensorizing schedule
    """
    int32_lanes = 16  # 16 int32 lanes in AVX512
    num_int8_elements = 4  # 4 int8 elements in int32
    data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data')
    kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
    k = tvm.reduce_axis((0, num_int8_elements), name='k')
    C = tvm.compute((int32_lanes,),
                    lambda i: tvm.sum(data[k].astype('int32') * kernel[i, k].astype('int32'),
                                      axis=k),
                    name="C")

    a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
                               offset_factor=1, strides=[1])
    b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
                               offset_factor=1, strides=[tvm.var('ldw'), 1])

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
                return ib.get()

            a_int8 = ins[0].vload([0], "uint8x4")
            re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8)
            vec_ai32 = re_int32.astype('int32x16')
            vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
            vec_b = ins[1].vload([0, 0], "int8x64")
            vec_one = tvm.const(1, "int16x32")
            pair_reduction = tvm.call_llvm_intrin('int16x32',
                                                  'llvm.x86.avx512.pmaddubs.w.512',
                                                  tvm.const(0, 'uint32'),
                                                  vec_a, vec_b)
            quad_reduction = tvm.call_llvm_intrin('int32x16',
                                                  'llvm.x86.avx512.pmaddw.d.512',
                                                  tvm.const(0, 'uint32'),
                                                  pair_reduction, vec_one)
            if index == 0:
                ib.emit(outs[0].vstore(0, quad_reduction))
            else:
                ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], 'int32x16')))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(C.op, _intrin_func,
                                      binds={data: a_buffer, kernel: b_buffer})
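# A hedged usage sketch (legacy TVM 0.x API): tile a uint8 x int8 reduction so
# the inner 16x4 block matches the intrinsic above, then tensorize it. The
# shapes and split factors below are illustrative assumptions, not part of the
# original snippet.
def avx512_dot_tensorize_example():
    M, K = 128, 256
    data = tvm.placeholder((K,), dtype='uint8', name='data')
    kernel = tvm.placeholder((M, K), dtype='int8', name='kernel')
    k = tvm.reduce_axis((0, K), name='k')
    out = tvm.compute((M,),
                      lambda i: tvm.sum(data[k].astype('int32') * kernel[i, k].astype('int32'),
                                        axis=k),
                      name='out')
    s = tvm.create_schedule(out.op)
    io, ii = s[out].split(out.op.axis[0], factor=16)
    ko, ki = s[out].split(out.op.reduce_axis[0], factor=4)
    s[out].reorder(io, ko, ii, ki)
    s[out].tensorize(ii, dot_16x1x16_uint8_int8_int32_skylake())
    return tvm.lower(s, [data, kernel, out], simple_mode=True)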
def _intrin_popcount(m, k_i, w_b, x_b):
    dtype = 'uint8'
    w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w')
    x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x')
    k = tvm.reduce_axis((0, k_i), name='k')
    bw = tvm.reduce_axis((0, w_b), name='bw')
    bx = tvm.reduce_axis((0, x_b), name='bx')
    z = tvm.compute((m,),
                    lambda i: tvm.sum(
                        tvm.popcount(w[bw, i, k].astype('uint16') & x[bx, k].astype('uint16'))
                        << (bw + bx).astype('uint16'),
                        axis=[bw, bx, k]),
                    name='z')
    Wb = tvm.decl_buffer(w.shape, w.dtype, name="W",
                         offset_factor=k_i,
                         strides=[tvm.var('ldw'), tvm.var('ldw'), 1])
    Xb = tvm.decl_buffer(x.shape, x.dtype, name="X",
                         offset_factor=k_i,
                         strides=[tvm.var('ldw'), 1])

    def _intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        vpadd = "llvm.arm.neon.vpadd.v8u8"
        vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16"
        args_1 = tvm.const(1, 'uint32')
        args_2 = tvm.const(2, 'uint32')

        def _instr(index):
            irb = tvm.ir_builder.create()
            if index == 1:
                irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8')))
                return irb.get()

            cnts8 = [None] * 8
            cnts4 = [None] * 4
            cnts2 = [None] * 2
            for bw in range(w_b):
                for bx in range(x_b):
                    if k_i == 16:
                        for i in range(m):
                            ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16')
                            cnts = tvm.popcount(ands)
                            upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts)
                            lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts)
                            cnts8[i] = upper_half + lower_half
                        for i in range(m // 2):
                            cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                            cnts8[i * 2], cnts8[i * 2 + 1])
                        for i in range(m // 4):
                            cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                            cnts4[i * 2], cnts4[i * 2 + 1])
                        cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                        out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                                   zz.vload(0, 'uint16x8'), shifted_cnts)
                    else:  # ki == 8
                        for i in range(m):
                            ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8')
                            cnts8[i] = tvm.popcount(ands)
                        for i in range(m // 2):
                            cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                            cnts8[i * 2], cnts8[i * 2 + 1])
                        for i in range(m // 4):
                            cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, args_1,
                                                            cnts4[i * 2], cnts4[i * 2 + 1])
                        cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1])
                        shifted_cnts = cnts << tvm.const(bw + bx, dtype)
                        out = tvm.call_llvm_intrin('uint16x8', vpadalu, args_2,
                                                   zz.vload(0, 'uint16x8'), shifted_cnts)
                    irb.emit(zz.vstore(0, out))
            return irb.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x: Xb})
def dot_int8_int8_int32(int32_lanes, dtype='uint'):
    """
    Int8 dot product by every 4 elements using ARM v8.2 udot.
    This function takes two arrays of int8 datatype -- data[4] and
    kernel[int32_lanes][4] -- and computes a dot product of data[4] with every
    4 elements of kernels, resulting in output[int32_lanes] of uint32 datatype.
    The pseudo code is as follows.

    .. code-block:: c

        void dot_int8_int8_int32(int8 data[4], int8 kernel[16][4], int32 output[16]){
            for (int i = 0; i < int32_lanes; i++){
                out[i] = 0;
                for (int k = 0; k < 4; k++){
                    out[i] += data[k] * kernel[i][k]
                }
            }
        }

    Physically, the kernel array sits in a vector register and
    the data[4] is broadcasted to another vector register. This
    function returns a TensorIntrin that can be used to tensorize
    a schedule.

    Parameters
    ----------
    int32_lanes: int
        How many int32/uint32 to produce
    dtype: str, optional, {"uint", "int"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The ARM uint8 TensorIntrin that can be used in tensorizing schedule
    """
    num_int8_elements = 4  # 4 int8 elements in int32

    data = tvm.placeholder((num_int8_elements,), dtype='%s8' % dtype, name='data')
    kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='%s8' % dtype, name='kernel')

    k = tvm.reduce_axis((0, num_int8_elements), name='k')
    C = tvm.compute((int32_lanes,),
                    lambda i: tvm.sum(data[k].astype('%s32' % dtype) *
                                      kernel[i, k].astype('%s32' % dtype),
                                      axis=k),
                    name="C")

    a_buffer = tvm.decl_buffer(data.shape, dtype='%s8' % dtype, name="a_buffer",
                               offset_factor=1, strides=[1])
    b_buffer = tvm.decl_buffer(kernel.shape, dtype='%s8' % dtype, name="b_buffer",
                               offset_factor=1, strides=[tvm.var('s'), 1])

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.ir_builder.create()
            if index == 1:
                ib.emit(outs[0].vstore(0, tvm.const(0, '%s32x%d' % (dtype, int32_lanes))))
                return ib.get()

            dtype_a = '%s8x%d' % (dtype, num_int8_elements)
            dtype_b = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)
            dtype_c = '%s32x%d' % (dtype, int32_lanes)

            a_int8 = ins[0].vload([0], dtype_a)
            re_int32 = tvm.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8)
            # broadcast a
            vec_ai32 = re_int32.astype(dtype_c)

            vec_a = tvm.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
            vec_b = ins[1].vload([0, 0], dtype_b)
            vec_c = outs[0].vload([0], dtype_c)

            inst = 'udot' if dtype == 'uint' else 'sdot'
            inst = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
                inst, int32_lanes, int32_lanes * num_int8_elements)
            vdot = tvm.call_llvm_intrin(dtype_c, inst, tvm.const(2, 'uint32'),
                                        vec_c, vec_a, vec_b)
            ib.emit(outs[0].vstore(0, vdot))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    with tvm.build_config(offset_factor=1, partition_const_loop=True):
        return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data: a_buffer, kernel: b_buffer})
def gemm(env, mock=False):
    """Matrix-matrix multiply intrinsic

    Parameters
    ----------
    env : Environment
        The Environment

    mock : bool
        Whether create a mock version.
    """
    wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
    assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
    wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
    assert wgt_shape[0] * wgt_shape[1] == wgt_lanes

    inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
    assert inp_lanes == env.BATCH * env.BLOCK_IN
    inp_shape = (env.BATCH, env.BLOCK_IN)
    assert inp_shape[0] * inp_shape[1] == inp_lanes

    out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
    assert out_lanes == env.BATCH * env.BLOCK_OUT
    out_shape = (env.BATCH, env.BLOCK_OUT)
    assert out_shape[0] * out_shape[1] == out_lanes

    wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
                          dtype="int%d" % env.WGT_WIDTH,
                          name=env.wgt_scope)
    inp = tvm.placeholder((inp_shape[0], inp_shape[1]),
                          dtype="int%d" % env.INP_WIDTH,
                          name=env.inp_scope)
    k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
    out_dtype = "int%d" % env.ACC_WIDTH
    out = tvm.compute((out_shape[0], out_shape[1]),
                      lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) *
                                           wgt[j, k].astype(out_dtype),
                                           axis=[k]),
                      name="out")
    wgt_layout = tvm.decl_buffer(
        wgt.shape, wgt.dtype, env.wgt_scope,
        scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
    inp_layout = tvm.decl_buffer(
        inp.shape, inp.dtype, env.inp_scope,
        scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
    out_layout = tvm.decl_buffer(
        out.shape, out.dtype, env.acc_scope,
        scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)

    def intrin_func(ins, outs):
        """Matrix-matrix multiply intrinsic function"""
        dinp, dwgt = ins
        dout = outs[0]

        def instr(index):
            """Generate matrix-matrix multiply VTA instruction"""
            irb = tvm.ir_builder.create()
            dev = env.dev
            irb.scope_attr(dev.vta_axis, "coproc_scope",
                           dev.get_task_qid(dev.QID_COMPUTE))
            irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
                           dev.vta_push_uop)
            if index in (0, 2):
                irb.emit(tvm.call_extern(
                    "int32", "VTAUopPush",
                    0, 0,
                    dout.access_ptr("rw", "int32"),
                    dinp.access_ptr("r", "int32"),
                    dwgt.access_ptr("r", "int32"),
                    0, 0, 0))
            else:
                irb.emit(tvm.call_extern(
                    "int32", "VTAUopPush",
                    0, 1,
                    dout.access_ptr("rw", "int32"),
                    0, 0, 0, 0, 0))
            return irb.get()

        # return a triple of normal-set, reset, update
        nop = tvm.make.Evaluate(0)
        if mock:
            return (nop, nop, nop)
        return (instr(0), instr(1), instr(2))

    return tvm.decl_tensor_intrin(out.op, intrin_func,
                                  name="GEMM",
                                  binds={inp: inp_layout,
                                         wgt: wgt_layout,
                                         out: out_layout})
def intrinsic_gemm(i, j, k, il, jl, kl, ic, jc, kc):
    """
    (i, k) * (k, j)
    i, j, k: normal iteration size
    il, jl, kl: last iteration size
    ic, jc, kc: last iteration condition
    """
    assert i * k + k * j <= 256 * 1024, 'input too large for scratchpad'
    assert 4 * (i * j) <= 64 * 1024, 'input too large for accumulator'

    a = tvm.placeholder((i, k), name='a', dtype=dtype)
    b = tvm.placeholder((k, j), name='b', dtype=dtype)
    kk = tvm.reduce_axis((0, k), name='k')
    c = tvm.compute((i, j), lambda ii, jj: tvm.sum(a[ii, kk] * b[kk, jj], axis=kk), name='c')

    strideA = tvm.var("sA")
    Ab = tvm.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[strideA, 1])
    strideB = tvm.var("sB")
    Bb = tvm.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[strideB, 1])
    strideC = tvm.var("sC")
    Cb = tvm.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[strideC, 1])

    II = i // DIM + (0 if i % DIM == 0 else 1)
    JJ = j // DIM + (0 if j % DIM == 0 else 1)
    KK = k // DIM + (0 if k % DIM == 0 else 1)
    pad_I = 0 if i % DIM == 0 else (DIM - i % DIM)
    pad_J = 0 if j % DIM == 0 else (DIM - j % DIM)
    pad_K = 0 if k % DIM == 0 else (DIM - k % DIM)

    IIl = il // DIM + (0 if il % DIM == 0 else 1)
    JJl = jl // DIM + (0 if jl % DIM == 0 else 1)
    KKl = kl // DIM + (0 if kl % DIM == 0 else 1)
    pad_Il = 0 if il % DIM == 0 else (DIM - il % DIM)
    pad_Jl = 0 if jl % DIM == 0 else (DIM - jl % DIM)
    pad_Kl = 0 if kl % DIM == 0 else (DIM - kl % DIM)

    II = tvm.if_then_else(ic, IIl, II)
    JJ = tvm.if_then_else(jc, JJl, JJ)
    KK = tvm.if_then_else(kc, KKl, KK)
    pad_I = tvm.if_then_else(ic, pad_Il, pad_I)
    pad_J = tvm.if_then_else(jc, pad_Jl, pad_J)
    pad_K = tvm.if_then_else(kc, pad_Kl, pad_K)

    # reset-update-finalize
    def intrin_func(ins, outs):
        aa, bb = ins
        cc, = outs

        def _body():
            ib = tvm.ir_builder.create()
            # int32_t matmul_kernel(const elem_t *A, const elem_t *B, const acc_t *D,
            #             elem_t *C, int32_t I, int32_t J, int32_t K, int32_t pad_I,
            #             int32_t pad_J, int32_t pad_K, int32_t A_row_len,
            #             int32_t B_row_len, int32_t D_row_len, int32_t C_row_len,
            #             bool no_bias, bool repeating_bias);
            # D is set to a dummy address 1 to determine whether to overwrite
            # accumulator contents: on the first run, 1 will be retained and
            # overwrite the value in the accumulator; on subsequent runs D will be
            # replaced by NULL and C will accumulate on top of the accumulator's contents
            # This is controlled via bit 1 << (ADDR_LEN - 2) - see kernel source
            ib.emit(tvm.call_extern("int32", "matmul_kernel",
                                    aa.access_ptr("r"), bb.access_ptr("r"),
                                    1, cc.access_ptr("rw"),
                                    II, JJ, KK, pad_I, pad_J, pad_K,
                                    strideA, strideB, 0, strideC, True, False))
            return ib.get()

        def _reset():
            ib = tvm.ir_builder.create()
            # int32_t matmul_reset(elem_t *C, int32_t I, int32_t J, int32_t pad_I,
            #                      int32_t pad_J, int32_t C_row_len);
            ib.emit(tvm.call_extern("int32", "matmul_reset",
                                    cc.access_ptr("w"),
                                    II, JJ, pad_I, pad_J, strideC))
            return ib.get()

        def _finalize():
            ib = tvm.ir_builder.create()
            # Move out C from accumulator
            # int32_t matmul_finalize(elem_t *C, int32_t I, int32_t J, int32_t pad_I,
            #                         int32_t pad_J, int32_t C_row_len);
            ib.emit(tvm.call_extern("int32", "matmul_finalize",
                                    cc.access_ptr("rw"),
                                    II, JJ, pad_I, pad_J, strideC))
            return ib.get()

        # standalone (without reduce axis split), reset, update, finalize
        return None, _reset(), _body(), _finalize()

    with tvm.build_config(offset_factor=1):
        return tvm.decl_tensor_intrin(c.op, intrin_func,
                                      binds={a: Ab, b: Bb, c: Cb},
                                      name="sp_gemm")