def intrin_wmma_store_matrix(shape):
    """Declare a tensor intrinsic that stores a WMMA accumulator fragment.

    The matched computation is an element-wise copy of an ``(n, m)`` float32
    tensor.  The source is bound to a ``wmma.accumulator`` scoped buffer and
    the destination to a ``global`` scoped buffer; the copy is replaced by a
    single ``tir.tvm_store_matrix_sync`` call.

    Parameters
    ----------
    shape : tuple
        The ``(n, m, l)`` triple forwarded to the store intrinsic.

    Returns
    -------
    intrin : TensorIntrin
        The result of ``te.decl_tensor_intrin`` for the copy pattern.
    """
    n, m, l = shape
    frag_elems = n * m

    src = te.placeholder((n, m), name='A', dtype='float32')
    src_buf = tvm.tir.decl_buffer(
        src.shape, src.dtype, scope='wmma.accumulator',
        data_alignment=32, offset_factor=frag_elems)
    dst = te.compute((n, m), lambda i, j: src[i, j], name='C')
    dst_buf = tvm.tir.decl_buffer(
        dst.shape, dst.dtype, scope='global',
        data_alignment=32, offset_factor=frag_elems)

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        frag, out = ins[0], outs[0]
        # The fragment index is the element offset measured in whole fragments.
        store_call = tvm.tir.call_intrin(
            'handle', 'tir.tvm_store_matrix_sync',
            frag.data, n, m, l, frag.elem_offset // frag_elems,
            out.access_ptr('w'), m, 'row_major')
        builder.emit(store_call)
        return builder.get()

    return te.decl_tensor_intrin(
        dst.op, intrin_func, binds={src: src_buf, dst: dst_buf})
def intrin_wmma_gemm(shape):
    """Declare a tensor intrinsic for a WMMA matrix multiply-accumulate.

    The matched computation is ``C[ii, jj] = sum_k A[ii, k] * B[k, jj]`` with
    float16 inputs cast via ``astype('float')`` before multiplication.  The
    operands are bound to ``wmma.matrix_a`` / ``wmma.matrix_b`` buffers and
    the result to a ``wmma.accumulator`` buffer.

    Parameters
    ----------
    shape : tuple
        The ``(n, m, l)`` triple: ``A`` is ``(n, l)``, ``B`` is ``(l, m)`` and
        ``C`` is ``(n, m)``.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic whose reset step is ``tir.tvm_fill_fragment`` and whose body
        and update steps are ``tir.tvm_mma_sync``.
    """
    n, m, l = shape

    lhs = te.placeholder((n, l), name='A', dtype='float16')
    rhs = te.placeholder((l, m), name='B', dtype='float16')
    k = te.reduce_axis((0, l), name="k")
    acc = te.compute(
        (n, m),
        lambda ii, jj: te.sum(
            lhs[ii, k].astype('float') * rhs[k, jj].astype('float'), axis=k),
        name='C')

    lhs_buf = tvm.tir.decl_buffer(
        lhs.shape, lhs.dtype, name='BA', scope='wmma.matrix_a',
        data_alignment=32, offset_factor=n * l)
    rhs_buf = tvm.tir.decl_buffer(
        rhs.shape, rhs.dtype, name='BB', scope='wmma.matrix_b',
        data_alignment=32, offset_factor=l * m)
    acc_buf = tvm.tir.decl_buffer(
        acc.shape, acc.dtype, name='BC', scope='wmma.accumulator',
        data_alignment=32, offset_factor=n * m)

    def intrin_func(ins, outs):
        a_frag, b_frag = ins
        c_frag, = outs

        def _fill():
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_intrin(
                'handle', 'tir.tvm_fill_fragment',
                c_frag.data, n, m, l, c_frag.elem_offset // (n * m), 0.0))
            return builder.get()

        def _mma():
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_intrin(
                'handle', 'tir.tvm_mma_sync',
                c_frag.data, c_frag.elem_offset // (n * m),
                a_frag.data, a_frag.elem_offset // (n * l),
                b_frag.data, b_frag.elem_offset // (l * m),
                c_frag.data, c_frag.elem_offset // (n * m)))
            return builder.get()

        # (body, reduce reset, reduce update)
        body = _mma()
        reset = _fill()
        update = _mma()
        return body, reset, update

    return te.decl_tensor_intrin(
        acc.op, intrin_func,
        binds={lhs: lhs_buf, rhs: rhs_buf, acc: acc_buf})
def intrin_gemv_no_reset(m, n):
    """Declare a matrix-vector tensor intrinsic that has no reset step.

    The matched computation is ``z[i] = sum_k w[i, k] * x[k]`` for an
    ``(m, n)`` matrix ``w`` and a length-``n`` vector ``x``.  The intrinsic
    returns ``None`` in the reduce-reset slot.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic calling the packed functions ``"gemv"`` (body) and
        ``"gemv_add"`` (update).
    """
    mat = te.placeholder((m, n), name='w')
    vec = te.placeholder((n, ), name='x')
    k = te.reduce_axis((0, n), name='k')
    out = te.compute(
        (m, ), lambda i: te.sum(mat[i, k] * vec[k], axis=k), name='z')
    # Only the matrix gets an explicit buffer; its leading stride is symbolic.
    mat_buf = tvm.tir.decl_buffer(
        mat.shape, mat.dtype, name="W", offset_factor=16,
        strides=[te.var('ldw'), 1])

    def intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        w_ptr = ww.access_ptr("r")
        x_ptr = xx.access_ptr("r")
        z_ptr = zz.access_ptr("w")
        body = tvm.tir.call_packed(
            "gemv", w_ptr, x_ptr, z_ptr, n, ww.strides[0])
        update = tvm.tir.call_packed(
            "gemv_add", w_ptr, x_ptr, z_ptr, n, ww.strides[0])
        return body, None, update

    with tvm.target.build_config(data_alignment=16, offset_factor=16):
        intrin = te.decl_tensor_intrin(
            out.op, intrin_func, binds={mat: mat_buf})
    return intrin
def test_tensor_intrin():
    """Check declaring a vector-add intrinsic and tensorizing a split loop.

    First declares the intrinsic on a length-16 add and checks its fields,
    then tensorizes the inner loop of a length-32 add split by 16 and checks
    the attributes recorded on the schedule.
    """
    n = 16
    lhs = te.placeholder((n, ), name='x')
    rhs = te.placeholder((n, ), name='y')
    total = te.compute(lhs.shape, lambda i: lhs[i] + rhs[i], name='z')

    def intrin_func(ins, outs):
        first = ins[0]
        assert (isinstance(first, tvm.te.schedule.Buffer))
        assert (first.shape[0].value == n)
        return tvm.tir.call_packed(
            "vadd", first.data, outs[0].data, first.shape[0])

    intrin = te.decl_tensor_intrin(total.op, intrin_func)
    assert intrin.op == total.op
    assert intrin.reduce_init is None
    assert tuple(intrin.inputs) == tuple(total.op.input_tensors)
    assert (intrin.buffers[0].shape[0].value == n)

    # A longer vector whose inner split loop matches the intrinsic length.
    m = 32
    lhs = te.placeholder((m, ), name='x')
    rhs = te.placeholder((m, ), name='y')
    total = te.compute(lhs.shape, lambda i: lhs[i] + rhs[i], name='z')
    sched = te.create_schedule(total.op)
    _outer, inner = sched[total].split(total.op.axis[0], factor=n)
    sched[total].tensorize(inner, intrin)

    inner_attrs = sched[total].iter_var_attrs[inner]
    assert (inner_attrs.tensor_intrin == intrin)
    assert (inner_attrs.iter_type == tvm.te.schedule.IterVar.Tensorized)
def intrin_test():
    """Declare a tensor intrinsic summing the first three rows of a matrix.

    The matched computation is ``c[i, j] = a[0, j] + a[1, j] + a[2, j]`` over
    a symbolic ``(m1, n1)`` input, producing a ``(1, n1)`` output.  It is
    replaced by an extern call to ``"test"``.
    """
    rows = te.var("m1")
    cols = te.var("n1")
    src = te.placeholder((rows, cols), name='a')
    dst = te.compute(
        (1, cols), lambda i, j: src[0, j] + src[1, j] + src[2, j], name='c')
    src_buf = tvm.tir.decl_buffer(src.shape, name="Abuf", offset_factor=1)
    dst_buf = tvm.tir.decl_buffer(dst.shape, name="Cbuf", offset_factor=1)

    def intrin_func(ins, outs):
        aa, cc = ins[0], outs[0]
        builder = tvm.tir.ir_builder.create()
        extern_call = tvm.tir.call_extern(
            "int32", "test", cc.access_ptr("w"), aa.access_ptr("r"))
        builder.emit(extern_call)
        return builder.get()

    with tvm.target.build_config(offset_factor=1):
        intrin = te.decl_tensor_intrin(
            dst.op, intrin_func, binds={src: src_buf, dst: dst_buf})
    return intrin
def intrin_vadd(n, cache_read=False, cache_write=False):
    """Declare a float32 vector-add tensor intrinsic of length ``n``.

    The matched computation is ``z[i] = x[i] + y[i]``; it is replaced by an
    extern call to ``"vadd"``.

    Parameters
    ----------
    n : int or PrimExpr
        Length of the vectors.
    cache_read : bool, optional
        When true, both inputs are bound to explicit ``'local'`` scoped
        buffers.
    cache_write : bool, optional
        When true, the output is bound to an explicit ``'local'`` scoped
        buffer.

    Returns
    -------
    intrin : TensorIntrin
        The declared intrinsic.
    """
    scope_ubuf = 'local'
    dtype = 'float32'
    x = te.placeholder((n, ), dtype=dtype, name='vx')
    y = te.placeholder((n, ), dtype=dtype, name='vy')
    z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')
    # No schedule is created here: the intrinsic is declared from ``z.op``
    # alone and a schedule object was never used by this function.

    def create_buffer(t):
        return tvm.tir.decl_buffer(
            t.shape, t.dtype, name='W' + t.name,
            scope=scope_ubuf, offset_factor=16)

    # Tensors left out of ``binds`` fall back to ``default_buffer_params``.
    binds = {}
    if cache_read:
        binds[x] = create_buffer(x)
        binds[y] = create_buffer(y)
    if cache_write:
        binds[z] = create_buffer(z)

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        ib.emit(
            tvm.tir.call_extern(
                outs[0].dtype, 'vadd',
                ins[0].access_ptr("r"),
                ins[1].access_ptr('r'),
                outs[0].access_ptr('wr')))
        return ib.get()

    return te.decl_tensor_intrin(
        z.op, intrin_func, binds=binds,
        default_buffer_params={"offset_factor": 16})
def intrin_wmma_load_matrix(shape, scope):
    """Declare a tensor intrinsic that loads a WMMA operand fragment.

    The matched computation is an element-wise copy of a float16 tensor from
    a ``shared`` scoped buffer into a buffer of the requested WMMA ``scope``;
    it is replaced by a ``tir.tvm_load_matrix_sync`` call.

    Parameters
    ----------
    shape : tuple
        The ``(n, m, l)`` triple forwarded to the load intrinsic.
    scope : str
        ``"wmma.matrix_a"`` (fragment is ``(n, l)``) or ``"wmma.matrix_b"``
        (fragment is ``(l, m)``).

    Returns
    -------
    intrin : TensorIntrin
        The declared intrinsic.

    Raises
    ------
    ValueError
        If ``scope`` is not one of the two supported scopes.
    """
    n, m, l = shape
    if scope == "wmma.matrix_a":
        row, col = n, l
    elif scope == "wmma.matrix_b":
        row, col = l, m
    else:
        # Without this guard ``row``/``col`` would be unbound and the failure
        # would surface later as an opaque UnboundLocalError.
        raise ValueError(
            "unsupported scope %r: expected 'wmma.matrix_a' or "
            "'wmma.matrix_b'" % (scope,))

    A = te.placeholder((row, col), name='A', dtype='float16')
    BA = tvm.tir.decl_buffer(
        A.shape, A.dtype, scope='shared',
        data_alignment=32, offset_factor=row * col)
    C = te.compute((row, col), lambda i, j: A[i, j], name='C')
    BC = tvm.tir.decl_buffer(
        C.shape, C.dtype, scope=scope,
        data_alignment=32, offset_factor=row * col)

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        # The fragment index is the element offset measured in whole fragments.
        ib.emit(
            tvm.tir.call_intrin(
                'handle', 'tir.tvm_load_matrix_sync',
                BC.data, n, m, l, BC.elem_offset // (row * col),
                BA.access_ptr('r'), col, 'row_major'))
        return ib.get()

    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
def intrin_wmma_store_matrix():
    """Declare a tensor intrinsic storing a 16x16 WMMA accumulator fragment.

    The matched computation is an element-wise copy of a ``(16, 16)`` float32
    tensor from a ``wmma.accumulator`` scoped buffer to a ``global`` scoped
    buffer; it is replaced by a ``tir.tvm_store_matrix_sync`` call.
    """
    n = 16
    frag_elems = 256

    src = te.placeholder((n, n), name="A", dtype="float32")
    src_buf = tvm.tir.decl_buffer(
        src.shape, src.dtype, scope="wmma.accumulator",
        data_alignment=32, offset_factor=frag_elems)
    dst = te.compute((n, n), lambda i, j: src[i, j], name="C")
    dst_buf = tvm.tir.decl_buffer(
        dst.shape, dst.dtype, scope="global",
        data_alignment=32, offset_factor=frag_elems)

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        frag, out = ins[0], outs[0]
        store_call = tvm.tir.call_intrin(
            "handle", "tir.tvm_store_matrix_sync",
            frag.data, n, n, n, frag.elem_offset // frag_elems,
            out.access_ptr("w"), n, "row_major")
        builder.emit(store_call)
        return builder.get()

    return te.decl_tensor_intrin(
        dst.op, intrin_func, binds={src: src_buf, dst: dst_buf})
def intrin_wmma_store_matrix(strides_dst, strides_from, shape, out_dtype, A_shape, C_shape):
    """Intrin function for storing the results from wmma.accumulator to shared

    Parameters
    ----------
    strides_dst : list
        Strides of the destination (``shared``) buffer; the first entry is
        also passed to the store intrinsic.
    strides_from : list
        Strides of the source (``wmma.accumulator``) buffer.
    shape : tuple
        The ``(wmma_m, wmma_n, wmma_k)`` triple.
    out_dtype : str
        Data type of the source placeholder.
    A_shape : tuple
        Shape of the source placeholder.
    C_shape : tuple
        Shape of the copy computation.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic emitting ``tir.tvm_store_matrix_sync``.
    """
    wmma_m, wmma_n, wmma_k = shape

    src = te.placeholder(A_shape, name='A', dtype=out_dtype)
    src_buf = tvm.tir.decl_buffer(
        src.shape, src.dtype, scope='wmma.accumulator',
        strides=strides_from, data_alignment=32, offset_factor=8)
    dst = te.compute(C_shape, lambda *i: src(*i), name='C')
    dst_buf = tvm.tir.decl_buffer(
        dst.shape, dst.dtype, scope='shared',
        strides=strides_dst, data_alignment=32, offset_factor=8)

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        frag, out = ins[0], outs[0]
        frag_elems = wmma_m * wmma_n
        offset = frag.elem_offset
        warp_index = offset // frag_elems + (offset % frag_elems) // wmma_n
        builder.emit(tvm.tir.call_intrin(
            'handle', 'tir.tvm_store_matrix_sync',
            frag.data, wmma_m, wmma_n, wmma_k, warp_index,
            out.access_ptr('w'), strides_dst[0], 'row_major'))
        return builder.get()

    return te.decl_tensor_intrin(
        dst.op, intrin_func, binds={src: src_buf, dst: dst_buf})
def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype):
    """Intrin function for loading data from shared memory to wmma.matrix_a

    Parameters
    ----------
    strides_dst : list
        Strides of the destination (``wmma.matrix_a``) buffer.
    strides_from : list
        Strides of the source (``shared``) buffer; the first entry is also
        passed to the load intrinsic.
    shape : tuple
        The ``(wmma_m, wmma_n, wmma_k)`` triple.
    layout : str
        Layout argument forwarded unchanged to ``tir.tvm_load_matrix_sync``.
    A_shape : tuple
        Shape of the source placeholder.
    C_shape : tuple
        Shape of the copy computation.
    in_dtype : str
        Data type of the source placeholder.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic emitting ``tir.tvm_load_matrix_sync``.
    """
    wmma_m, wmma_n, wmma_k = shape

    src = te.placeholder(A_shape, name='A', dtype=in_dtype)
    src_buf = tvm.tir.decl_buffer(
        src.shape, src.dtype, scope='shared',
        strides=strides_from, data_alignment=32, offset_factor=8)
    dst = te.compute(C_shape, lambda *i: src(*i), name='C')
    dst_buf = tvm.tir.decl_buffer(
        dst.shape, dst.dtype, scope="wmma.matrix_a",
        strides=strides_dst, data_alignment=32, offset_factor=8)

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        mem, frag = ins[0], outs[0]
        frag_elems = wmma_m * wmma_k
        offset = frag.elem_offset
        warp_index = offset // frag_elems + (offset % frag_elems) // wmma_k
        builder.emit(tvm.tir.call_intrin(
            'handle', 'tir.tvm_load_matrix_sync',
            frag.data, wmma_m, wmma_n, wmma_k, warp_index,
            mem.access_ptr('r'), strides_from[0], layout))
        return builder.get()

    return te.decl_tensor_intrin(
        dst.op, intrin_func, binds={src: src_buf, dst: dst_buf})
def intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, strides_A, strides_W, strides_Conv, shape):
    """Intrin for wmma fill_fragment and mma_sync

    Parameters
    ----------
    AL_gemm : tvm.te.placeholder
        wmma matrix A
    WL_gemm : tvm.te.placeholder
        wmma matrix B
    CL_compute : tvm.te.compute
        The definition of wmma gemm
    strides_A : list
        Strides of the ``wmma.matrix_a`` buffer bound to ``AL_gemm``.
    strides_W : list
        Strides of the ``wmma.matrix_b`` buffer bound to ``WL_gemm``.
    strides_Conv : list
        Strides of the ``wmma.accumulator`` buffer bound to ``CL_compute``.
    shape : tuple
        The ``(wmma_m, wmma_n, wmma_k)`` triple.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic whose reset step is ``tir.tvm_fill_fragment`` and whose body
        and update steps are ``tir.tvm_mma_sync``.
    """
    wmma_m, wmma_n, wmma_k = shape

    buf_a = tvm.tir.decl_buffer(
        AL_gemm.shape, AL_gemm.dtype, name='BA', scope='wmma.matrix_a',
        data_alignment=32, offset_factor=8, strides=strides_A)
    buf_b = tvm.tir.decl_buffer(
        WL_gemm.shape, WL_gemm.dtype, name='BB', scope='wmma.matrix_b',
        data_alignment=32, offset_factor=8, strides=strides_W)
    buf_c = tvm.tir.decl_buffer(
        CL_compute.shape, CL_compute.dtype, name='BC',
        scope='wmma.accumulator',
        data_alignment=32, offset_factor=8, strides=strides_Conv)

    def intrin_func(ins, outs):
        a_frag, b_frag = ins
        c_frag, = outs

        def _warp_index(offset, rows, cols):
            frag_elems = rows * cols
            return offset // frag_elems + (offset % frag_elems) // cols

        idx_a = _warp_index(a_frag.elem_offset, wmma_m, wmma_k)
        idx_b = _warp_index(b_frag.elem_offset, wmma_k, wmma_n)
        idx_c = _warp_index(c_frag.elem_offset, wmma_m, wmma_n)

        def _fill():
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_intrin(
                'handle', 'tir.tvm_fill_fragment',
                c_frag.data, wmma_m, wmma_n, wmma_k, idx_c, 0.0))
            return builder.get()

        def _mma():
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_intrin(
                'handle', 'tir.tvm_mma_sync',
                c_frag.data, idx_c,
                a_frag.data, idx_a,
                b_frag.data, idx_b,
                c_frag.data, idx_c))
            return builder.get()

        # (body, reduce reset, reduce update)
        body = _mma()
        reset = _fill()
        update = _mma()
        return body, reset, update

    return te.decl_tensor_intrin(
        CL_compute.op, intrin_func,
        binds={AL_gemm: buf_a, WL_gemm: buf_b, CL_compute: buf_c})
def intrin_vadd(n):
    """Declare a float32 vector-add tensor intrinsic of length ``n``.

    The matched computation is ``z[i] = x[i] + y[i]``; it is replaced by an
    extern call to ``"vadd"``.  Both inputs and the output are bound to
    explicit buffers declared with ``offset_factor=16``.

    Parameters
    ----------
    n : int or PrimExpr
        Length of the vectors.

    Returns
    -------
    intrin : TensorIntrin
        The declared intrinsic.
    """
    dtype = "float32"
    x = te.placeholder((n, ), dtype=dtype, name="vx")
    y = te.placeholder((n, ), dtype=dtype, name="vy")
    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
    # No schedule is created here: the intrinsic is declared from ``z.op``
    # alone and a schedule object was never used by this function.

    def create_buffer(t):
        return tvm.tir.decl_buffer(
            t.shape, t.dtype, name="W" + t.name, offset_factor=16)

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        ib.emit(
            tvm.tir.call_extern(
                "float32",
                "vadd",
                ins[0].access_ptr("r"),
                ins[1].access_ptr("r"),
                outs[0].access_ptr("wr"),
            ))
        return ib.get()

    return te.decl_tensor_intrin(z.op, intrin_func, binds={
        x: create_buffer(x),
        y: create_buffer(y),
        z: create_buffer(z),
    })
def intrin_multivadd(n):
    """Declare a tensor intrinsic named ``"multivadd"``.

    The matched computation applies an extern ``"vadd"`` call per index ``i``
    to three strided float32 buffers, each accessed at an offset of its
    symbolic stride times ``i``.  The replacement is a bare
    ``call_packed("multivadd")`` without arguments.
    """
    stride_a = te.var("n_a")
    buf_a = tvm.tir.decl_buffer((n, ), "float32", strides=[stride_a])
    stride_b = te.var("n_b")
    buf_b = tvm.tir.decl_buffer((n, ), "float32", strides=[stride_b])
    stride_c = te.var("n_c")
    buf_c = tvm.tir.decl_buffer((n, ), "float32", strides=[stride_c])

    pattern = te.compute(
        (n, ),
        lambda i: tvm.tir.call_extern(
            "float32", "vadd",
            buf_a.access_ptr("w", offset=stride_a * i),
            buf_b.access_ptr("r", offset=stride_b * i),
            buf_c.access_ptr("r", offset=stride_c * i)))

    # Replace the pattern with the multivadd call.  Passing the right
    # parameters to it is still an open question.
    def intrin_func(ins, outs):
        return tvm.tir.call_packed("multivadd")

    return te.decl_tensor_intrin(pattern.op, intrin_func, name="multivadd")
def test_tensor_intrin_scalar_params():
    """Check a tensor intrinsic declared with scalar parameters.

    Declares ``z[i] = x[i] * v + w`` with ``v`` and ``w`` as scalar
    parameters, checks the declared fields, then calls the intrinsic inside a
    compute and inspects the arguments of the lowered call.
    """
    n = te.size_var("n")
    x = te.placeholder((n, ), name='x')
    v = te.size_var("v")
    w = te.size_var("w")
    z = te.compute((n, ), lambda i: x[i] * v + w, name='z')

    def intrin_func(ins, outs, sp):
        first = ins[0]
        assert (isinstance(first, tvm.te.schedule.Buffer))
        assert (first.shape[0] == n)
        assert (sp[0] == v)
        assert (sp[1] == w)
        return tvm.tir.call_packed(
            "hw_func", first.data, outs[0].data, sp[0], sp[1])

    with tvm.target.build_config(offset_factor=1):
        intrin = te.decl_tensor_intrin(
            z.op, intrin_func, scalar_params=[v, w])

    assert intrin.op == z.op
    assert intrin.reduce_init is None
    assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
    assert (intrin.buffers[0].shape[0] == n)
    assert tuple(intrin.scalar_params) == (v, w)

    A = te.placeholder((10, 10), name='A')
    # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
    C = te.compute(
        (10, 10), lambda i, j: intrin(i * i, A[i, j], i + j), name="C")
    sched = te.create_schedule(C.op)
    stmt = tvm.lower(sched, [A, C], simple_mode=True)

    innermost = stmt.body.body.body
    assert isinstance(innermost, tvm.tir.Evaluate)
    call_args = innermost.value.args
    assert len(call_args) == 5
    assert str(call_args[3]) == "(i*i)"
    assert str(call_args[4]) == "(i + j)"
def intrin_gemv(m, l):
    """Declare a matrix-vector tensor intrinsic calling ``"gemv_update"``.

    The matched computation is ``c[i] = sum_k a[k] * b[i, k]`` for a
    length-``l`` vector ``a`` and an ``(m, l)`` matrix ``b``.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic emitting one extern call that receives the output, vector
        and matrix pointers, then ``m``, ``l`` and the matrix's leading
        stride.
    """
    vec = te.placeholder((l,), name='a')
    mat = te.placeholder((m, l), name='b')
    k = te.reduce_axis((0, l), name='k')
    out = te.compute(
        (m,), lambda i: te.sum(vec[k] * mat[i, k], axis=k), name='c')

    vec_buf = tvm.tir.decl_buffer(
        vec.shape, vec.dtype, name="A", offset_factor=1, strides=[1])
    # The matrix row stride is symbolic; its inner stride is 1.
    mat_buf = tvm.tir.decl_buffer(
        mat.shape, mat.dtype, name="B", offset_factor=1,
        strides=[te.var("s1"), 1])
    out_buf = tvm.tir.decl_buffer(
        out.shape, out.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        extern_call = tvm.tir.call_extern(
            "int32", "gemv_update",
            cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"),
            m, l, bb.strides[0])
        builder.emit(extern_call)
        return builder.get()

    return te.decl_tensor_intrin(
        out.op, intrin_func,
        binds={vec: vec_buf, mat: mat_buf, out: out_buf})
def intrin_vadd(n):
    """Declare a float32 vector-add tensor intrinsic of length ``n``.

    The matched computation is ``z[i] = x[i] + y[i]``; it is replaced by an
    extern call to ``"vadd"``.  Both inputs and the output are bound to
    explicit buffers declared with ``offset_factor=16``, and the declaration
    runs inside ``tvm.target.build_config(offset_factor=16)``.

    Parameters
    ----------
    n : int or PrimExpr
        Length of the vectors.

    Returns
    -------
    intrin : TensorIntrin
        The declared intrinsic.
    """
    dtype = 'float32'
    x = te.placeholder((n, ), dtype=dtype, name='vx')
    y = te.placeholder((n, ), dtype=dtype, name='vy')
    z = te.compute(x.shape, lambda i: x[i] + y[i], name='z')
    # No schedule is created here: the intrinsic is declared from ``z.op``
    # alone and a schedule object was never used by this function.

    def create_buffer(t):
        return tvm.tir.decl_buffer(
            t.shape, t.dtype, name='W' + t.name, offset_factor=16)

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        ib.emit(
            tvm.tir.call_extern(
                "float32", 'vadd',
                ins[0].access_ptr("r"),
                ins[1].access_ptr('r'),
                outs[0].access_ptr('wr')))
        return ib.get()

    with tvm.target.build_config(offset_factor=16):
        return te.decl_tensor_intrin(z.op, intrin_func, binds={
            x: create_buffer(x),
            y: create_buffer(y),
            z: create_buffer(z),
        })
def intrin_gemv(m, n):
    """Declare a matrix-vector tensor intrinsic with body, reset and update.

    The matched computation is ``z[i] = sum_k w[i, k] * x[k]`` for an
    ``(m, n)`` matrix ``w`` and a length-``n`` vector ``x``.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic built from three packed calls: ``"gemm"`` (body),
        ``"fill_zero"`` (reset) and ``"gemv_add"`` (update).
    """
    mat = te.placeholder((m, n), name="w")
    vec = te.placeholder((n, ), name="x")
    k = te.reduce_axis((0, n), name="k")
    out = te.compute(
        (m, ), lambda i: te.sum(mat[i, k] * vec[k], axis=k), name="z")
    # Only the matrix gets an explicit buffer; its leading stride is symbolic.
    mat_buf = tvm.tir.decl_buffer(
        mat.shape, mat.dtype, name="W", offset_factor=16,
        strides=[te.var("ldw"), 1])

    def intrin_func(ins, outs):
        ww, xx = ins
        zz = outs[0]
        w_ptr = ww.access_ptr("r")
        x_ptr = xx.access_ptr("r")
        z_ptr = zz.access_ptr("w")
        # NOTE(review): the body is named "gemm" while the update is
        # "gemv_add" -- confirm the packed-function name is intended.
        body = tvm.tir.call_packed(
            "gemm", w_ptr, x_ptr, z_ptr, n, ww.strides[0])
        reset = tvm.tir.call_packed("fill_zero", z_ptr, n)
        update = tvm.tir.call_packed(
            "gemv_add", w_ptr, x_ptr, z_ptr, n, ww.strides[0])
        return body, reset, update

    return te.decl_tensor_intrin(
        out.op, intrin_func, binds={mat: mat_buf},
        default_buffer_params={"data_alignment": 16, "offset_factor": 16})
def dp4a(x_scope='local', y_scope='local', z_scope='local'):
    """
    Int8 dot product reduced by every 4 elements using __dp4a

    Parameters
    ----------
    x_scope : str, optional
        The storage scope of buffer for lhs
    y_scope : str, optional
        The storage scope of buffer for rhs
    z_scope : str, optional
        The storage scope of buffer for result

    Returns
    -------
    intrin : TensorIntrin
        The dp4a TensorIntrin that can be used in tensorizing schedule.
    """

    n = 4  # dp4a requires operands packed by 4
    x = te.placeholder((n,), name='x', dtype='int8')
    y = te.placeholder((n,), name='y', dtype='int8')
    k = te.reduce_axis((0, n), name='rc')
    z = te.compute(
        (1,),
        lambda i: te.sum(
            x[k].astype('int32') * y[k].astype('int32'), axis=[k]))

    def _intrin_func(ins, outs):
        xx, yy = ins
        zz = outs[0]

        def _dot(accumulate):
            # Build one __dp4a call; the update variant adds to the stored
            # result, the body variant starts from zero.
            builder = tvm.tir.ir_builder.create()
            vec_x = xx.vload(0, dtype='int8x4')
            vec_y = yy.vload(0, dtype='int8x4')
            prev_z = zz.vload(0) if accumulate else 0
            new_z = tvm.tir.call_pure_extern(
                'int32', '__dp4a', vec_x, vec_y, prev_z)
            builder.emit(zz.vstore(0, new_z))
            return builder.get()

        body = _dot(False)
        reset = zz.vstore(0, 0)
        update = _dot(True)
        return body, reset, update

    default_buffer_params = {
        "data_alignment": 4,
        "offset_factor": 1,
    }
    binds = {}
    for tensor, scope in ((x, x_scope), (y, y_scope), (z, z_scope)):
        binds[tensor] = tvm.tir.decl_buffer(
            tensor.shape, tensor.dtype, tensor.op.name,
            scope=scope, **default_buffer_params)

    return te.decl_tensor_intrin(
        z.op, _intrin_func, binds=binds,
        default_buffer_params=default_buffer_params)
def intrin_max(shape, in_dtype, out_dtype):
    """Defines a v7e-m DSP-accelerated max pool.

    Parameters
    ----------
    shape : tuple
        Shape of both the input placeholder and the output.
    in_dtype : str
        Input data type; must be ``"int8"``.
    out_dtype : str
        Output data type; must be ``"int8"``.

    Returns
    -------
    intrin_decl : TensorIntrin
        The declared intrinsic.
    uniq_id : str
        Random 8-letter uppercase suffix used in the extern function names
        ``max8_<id>`` and ``max8_reset_<id>``.
    """
    UNIQ_ID_LEN = 8
    uniq_id = "".join(random.choices(string.ascii_uppercase, k=UNIQ_ID_LEN))
    func_prefix = "max8"

    assert in_dtype == "int8"
    assert out_dtype == "int8"

    x = te.placeholder(shape, name="x", dtype=in_dtype)
    k = te.reduce_axis((0, 1), name="rc")
    z = te.compute(
        shape, lambda *i: tvm.tir.max(x[i], axis=[k]).astype(out_dtype))

    def _intrin_func(ins, outs):
        src, dst = ins[0], outs[0]

        def _emit_extern(symbol, *args):
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_extern(dst.dtype, symbol, *args))
            return builder.get()

        def _pool():
            return _emit_extern(
                f"{func_prefix}_{uniq_id}",
                src.access_ptr("r"), dst.access_ptr("w"), dst.strides[0])

        def _reset():
            return _emit_extern(
                f"{func_prefix}_reset_{uniq_id}",
                dst.access_ptr("w"), dst.strides[0])

        # The reduce update is the same call as the body.
        return _pool(), _reset(), _pool()

    # Bind both tensors to buffers with one symbolic stride per dimension.
    binds = {}
    for tensor in (x, z):
        stride_vars = [
            te.var(f"{tensor.op.name}_s_{dim}")
            for dim in range(len(tensor.shape))
        ]
        binds[tensor] = tvm.tir.decl_buffer(
            tensor.shape, tensor.dtype, tensor.op.name,
            strides=stride_vars, offset_factor=1)

    intrin_decl = te.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
    return intrin_decl, uniq_id
def intrin_vadd(n):
    """Declare a vector-add tensor intrinsic of length ``n``.

    The matched computation is ``z[i] = x[i] + y[i]``; it is replaced by an
    extern call to ``'vadd'`` whose return dtype is the output buffer's
    dtype.  The declaration runs with ``offset_factor=n`` in the build config.
    """
    lhs = te.placeholder((n,))
    rhs = te.placeholder((n,))
    total = te.compute(lhs.shape, lambda i: lhs[i] + rhs[i])

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        out = outs[0]
        extern_call = tvm.tir.call_extern(
            out.dtype, 'vadd',
            ins[0].access_ptr("r"),
            ins[1].access_ptr('r'),
            out.access_ptr('wr'))
        builder.emit(extern_call)
        return builder.get()

    with tvm.target.build_config(offset_factor=n):
        intrin = te.decl_tensor_intrin(total.op, intrin_func)
    return intrin
def op_intrin():
    """Declare a tensor intrinsic for a 9x9 gather from a 5x5 input.

    Output element ``(i, j)`` reads the input at
    ``[idxd(j, 3) + idxm(i, 3), idxm(j, 3) + idxd(i, 3)]``; ``idxd`` and
    ``idxm`` are helpers defined outside this function (presumably index
    division and modulo -- confirm).  The pattern is replaced by a packed
    call to ``"op"``.
    """
    out_h = 9
    out_w = 9
    src = te.placeholder((5, 5), name="A")
    gathered = te.compute(
        (out_h, out_w),
        lambda i, j: src[idxd(j, 3) + idxm(i, 3), idxm(j, 3) + idxd(i, 3)])

    def intrin_func(ins, outs):
        (inp,) = ins
        out = outs[0]
        return tvm.tir.call_packed("op", inp, out)

    return te.decl_tensor_intrin(
        gathered.op, intrin_func,
        default_buffer_params={"offset_factor": 2})
def intrin_vadd(n):
    """Declare a vector-add tensor intrinsic of length ``n``.

    The matched computation is ``z[i] = x[i] + y[i]``; it is replaced by a
    packed call to ``"vadd"`` that receives the two input buffers and the
    output buffer.  The declaration runs with ``offset_factor=16`` in the
    build config.
    """
    lhs = te.placeholder((n, ), name='vx')
    rhs = te.placeholder((n, ), name='vy')
    total = te.compute(lhs.shape, lambda i: lhs[i] + rhs[i], name='z')

    def intrin_func(ins, outs):
        in_a, in_b = ins
        out = outs[0]
        return tvm.tir.call_packed("vadd", in_a, in_b, out)

    with tvm.target.build_config(offset_factor=16):
        intrin = te.decl_tensor_intrin(total.op, intrin_func)
    return intrin
def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
    """Define and return tensor intrinsic for mem copy

    Parameters
    ----------
    shape : tuple
        Shape of the source and destination.  The copy pattern indexes with a
        single variable and the byte size uses only ``shape[0]``.
    dtype : str
        Element data type; its item size is taken from ``np.dtype``.
    dst_scope : str
        Storage scope of the destination buffer.
    src_scope : str
        Storage scope of the source buffer.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic emitting a ``tir.mem_copy`` call on the addresses of the
        first elements of the two declared buffers.
    """
    src = te.placeholder(shape=shape, dtype=dtype, name="src")
    dst = te.compute(shape, lambda i: src[i], name="dst")
    num_bytes = shape[0] * np.dtype(dtype).itemsize

    src_buffer = tvm.tir.decl_buffer(
        shape, dtype, scope=src_scope, offset_factor=1,
        name="mem_copy_src_buffer")
    dst_buffer = tvm.tir.decl_buffer(
        shape, dtype, scope=dst_scope, offset_factor=1,
        name="mem_copy_dst_buffer")

    origin = [0 for _ in shape]

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        # The buffers passed in are looked up but not used below; the
        # addresses come from the buffers declared above.
        _src, _dst = ins[0], outs[0]

        dst_handle = builder.buffer_ptr(dst_buffer)
        src_handle = builder.buffer_ptr(src_buffer)
        dst_addr = tvm.tir.call_intrin(
            "handle", "tir.address_of", dst_handle[origin])
        src_addr = tvm.tir.call_intrin(
            "handle", "tir.address_of", src_handle[origin])
        builder.emit(tvm.tir.call_intrin(
            "handle", "tir.mem_copy", dst_addr, src_addr, num_bytes))
        return builder.get()

    return te.decl_tensor_intrin(
        dst.op, intrin_func, binds={src: src_buffer, dst: dst_buffer})
def op_intrin():
    """Declare a tensor intrinsic for a 9x9 gather from a 5x5 input.

    Output element ``(i, j)`` reads the input at
    ``[idxd(j, 3) + idxm(i, 3), idxm(j, 3) + idxd(i, 3)]``; ``idxd`` and
    ``idxm`` are helpers defined outside this function (presumably index
    division and modulo -- confirm).  The pattern is replaced by a packed
    call to ``"op"``, declared with ``offset_factor=2`` in the build config.
    """
    out_h = 9
    out_w = 9
    src = te.placeholder((5, 5), name='A')
    gathered = te.compute(
        (out_h, out_w),
        lambda i, j: src[idxd(j, 3) + idxm(i, 3), idxm(j, 3) + idxd(i, 3)])

    def intrin_func(ins, outs):
        inp, = ins
        out = outs[0]
        return tvm.tir.call_packed("op", inp, out)

    with tvm.target.build_config(offset_factor=2):
        intrin = te.decl_tensor_intrin(gathered.op, intrin_func)
    return intrin
def intrin_vadd(n):
    """Declare a vector-add tensor intrinsic of length ``n``.

    The matched computation is ``z[i] = x[i] + y[i]``; it is replaced by a
    packed call to ``"vadd"`` that receives the two input buffers and the
    output buffer.  Buffers are created with ``offset_factor=16`` through
    ``default_buffer_params``.
    """
    lhs = te.placeholder((n, ), name='vx')
    rhs = te.placeholder((n, ), name='vy')
    total = te.compute(lhs.shape, lambda i: lhs[i] + rhs[i], name='z')

    def intrin_func(ins, outs):
        in_a, in_b = ins
        out = outs[0]
        return tvm.tir.call_packed("vadd", in_a, in_b, out)

    return te.decl_tensor_intrin(
        total.op, intrin_func,
        default_buffer_params={"offset_factor": 16})
def intrin_libxsmm(m, k, n):
    """Declare a matmul tensor intrinsic backed by a libxsmm packed function.

    The matched computation is ``c[i, j] = sum_k a[i, k] * b[k, j]`` for an
    ``(m, k)`` matrix ``a`` and a ``(k, n)`` matrix ``b``.  All three tensors
    are bound to buffers with a symbolic row stride and an inner stride of 1.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic with no reset step; body and update both call
        ``"tvm.contrib.libxsmm.matmul"``, the update passing two extra
        trailing ``1`` arguments.
    """
    lhs = te.placeholder((m, k), name='a')
    rhs = te.placeholder((k, n), name='b')
    # A separate name keeps the reduction axis from shadowing the ``k`` extent.
    red = te.reduce_axis((0, k), name='k')
    prod = te.compute(
        (m, n),
        lambda i, j: te.sum(lhs[i, red] * rhs[red, j], axis=red),
        name='c')

    lhs_buf = tvm.tir.decl_buffer(
        lhs.shape, lhs.dtype, name='a_buffer', offset_factor=1,
        strides=[te.var('s1'), 1])
    rhs_buf = tvm.tir.decl_buffer(
        rhs.shape, rhs.dtype, name='b_buffer', offset_factor=1,
        strides=[te.var('s2'), 1])
    out_buf = tvm.tir.decl_buffer(
        prod.shape, prod.dtype, name='c_buffer', offset_factor=1,
        strides=[te.var('s3'), 1])

    def intrin_func(ins, outs):
        def _matmul(*extra):
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_packed(
                "tvm.contrib.libxsmm.matmul",
                ins[0], ins[1], outs[0], False, False, *extra))
            return builder.get()

        body = _matmul()
        update = _matmul(1, 1)
        return body, None, update

    return te.decl_tensor_intrin(
        prod.op, intrin_func,
        binds={lhs: lhs_buf, rhs: rhs_buf, prod: out_buf})
def intrin_vadd(n):
    """Declare a vector-add tensor intrinsic of length ``n``.

    The matched computation is ``z[i] = x[i] + y[i]``; it is replaced by an
    extern call to ``'vadd'`` whose return dtype is the output buffer's
    dtype.  Buffers are created with ``offset_factor=n`` through
    ``default_buffer_params``.
    """
    lhs = te.placeholder((n, ))
    rhs = te.placeholder((n, ))
    total = te.compute(lhs.shape, lambda i: lhs[i] + rhs[i])

    def intrin_func(ins, outs):
        builder = tvm.tir.ir_builder.create()
        out = outs[0]
        extern_call = tvm.tir.call_extern(
            out.dtype, 'vadd',
            ins[0].access_ptr("r"),
            ins[1].access_ptr('r'),
            out.access_ptr('wr'))
        builder.emit(extern_call)
        return builder.get()

    return te.decl_tensor_intrin(
        total.op, intrin_func,
        default_buffer_params={"offset_factor": n})
def intrin_pool():
    """Declare a tensor intrinsic for a 3x3 max reduction.

    The matched computation takes the maximum of ``A[c, oh + kh, ow + kw]``
    over ``kh, kw`` in ``[0, 3)`` for a ``(64, 16, 16)`` input, giving a
    ``(64, 14, 14)`` output.  It is replaced by a packed call to ``"op"``,
    declared with ``offset_factor=1`` in the build config.
    """
    data = te.placeholder((64, 16, 16), name='A')
    kh = te.reduce_axis((0, 3), name='kh')
    kw = te.reduce_axis((0, 3), name='kw')
    pooled = te.compute(
        (64, 14, 14),
        lambda c, oh, ow: tvm.te.max(
            data[c, oh + kh, ow + kw], axis=[kh, kw]),
        name='p')

    def intrin_func(ins, outs):
        return tvm.tir.call_packed("op", ins[0], outs[0])

    with tvm.target.build_config(offset_factor=1):
        intrin = te.decl_tensor_intrin(pooled.op, intrin_func)
    return intrin
def intrin_pool():
    """Declare a tensor intrinsic for a 3x3 max reduction.

    The matched computation takes the maximum of ``A[c, oh + kh, ow + kw]``
    over ``kh, kw`` in ``[0, 3)`` for a ``(64, 16, 16)`` input, giving a
    ``(64, 14, 14)`` output.  It is replaced by a packed call to ``"op"``,
    with buffers created using ``offset_factor=1``.
    """
    data = te.placeholder((64, 16, 16), name="A")
    kh = te.reduce_axis((0, 3), name="kh")
    kw = te.reduce_axis((0, 3), name="kw")
    pooled = te.compute(
        (64, 14, 14),
        lambda c, oh, ow: tvm.te.max(
            data[c, oh + kh, ow + kw], axis=[kh, kw]),
        name="p")

    def intrin_func(ins, outs):
        return tvm.tir.call_packed("op", ins[0], outs[0])

    return te.decl_tensor_intrin(
        pooled.op, intrin_func,
        default_buffer_params={"offset_factor": 1})
def intrin_gemm(m, n, l):
    """Declare a matrix-multiply tensor intrinsic with body, reset and update.

    The matched computation is ``z[i, j] = sum_k x[i][k] * y[j][k]`` for an
    ``(m, l)`` tensor ``x`` and an ``(n, l)`` tensor ``y``.

    Returns
    -------
    intrin : TensorIntrin
        Intrinsic built from three packed calls: ``"gemv"`` (body),
        ``"fill_zero"`` (reset) and ``"gemv_add"`` (update).  Buffers are
        created with ``offset_factor=n``.
    """
    k = te.reduce_axis((0, l))
    lhs = te.placeholder((m, l))
    rhs = te.placeholder((n, l))
    # in theory, no relation
    prod = te.compute(
        (m, n), lambda i, j: te.sum(lhs[i][k] * rhs[j][k], axis=k))

    def intrin_func(ins, outs):
        lhs_ptr = ins[0].access_ptr("r")
        rhs_ptr = ins[1].access_ptr("r")
        out_ptr = outs[0].access_ptr("w")
        body = tvm.tir.call_packed(
            "gemv", lhs_ptr, rhs_ptr, out_ptr, m, n, l)
        reset = tvm.tir.call_packed("fill_zero", out_ptr, m, n)
        update = tvm.tir.call_packed(
            "gemv_add", lhs_ptr, rhs_ptr, out_ptr, m, n, l)
        return body, reset, update

    return te.decl_tensor_intrin(
        prod.op, intrin_func,
        default_buffer_params={"offset_factor": n})