def test_gemm():
    """Build a tiled GEMM schedule and check it on every enabled GPU target.

    The computation is ``C[ii, jj] = sum_k A[k, jj] * B[k, ii]`` on square
    ``nn x nn`` operands.  The schedule stages ``A`` and ``B`` through
    "shared" and "local" caches, binds the output tiles to block, thread and
    virtual-thread axes, and double-buffers the shared stages.  For each
    target in the device list the kernel is built, compared against
    ``np.dot(b_np.T, a_np)`` and then timed; targets that are not enabled are
    skipped with a message.
    """
    # graph
    nn = 2048
    n = tvm.runtime.convert(nn)
    m, l = n, n
    A = te.placeholder((l, n), name="A")
    B = te.placeholder((l, m), name="B")
    k = te.reduce_axis((0, l), name="k")
    C = te.compute((m, n), lambda ii, jj: te.sum(A[k, jj] * B[k, ii], axis=k), name="C")

    # schedule
    s = te.create_schedule(C.op)
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = te.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = te.thread_axis((0, 2), "vthread", name="vy")

    # Tile the output: one block_factor x block_factor tile per block.
    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].reorder(by, bx, yi, xi)

    # Split each tile further across virtual threads and threads.
    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_thread)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    # Split the reduction axis and attach the cache stages to it.
    yo, xo = CC.op.axis
    ko, ki = s[CC].split(k, factor=8)
    kt, ki = s[CC].split(ki, factor=1)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)

    # Schedule for A's shared memory load
    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    s[AA].vectorize(xi)

    # Schedule for B's shared memory load
    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)
    s[BB].vectorize(xi)
    s[AA].double_buffer()
    s[BB].double_buffer()

    # correctness
    def check_device(device):
        """Build for ``device``, verify the result and print the timing."""
        dev = tvm.device(device, 0)
        if not dev.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Device %s" % device)
        f = tvm.build(s, [A, B, C], device)

        # launch the kernel.
        # Host shapes follow the declared tensor shapes: A is (l, n),
        # B is (l, m) and C is (m, n); every extent equals nn here.
        depth, rows, cols = nn, nn, nn
        a_np = np.random.uniform(size=(depth, cols)).astype(A.dtype)
        b_np = np.random.uniform(size=(depth, rows)).astype(B.dtype)
        a = tvm.nd.array(a_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros((rows, cols), dtype=C.dtype), dev)
        for _ in range(2):
            f(a, b, c)
        tvm.testing.assert_allclose(c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)

        num_flops = 2 * nn * nn * nn
        num_runs = 10
        timer_f = f.time_evaluator(f.entry_name, dev, number=num_runs)
        t = timer_f(a, b, c).mean
        # t is scaled to milliseconds for printing; flops / ms / 1e6 = GFLOPS.
        gflops = num_flops / (t * 1e3) / 1e6
        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, gflops))

    for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
        # Explicit unrolling is requested for every target except cuda.
        with tvm.transform.PassContext(
            config={"tir.UnrollLoop": {"auto_max_step": 128, "explicit_unroll": device != "cuda"}}
        ):
            check_device(device)
def test_expr_constructor():
    """Construct each ``tvm.tir`` expression node and check its fields.

    Every section builds one node type directly through its constructor and
    asserts that the attributes read back match the constructor arguments.
    """
    # Var
    x = tvm.tir.Var("xx", "float32")
    assert isinstance(x, tvm.tir.Var)
    assert x.name == "xx"

    # Reduce, built with no combiner and no condition
    x = tvm.tir.Reduce(None, [1], [tvm.tir.IterVar((0, 1), "x", 2)], None, 0)
    assert isinstance(x, tvm.tir.Reduce)
    # Identity comparison is the correct test for an absent combiner.
    assert x.combiner is None
    assert x.value_index == 0

    # Immediates
    x = tvm.tir.FloatImm("float32", 1.0)
    assert isinstance(x, tvm.tir.FloatImm)
    assert x.value == 1.0
    assert x.dtype == "float32"

    x = tvm.tir.IntImm("int64", 2)
    assert isinstance(x, tvm.tir.IntImm)
    assert x.value == 2
    assert x.dtype == "int64"

    x = tvm.tir.StringImm("xyza")
    assert isinstance(x, tvm.tir.StringImm)
    assert x.value == "xyza"

    # Cast
    x = tvm.tir.Cast("float32", tvm.tir.IntImm("uint32", 1))
    assert isinstance(x, tvm.tir.Cast)
    assert x.dtype == "float32"
    assert x.value.value == 1

    # Binary arithmetic and comparison nodes on float operands
    a = tvm.tir.const(1.0, dtype="float32")
    b = te.var("x", dtype="float32")

    for cls in [
        tvm.tir.Add,
        tvm.tir.Sub,
        tvm.tir.Mul,
        tvm.tir.Div,
        tvm.tir.Mod,
        tvm.tir.Min,
        tvm.tir.Max,
        tvm.tir.LT,
        tvm.tir.LE,
        tvm.tir.GT,
        tvm.tir.GE,
    ]:
        x = cls(a, b)
        assert isinstance(x, cls)
        assert x.a == a
        assert x.b.same_as(b)

    # Logical nodes on boolean operands (a and b are rebound here)
    a = tvm.runtime.convert(te.var("x") > 1)
    b = tvm.runtime.convert(te.var("x") == 1)

    for cls in [tvm.tir.And, tvm.tir.Or]:
        x = cls(a, b)
        assert isinstance(x, cls)
        assert x.a == a
        assert x.b.same_as(b)

    x = tvm.tir.Not(a)
    assert isinstance(x, tvm.tir.Not)
    assert x.a == a

    # Select
    x = tvm.tir.Select(a, a, b)
    assert isinstance(x, tvm.tir.Select)
    assert x.true_value == a
    assert x.false_value == b
    assert x.condition == a

    # Load with a predicate
    buffer_var = te.var("x", dtype="handle")
    x = tvm.tir.Load("float32", buffer_var, 1, a)
    assert isinstance(x, tvm.tir.Load)
    assert x.dtype == "float32"
    assert x.buffer_var == buffer_var
    assert x.index.value == 1
    assert x.predicate == a

    # Vector nodes
    x = tvm.tir.Ramp(1, 2, 10)
    assert isinstance(x, tvm.tir.Ramp)
    assert x.base.value == 1
    assert x.stride.value == 2
    assert x.lanes == 10

    x = tvm.tir.Broadcast(a, 10)
    assert isinstance(x, tvm.tir.Broadcast)
    assert x.value == a
    assert x.lanes == 10

    x = tvm.tir.Shuffle([a], [0])
    assert isinstance(x, tvm.tir.Shuffle)
    assert x.vectors[0] == a
    assert x.indices[0].value == 0

    # Call
    x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a], tvm.tir.Call.Extern)
    assert isinstance(x, tvm.tir.Call)
    assert x.dtype == "float32"
    assert x.op.name == "tir.call_extern"
    assert x.args[1] == a
    assert x.call_type == tvm.tir.Call.Extern

    # Let
    v = te.var("aa")
    x = tvm.tir.Let(v, 1, v)
    assert x.var == v
    assert x.value.value == 1
    assert x.body == v
def test_select():
    """Check the integer set inferred for a ``Select`` over a bounded variable."""
    checker = IntSetChecker()
    x = te.var("x")
    y = te.var("y")  # created alongside x; the checked expression does not use it

    # With x in [0, 10] the two branches x - 1 and x + 1 together span [-1, 11].
    selected = tvm.tir.Select(x > 0, x - 1, x + 1)
    domain = {x: tvm.arith.IntervalSet(0, 10)}
    expected_bounds = (-1, 11)
    checker.verify(selected, domain, expected_bounds)
# TVM adopts tensor semantics, with each intermediate result # represented as a multi-dimensional array. The user needs to describe # the computation rule that generates the tensors. # # We first define a symbolic variable n to represent the shape. # We then define two placeholder Tensors, A and B, with given shape (n,) # # We then describe the result tensor C, with a compute operation. The # compute function takes the shape of the tensor, as well as a lambda # function that describes the computation rule for each position of # the tensor. # # No computation happens during this phase, as we are only declaring how # the computation should be done. # n = te.var("n") A = te.placeholder((n, ), name="A") B = te.placeholder((n, ), name="B") C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") print(type(C)) ###################################################################### # Schedule the Computation # ------------------------ # While the above lines describe the computation rule, we can compute # C in many ways since the axis of C can be computed in a data # parallel manner. TVM asks the user to provide a description of the # computation called a schedule. # # A schedule is a set of transformation of computation that transforms # the loop of computations in the program.
def test_basic_operation():
    """Run ``check_grad`` over a series of small tensor expressions.

    Each case builds a tensor named 'B' from the placeholders ``A0`` and
    ``A1`` and hands it to ``check_grad`` together with the inputs to
    differentiate against.  The cases run in a fixed order after seeding
    NumPy's random generator.
    """
    np.random.seed(0)
    shape = (10, 10)
    unused_var = te.var("x", dtype='float32')  # not referenced by any case below
    k = te.reduce_axis((0, 10), name="k")
    unused_axis = te.reduce_axis((0, 10), name="l")  # not referenced by any case below
    A0 = te.placeholder(shape, name='A0')
    A1 = te.placeholder(shape, name='A1')
    zeros = np.zeros(shape)

    def grad_case(fcompute, wrt, out_shape=shape, **check_opts):
        """Build 'B' from ``fcompute`` and run ``check_grad`` against ``wrt``."""
        out = te.compute(out_shape, fcompute, name='B')
        check_grad(out, wrt, **check_opts)

    # Plain element-wise expressions.
    grad_case(lambda i, j: A0[i, j], [A0])
    grad_case(lambda i, j: A0[i, j] + A1[i, j], [A0, A1])
    grad_case(lambda i, j: A0[i, j] + A0[j, i], A0)

    # Rounding operations: the expected gradient is supplied as all zeros.
    for rounding in (te.floor, te.ceil, te.trunc, te.round):
        grad_case(lambda i, j: rounding(A0[i, j]), A0, desired_grads=[zeros])

    # Transcendental functions.
    grad_case(lambda i, j: A0[i, j] + te.exp(A0[j, i]), A0)
    grad_case(lambda i, j: te.log(0.1 + te.abs(A0[i, j] + te.exp(A0[j, i]))), A0)
    for squash in (te.sigmoid, te.tanh):
        grad_case(lambda i, j: squash(A0[i, j] * A0[i, j] * A0[j, i]), A0)

    # Cases that pass an explicit data range to check_grad.
    grad_case(lambda i, j: te.sqrt(A0[i, j] * A0[i, j] * A0[j, i]), A0, data_range=(0.1, 10))
    grad_case(lambda i, j: te.power(te.abs(A0[i, j]), A0[j, i]), A0, data_range=(-4, 4))

    grad_case(lambda i, j: A0[i, j] * A0[j, i], A0)

    # Reductions over k.
    grad_case(lambda i: te.sum(A0[i, k] * A0[k, i], axis=k), A0, out_shape=(10, ))
    grad_case(lambda i, j: te.sum(A0[i, k] * A0[k, i] + 5, axis=k), A0)
    grad_case(lambda i, j: te.max(A0[i, k] * A0[k, j] + 5, axis=k), A0)

    grad_case(lambda i, j: A0[i, j] * (A1[j, i] + A0[j, i]), [A0, A1])

    # Reduction whose indices include a clamped expression.
    grad_case(
        lambda i, j: te.sum(A0[k, k] - A0[te.min(j + k, 9), j] * A0[i, k], axis=k),
        A0,
    )

    # User-defined commutative reducer (a product).
    def fcombine(x, y):
        return x * y

    def fidentity(t0):
        return tvm.tir.const(1, t0)

    prod = te.comm_reducer(fcombine, fidentity, name='prod')
    grad_case(lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), A0, out_shape=(10, 10))

    # Gradient through a topi operator applied to two derived tensors.
    X = te.placeholder((10, ), name='X')
    summed = te.compute((10, ), lambda i: X[i] + X[9 - i])
    multiplied = te.compute((10, ), lambda i: X[i] * X[9 - i])
    Y = topi.tensordot(summed, multiplied, 1)
    check_grad(Y, X)
def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type):
    """
    Int8 4x4 matrix multiplication and accumulation using a sequence of
    umull -> uadalp -> umull2 -> uadalp instructions. This function
    takes two arrays of int8 data type A[4][K] and B[4][K], and produces
    a 4x4 matrix which is equal to A*B'.

    The pseudo code is as follows.

    .. code-block:: c

        void gemm_4x4_int8_int8_int32(int8 A[4][K], int8 B[4][K], int32 C[4][4]){
            for (int i = 0; i < 4; i++){
                for (int j = 0; j < 4; j++){
                    for (int k = 0; k < K; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    Notes:
        * The tiling strategy is picked to maximize register usage.
        * The inputs are declared with shape (K // 16, rows, 16), i.e. the
          reduction axis is split into tiles of 16 elements.

    Parameters
    ----------
    M : int
        Number of rows of A that are loaded and of output rows that are
        stored (rows at index >= M are replaced by zero vectors / skipped)
    N : int
        Number of rows of B that are loaded; it also selects the width of
        the vectors written to the output
    K : int
        Length of the reduction axis
    unroll : bool
        Unroll the loop accumulation if True
    in_type : str, {'uint8', 'int8'}
        Element type of A and B; selects the unsigned (umull/uaddlp) or the
        signed (smull/saddlp) LLVM intrinsics

    Returns
    -------
    intrin : TensorIntrin
        The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule
    """
    assert in_type in ["uint8", "int8"]
    A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A")
    B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B")
    dtype_vec = in_type + "x16"
    idxm = tvm.tir.indexmod

    k = te.reduce_axis((0, K), "k")
    # Reference computation the intrinsic stands for: element k of the
    # reduction lives in tile k // 16 at position k % 16.
    C = te.compute(
        (te.var("m"), te.var("n")),
        lambda x, y: te.sum(
            A[k // 16, x, idxm(k, 16)].astype("int32") * B[k // 16, y, idxm(k, 16)].astype("int32"),
            axis=k,
        ),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        A.shape,
        dtype=in_type,
        name="a_buffer",
        offset_factor=1,
        strides=[te.var("sa_1"), te.var("sa_2"), 1],
    )

    b_buffer = tvm.tir.decl_buffer(
        B.shape,
        dtype=in_type,
        name="b_buffer",
        offset_factor=1,
        strides=[te.var("sb_1"), te.var("sb_2"), 1],
    )

    c_buffer = tvm.tir.decl_buffer(C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1])

    # Intrinsics used in the following algorithm
    umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
    uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
    addp_intrin = "llvm.aarch64.neon.addp"

    def uadalp(a, b):
        """Add pair and accumulate

        Parameters:
        ----------
        a: int32x4 vector (the accumulator)
        b: int16x8 vector (a result of umull/umull2)

        Returns:
        --------
            return a int32x4 vector

        Pseudocode:
        ----------
            a += (b0+b1, b2+b3, b4+b5, b6+b7)
        """
        # The leading uint32 constant presumably carries the number of
        # intrinsic operands (1 here) -- TODO confirm against the TVM API.
        return a + tvm.tir.call_llvm_pure_intrin("int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b)

    def umull(a, b):
        """Multiply long, applied to the halves selected by ``tir.vectorhigh``

        Parameters:
        ----------
        a: int8x16 vector
        b: int8x16 vector

        Returns:
        --------
            return a int16x8 vector

        Pseudocode:
        ----------
            c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7)

        NOTE(review): the pseudocode lists lanes 0-7 while the code takes
        the ``tir.vectorhigh`` halves; which lanes that selects cannot be
        told from this function -- confirm. Both halves of every operand are
        accumulated into the same register by the caller, so the order of
        the halves does not change the final sum.
        """
        a_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
        b_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
        c = tvm.tir.call_llvm_pure_intrin("int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high)
        return c

    def umull2(a, b):
        """Multiply long, applied to the halves selected by ``tir.vectorlow``

        Parameters:
        ----------
        a: int8x16 vector
        b: int8x16 vector

        Returns:
        --------
            return a int16x8 vector

        Pseudocode:
        ----------
            c = (a8*b8, a9*b9, a10*b10, a11*b11, a12*b12, a13*b13, a14*b14, a15*b15)

        NOTE(review): the pseudocode lists lanes 8-15 while the code takes
        the ``tir.vectorlow`` halves -- confirm the lane numbering.
        """
        a_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
        b_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
        c = tvm.tir.call_llvm_pure_intrin("int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low)
        return c

    def addp(a, b):
        """Add two vectors in pairs

        Parameters:
        ----------
        a: int32x4 vector
        b: int32x4 vector

        Returns:
        --------
            return a int32x4 vector

        Pseudocode:
        ----------
            c = (a0+a1, a2+a3, b0+b1, b2+b3)
        """
        return tvm.tir.call_llvm_pure_intrin("int32x4", addp_intrin, tvm.tir.const(2, "uint32"), a, b)

    def accumulation_loop(M, N, ins, acc, tile_idx):
        """Internal tile accumulation. This function
        takes two arrays of int8 data type A[tile_idx][4][16] and B[tile_idx][4][16], produces
        a 4x4 matrix which is equal to A*B' and accumulates into C[4][4]

        The pseudo code is as follows.

        .. code-block:: c

            void gemm_4x4_int8_int8_int32(int8 A[tile_idx][4][16],
                                          int8 B[tile_idx][4][16],
                                          int32 C[4][4]){
                for (int i = 0; i < 4; i++){
                    for (int j = 0; j < 4; j++){
                        for (int k = 0; k < 16; k++){
                            C[i][j] += A[tile_idx][i][k] * B[tile_idx][j][k]
                        }
                    }
                }
            }

        Notes:
            * The tiling strategy is picked to maximize register usage.
            * The function returns nothing; it only assigns into ``acc``.

        Parameters:
        ----------
        M : int
            Number of total rows of the output matrix
        N : int
            Number of total columns of the output matrix
        ins : list of tvm.tir.buffer
            Input buffers
        acc : tvm.tir.ir_builder.BufferVar
            Bank of register accumulators
        tile_idx : int
            Index of a sub-tile of A and B in A[tile_idx][:][:] and B[tile_idx][:][:].
            Please note that  0 <= tile_idx < K//16

        """
        # Load up to four 16-element rows of A; rows at index >= M are
        # replaced by a zero vector.
        # NOTE(review): the zero vectors are always "int8x16", also when
        # in_type is 'uint8' -- confirm this dtype mismatch is intended.
        a0 = ins[0].vload([tile_idx, 0, 0], dtype_vec)
        a1 = tvm.tir.const(0, "int8x16")
        if M > 1:
            a1 = ins[0].vload([tile_idx, 1, 0], dtype_vec)
        a2 = tvm.tir.const(0, "int8x16")
        if M > 2:
            a2 = ins[0].vload([tile_idx, 2, 0], dtype_vec)
        a3 = tvm.tir.const(0, "int8x16")
        if M > 3:
            a3 = ins[0].vload([tile_idx, 3, 0], dtype_vec)

        # Same for B: rows at index >= N are replaced by a zero vector.
        b0 = ins[1].vload([tile_idx, 0, 0], dtype_vec)
        b1 = tvm.tir.const(0, "int8x16")
        if N > 1:
            b1 = ins[1].vload([tile_idx, 1, 0], dtype_vec)
        b2 = tvm.tir.const(0, "int8x16")
        if N > 2:
            b2 = ins[1].vload([tile_idx, 2, 0], dtype_vec)
        b3 = tvm.tir.const(0, "int8x16")
        if N > 3:
            b3 = ins[1].vload([tile_idx, 3, 0], dtype_vec)

        # First half (rows a0 and a1 -> accumulators 0..7)
        # umull products of a0 * {b0,b1,b2,b3}
        d00 = umull(a0, b0)
        d01 = umull(a0, b1)
        d02 = umull(a0, b2)
        d03 = umull(a0, b3)

        # umull products of a1 * {b0,b1,b2,b3}
        d10 = umull(a1, b0)
        d11 = umull(a1, b1)
        d12 = umull(a1, b2)
        d13 = umull(a1, b3)

        # Accumulate
        acc[0] = uadalp(acc[0], d00)
        acc[1] = uadalp(acc[1], d01)
        acc[2] = uadalp(acc[2], d02)
        acc[3] = uadalp(acc[3], d03)
        acc[4] = uadalp(acc[4], d10)
        acc[5] = uadalp(acc[5], d11)
        acc[6] = uadalp(acc[6], d12)
        acc[7] = uadalp(acc[7], d13)

        # umull2 products (the other half) of a0 * {b0,b1,b2,b3}
        d00 = umull2(a0, b0)
        d01 = umull2(a0, b1)
        d02 = umull2(a0, b2)
        d03 = umull2(a0, b3)

        # umull2 products (the other half) of a1 * {b0,b1,b2,b3}
        d10 = umull2(a1, b0)
        d11 = umull2(a1, b1)
        d12 = umull2(a1, b2)
        d13 = umull2(a1, b3)

        # Accumulate again
        acc[0] = uadalp(acc[0], d00)
        acc[1] = uadalp(acc[1], d01)
        acc[2] = uadalp(acc[2], d02)
        acc[3] = uadalp(acc[3], d03)
        acc[4] = uadalp(acc[4], d10)
        acc[5] = uadalp(acc[5], d11)
        acc[6] = uadalp(acc[6], d12)
        acc[7] = uadalp(acc[7], d13)

        # Second half (rows a2 and a3 -> accumulators 8..15)
        # umull products of a2 * {b0,b1,b2,b3}
        d00 = umull(a2, b0)
        d01 = umull(a2, b1)
        d02 = umull(a2, b2)
        d03 = umull(a2, b3)

        # umull products of a3 * {b0,b1,b2,b3}
        d10 = umull(a3, b0)
        d11 = umull(a3, b1)
        d12 = umull(a3, b2)
        d13 = umull(a3, b3)

        # Accumulate
        acc[8] = uadalp(acc[8], d00)
        acc[9] = uadalp(acc[9], d01)
        acc[10] = uadalp(acc[10], d02)
        acc[11] = uadalp(acc[11], d03)
        acc[12] = uadalp(acc[12], d10)
        acc[13] = uadalp(acc[13], d11)
        acc[14] = uadalp(acc[14], d12)
        acc[15] = uadalp(acc[15], d13)

        # umull2 products (the other half) of a2 * {b0,b1,b2,b3}
        d00 = umull2(a2, b0)
        d01 = umull2(a2, b1)
        d02 = umull2(a2, b2)
        d03 = umull2(a2, b3)

        # umull2 products (the other half) of a3 * {b0,b1,b2,b3}
        d10 = umull2(a3, b0)
        d11 = umull2(a3, b1)
        d12 = umull2(a3, b2)
        d13 = umull2(a3, b3)

        # Accumulate again
        acc[8] = uadalp(acc[8], d00)
        acc[9] = uadalp(acc[9], d01)
        acc[10] = uadalp(acc[10], d02)
        acc[11] = uadalp(acc[11], d03)
        acc[12] = uadalp(acc[12], d10)
        acc[13] = uadalp(acc[13], d11)
        acc[14] = uadalp(acc[14], d12)
        acc[15] = uadalp(acc[15], d13)

    def _intrin_func(ins, outs):
        def _instr():
            ib = tvm.tir.ir_builder.create()
            # Allocate a local buffer (possibly translates to registers)
            acc = ib.allocate("int32x4", 16, name="accs", scope="local")
            # NOTE(review): m and n are not read anywhere below.
            m = outs[0].shape[0]
            n = outs[0].shape[1]
            # Initialization
            for i in range(0, 16):
                acc[i] = tvm.tir.const(0, "int32x4")

            if unroll:
                # Emit one accumulation step per 16-element tile at
                # construction time (Python-level loop).
                for i in range(0, int(K // 16)):
                    accumulation_loop(M, N, ins, acc, i)
            else:
                # Emit a single loop over the tiles in the generated IR.
                with ib.for_range(0, K // 16, name="i") as i:
                    accumulation_loop(M, N, ins, acc, i)

            # Final accumulations
            # acc[4*r + c] contains the partial accumulations of element C[r][c]
            #
            # In particular:
            # acc[4*r]   contains the partial sums of a[r, 0:K].*b[0, 0:K] -> (a,b,c,d)
            # acc[4*r+1] contains the partial sums of a[r, 0:K].*b[1, 0:K] -> (e,f,g,h)
            # acc[4*r+2] contains the partial sums of a[r, 0:K].*b[2, 0:K] -> (i,j,k,l)
            # acc[4*r+3] contains the partial sums of a[r, 0:K].*b[3, 0:K] -> (m,n,o,p)
            #
            # Please note that 0 <= r, c < 4
            #
            # After the three pairwise additions per row, acc[4*r] holds the
            # four finished elements C[r][0..3].

            acc[0] = addp(acc[0], acc[1])  # (a+b, c+d, e+f, g+h)
            acc[1] = addp(acc[2], acc[3])  # (i+j, k+l, m+n, o+p)
            acc[0] = addp(acc[0], acc[1])  # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            acc[4] = addp(acc[4], acc[5])  # (a+b, c+d, e+f, g+h)
            acc[5] = addp(acc[6], acc[7])  # (i+j, k+l, m+n, o+p)
            acc[4] = addp(acc[4], acc[5])  # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            acc[8] = addp(acc[8], acc[9])  # (a+b, c+d, e+f, g+h)
            acc[9] = addp(acc[10], acc[11])  # (i+j, k+l, m+n, o+p)
            acc[8] = addp(acc[8], acc[9])  # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            acc[12] = addp(acc[12], acc[13])  # (a+b, c+d, e+f, g+h)
            acc[13] = addp(acc[14], acc[15])  # (i+j, k+l, m+n, o+p)
            acc[12] = addp(acc[12], acc[13])  # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            # Store the result
            # N selects the width of the stored vectors (4, 3, 2 or 1 lanes).
            # NOTE(review): the narrower widths go through tir.reinterpret
            # from an int32x4 value -- presumably keeping the leading lanes;
            # confirm.
            if N > 3:
                out_0 = acc[0]
                out_1 = acc[4]
                out_2 = acc[8]
                out_3 = acc[12]
            elif N > 2:
                out_0 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[0])
                out_1 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[4])
                out_2 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[8])
                out_3 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[12])
            elif N > 1:
                out_0 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[0])
                out_1 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[4])
                out_2 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[8])
                out_3 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[12])
            else:
                out_0 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[0])
                out_1 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[4])
                out_2 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[8])
                out_3 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[12])

            # Row 0 is always stored; rows 1..3 only when M covers them.
            ib.emit(outs[0].vstore([0, 0], out_0))
            if M > 1:
                ib.emit(outs[0].vstore([1, 0], out_1))
            if M > 2:
                ib.emit(outs[0].vstore([2, 0], out_2))
            if M > 3:
                ib.emit(outs[0].vstore([3, 0], out_3))
            return ib.get()

        # A single statement is returned here rather than a
        # (body, reset, update) triple.
        return _instr()

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={A: a_buffer, B: b_buffer, C: c_buffer},
        default_buffer_params=buffer_params,
    )
def gemm_acc_4x4_int8_int8_int32(dtype):
    """
    Int8 4x4 matrix multiplication and accumulation using sdot/udot
    instructions. Given A[4][4] and B[4][4] of int8 (or uint8) elements,
    the intrinsic accumulates A*B' into a 4x4 int32 tile.

    The pseudo code is as follows.

    .. code-block:: c

        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
            for (int i = 0; i < 4; i++){
                for (int j = 0; j < 4; j++){
                    for (int k = 0; k < 4; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    Notes:
        * The tiling strategy is picked to maximize register usage.

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ("uint8", "int8")

    # The row extent of A is symbolic: because of padding TVM may decide
    # that only a single row has to be computed.
    a_rows = te.var("rows")
    A = te.placeholder((a_rows, 4), dtype, name="A")
    B = te.placeholder((4, 4), dtype, name="B")
    lanes16_dtype = dtype + "x16"

    k = te.reduce_axis((0, 4), name="k")

    def _row_by_row_dot(i, j):
        # C[i, j] is the int32 dot product of row i of A and row j of B.
        return te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k)

    c_rows = te.var("rows")
    C = te.compute((c_rows, 4), _row_by_row_dot, name="C")

    def _row_strided_buffer(tensor, elem_dtype, buffer_name, stride_name):
        # 2-D buffer with a symbolic row stride and unit element stride.
        return tvm.tir.decl_buffer(
            tensor.shape,
            elem_dtype,
            name=buffer_name,
            offset_factor=1,
            strides=[te.var(stride_name), 1],
        )

    aa_buffer = _row_strided_buffer(A, dtype, "aa_buffer", "sa")
    bb_buffer = _row_strided_buffer(B, dtype, "bb_buffer", "sb")
    cc_buffer = _row_strided_buffer(C, "int32", "cc_buffer", "sc")

    if dtype == "int8":
        dot_intrin = "llvm.aarch64.neon.sdot"
    else:
        dot_intrin = "llvm.aarch64.neon.udot"

    def _intrin_func(ins, outs):
        out_tile = outs[0]

        def _reset():
            # Zero the four int32x4 rows of the output tile.
            builder = tvm.tir.ir_builder.create()
            for row in range(4):
                builder.emit(out_tile.vstore([row, 0], tvm.tir.const(0, "int32x4")))
            return builder.get()

        def _accumulate():
            builder = tvm.tir.ir_builder.create()

            # Whole tile of A in one 16-lane vector:
            #   tile_a = [a, b, c, d,
            #             e, f, g, h,
            #             l, m, n, o,
            #             p, q, r, s]
            tile_a = ins[0].vload([0, 0], lanes16_dtype)

            # One vector per row of A, holding that row repeated four
            # times, e.g. rows_of_a[0] = [a, b, c, d, a, b, c, d, ...].
            rows_of_a = [select_word(tile_a, row, lanes16_dtype) for row in range(4)]

            # Whole tile of B. B is transposed, so:
            #   tile_b = [0, 4,  8, 12,
            #             1, 5,  9, 13,
            #             2, 6, 10, 14,
            #             3, 7, 11, 15]
            tile_b = ins[1].vload([0, 0], lanes16_dtype)

            for row, replicated_row in enumerate(rows_of_a):
                running_sum = out_tile.vload([row, 0], "int32x4")
                # sdot/udot take the dot product of each group of four
                # elements and add it to the matching int32 lane, e.g. for
                # row 0:
                #   [a*0+b*4+c*8+d*12,
                #    a*1+b*5+c*9+d*13,
                #    a*2+b*6+c*10+d*14,
                #    a*3+b*7+c*11+d*15]
                updated_sum = tvm.tir.call_llvm_intrin(
                    "int32x4",
                    dot_intrin,
                    tvm.tir.const(3, "uint32"),
                    running_sum,
                    tile_b,
                    replicated_row,
                )
                builder.emit(out_tile.vstore([row, 0], updated_sum))
            return builder.get()

        # body, reset, update -- body and update are the same accumulation.
        return _accumulate(), _reset(), _accumulate()

    buffer_params = {"offset_factor": 1}
    tensor_binds = {A: aa_buffer, B: bb_buffer, C: cc_buffer}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds=tensor_binds,
        default_buffer_params=buffer_params,
    )
def test_max_index_simplify():
    """Check the rewrite-simplifier rules that apply to ``max`` expressions.

    Every ``check(expr, expected)`` call asks the ``RewriteChecker`` to
    verify that ``expr`` is rewritten to ``expected``.
    """
    ck = RewriteChecker()
    check = ck.verify
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    te_max = tvm.te.max
    te_min = tvm.te.min
    flm = tvm.te.floormod
    fld = tvm.te.floordiv
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod

    # const int bound
    check(te_max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10)
    check(te_max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10)

    # operands that differ by a constant
    check(te_max(x + 1, x + 10), x + 10)
    check(te_max(x + 111, x + 10), x + 111)
    check(te_max(x + 1, x), x + 1)
    check(te_max(x, x + 2), x + 2)
    check(te_max(1 - x, 2 - x), 2 - x)
    check(te_max(3 - x, 2 - x), 3 - x)

    # max of nested min/max over the same operands
    check(te_max(te_min(x, y), te_max(x, y)), te_max(x, y))
    check(te_max(te_min(x, y), te_max(y, x)), te_max(x, y))

    check(te_max(te_min(x, y), x), x)
    check(te_max(te_min(y, x), x), x)
    check(te_max(te_max(x, y), x), te_max(x, y))
    check(te_max(te_max(x, y), y), te_max(x, y))

    check(te_max(x, te_min(x, y)), x)
    check(te_max(x, te_min(y, x)), x)
    check(te_max(x, te_max(x, y)), te_max(x, y))
    check(te_max(y, te_max(x, y)), te_max(x, y))

    # an operand repeated deeper inside a chain of max
    check(te_max(te_max(te_max(x, y), z), y), te_max(te_max(x, y), z))
    check(
        te_max(te_max(te_max(te_max(x, y), z), x * 2), y),
        te_max(te_max(te_max(x, y), z), x * 2),
    )
    check(
        te_max(te_max(te_max(te_max(te_max(x, y), z), x * 2), z * 2), y),
        te_max(te_max(te_max(te_max(x, y), z), x * 2), z * 2),
    )

    # max of two min sharing an operand
    check(te_max(te_min(x, y), te_min(x, z)), te_min(te_max(y, z), x))
    check(te_max(te_min(x, y), te_min(z, x)), te_min(te_max(y, z), x))
    check(te_max(te_min(y, x), te_min(x, z)), te_min(te_max(y, z), x))
    check(te_max(te_min(y, x), te_min(z, x)), te_min(te_max(y, z), x))

    # common term in sums and differences
    check(te_max(y + x, z + x), te_max(y, z) + x)
    check(te_max(y + x, x + z), te_max(y, z) + x)
    check(te_max(x + y, z + x), te_max(y, z) + x)
    check(te_max(x + y, x + z), te_max(y, z) + x)

    check(te_max(x - y, x - z), x - te_min(y, z))
    check(te_max(y - x, z - x), te_max(y, z) - x)

    # constants
    check(te_max(te_max(x, 1), 10), te_max(x, 10))
    check(te_max(te_max(x, 11), 10), te_max(x, 11))

    check(te_max(x * 3, 9), te_max(x, 3) * 3)
    check(te_max(3 - x, 1), 3 - te_min(x, 2))
    check(te_max(x * 2, 0), te_max(x, 0) * 2)
    check(te_max(0 - x * 2, 0), te_min(x, 0) * -2)
    check(te_max(x * (-2), -4), te_min(x, 2) * -2)
    check(te_max(x * (-2), 4), te_min(x, -2) * -2)
    check(te_max(x * (0), 4), 4)
    check(te_max(x * (0), -4), 0)

    # DivMod rules
    # trunc div
    check(te_max(tdiv(x, 10), tdiv(y, 10)), tdiv(te_max(x, y), 10))
    check(te_max(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(te_min(x, y), (-10)))
    check(te_max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4)

    # floordiv
    check(te_max(fld(x, 10), fld(y, 10)), fld(te_max(x, y), 10))
    check(te_max(fld(x, (-10)), fld(y, (-10))), fld(te_min(x, y), (-10)))
    check(te_max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4)
    check(te_max(fld(x, 4) * 4, x), x)
    check(te_max(x, fld(x, 4) * 4), x)
def test_cmp_simplify():
    """Check the rewrite-simplifier rules that apply to comparisons.

    Every ``check(expr, expected)`` call asks the ``RewriteChecker`` to
    verify that ``expr`` is rewritten to ``expected``.  The last group of
    checks runs after constant bounds for x, y and z have been registered
    with the analyzer.
    """
    ck = RewriteChecker()
    check = ck.verify
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    LT = tvm.tir.LT
    LE = tvm.tir.LE
    flm = tvm.te.floormod
    fld = tvm.te.floordiv
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod

    def true():
        return tvm.tir.const(1, "bool")

    def false():
        return tvm.tir.const(0, "bool")

    # const int bound
    check((tmod(x, 2) + 10).equal(0), false())
    check(tvm.tir.NE(tmod(x, 2) + 10, 0), true())
    check(tmod(x, 2) + 10 > 1, true())
    check(tmod(x, 2) + 10 <= 1, false())
    check(flm(x, 2) + 2 > 1, true())
    check(flm(x, 2) + 10 <= 1, false())

    check(x * 3 + 10 == 0, false())
    check(x * 3 + 10 != 0, true())

    # canonicalization
    check((x - 10).equal(0), x.equal(10))
    check((10 - x).equal(0), x.equal(10))
    check((x * y).equal(0), tvm.tir.Or(x.equal(0), y.equal(0)))

    # cmp bound
    check(x + y < x + z, y < z)
    check(x + y < z + x, y < z)
    check(y + x < x + z, y < z)
    check(y + x < z + x, y < z)
    check(y - x < z - x, y < z)
    check(x - y < x - z, z < y)

    check(x < z + x, LT(0, z))
    check(x < x + z, LT(0, z))

    check(100 < x + 1, LT(99, x))
    check(1 < 100 - x, LT(x, 99))
    check(x * 3 < y * 3, x < y)
    check(x * (-3) < y * (-3), y < x)
    check(x * 3 >= y * 3, y <= x)

    check(x * 4 >= 2, LE(1, x))
    check(x * 2 >= 50, LE(25, x))
    check(x * 4 <= 2, x <= 0)
    check((0 - x * 3) <= 0, LE(0, x))
    check((0 - x * 3) >= 0, LE(x, 0))
    check(2 * x <= 0, x <= 0)

    # Scaled variable against constants 3 .. -3; each table row is
    # (constant on the right-hand side, bound in the simplified form).
    for rhs, bound in [(3, 2), (2, 1), (1, 1), (0, 0), (-1, 0), (-2, -1), (-3, -1)]:
        check(x * 2 >= rhs, LE(bound, x))

    for rhs, bound in [(3, 1), (2, 1), (1, 0), (0, 0), (-1, -1), (-2, -1), (-3, -2)]:
        check(x * 2 <= rhs, LE(x, bound))

    for rhs, bound in [(3, -2), (2, -1), (1, -1), (0, 0), (-1, 0), (-2, 1), (-3, 1)]:
        check(x * (-2) >= rhs, LE(x, bound))

    for rhs, bound in [(3, -1), (2, -1), (1, 0), (0, 0), (-1, 1), (-2, 1), (-3, 2)]:
        check(x * (-2) <= rhs, LE(bound, x))

    # DivMod rules
    # trunc div
    check(tdiv(x, 2) < 3, x < 6)
    check(3 < tdiv(x, 2), LT(7, x))
    check(tdiv(x, 3) >= 0, LE(-2, x))
    check(tdiv(x, 2) >= 1, LE(2, x))
    check(tdiv(x, 2) >= 0, LE(-1, x))
    check(tdiv(x, 2) >= -1, LE(-3, x))

    check(tdiv(x, 2) <= 1, LE(x, 3))
    check(tdiv(x, 2) <= 0, LE(x, 1))
    check(tdiv(x, 2) <= -1, LE(x, -2))

    check(tdiv(x, 4) * 4 < x, LT(0, tmod(x, 4)))
    check(tdiv(x, 4) * 4 >= x, LE(tmod(x, 4), 0))

    check(tdiv(x, 4) * 4 < x + y, LT(0, tmod(x, 4) + y))
    check(tdiv(x, 4) * 4 < x - y, LT(y, tmod(x, 4)))

    check(tdiv(x + 2, 4) * 4 >= x, LE(tmod(x + 2, 4), 2))
    check(tdiv(x + 2, 4) * 4 >= x + y, LE(tmod(x + 2, 4) + y, 2))
    check(tdiv(x + 2, 4) * 4 >= x - y, LE(tmod(x + 2, 4) + (-2), y))

    # floor div
    check(fld(x, 2) < 3, x < 6)
    check(3 < fld(x, 2), LT(7, x))
    check(-3 < fld(x, 2), LT(-5, x))
    check(fld(x, 3) >= 0, LE(0, x))
    check(fld(x, 2) >= 1, LE(2, x))
    check(fld(x, 2) >= 0, LE(0, x))
    check(fld(x, 2) >= -1, LE(-2, x))

    check(fld(x, 2) <= 1, LE(x, 3))
    check(fld(x, 2) <= 0, LE(x, 1))
    check(fld(x, 2) <= -1, LE(x, -1))

    check(fld(x, 4) * 4 < x, LT(0, flm(x, 4)))
    check(fld(x, 4) * 4 >= x, LE(flm(x, 4), 0))

    check(fld(x, 4) * 4 < x + y, LT(0, flm(x, 4) + y))
    check(fld(x, 4) * 4 < x - y, LT(y, flm(x, 4)))

    check(fld(x + 2, 4) * 4 >= x, LE(flm(x + 2, 4), 2))
    check(fld(x + 2, 4) * 4 >= x + y, LE(flm(x + 2, 4) + y, 2))
    check(fld(x + 2, 4) * 4 >= x - y, LE(flm(x + 2, 4) + (-2), y))
    # End DivMod Rules

    # comparisons against min/max
    check(tvm.te.min(x, 11) < 10, x < 10)
    check(tvm.te.min(x, 8) < 10, true())
    check(tvm.te.max(8, x) > 10, LT(10, x))
    check(x + 1 < tvm.te.max(8, x), x < 7)

    # From here on the analyzer knows constant bounds for x, y and z.
    for var, lower, upper in [(x, 0, 10), (y, -10, 0), (z, -5, 5)]:
        ck.analyzer.update(var, tvm.arith.ConstIntBound(lower, upper), override=True)

    always_true = (
        x < 11,
        x <= 10,
        z <= 5,
        x + y <= 10,
        x + y >= -10,
        z - 5 <= y + 10,
        tvm.tir.all(x > -1, z <= x + 5),
        x * y <= 0,
        (x + 1) * (y - 1) < 0,
        y * y >= 0,
    )
    for condition in always_true:
        check(condition, true())

    check(x * 6 <= -3, false())
    check(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0)
def test_sub_index_simplify():
    """Exercise the rewrite rules that apply to subtraction expressions.

    Every ``check(a, b)`` call forwards its two arguments to
    ``RewriteChecker.verify``. Judging from the call sites, the first argument
    is the expression handed to the simplifier and the second is the form the
    test expects back.
    """
    checker = RewriteChecker()
    check = checker.verify
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    a, b = tvm.tir.Any(), tvm.tir.Any()
    vmin = tvm.te.min
    vmax = tvm.te.max
    bound = tvm.arith.ConstIntBound

    def check_divmod_common(div, mod):
        # Rules that are written identically for the truncating and the
        # flooring operator pairs; only `div` / `mod` differ.
        check(div(y, 3) * 3 - y, 0 - mod(y, 3))
        check(y - div(y - 6, 5) * 5, mod(y + (-6), 5) + 6)
        check(div(y - 6, 5) * 5 - y, (-6) - mod(y + (-6), 5))
        check(y - div(y + z, 5) * 5, mod(y + z, 5) - z)
        check(div(y + z, 5) * 5 - y, z - mod(y + z, 5))
        check(y - div(y - z, 5) * 5, mod(y - z, 5) + z)
        check(div(y - z, 5) * 5 - y, 0 - mod(y - z, 5) - z)

        check(y * 3 - div(y, 2) * 6, mod(y, 2) * 3)
        check(div(y, 3) * 6 - y * 2, mod(y, 3) * (-2))
        check(y * 5 - div(y + z, 2) * 10, (mod(y + z, 2) - z) * 5)
        check(y * 5 - div(y - z, 2) * 10, (mod(y - z, 2) + z) * 5)
        check(div(y + z, 3) * 6 - y * 2, (z - mod(y + z, 3)) * 2)
        check(div(y - z, 3) * 6 - y * 2, (0 - mod(y - z, 3) - z) * 2)
        check(5 * y - div(y + z, 2) * 10, (mod(y + z, 2) - z) * 5)
        check(5 * y - 10 * div(y - z, 2), (mod(y - z, 2) + z) * 5)
        check(6 * div(y + z, 3) - y * 2, (z - mod(y + z, 3)) * 2)
        check(div(y - z, 3) * 6 - 2 * y, (0 - mod(y - z, 3) - z) * 2)

    # Cancellation and min/max distribution.
    check(x + y - y, x)
    check(x + y - x, y)
    check(x - (y + x), 0 - y)
    check(x - (x + y), 0 - y)
    check(vmin(x, y) - x, vmin(0, y - x))
    check(vmin(x, y) - y, vmin(x - y, 0))
    check(vmax(x, y) - x, vmax(0, y - x))
    check(vmax(x, y) - y, vmax(x - y, 0))
    check(x - vmin(x, y), vmax(0, x - y))
    check(y - vmin(x, y), vmax(y - x, 0))
    check(x - vmax(x, y), vmin(0, x - y))
    check(y - vmax(x, y), vmin(y - x, 0))

    # Multiplication coefficient folding.
    check(x - x, 0)
    check(a - a, 0)
    check(a - b, a - b)
    check(x * y - x, x * (y + (-1)))
    check(x * y - 10 * x, x * (y + (-10)))
    check(y * x - x * z, x * (y - z))
    check(y * x - z * x, x * (y - z))
    check(x + 10 - 20, x + (-10))

    # Four-operand patterns.
    check((x + y) - (x + z), y - z)
    check((y + x) - (x + z), y - z)
    check((x + y) - (z + x), y - z)
    check((y + x) - (z + x), y - z)

    check(vmin(x + y, z) - x, vmin(y, z - x))
    check(vmin(y + x, z) - x, vmin(y, z - x))
    check(vmin(z, x + y) - x, vmin(z - x, y))
    check(vmin(z, y + x) - x, vmin(z - x, y))

    check(vmax(x + y, z) - x, vmax(y, z - x))
    check(vmax(y + x, z) - x, vmax(y, z - x))
    check(vmax(z, x + y) - x, vmax(z - x, y))
    check(vmax(z, y + x) - x, vmax(z - x, y))

    check(x - vmin(x + y, z), vmax(0 - y, x - z))
    check(x - vmin(y + x, z), vmax(0 - y, x - z))
    check(x - vmin(z, x + y), vmax(x - z, 0 - y))
    check(x - vmin(z, y + x), vmax(x - z, 0 - y))

    check(vmin(x, y) - vmin(y, x), 0)
    check(vmax(x, y) - vmax(y, x), 0)
    check(vmin(x, y) - vmin(x + 10, y + 10), -10)
    check(vmin(x + 10, y + 1) - vmin(x, y - 9), 10)

    # DivMod patterns: truncating division.
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod
    checker.analyzer.update(x, bound(0, 1000), override=True)
    check(x - tdiv(x, 3) * 3, tmod(x, 3))
    check(tdiv(x + 5, 3) - tdiv(x, 3), tdiv(tmod(x, 3) + 5, 3))
    check(tdiv(x + 5, 3) - tdiv(x + 1, 3), tdiv(tmod(x + 1, 3) + 4, 3))
    check(y - tdiv(y, (-5)) * (-5), tmod(y, 5))
    check_divmod_common(tdiv, tmod)

    # DivMod patterns: flooring division.
    fld = tvm.te.floordiv
    flm = tvm.te.floormod
    checker.analyzer.update(x, bound(-1000, 1000), override=True)
    checker.analyzer.update(y, bound(-1000, 1000), override=True)
    check(x - fld(x, 3) * 3, flm(x, 3))
    check(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3))
    check(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1)
    check_divmod_common(fld, flm)
def test_vector_simplify():
    """Exercise the rewrite rules for vector (Ramp / Broadcast) expressions.

    Every ``check(a, b)`` call forwards its two arguments to
    ``RewriteChecker.verify``. Where both arguments are the same expression
    the test expects that no rewrite is applied.
    """
    checker = RewriteChecker()
    check = checker.verify
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    ramp = tvm.tir.Ramp
    bcast = tvm.tir.Broadcast
    bound = tvm.arith.ConstIntBound

    def i32x2(expr):
        return expr.astype("int32x2")

    def u1x2(expr):
        return expr.astype("uint1x2")

    # Add rules
    check(ramp(x, 1, 4) + ramp(y, 2, 4), ramp(x + y, 3, 4))
    check(ramp(x, 1, 2) + y, ramp(x + y, 1, 2))
    check(y + ramp(x, 1, 2), ramp(y + x, 1, 2))
    check(i32x2(y) + i32x2(x), i32x2(y + x))
    check(bcast(0, 4) + y, bcast(y, 4))
    check(
        ramp(x, 1, 4).astype("float32x4") + bcast(0.0, 4),
        ramp(x, 1, 4).astype("float32x4"),
    )

    # Sub rules
    check(ramp(x, 4, 4) - ramp(y, 2, 4), ramp(x - y, 2, 4))
    check(ramp(x, 1, 2) - y, ramp(x - y, 1, 2))
    check(y - ramp(x, 1, 2), ramp(y - x, -1, 2))
    check(i32x2(y) - i32x2(x), i32x2(y - x))

    # Mul rules
    check(i32x2(y) * i32x2(x), i32x2(y * x))
    check(ramp(x, 4, 4) * 2, ramp(x * 2, 8, 4))
    check(2 * ramp(x, 4, 4), ramp(x * 2, 8, 4))
    check(bcast(0, 4) * x, bcast(0, 4))
    check(bcast(0.0, 4) * x, bcast(0.0, 4))

    # DivMod rules
    tdiv = tvm.tir.truncdiv
    tmod = tvm.tir.truncmod

    # Truncating div
    check(tdiv(i32x2(y), i32x2(x)), i32x2(tdiv(y, x)))
    check(tdiv(ramp(x, 4, 4), 2), ramp(tdiv(x, 2), 2, 4))
    checker.analyzer.update(x, bound(0, 1000), override=True)
    check(tdiv(ramp(x * 8 + 1, 1, 4), 8), x.astype("int32x4"))
    check(tdiv(ramp(x * 8 + 15, 1, 4), 8), tdiv(ramp(x * 8 + 15, 1, 4), 8))

    # Truncating mod
    check(tmod(i32x2(y), i32x2(x)), i32x2(tmod(y, x)))
    check(tmod(ramp(x, 4, 4), 2), bcast(tmod(x, 2), 4))
    check(tmod(ramp(x * 8 + 1, 1, 4), 8), ramp(1, 1, 4))
    check(tmod(ramp(x * 8 + 1, 15, 4), 8), tmod(ramp(1, 15, 4), 8))

    # Floor div
    fld = tvm.te.floordiv
    flm = tvm.te.floormod
    checker.analyzer.update(x, bound(-10, 1000), override=True)
    check(fld(i32x2(y), i32x2(x)), i32x2(fld(y, x)))
    check(fld(ramp(x, 4, 4), 2), ramp(fld(x, 2), 2, 4))
    check(fld(ramp(x * 8 + 1, 1, 4), 8), x.astype("int32x4"))
    check(fld(ramp(x * 8 + 15, 1, 4), 8), fld(ramp(x * 8 + 15, 1, 4), 8))
    check(fld(ramp(x, 8, 5), bcast(4, 5)), ramp(fld(x, 4), 2, 5))
    check(
        fld(ramp(x, 7, 4), bcast(4, 4)),
        fld(ramp(x, 7, 4), bcast(4, 4)),
    )
    check(fld(ramp(x * 8, 1, 4), bcast(4, 4)), bcast(x * 2, 4))
    check(
        fld(ramp(x * 8, 3, 4), bcast(4, 4)),
        fld(ramp(x * 8, 3, 4), bcast(4, 4)),
    )
    check(
        fld(ramp(x * 8 + 15, 1, 4), bcast(4, 4)),
        fld(ramp(x * 8 + 15, 1, 4), bcast(4, 4)),
    )
    check(fld(ramp(x * 4, 1, 4), bcast(64, 4)), bcast(fld(x, 16), 4))
    check(fld(ramp(x * 8, 2, 4), bcast(64, 4)), bcast(fld(x, 8), 4))
    # Negative case, e.g. x = 15: [60, 61, 62, 63, 64] / 64 = [0, 0, 0, 0, 1]
    check(
        fld(ramp(x * 4, 1, 5), bcast(64, 5)),
        fld(ramp(x * 4, 1, 5), bcast(64, 5)),
    )
    # Negative case, e.g. x = 15: [63, 64, 65, 66] / 64 = [0, 1, 1, 1]
    check(
        fld(ramp(x * 4 + 3, 1, 4), bcast(64, 4)),
        fld(ramp(x * 4 + 3, 1, 4), bcast(64, 4)),
    )
    # Negative case, e.g. x = 9: [63, 64, 65, 66] / 64 = [0, 1, 1, 1]
    check(
        fld(ramp(x * 7, 1, 4), bcast(64, 4)),
        fld(ramp(x * 7, 1, 4), bcast(64, 4)),
    )

    # Floor mod
    check(flm(i32x2(y), i32x2(x)), i32x2(flm(y, x)))
    check(flm(ramp(x, 4, 4), 2), bcast(flm(x, 2), 4))
    check(flm(ramp(x * 8 + 1, 1, 4), 8), ramp(1, 1, 4))
    check(flm(ramp(x * 8 + 1, 15, 4), 8), flm(ramp(1, 15, 4), 8))
    check(flm(ramp(x, 8, 4), bcast(4, 4)), bcast(flm(x, 4), 4))
    check(
        flm(ramp(x, 7, 4), bcast(4, 4)),
        flm(ramp(x, 7, 4), bcast(4, 4)),
    )
    check(flm(ramp(x * 8, 1, 4), bcast(4, 4)), ramp(0, 1, 4))
    check(
        flm(ramp(x * 8, 1, 5), bcast(4, 5)),
        flm(ramp(0, 1, 5), bcast(4, 5)),
    )
    check(
        flm(ramp(x * 8 + 7, 1, 4), bcast(4, 4)),
        flm(ramp(3, 1, 4), bcast(4, 4)),
    )
    check(flm(ramp(x * 4, 1, 4), bcast(64, 4)), ramp(flm(x * 4, 64), 1, 4))
    check(flm(ramp(x * 8, 2, 4), bcast(64, 4)), ramp(flm(x * 8, 64), 2, 4))
    # Negative case, e.g. x = 15:
    # [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0]
    check(
        flm(ramp(x * 4, 1, 5), bcast(64, 5)),
        flm(ramp(x * 4, 1, 5), bcast(64, 5)),
    )
    # Negative case, e.g. x = 15: [63, 64, 65, 66] % 64 = [63, 0, 1, 2]
    check(
        flm(ramp(x * 4 + 3, 1, 4), bcast(64, 4)),
        flm(ramp(x * 4 + 3, 1, 4), bcast(64, 4)),
    )
    # Negative case, e.g. x = 9:
    # [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5]
    check(
        flm(ramp(x * 2, 1, 8), bcast(20, 8)),
        flm(ramp(x * 2, 1, 8), bcast(20, 8)),
    )
    # Negative case, e.g. x = 9: [63, 64, 65, 66] % 64 = [63, 0, 1, 2]
    check(
        flm(ramp(x * 7, 1, 4), bcast(64, 4)),
        flm(ramp(x * 7, 1, 4), bcast(64, 4)),
    )

    # Min/Max rules
    vx = te.var("vx", dtype="int32x2")
    vc = te.var("vc", dtype="uint1")
    vmin = tvm.te.min
    vmax = tvm.te.max
    check(vmin(i32x2(y), i32x2(x)), i32x2(vmin(y, x)))
    check(
        vmin(vmin(vx, i32x2(y)), i32x2(x)),
        vmin(vx, i32x2(vmin(y, x))),
    )
    check(vmax(i32x2(y), i32x2(x)), i32x2(vmax(y, x)))
    check(
        vmax(vmax(vx, i32x2(y)), i32x2(x)),
        vmax(vx, i32x2(vmax(y, x))),
    )

    # Logical rules
    check(i32x2(y).equal(i32x2(x)), u1x2(y.equal(x)))
    check(tvm.tir.NE(i32x2(y), i32x2(x)), u1x2(tvm.tir.NE(y, x)))
    check(i32x2(y) > i32x2(x), u1x2(x < y))
    check(i32x2(y) >= i32x2(x), u1x2(x <= y))
    check(i32x2(y) < i32x2(x), u1x2(y < x))
    check(i32x2(y) <= i32x2(x), u1x2(y <= x))
    check(
        tvm.tir.And(i32x2(y) <= i32x2(x), u1x2(vc)),
        u1x2(tvm.tir.And(y <= x, vc)),
    )
    check(
        tvm.tir.Or(i32x2(y) <= i32x2(x), u1x2(vc)),
        u1x2(tvm.tir.Or(y <= x, vc)),
    )
def verify_batch_matmul(x_batch, y_batch, M, N, K, dynamic=False, debug=False):
    """Build a batch matmul for every enabled target and compare with NumPy.

    Parameters
    ----------
    x_batch, y_batch : int
        Batch sizes of the two operands. In dynamic mode they must be equal,
        or one of them must be 1.
    M, N, K : int
        The first operand has shape ``(x_batch, M, K)`` and the second
        ``(y_batch, N, K)``.
    dynamic : bool
        When true the placeholders use symbolic dimensions and a default
        schedule is created; "cuda" and "nvptx" targets are skipped.
    debug : bool
        When true the lowered function is printed.
    """
    if dynamic:
        assert x_batch == y_batch or x_batch == 1 or y_batch == 1
        batch_size = max(x_batch, y_batch)
        dyn_batch = te.var("dynamic_batch_size")
        dyn_m = te.var("dynamic_M")
        dyn_n = te.var("dynamic_N")
        dyn_k = te.var("dynamic_K")
        x = te.placeholder((dyn_batch, dyn_m, dyn_k), name="x")
        y = te.placeholder((dyn_batch, dyn_n, dyn_k), name="y")
    else:
        x = te.placeholder((x_batch, M, K), name="x")
        y = te.placeholder((y_batch, N, K), name="y")
    dtype = x.dtype

    # use memoize to pickle the test data for next time use
    # NOTE(review): the memoize helper presumably derives its cache key from
    # this function's name and free variables, so both are left untouched.
    @memoize("topi.tests.test_topi_batch_matmul")
    def get_ref_data():
        lhs = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
        rhs = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
        expected = tvm.topi.testing.batch_matmul(lhs, rhs)
        return (lhs, rhs, expected)

    # get the test data
    a_np, b_np, c_np = get_ref_data()

    def check_device(target, dev):
        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            fcompute, fschedule = tvm.topi.testing.dispatch(
                target, _batch_matmul_implement
            )
            out = fcompute(x, y)
            if dynamic:
                sched = te.create_schedule(out.op)
                out_shape = (batch_size, M, N)
            else:
                sched = fschedule([out])
                out_shape = out.shape
            if debug:
                print(tvm.lower(sched, [x, y, out], simple_mode=True))

        lhs_dev = tvm.nd.array(a_np, dev)
        rhs_dev = tvm.nd.array(b_np, dev)
        zeros = np.zeros(get_const_tuple(out_shape), dtype=dtype)
        out_dev = tvm.nd.array(zeros, dev)
        func = tvm.build(sched, [x, y, out], target, name="dense")
        func(lhs_dev, rhs_dev, out_dev)
        tvm.testing.assert_allclose(out_dev.numpy(), c_np, rtol=1e-5)

    for target, dev in tvm.testing.enabled_targets():
        if dynamic and target in ("cuda", "nvptx"):
            print("Dynamic batch matmul test is skippped on %s" % target)
        else:
            check_device(target, dev)
def test_stmt_constructor():
    """Construct each TIR statement node and check the fields it exposes."""
    tir = tvm.tir
    var = te.var("aa")
    buffer_var = te.var("buf", dtype="handle")
    nop = tir.Evaluate(1)

    def true_cond():
        # A fresh constant-true condition for each node that takes one.
        return tir.const(1, "uint1")

    let_stmt = tir.LetStmt(var, 1, tir.Evaluate(1))
    assert isinstance(let_stmt, tir.LetStmt)
    assert let_stmt.var == var
    assert let_stmt.value.value == 1
    assert isinstance(let_stmt.body, tir.Evaluate)

    attr_stmt = tir.AttrStmt(var == 1, "xx", 1, tir.Evaluate(1))
    assert isinstance(attr_stmt, tir.AttrStmt)
    assert attr_stmt.value.value == 1

    assert_stmt = tir.AssertStmt(true_cond(), tvm.runtime.convert("hellow"), nop)
    assert isinstance(assert_stmt, tir.AssertStmt)
    assert assert_stmt.body == nop

    loop = tir.For(te.var("x"), 0, 10, 0, 0, nop)
    assert isinstance(loop, tir.For)
    assert loop.min.value == 0
    assert loop.extent.value == 10
    assert loop.body == nop

    store = tir.Store(buffer_var, 1, 10, true_cond())
    assert isinstance(store, tir.Store)
    assert store.buffer_var == buffer_var
    assert store.index.value == 10
    assert store.value.value == 1

    tensor = te.placeholder((), dtype="float32")
    provide = tir.Provide(tensor.op, 0, 10, [])
    assert isinstance(provide, tir.Provide)
    assert provide.value_index == 0
    assert provide.value.value == 10

    alloc = tir.Allocate(buffer_var, "float32", [10], true_cond(), nop)
    assert isinstance(alloc, tir.Allocate)
    assert alloc.dtype == "float32"
    assert alloc.buffer_var == buffer_var
    assert alloc.body == nop

    buf_attr = tir.AttrStmt(buffer_var, "xyz", 1, nop)
    assert isinstance(buf_attr, tir.AttrStmt)
    assert buf_attr.node == buffer_var
    assert buf_attr.attr_key == "xyz"
    assert buf_attr.body == nop

    free = tir.Free(buffer_var)
    assert isinstance(free, tir.Free)
    assert free.buffer_var == buffer_var

    realize = tir.Realize(None, 0, "float", [], true_cond(), nop)
    assert isinstance(realize, tir.Realize)
    assert realize.body == nop

    branch = tir.IfThenElse(true_cond(), tir.Evaluate(11), nop)
    assert isinstance(branch, tir.IfThenElse)
    assert branch.then_case.value.value == 11
    assert branch.else_case == nop

    prefetch = tir.Prefetch(None, 1, "float32", [])
    assert isinstance(prefetch, tir.Prefetch)
    assert prefetch.value_index == 1
def check_packed_func(target="llvm"):
    """Lower a packed call inside a parallel loop and inspect the allocas.

    A ``tvm.test_matmul`` packed call over three strided buffers is emitted
    inside a parallel for-loop, wrapped in a PrimFunc and run through
    ``MakePackedAPI`` and ``LowerTVMBuiltin``. The test then asserts that the
    loop body starts with four nested ``LetStmt`` nodes, each bound to a
    ``tir.tvm_stack_alloca`` call.

    Parameters
    ----------
    target : str
        Target string attached to the PrimFunc before lowering.
    """
    ib = tvm.tir.ir_builder.create()

    m = n = k = 16

    # Tensors that provide shape and dtype for buffers a, b and c.
    a = te.placeholder((m, k), name="a", dtype="float64")
    b = te.placeholder((k, n), name="b", dtype="float64")
    k = te.reduce_axis((0, k), name="k")
    c = te.compute((m, n), lambda i, j: te.sum(a[i, k] * b[k, j], axis=k), name="c")

    def strided_buffer(tensor, buffer_name, stride_name):
        # 2-D buffer with a symbolic outer stride and a unit inner stride.
        return tvm.tir.decl_buffer(
            tensor.shape,
            tensor.dtype,
            name=buffer_name,
            offset_factor=1,
            strides=[te.var(stride_name), 1],
        )

    a_buffer = strided_buffer(a, "a_buffer", "s1")
    b_buffer = strided_buffer(b, "b_buffer", "s2")
    c_buffer = strided_buffer(c, "c_buffer", "s3")

    with ib.for_range(0, 10, "i", kind="parallel"):
        packed_call = tvm.tir.call_packed(
            "tvm.test_matmul", a_buffer, b_buffer, c_buffer
        )
        ib.emit(packed_call)
    stmt = ib.get()

    # Construct a valid IRModule to be lowered:
    prim = tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt)
    mod = tvm.IRModule.from_expr(prim)

    tgt = tvm.target.Target(target)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tgt))(mod)
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
    mod = tvm.tir.transform.MakePackedAPI()(mod)

    # Do the lowering:
    mod = tvm.tir.transform.LowerTVMBuiltin()(mod)

    # Get the PrimFunc from module:
    prim_func = mod.functions.items()[0][1]

    # Skip the leading assert / let / attr wrappers until the for-loop.
    wrappers = (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)
    node = prim_func.body
    while isinstance(node, wrappers):
        node = node.body
    assert isinstance(node, tvm.tir.stmt.For)

    # Expected nesting inside the loop body, outermost first:
    #   let stack_tcode = tir.tvm_stack_alloca("arg_tcode", 4)
    #   let stack_value = tir.tvm_stack_alloca("arg_value", 4)
    #   let stack_array = tir.tvm_stack_alloca("array", 3)
    #   let stack_shape = tir.tvm_stack_alloca("shape", 12)
    expected_allocas = (
        ("arg_tcode", 4),
        ("arg_value", 4),
        ("array", 3),
        ("shape", 12),
    )
    alloca = node.body
    for kind, size in expected_allocas:
        assert isinstance(alloca, tvm.tir.LetStmt)
        expected_value = tvm.tir.call_intrin(
            "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), kind, size
        )
        # Reuse the actual variable and body so that only the bound value is
        # really compared.
        expected_stmt = tvm.tir.LetStmt(alloca.var, expected_value, alloca.body)
        tvm.ir.assert_structural_equal(alloca, expected_stmt, map_free_vars=True)
        alloca = alloca.body
def smlal_int16_int32():
    """Tensor intrinsic: widening int16x8 multiply accumulated into int32x8.

    The pseudo-code for the algorithm is as follows:

        vec_a = vload(A, "int16x8")
        vec_b = vload(B, "int16x8")

        vec_c[0:4] += vec_a[0:4]*vec_b[0:4]
        vec_c[4:8] += vec_a[4:8]*vec_b[4:8]

    One int16x8 vector is loaded from each input and its lower (0:4) and
    higher (4:8) halves are multiplied and accumulated separately. Each half
    is emitted as an ``llvm.aarch64.neon.smull`` call followed by a vector
    add. NOTE(review): presumably the backend is expected to fuse these into
    the smlal / smlal2 pair the function is named after -- confirm.

    Returns
    -------
    The value of ``te.decl_tensor_intrin`` for the computation above.
    """
    lanes = 8
    lhs = te.placeholder((lanes,), dtype="int16", name="A")
    rhs = te.placeholder((lanes, 1), dtype="int16", name="B")
    out = te.compute(
        (lanes,),
        lambda i: lhs[i].astype("int32") * rhs[i, 0].astype("int32"),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        lhs.shape, dtype="int16", name="a_buffer", offset_factor=1, strides=[1]
    )
    b_buffer = tvm.tir.decl_buffer(
        rhs.shape,
        dtype="int16",
        name="b_buffer",
        offset_factor=1,
        strides=[te.var("sb"), 1],
    )
    c_buffer = tvm.tir.decl_buffer(
        out.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[1]
    )

    def _intrin_func(ins, outs):
        def _instr(index):
            builder = tvm.tir.ir_builder.create()
            dst = outs[0]
            if index == 1:
                # Reset: zero the whole accumulator.
                builder.emit(dst.vstore(0, tvm.tir.const(0, "int32x8")))
                return builder.get()

            vec_a = ins[0].vload([0], "int16x8")
            vec_b = ins[1].vload([0, 0], "int16x8")

            def accumulate_half(offset, extract):
                # acc[offset:offset+4] + half(vec_a) * half(vec_b), as int32x4
                acc = dst.vload([offset], "int32x4")
                half_a = tvm.tir.call_intrin("int16x4", extract, vec_a)
                half_b = tvm.tir.call_intrin("int16x4", extract, vec_b)
                product = tvm.tir.call_llvm_pure_intrin(
                    "int32x4",
                    "llvm.aarch64.neon.smull",
                    tvm.tir.const(2, "uint32"),
                    half_a,
                    half_b,
                )
                return acc + product

            upper = accumulate_half(4, "tir.vectorhigh")
            lower = accumulate_half(0, "tir.vectorlow")

            # Combine higher and lower part in a single int32x8 vector to
            # store (this will require two different store instructions,
            # since the length of a NEON vector is fixed at 128 bits).
            combined = tvm.tir.call_intrin(
                "int32x8", "tir.vectorcombine", lower, upper
            )
            builder.emit(dst.vstore(0, combined))
            return builder.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    return te.decl_tensor_intrin(
        out.op,
        _intrin_func,
        binds={lhs: a_buffer, rhs: b_buffer, out: c_buffer},
        default_buffer_params={"offset_factor": 1},
    )
def test_let_simplify():
    """Check that a ``Let`` binding is folded when the expression is rewritten.

    ``Let(x, 1, x + 1)`` added to itself is expected to be verified against
    the constant 4.
    """
    checker = RewriteChecker()
    x, y = te.var("x"), te.var("y")
    bound_expr = tvm.tir.Let(x, 1, x + 1)
    doubled = bound_expr + bound_expr
    checker.verify(doubled, 4)
def gemm_acc_2x2_int8_int8_int32(dtype):
    """
    Int8 2x2 matrix multiplication using smmla/ummla instructions

    This function takes two arrays of int8 datatype -- A[2][8] and B[2][8] --
    and produces a 2x2 matrix which is equal to A*B'.

    The pseudo code is as follows.

    .. code-block:: c

        void mmla_2x2_int8_int8_int32(int8 A[2][8], int8 B[2][8], int32 C[2][2]){
            for (int i = 0; i < 2; i++){
                for (int j = 0; j < 2; j++){
                    for (int k = 0; k < 8; k++){
                        C[i][j] += A[i][k] * B[j][k]
                    }
                }
            }
        }

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint8", "int8"]

    lhs = te.placeholder((2, 8), dtype, name="A")
    rhs = te.placeholder((2, 8), dtype, name="B")
    dtype_vec = dtype + "x16"

    k = te.reduce_axis((0, 8), name="k")
    out = te.compute(
        (2, 2),
        lambda i, j: te.sum(
            lhs[i, k].astype("int32") * rhs[j, k].astype("int32"), axis=k
        ),
        name="C",
    )

    aa_buffer = tvm.tir.decl_buffer(
        lhs.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
    )
    bb_buffer = tvm.tir.decl_buffer(
        rhs.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
    )
    cc_buffer = tvm.tir.decl_buffer(
        out.shape,
        dtype="int32",
        name="cc_buffer",
        offset_factor=1,
        strides=[te.var("sc"), 1],
    )

    # Signed inputs use smmla, unsigned inputs use ummla.
    if dtype == "int8":
        llvm_intrin = "llvm.aarch64.neon.smmla"
    else:
        llvm_intrin = "llvm.aarch64.neon.ummla"

    def _intrin_func(ins, outs):
        def _instr(index):
            builder = tvm.tir.ir_builder.create()
            dst = outs[0]
            if index == 1:
                # Reset: zero the 2x2 accumulator.
                builder.emit(dst.vstore([0, 0], tvm.tir.const(0, "int32x4")))
                return builder.get()

            # Load in vec_a the two rows of A
            # vec_a = [a, b, c, d, e, f, g, h;
            #          i, j, k, l, m, n, o, p,]
            vec_a = ins[0].vload([0, 0], dtype_vec)
            # Load in vec_b the two rows of B
            # vec_b = [0, 2, 4, 6, 8, 10, 12, 14;
            #          1, 3, 5, 7, 9, 11, 13, 15,]
            vec_b = ins[1].vload([0, 0], dtype_vec)

            # Execute the matrix multiplication via (s/u)mmla:
            # vec_c = [a*0 + b*2 + c*4 + d*6 + e*8 + f*10 + g*12 + h*14;
            #          a*1 + b*3 + c*5 + d*7 + e*9 + f*11 + g*13 + h*15;
            #          i*0 + j*2 + k*4 + l*6 + m*8 + n*10 + o*12 + p*14;
            #          i*1 + j*3 + k*5 + l*7 + m*9 + n*11 + o*13 + p*15]
            vec_c = dst.vload([0, 0], "int32x4")
            accumulated = tvm.tir.call_llvm_intrin(
                "int32x4",
                llvm_intrin,
                tvm.tir.const(3, "uint32"),
                vec_c,
                vec_a,
                vec_b,
            )
            # Store the result
            builder.emit(dst.vstore([0, 0], accumulated))
            return builder.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    return te.decl_tensor_intrin(
        out.op,
        _intrin_func,
        binds={lhs: aa_buffer, rhs: bb_buffer, out: cc_buffer},
        default_buffer_params={"offset_factor": 1},
    )
def rnn_matexp():
    """Build, schedule, run and check a scan that repeatedly multiplies by Whh.

    The scan state has shape (num_step, batch_size, num_hidden). It starts at
    1.0 everywhere and each step computes
    ``state[t] = state[t - 1] x Whh`` (reduction over the hidden axis).
    The schedule is bound to blockIdx.x / threadIdx.x / threadIdx.y and the
    result is built and executed through ``check_device("cuda")``.

    Reads the module-level flags ``DETECT_GLOBAL_BARRIER``, ``PERSIST_KERNEL``
    and ``SKIP_CHECK``.
    """
    # Concrete sizes used for the host-side arrays and the reference check.
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    # The step count stays symbolic in the compute definition.
    num_step = te.var("num_step")
    num_hidden = tvm.runtime.convert(n_num_hidden)
    batch_size = tvm.runtime.convert(n_batch_size)
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

    Whh = te.placeholder((num_hidden, num_hidden), name="Whh")
    # Initial state: a single step filled with 1.0.
    s_init = te.compute((1, batch_size, num_hidden), lambda _, i, j: 1.0, name="init")
    s_state = te.placeholder((num_step, batch_size, num_hidden))
    kh = te.reduce_axis((0, num_hidden), name="kh")
    # Update rule: previous state row times Whh.
    s_update = te.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: te.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh),
        name="update",
    )
    s_scan = tvm.te.scan(s_init, s_update, s_state)

    # schedule
    s = te.create_schedule(s_scan.op)
    CL = s_update
    # Stage the state through "shared" then "local"; Whh goes to "local".
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])

    # Split the reduction into num_thread_y parts and factor the outer part
    # out into its own stage (CLF).
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)

    block_x = te.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = te.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = te.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])

    # Distribute the hidden axis of the init stage over blocks and threads.
    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)

    # Same distribution for the update stage; the remaining reduction axis is
    # bound to threadIdx.y.
    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)
    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    # Duplicate store predicate.
    s[CL].set_store_predicate(thread_y.equal(0))

    # Placement of the local Whh copy depends on the PERSIST_KERNEL flag.
    if PERSIST_KERNEL:
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])

    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)

    # Spread the shared-memory copy of the state over threadIdx.y / threadIdx.x.
    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)

    def check_device(target):
        # Build for `target`, run twice (timing the second run) and, unless
        # SKIP_CHECK is set, compare against a NumPy reference.
        with tvm.transform.PassContext(
            config={
                "tir.UnrollLoop": {
                    "auto_max_step": 128,
                },
                "tir.detect_global_barrier": detect_global_barrier,
            }):
            f = tvm.build(s, [s_scan, Whh], target)
        dev = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        res_np = np.zeros(
            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        # Every weight is 2 / n_num_hidden, then the right half of the
        # columns is zeroed.
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden // 2:] = 0

        res_a = tvm.nd.array(res_np, dev)
        Whh_a = tvm.nd.array(Whh_np, dev)
        # Skip first pass as it is compilation
        f(res_a, Whh_a)
        dev.sync()
        # measure time cost of second step.
        tstart = time.time()
        f(res_a, Whh_a)
        dev.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)
        # correctness
        if not SKIP_CHECK:
            res_gpu = res_a.asnumpy()
            # Reference: start from ones and apply the update in float64.
            res_cmp = np.ones_like(res_np).astype("float64")
            Whh_np = Whh_np.astype("float64")
            for t in range(1, n_num_step):
                res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
            # Print every mismatching entry of batch row 0 before the
            # assertion below.
            for i in range(n_num_step):
                for j in range(n_num_hidden):
                    if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
                        print("%d, %d: %g vs %g" %
                              (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
            tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)

    check_device("cuda")
def dot_int8_int8_int32_neon():
    """
    Int8 dot product using the smull, saddlp and addp NEON instructions

    .. code-block:: c

        void dot_int8_int8_int32(int8 data[4], int8 kernel[4][4], int32 output[4]){
            for (int i = 0; i < 4; i++){
                out[i] = 0;
                for (int k = 0; k < 4; k++){
                    out[i] += data[k] * kernel[i][k]
                }
            }
        }

    As emitted below:
    smull : int8x8 * int8x8 -> int16x8 elementwise multiplication
    saddlp: int16x8 -> int32x4 pairwise addition of elements

    Data is broadcast across the register

    int8 elements
    |       data        |       data        |
    |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |

                      smull

    int8 elements
    |     kernel[i]     |    kernel[i+1]    |
    |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |

                        =

    int16 elements
    |   data * kernel[i]    |  data * kernel[i+1]   |
    |  0  |  1  |  2  |  3  |  4  |  5  |  6  |  7  |

                     saddlp =

    int32 elements
    | partial sum(data * kernel[i]) | partial sum(data * kernel[i+1]) |
    |       0       |       1       |        2       |       3        |

    We apply the above kernel twice and use addp to compute the second set of
    pairwise additions

    int32 elements (narrowed so they fit on a line)
    | psum d*k[i] | psum d*k[i+1] |      | psum d*k[i+2] | psum d*k[i+3] |
    |  0   |  1   |   2   |   3   | addp |   4   |   5   |   6   |   7   |
                                     =
    |sum d*ki |sum d*ki1|sum d*ki2|sum d*ki3|
    |    0    |    1    |    2    |    3    |

    Returns
    -------
    The value of ``te.decl_tensor_intrin`` for the computation above.
    """
    int32_lanes = 4  # 4 int32 lanes = 128 bits
    num_int8_elements = 4  # 4 int8 elements in int32

    data = te.placeholder((num_int8_elements,), dtype="int8", name="data")
    kernel = te.placeholder(
        (int32_lanes, num_int8_elements), dtype="int8", name="kernel"
    )
    k = te.reduce_axis((0, num_int8_elements), name="k")
    out = te.compute(
        (int32_lanes,),
        lambda i: te.sum(
            data[k].astype("int32") * kernel[i, k].astype("int32"), axis=k
        ),
        name="C",
    )

    a_buffer = tvm.tir.decl_buffer(
        data.shape, dtype="int8", name="a_buffer", offset_factor=1, strides=[1]
    )
    b_buffer = tvm.tir.decl_buffer(
        kernel.shape,
        dtype="int8",
        name="b_buffer",
        offset_factor=1,
        strides=[te.var("ldw"), 1],
    )

    def _intrin_func(ins, outs):
        def _instr(index):
            half_i8 = "int8x8"
            acc_i32 = "int32x4"
            builder = tvm.tir.ir_builder.create()
            dst = outs[0]
            if index == 1:
                # Reset: zero the accumulator.
                builder.emit(dst.vstore(0, tvm.tir.const(0, acc_i32)))
                return builder.get()

            # Broadcast the 4 data bytes to 8 lanes: reinterpret them as one
            # int32, replicate it to int32x2, reinterpret back as int8x8.
            data_i8x4 = ins[0].vload([0], "int8x4")
            data_i32 = tvm.tir.call_intrin("int32", "tir.reinterpret", data_i8x4)
            data_i32x2 = data_i32.astype("int32x2")
            vec_a = tvm.tir.call_intrin(half_i8, "tir.reinterpret", data_i32x2)

            vec_b = ins[1].vload([0, 0], "int8x16")

            def pairwise_add_mul(extract_half):
                kernel_half = tvm.tir.call_intrin("int8x8", extract_half, vec_b)
                # widening elementwise multiply: int8x8 * int8x8 -> int16x8
                products = tvm.tir.call_llvm_pure_intrin(
                    "int16x8",
                    "llvm.aarch64.neon.smull.v8i16",
                    tvm.tir.const(2, "uint32"),
                    vec_a,
                    kernel_half,
                )
                # pairwise add: int16x8 -> int32x4
                return tvm.tir.call_llvm_pure_intrin(
                    "int32x4",
                    "llvm.aarch64.neon.saddlp.v4i32.v8i16",
                    tvm.tir.const(1, "uint32"),
                    products,
                )

            low_pairs = pairwise_add_mul("tir.vectorlow")
            high_pairs = pairwise_add_mul("tir.vectorhigh")
            dot = tvm.tir.call_llvm_pure_intrin(
                "int32x4",
                "llvm.aarch64.neon.addp.v4i32",
                tvm.tir.const(2, "uint32"),
                low_pairs,
                high_pairs,
            )

            # body (index 0) overwrites the output; update adds to it.
            result = dot if index == 0 else dot + dst.vload([0], acc_i32)
            builder.emit(dst.vstore(0, result))
            return builder.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    return te.decl_tensor_intrin(
        out.op,
        _intrin_func,
        binds={data: a_buffer, kernel: b_buffer},
        default_buffer_params={"offset_factor": 1},
    )
def test_domain_touched():
    """Check ``DomainTouched`` on a two-level loop nest over buffers a and b.

    The loop body is ``a[i, j] = b[i - 1, j + 1] + a[i - 1, j - 1]`` with
    ``i`` in [0, 100) and ``j`` in [0, m). The read, write and read+write
    regions of ``a`` and the read and write regions of ``b`` are queried and
    compared with the expected min / extent per dimension.
    """
    i = te.var("i")
    j = te.var("j")
    n = tvm.runtime.convert(100)
    m = te.var("m")

    a = tvm.tir.decl_buffer((n, m), name="a")
    b = tvm.tir.decl_buffer((n, m), name="b")

    # Build the nest inside-out: store, inner j-loop, outer i-loop.
    rhs = tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(
        a, [i - 1, j - 1]
    )
    store = tvm.tir.BufferStore(a, rhs, [i, j])
    inner = tvm.tir.For(j, 0, m, 0, 0, store)
    ir = tvm.tir.For(i, 0, n, 0, 0, inner)

    def touched(buf, consider_reads, consider_writes):
        return tvm.arith._ffi_api.DomainTouched(
            ir, buf, consider_reads, consider_writes
        )

    a_read = touched(a, True, False)
    assert a_read[0].min.value == -1
    assert a_read[0].extent.value == 100
    assert a_read[1].min.value == -1
    assert a_read[1].extent.name == "m"

    a_write = touched(a, False, True)
    assert a_write[0].min.value == 0
    assert a_write[0].extent.value == 100
    assert a_write[1].min.value == 0
    assert a_write[1].extent.name == "m"

    a_both = touched(a, True, True)
    assert a_both[0].min.value == -1
    assert a_both[0].extent.value == 101
    assert a_both[1].min.value == -1
    assert isinstance(a_both[1].extent, tvm.tir.Add)
    assert a_both[1].extent.a.name == "m"
    assert a_both[1].extent.b.value == 1

    b_read = touched(b, True, False)
    assert b_read
    assert b_read[0].min.value == -1
    assert b_read[0].extent.value == 100
    assert b_read[1].min.value == 1
    assert b_read[1].extent.name == "m"

    # b is never written, so the write region is an empty Array.
    b_write = touched(b, False, True)
    assert isinstance(b_write, tvm.container.Array)
    assert len(b_write) == 0
def gemm_acc_nx16_int8_int8_int32(dtype, rows):
    """
    Int8 nx16 matrix multiplication and accumulation using sdot/udot instructions.

    This function takes two arrays of int8 datatype -- A[n][16] and
    B[4][16][4] -- and produces a rowsx16 matrix which is equal to A*B'.
    The pseudo code is as follows.

    .. code-block:: c

        void mmla_nx16_int8_int8_int32(int8 A[n][16], int8 B[4][16][4], int32 output[n][16]){
            for (int i = 0; i < n; i++){
                for (int j = 0; j < 16; j++){
                    for (int k = 0; k < 16; k++){
                        output[i][j] += A[i][k] * B[k//4][j][k%4]
                    }
                }
            }
        }

    Notes:
        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
          we need 4 tiles of B to compute a single row of the output. The first 4 values of
          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
        * The tiling strategy is picked to maximize register usage.

    Parameters
    ----------
    dtype : str, {"uint8", "int8"}
        Whether it works on unsigned int or signed int
    rows : int
        Number of the output rows "n"

    Returns
    -------
    intrin : TensorIntrin
        The Arm TensorIntrin that can be used in tensorizing schedule
    """
    assert dtype in ["uint8", "int8"]
    A = te.placeholder((rows, 16), dtype, name="A")
    B = te.placeholder((4, 16, 4), dtype, name="B")
    # Vector type used for every 16-lane load of A and B below.
    dtype_vec = dtype + "x16"
    idxm = tvm.tir.indexmod
    k = te.reduce_axis((0, 16), name="k")
    # Compute pattern matched by the intrinsic: C[i, j] = sum_k A[i, k] * B[k // 4, j, k % 4].
    C = te.compute(
        (rows, 16),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[k // 4, j, idxm(k, 4)].astype("int32"), axis=k),
        name="C",
    )

    # Buffers bound to A, B and C; the outer strides are symbolic, the innermost stride is 1.
    aa_buffer = tvm.tir.decl_buffer(A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1])
    bb_buffer = tvm.tir.decl_buffer(
        B.shape,
        dtype,
        name="bb_buffer",
        offset_factor=1,
        strides=[te.var("sb0"), te.var("sb1"), 1],
    )
    cc_buffer = tvm.tir.decl_buffer(C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1])

    # Signed inputs use sdot, unsigned inputs use udot.
    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"

    def _intrin_func(ins, outs):
        def _instr(index):
            ib = tvm.tir.ir_builder.create()
            if index == 1:
                # Reset: store an all-zero int32x16 vector into every output row.
                for i in range(0, rows):
                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x16")))
                return ib.get()
            # Iterate on the number of rows of the output
            # (this `k` is a plain Python loop index local to `_instr`, not the reduce axis above)
            for k in range(0, rows):
                # Load 16 elements of A
                # vec_a = [a, b, c, d, e, f, g, h, l, m, n, o, p, q, r, s];
                vec_a = ins[0].vload([k, 0], dtype_vec)

                # Iterate over each of the 4 rowsx4 tiles of the output
                for j in range(0, 4):
                    # Accumulate over each of the 4 (16x4) tiles contained in B
                    for i in range(0, 4):
                        # Replicate a single 4-element group of A (A[k, i:i+4])
                        # (`select_word` is defined outside this block)
                        vec_aa = select_word(vec_a, i, dtype_vec)

                        # Load 4 rows (each row with 4 elements) from B (B[i:i+4, j:j+4])
                        # vec_b = [0, 16, 32, 48,
                        #          1, 17, 33, 49,
                        #          2, 18, 34, 50,
                        #          3, 19, 35, 51,];
                        vec_b = ins[1].vload([i, 4 * j, 0], dtype_vec)

                        # Accumulate in the correct part of the output
                        vec_c = outs[0].vload([k, 4 * j], "int32x4")

                        # Compute the dot product between the rowsx4 tile
                        # from A and the 4x4 tile from B
                        #
                        # For instance, for i=0, we have:
                        # sdot(vec_aa[0], vec_b) = [a*0+b*16+c*32+d*48,
                        #                           a*1+b*17+c*33+d*49,
                        #                           a*2+b*18+c*34+d*50,
                        #                           a*3+b*19+c*35+d*51]
                        vdot = tvm.tir.call_llvm_intrin(
                            "int32x4",
                            llvm_intrin,
                            tvm.tir.const(3, "uint32"),
                            vec_c,
                            vec_b,
                            vec_aa,
                        )
                        ib.emit(outs[0].vstore([k, 4 * j], vdot))
            return ib.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    buffer_params = {"offset_factor": 1}
    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={
            A: aa_buffer,
            B: bb_buffer,
            C: cc_buffer
        },
        default_buffer_params=buffer_params,
    )
def intrin_gemm_MxKxN(M, K, N, in_dtype, out_dtype):
    """Declare a tensor intrinsic for a transposed (M x K) * (N x K)^T matmul.

    The intrinsic lowers to three extern calls named
    ``gemm_{M}x{K}x{N}_{body|reset|update}_{uniq_id}``.

    Returns
    -------
    (intrin_decl, uniq_id)
        The declared tensor intrinsic and the random suffix baked into the
        extern symbol names.
    """
    # Every intrinsic definition gets its own random suffix so that several
    # operators in the same module using the same intrinsic do not collide
    # in the generated source.
    #
    # TODO(weberlo, areusch): to cut down on memory usage, we should cache each intrinsic
    # instantiation and include it only once, eliminating the need for unique
    # IDs
    UNIQ_ID_LEN = 8
    uniq_id = "".join(random.choices(string.ascii_uppercase, k=UNIQ_ID_LEN))

    # Unwrap constant TIR integers into plain Python ints.
    M, K, N = (
        dim.value if isinstance(dim, tvm.tir.IntImm) else dim for dim in (M, K, N)
    )

    assert K % 4 == 0
    # TODO(weberlo, areusch): support more dtypes?
    assert in_dtype == "int8"
    assert out_dtype == "int32"

    A = te.placeholder((M, K), name="a", dtype=in_dtype)
    B = te.placeholder((N, K), name="b", dtype=in_dtype)
    k = te.reduce_axis((0, K), name="k")

    def _dot(i, j):
        return te.sum(A[i, k].astype(out_dtype) * B[j, k].astype(out_dtype), axis=k)

    C = te.compute((M, N), _dot, name="c")

    def _strided_buffer(tensor, buf_name, stride_name):
        # Row-major buffer with a symbolic row stride and unit column stride.
        return tvm.tir.decl_buffer(
            tensor.shape,
            tensor.dtype,
            name=buf_name,
            offset_factor=1,
            strides=[te.var(stride_name), 1],
        )

    A_buf = _strided_buffer(A, "A", "A_s")
    B_buf = _strided_buffer(B, "B", "B_s")
    C_buf = _strided_buffer(C, "C", "C_s")

    def intrin_func(ins, outs):
        lhs, rhs = ins
        out = outs[0]

        def _emit_call(symbol, *call_args):
            # Wrap a single extern call in its own IR builder statement.
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_extern("int32", symbol, *call_args))
            return builder.get()

        def _gemm_args():
            # Operand pointers followed by the three row strides.
            return (
                lhs.access_ptr("r"),
                rhs.access_ptr("r"),
                out.access_ptr("w"),
                lhs.strides[0],
                rhs.strides[0],
                out.strides[0],
            )

        body_stmt = _emit_call(f"gemm_{M}x{K}x{N}_body_{uniq_id}", *_gemm_args())
        reset_stmt = _emit_call(
            f"gemm_{M}x{K}x{N}_reset_{uniq_id}", out.access_ptr("w"), out.strides[0]
        )
        update_stmt = _emit_call(f"gemm_{M}x{K}x{N}_update_{uniq_id}", *_gemm_args())
        return body_stmt, reset_stmt, update_stmt

    intrin_decl = te.decl_tensor_intrin(C.op, intrin_func, binds={A: A_buf, B: B_buf, C: C_buf})
    return intrin_decl, uniq_id
def dot_int8_int8_int32(int32_lanes, dtype='uint'):
    """
    Int8 dot product over groups of 4 elements using the Arm udot/sdot intrinsics.

    Takes data[4] and kernel[int32_lanes][4] (both 8-bit) and computes the
    dot product of data with every row of kernel, producing
    output[int32_lanes] of the matching 32-bit type.

    .. code-block:: c

        void dot_int8_int8_int32(int8 data[4], int8 kernel[int32_lanes][4],
                                 int32 output[int32_lanes]){
            for (int i = 0; i < int32_lanes; i++){
                output[i] = 0;
                for (int k = 0; k < 4; k++){
                    output[i] += data[k] * kernel[i][k]
                }
            }
        }

    The 4 data elements are reinterpreted as one 32-bit value and broadcast
    across a vector before being fed to the dot-product intrinsic together
    with the kernel vector.

    Parameters
    ----------
    int32_lanes: int
        How many int32/uint32 to produce
    dtype: str, optional, {"uint", "int"}
        Whether it works on unsigned int or signed int

    Returns
    -------
    intrin : TensorIntrin
        The TensorIntrin that can be used in tensorizing a schedule
    """
    num_int8_elements = 4  # 4 int8 elements in int32

    narrow_t = '%s8' % dtype
    wide_t = '%s32' % dtype

    data = te.placeholder((num_int8_elements, ), dtype=narrow_t, name='data')
    kernel = te.placeholder((int32_lanes, num_int8_elements), dtype=narrow_t, name='kernel')

    k = te.reduce_axis((0, num_int8_elements), name='k')

    def _dot(i):
        return te.sum(data[k].astype(wide_t) * kernel[i, k].astype(wide_t), axis=k)

    C = te.compute((int32_lanes, ), _dot, name="C")

    a_buffer = tvm.tir.decl_buffer(
        data.shape, dtype=narrow_t, name="a_buffer", offset_factor=1, strides=[1])
    b_buffer = tvm.tir.decl_buffer(
        kernel.shape, dtype=narrow_t, name="b_buffer", offset_factor=1,
        strides=[te.var('s'), 1])

    def _intrin_func(ins, outs):
        def _instr(index):
            builder = tvm.tir.ir_builder.create()
            acc_vec_t = '%s32x%d' % (dtype, int32_lanes)

            if index == 1:
                # Reset: zero the accumulator vector.
                builder.emit(outs[0].vstore(0, tvm.tir.const(0, acc_vec_t)))
                return builder.get()

            data_vec_t = '%s8x%d' % (dtype, num_int8_elements)
            kernel_vec_t = '%s8x%d' % (dtype, int32_lanes * num_int8_elements)

            # View the 4 data bytes as one 32-bit scalar, broadcast it, then
            # view the broadcast vector as 8-bit lanes again.
            data_bytes = ins[0].vload([0], data_vec_t)
            data_word = tvm.tir.call_pure_intrin(wide_t, 'reinterpret', data_bytes)
            data_bcast = data_word.astype(acc_vec_t)
            vec_a = tvm.tir.call_pure_intrin(kernel_vec_t, 'reinterpret', data_bcast)

            vec_b = ins[1].vload([0, 0], kernel_vec_t)
            vec_c = outs[0].vload([0], acc_vec_t)

            mnemonic = 'udot' if dtype == 'uint' else 'sdot'
            intrin_name = 'llvm.aarch64.neon.%s.v%di32.v%di8' % (
                mnemonic, int32_lanes, int32_lanes * num_int8_elements)
            accumulated = tvm.tir.call_llvm_intrin(
                acc_vec_t, intrin_name, tvm.tir.const(2, 'uint32'), vec_c, vec_a, vec_b)
            builder.emit(outs[0].vstore(0, accumulated))
            return builder.get()

        # body, reset, update
        return _instr(0), _instr(1), _instr(2)

    return te.decl_tensor_intrin(
        C.op,
        _intrin_func,
        binds={data: a_buffer, kernel: b_buffer},
        default_buffer_params={"offset_factor": 1},
    )
from __future__ import absolute_import, print_function

import tvm
import tvm.testing
from tvm import te
from tvm import topi
import numpy as np

######################################################################
# Basic example
# -------------
# Let's revisit the sum of rows operation (equivalent to :code:`B = numpy.sum(A, axis=1)`).
# To compute the sum of rows of a two dimensional TVM tensor A, we should
# specify the symbolic operation as well as the schedule as follows
#
n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), "k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
s = te.create_schedule(B.op)

######################################################################
# and to examine the IR code in human readable format, we can do
#
print(tvm.lower(s, [A], simple_mode=True))

######################################################################
# However, for such a common operation we had to define the reduce axis ourselves
# as well as the explicit computation with :code:`te.compute`. Imagine how many
# details we would need to provide for more complicated operations.
# Fortunately, we can replace those two lines with a simple :code:`topi.sum`, much like :code:`numpy.sum`
from __future__ import absolute_import, print_function

import tvm
import tvm.testing
from tvm import te
import numpy as np

# Scan cell: the state at step t is the state at step t - 1 plus X[t].
m = te.var('m')
n = te.var('n')
X = te.placeholder((m, n), name='x')
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
s_scan = te.scan(s_init, s_update, s_state, inputs=[X])

# Schedule the init and update stages separately, splitting their second
# axis by `num_thread` and binding the pieces to block / thread indices.
s = te.create_schedule(s_scan.op)
num_thread = 256
block_X = te.thread_axis('blockIdx.x')
thread_X = te.thread_axis('threadIdx.x')
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_X)
s[s_init].bind(xi, thread_X)
# The axis must come from the `s_update` tensor; the previous `s.update`
# looked up a non-existent attribute on the schedule object.
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_X)
s[s_update].bind(xi, thread_X)
print(tvm.lower(s, [X, s_scan], simple_mode=True))

# multi-stage scan cell
m = te.var('m')
n = te.var('n')
def test_cse():
    """Check that CommonSubexprElimTIR introduces and uses two new variables.

    The input program is a chain of let-bindings in which ``z1 + z2`` and
    ``x + y`` each occur more than once. The test walks the transformed body
    and checks that ``cse_var_1`` and ``cse_var_2`` are bound at the expected
    levels and that the values of ``a`` and ``b`` were rewritten to use them.
    """
    z1 = te.var("z1")
    z2 = te.var("z2")
    z3 = te.var("z3")
    i1 = te.var("i1")
    i2 = te.var("i2")
    x = te.var("x")
    y = te.var("y")
    a = te.var("a")
    b = te.var("b")
    dtype = "int32"
    buffer = tvm.tir.decl_buffer((50,), dtype)
    # Test prog :
    # let z1=1 in let z2=2 in
    #   Mem[i1] = z1+z2;
    #   let x = 1 in let y = 1 in
    #     let a = (x+y) + (z1+z2) in
    #       let b = (x+y) + z3 in
    #         Mem[i2] = a+b;
    body = tvm.tir.LetStmt(
        z1,
        1,
        tvm.tir.LetStmt(
            z2,
            2,
            tvm.tir.SeqStmt(
                [
                    tvm.tir.BufferStore(buffer, z1 + z2, [i1]),
                    tvm.tir.LetStmt(
                        x,
                        1,
                        tvm.tir.LetStmt(
                            y,
                            1,
                            tvm.tir.LetStmt(
                                a,
                                (x + y) + (z1 + z2),
                                tvm.tir.LetStmt(
                                    b, (x + y) + z3, tvm.tir.BufferStore(buffer, a + b, [i2])
                                ),
                            ),
                        ),
                    ),
                ]
            ),
        ),
    )
    # This test program gives the opportunity to introduce two new variables, at two different
    # levels, and to perform replacements in the value of "a" and "b", using these new variables.
    # We will check all of that underneath and more, making also sure that nothing else has
    # been changed.

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body))
    body = tvm.tir.transform.CommonSubexprElimTIR()(mod)

    tvm.transform.PrintIR()(body)

    body = body["main"].body  # Gets the body of the main, i.e. the full statement

    assert body.var.name == "z1"
    assert body.value == 1

    body = body.body

    assert body.var.name == "z2"
    assert body.value == 2
    # This is the let-in for the first variable generated cse_var_1
    assert isinstance(body.body, tvm.tir.LetStmt)

    body = body.body

    # And this is the name and value of this variable
    cse_var_1 = body.var  # Keep the variable accessible for later checking the replacements
    assert body.var.name == "cse_var_1"
    assert tvm.ir.structural_equal(body.value, z1 + z2)
    assert isinstance(body.body, tvm.tir.SeqStmt)

    body = body.body

    assert isinstance(body[0], tvm.tir.BufferStore)
    assert isinstance(body[1], tvm.tir.LetStmt)

    body = body[1]

    assert body.var.name == "x"
    assert body.value == 1

    body = body.body

    assert body.var.name == "y"
    assert body.value == 1
    # This is the let-in for the second variable generated cse_var_2
    assert isinstance(body.body, tvm.tir.LetStmt)

    body = body.body

    # And this is the name and value of this variable
    cse_var_2 = body.var  # Keep the variable accessible for later checking the replacements
    assert body.var.name == "cse_var_2"
    assert tvm.ir.structural_equal(body.value, x + y)

    body = body.body

    # These two name checks were bare comparisons without `assert`, so they
    # never verified anything.
    assert body.var.name == "a"
    # Check that the replacement has been done correctly!
    assert tvm.ir.structural_equal(body.value, cse_var_2 + cse_var_1)

    body = body.body

    assert body.var.name == "b"
    # Check that the replacement has been done correctly!
    assert tvm.ir.structural_equal(body.value, cse_var_2 + z3)

    assert isinstance(body.body, tvm.tir.BufferStore)
def test_split_infer_type():
    """Verify the tuple type inferred for ``relay.split`` on static and symbolic shapes."""

    def _float32_tuple(*shapes):
        # Tuple type holding one float32 tensor type per given shape.
        tensor_types = [relay.ty.TensorType(shape, "float32") for shape in shapes]
        return relay.ty.TupleType(tvm.runtime.convert(tensor_types))

    def verify_split(dshape, indices_or_sections, ret_type, axis=None):
        data = relay.var("x", relay.ty.TensorType(dshape, "float32"))
        pieces = relay.split(data, indices_or_sections, axis=axis)
        inferred = run_infer_type(pieces.astuple())
        assert inferred.checked_type == ret_type

    idxd = tvm.tir.indexdiv

    d1, d2, d3, d4 = (te.var(name) for name in ("d1", "d2", "d3", "d4"))
    axis = te.var("axis")  # not referenced below

    # Static shape, five equal sections along axis 1 and axis 0.
    verify_split((5, 5, 2, 2), 5, _float32_tuple(*([(5, 1, 2, 2)] * 5)), axis=1)
    verify_split((5, 5, 2, 2), 5, _float32_tuple(*([(1, 5, 2, 2)] * 5)), axis=0)

    # Symbolic shape, equal sections: the split extent becomes an index division.
    verify_split(
        (d1, d2, d3, d4),
        4,
        _float32_tuple(*([(d1, d2, idxd(d3, 4), d4)] * 4)),
        axis=2,
    )
    verify_split(
        (d1, d2, d3, d4),
        2,
        _float32_tuple(*([(idxd(d1, 2), d2, d3, d4)] * 2)),
        axis=0,
    )

    # Symbolic shape, explicit split indices along axis 1.
    verify_split(
        (d1, d2, d3, d4),
        (2, 4, 7),
        _float32_tuple(
            (d1, 2, d3, d4),
            (d1, 2, d3, d4),
            (d1, 3, d3, d4),
            (d1, (d2 - 7), d3, d4),
        ),
        axis=1,
    )
def test_reduce_functions():
    """Run ``verify_reduce`` for every relay/numpy reduction pair over a fixed case table."""

    def _with_keepdims(np_func):
        # Give a numpy function that lacks `keepdims` a compatible signature.
        def _wrapper(data, axis=None, keepdims=False):
            if not keepdims:
                return np_func(data, axis=axis)
            if axis is None:
                kept_shape = [1] * len(data.shape)
            else:
                axis = axis if isinstance(axis, int) else axis[0]
                kept_shape = list(data.shape)
                kept_shape[axis] = 1
            return np_func(data, axis=axis).reshape(kept_shape)

        return _wrapper

    def _np_log_sum_exp(x, axis, keepdims=False):
        # Subtract the maximum before exponentiating, then add it back.
        peak = np.max(x, axis=axis, keepdims=True)
        result = np.log(np.sum(np.exp(x - peak), axis=axis, keepdims=True)) + peak
        if not keepdims:
            result = np.squeeze(result, axis=axis)
        return result

    def _unbiased_relay_wrapper(f):
        def _unbiased_func(x, axis=None, keepdims=False, exclude=False):
            return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)

        return _unbiased_func

    def _unbiased_np_wrapper(f):
        def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
            return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)

        return _unbiased_func

    d1, d2, d3, d4 = (te.var(name) for name in ("d1", "d2", "d3", "d4"))

    # [relay reduction, numpy reference] pairs.
    op_pairs = [
        [relay.sum, np.sum],
        [relay.max, np.max],
        [relay.min, np.min],
        [relay.mean, np.mean],
        [relay.variance, np.var],
        [_unbiased_relay_wrapper(relay.variance), _unbiased_np_wrapper(np.var)],
        [relay.std, np.std],
        [_unbiased_relay_wrapper(relay.std), _unbiased_np_wrapper(np.std)],
        [relay.prod, np.prod],
        [relay.all, np.all],
        [relay.any, np.any],
        [relay.logsumexp, _np_log_sum_exp],
        [relay.argmin, _with_keepdims(np.argmin)],
        [relay.argmax, _with_keepdims(np.argmax)],
    ]

    # Positional arguments passed to verify_reduce after the function pair.
    # NOTE(review): presumably (input shape, axis, keepdims, exclude, expected
    # output shape) -- confirm against verify_reduce's signature.
    cases = [
        ((d1, d2, d3, d4), None, False, False, ()),
        ((d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4)),
        ((d1, d2, d3, d4), 0, True, False, (1, d2, d3, d4)),
        ((d1, d2, d3), 1, True, False, (d1, 1, d3)),
        ((d1, d2, d3), 0, True, False, (1, d2, d3)),
        ((d1, d2, d3), None, True, False, (1, 1, 1)),
        ((d1, d2, d3), (0, 1), True, False, (1, 1, d3)),
        ((2, 3, 4), 1, True, False, (2, 1, 4)),
        ((2, 3, 4), (1,), True, False, (2, 1, 4)),
        ((2, 3, 4), -1, True, False, (2, 3, 1)),
        ((2, 3, 4), (0, 1, 2), False, False, ()),
        ((4, 4, 3), None, False, False, ()),
        ((4, 4, 3), (0, 2), False, False, (4,)),
        ((128, 24, 128), (0, 1), False, False, (128,)),
        ((128, 24, 128), (0, 2), False, False, (24,)),
        ((128, 24, 128), (0, 1), True, False, (1, 1, 128)),
        ((128, 24, 128), (0, 2), True, False, (1, 24, 1)),
    ]

    for func in op_pairs:
        for case_args in cases:
            verify_reduce(func, *case_args)
"""
Split

split is the inverse operation of fuse: it separates an iteration axis into
an outer and an inner loop at intervals of `factor`, adding one loop level,
and is used to divide a loop into smaller sub-tasks.
In practice, taking CUDA as an example, gridDim and blockDim can each have at
most three dimensions, so split can produce new dimensions to bind to the
grid and block.
"""

import tvm
from tvm import te
import numpy as np

# declare some variables for use later
n = te.var('n')
m = te.var('m')

# First example: B[i] = A[i] * 2, with the single axis split by a factor of 32.
A = te.placeholder((m, ), name='A')
B = te.compute((m, ), lambda i: A[i] * 2, name='B')
s = te.create_schedule(B.op)
xo, xi = s[B].split(
    B.op.axis[0],
    factor=32)  # split divides the specified axis into two axes by the given factor.
print(tvm.lower(s, [A, B], simple_mode=True))

# Second example: a plain copy B[i] = A[i] with a fresh schedule.
A = te.placeholder((m, ), name='A')
B = te.compute((m, ), lambda i: A[i], name='B')
s = te.create_schedule(B.op)
# The scan is carried over the highest dimension of the tensor. # :code:`s_state` is a placeholder that describes the transition state of the scan. # :code:`s_init` describes how we can initialize the first k timesteps. # Here since s_init's first dimension is 1, it describes how we initialize # The state at first timestep. # # :code:`s_update` describes how to update the value at timestep t. The update # value can refer back to the values of previous timestep via state placeholder. # Note that while it is invalid to refer to :code:`s_state` at current or later timestep. # # The scan takes in state placeholder, initial value and update description. # It is also recommended(although not necessary) to list the inputs to the scan cell. # The result of the scan is a tensor, giving the result of :code:`s_state` after the # update over the time domain. # m = te.var("m") n = te.var("n") X = te.placeholder((m, n), name="X") s_state = te.placeholder((m, n)) s_init = te.compute((1, n), lambda _, i: X[0, i]) s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i]) s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X]) ###################################################################### # Schedule the Scan Cell # ---------------------- # We can schedule the body of the scan by scheduling the update and # init part seperately. Note that it is invalid to schedule the # first iteration dimension of the update part. # To split on the time iteration, user can schedule on scan_op.scan_axis instead. #