def gemv_impl():
    cc_code = """
      extern "C" int gemv_update(float *cc, float *aa, float *bb,
                                 int m, int l, int stride) {
        for (int i = 0; i < m; ++i) {
          for (int j = 0; j < l; ++j) {
            cc[i] += aa[j] * bb[i * stride + j];
          }
        }
        return 0;
      }
    """
    from tvm.contrib import util, clang

    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM IR from the C source code
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    return ll_code

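# A minimal usage sketch (not part of the original snippet): in TVM's
# tensorize flow, the LLVM IR returned by gemv_impl() is linked into the
# generated module with the "import_llvm" pragma so that a tensor intrinsic
# can call gemv_update(). The schedule s, the output tensor C, and its outer
# axis x are assumed to exist in the surrounding code.
s[C].pragma(x, "import_llvm", gemv_impl())
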
def check_llvm(use_file):
    if not clang.find_clang(required=False):
        print("skip because clang is not available")
        return
    temp = utils.tempdir()
    ll_path = temp.relpath("temp.ll")
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    s = te.create_schedule(B.op)
    if use_file:
        s[B].pragma(s[B].op.axis[0], "import_llvm", ll_path)
    else:
        s[B].pragma(s[B].op.axis[0], "import_llvm", ll_code)
    # Build and invoke the kernel.
    f = tvm.build(s, [A, B], "llvm")
    dev = tvm.cpu(0)
    # Launch the kernel.
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
    f(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)

def check_llvm(use_file):
    # Older-API variant of the same check (tvm.module.enabled and
    # tvm.create_schedule predate TVM 0.7).
    if not tvm.module.enabled("llvm"):
        return
    if not clang.find_clang(required=False):
        print("skip because clang is not available")
        return
    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    s = tvm.create_schedule(B.op)
    if use_file:
        s[B].pragma(s[B].op.axis[0], "import_llvm", ll_path)
    else:
        s[B].pragma(s[B].op.axis[0], "import_llvm", ll_code)
    # Build and invoke the kernel.
    f = tvm.build(s, [A, B], "llvm")
    ctx = tvm.cpu(0)
    # Launch the kernel.
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
    f(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)

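# Both check_llvm() variants above close over cc_code, n, A, and B from an
# enclosing test. A minimal sketch of that setup, assuming the newer te API;
# the C function my_add is what the pragma-imported IR must define:
cc_code = """
extern "C" float my_add(float x, float y) {
  return x + y;
}
"""
n = 10
A = te.placeholder((n,), name="A")
B = te.compute(
    (n,),
    lambda *i: tvm.tir.call_pure_extern("float32", "my_add", A(*i), 1.0),
    name="B",
)
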
def gemv_quantized_impl(M, N, data_type='uint8'):
    """Assembly implementation of a blocked gemv. Given a block a of
    shape (4, k) and a block b' of shape (4, k), produces the output
    block c = a*b of shape (4, 4)."""
    stepA = min(4, M)
    stepB = min(4, N)
    assert data_type in ['uint8', 'int8'], \
        'Only uint8/int8 supported for this implementation'

    cc_code = """
          extern "C" int gemv_{0}_{0}_int32_{1}_{2}(int *c_buffer,
                                                    unsigned char *a_buffer,
                                                    unsigned char *b_buffer,
                                                    int K, int m, int n)
    """.format(data_type, stepA, stepB)

    cc_code += """
    {
        unsigned char * a_ptr = a_buffer;
        unsigned char * b_ptr = b_buffer;
        int * c_ptr = c_buffer;

        int k = K / 16;

        __asm__  __volatile__ (
            "movi v16.4s, #0\\n"
            "movi v17.4s, #0\\n"
            "movi v18.4s, #0\\n"
            "movi v19.4s, #0\\n"
            "movi v20.4s, #0\\n"
            "movi v21.4s, #0\\n"
            "movi v22.4s, #0\\n"
            "movi v23.4s, #0\\n"
            "movi v24.4s, #0\\n"
            "movi v25.4s, #0\\n"
            "movi v26.4s, #0\\n"
            "movi v27.4s, #0\\n"
            "movi v28.4s, #0\\n"
            "movi v29.4s, #0\\n"
            "movi v30.4s, #0\\n"
            "movi v31.4s, #0\\n"
        "1:"
    """

    cc_code += ' "ldr q0, [%[a_ptr]]\\n" '

    if M > 1:
        cc_code += ' "ldr q1, [%[a_ptr], #16]\\n" '
    else:
        cc_code += ' "movi v1.4s, #0\\n" '

    if M > 2:
        cc_code += ' "ldr q2, [%[a_ptr], #32]\\n" '
    else:
        cc_code += ' "movi v2.4s, #0\\n" '

    if M > 3:
        cc_code += ' "ldr q3, [%[a_ptr], #48]\\n" '
    else:
        cc_code += ' "movi v3.4s, #0\\n" '

    cc_code += ' "ldr q4, [%[b_ptr]]\\n" '

    if N > 1:
        cc_code += ' "ldr q5, [%[b_ptr], #16]\\n" '

    if N > 2:
        cc_code += ' "ldr q6, [%[b_ptr], #32]\\n" '

    if N > 3:
        cc_code += ' "ldr q7, [%[b_ptr], #48]\\n" '

    cc_code += """
            // First half
            // Lower part of a0 * {b0,b1,b2,b3}
            "umull v8.8h, v0.8b, v4.8b\\n"
            "umull v9.8h, v0.8b, v5.8b\\n"
            "umull v10.8h, v0.8b, v6.8b\\n"
            "umull v11.8h, v0.8b, v7.8b\\n"

            // Lower part of a1 * {b0,b1,b2,b3}
            "umull v12.8h, v1.8b, v4.8b\\n"
            "umull v13.8h, v1.8b, v5.8b\\n"
            "umull v14.8h, v1.8b, v6.8b\\n"
            "umull v15.8h, v1.8b, v7.8b\\n"

            // Accumulate
            "uadalp v16.4s, v8.8h\\n"
            "uadalp v17.4s, v9.8h\\n"
            "uadalp v18.4s, v10.8h\\n"
            "uadalp v19.4s, v11.8h\\n"
            "uadalp v20.4s, v12.8h\\n"
            "uadalp v21.4s, v13.8h\\n"
            "uadalp v22.4s, v14.8h\\n"
            "uadalp v23.4s, v15.8h\\n"

            // Higher part of a0 * {b0,b1,b2,b3}
            "umull2 v8.8h, v0.16b, v4.16b\\n"
            "umull2 v9.8h, v0.16b, v5.16b\\n"
            "umull2 v10.8h, v0.16b, v6.16b\\n"
            "umull2 v11.8h, v0.16b, v7.16b\\n"

            // Higher part of a1 * {b0,b1,b2,b3}
            "umull2 v12.8h, v1.16b, v4.16b\\n"
            "umull2 v13.8h, v1.16b, v5.16b\\n"
            "umull2 v14.8h, v1.16b, v6.16b\\n"
            "umull2 v15.8h, v1.16b, v7.16b\\n"

            // Accumulate again
            "uadalp v16.4s, v8.8h\\n"
            "uadalp v17.4s, v9.8h\\n"
            "uadalp v18.4s, v10.8h\\n"
            "uadalp v19.4s, v11.8h\\n"
            "uadalp v20.4s, v12.8h\\n"
            "uadalp v21.4s, v13.8h\\n"
            "uadalp v22.4s, v14.8h\\n"
            "uadalp v23.4s, v15.8h\\n"

            // Second half
            // Lower part of a2 * {b0,b1,b2,b3}
            "umull v8.8h, v2.8b, v4.8b\\n"
            "umull v9.8h, v2.8b, v5.8b\\n"
            "umull v10.8h, v2.8b, v6.8b\\n"
            "umull v11.8h, v2.8b, v7.8b\\n"

            // Lower part of a3 * {b0,b1,b2,b3}
            "umull v12.8h, v3.8b, v4.8b\\n"
            "umull v13.8h, v3.8b, v5.8b\\n"
            "umull v14.8h, v3.8b, v6.8b\\n"
            "umull v15.8h, v3.8b, v7.8b\\n"

            // Accumulate
            "uadalp v24.4s, v8.8h\\n"
            "uadalp v25.4s, v9.8h\\n"
            "uadalp v26.4s, v10.8h\\n"
            "uadalp v27.4s, v11.8h\\n"
            "uadalp v28.4s, v12.8h\\n"
            "uadalp v29.4s, v13.8h\\n"
            "uadalp v30.4s, v14.8h\\n"
            "uadalp v31.4s, v15.8h\\n"

            // Higher part of a2 * {b0,b1,b2,b3}
            "umull2 v8.8h, v2.16b, v4.16b\\n"
            "umull2 v9.8h, v2.16b, v5.16b\\n"
            "umull2 v10.8h, v2.16b, v6.16b\\n"
            "umull2 v11.8h, v2.16b, v7.16b\\n"

            // Higher part of a3 * {b0,b1,b2,b3}
            "umull2 v12.8h, v3.16b, v4.16b\\n"
            "umull2 v13.8h, v3.16b, v5.16b\\n"
            "umull2 v14.8h, v3.16b, v6.16b\\n"
            "umull2 v15.8h, v3.16b, v7.16b\\n"

            // Accumulate again
            "uadalp v24.4s, v8.8h\\n"
            "uadalp v25.4s, v9.8h\\n"
            "uadalp v26.4s, v10.8h\\n"
            "uadalp v27.4s, v11.8h\\n"
            "uadalp v28.4s, v12.8h\\n"
            "uadalp v29.4s, v13.8h\\n"
            "uadalp v30.4s, v14.8h\\n"
            "uadalp v31.4s, v15.8h\\n"
    """

    blockA = min(64, M * 16)
    blockB = min(64, N * 16)

    cc_code += """
            // Increment pointers and decrement k
            "add %[a_ptr], %[a_ptr], #{0}\\n"
            "add %[b_ptr], %[b_ptr], #{1}\\n"
            "subs %w[k], %w[k], #1\\n"
    """.format(blockA, blockB)

    stepC = min(4, N)

    cc_code += """
            "cbnz %w[k], 1b\\n"

            // Final additions

            // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d)
            // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h)
            // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l)
            // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p)
            "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h)
            "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p)
            "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d)
            // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h)
            // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l)
            // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p)
            "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h)
            "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p)
            "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d)
            // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h)
            // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l)
            // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p)
            "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b, c+d, e+f, g+h)
            "addp v25.4s, v26.4s, v27.4s\\n" // v25 = (i+j, k+l, m+n, o+p)
            "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d)
            // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h)
            // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l)
            // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p)
            "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h)
            "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p)
            "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            "str q16, [%[c_ptr]]\\n"
    """

    if M > 1:
        cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4)

    if M > 2:
        cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8)

    if M > 3:
        cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12)

    cc_code += """
        : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k)
        :
        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
          "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
          "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
          "v26", "v27", "v28", "v29", "v30", "v31"
        );
        return 0;
    }
    """

    if data_type == 'int8':
        cc_code = cc_code.replace('unsigned char', 'char')
        cc_code = cc_code.replace('umull', 'smull')
        cc_code = cc_code.replace('uadalp', 'sadalp')

    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM IR from the C source code
    ll_code = clang.create_llvm(cc_code,
                                options=["--target=aarch64-linux-gnu -mattr=+neon"],
                                output=ll_path)
    return ll_code

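# A minimal usage sketch (not in the original source): with M = N = 4 and
# data_type='int8', the IR emitted above defines the symbol
# gemv_int8_int8_int32_4_4. In TVM it would typically be linked into a
# schedule via the "import_llvm" pragma so a tensorize intrinsic can call it;
# the schedule s and output tensor C are assumed to exist in surrounding code.
ll_code = gemv_quantized_impl(4, 4, data_type='int8')
s[C].pragma(s[C].op.axis[0], "import_llvm", ll_code)
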
def gemm_quantized_impl(M, N, K, unroll, interleave, data_type="uint8"):
    """Assembly implementation of a blocked gemm. Given a block a of shape
    (4, k) and a block b' of shape (4, k), produces the output block
    c = a*b of shape (4, 4)."""
    stepA = min(4, M)
    stepB = min(4, N)
    assert data_type in ["uint8", "int8"], "Only uint8/int8 supported for this implementation"

    signature = """extern "C" int gemm_quantized_{0}_{0}_int32_{1}_{2}""".format(
        data_type, stepA, stepB
    )

    if unroll:
        signature += "_" + str(K)

    if interleave:
        signature += "_interleaved"

    signature += """(int *c_buffer,
                     unsigned char *a_buffer,
                     unsigned char *b_buffer,
                     int K, int m, int n)"""

    cc_code = signature
    cc_code += """
    {
        unsigned char * a_ptr = a_buffer;
        unsigned char * b_ptr = b_buffer;
        int * c_ptr = c_buffer;

        int k = K / 16;

        __asm__  __volatile__ (
            "movi v16.4s, #0\\n"
            "movi v17.4s, #0\\n"
            "movi v18.4s, #0\\n"
            "movi v19.4s, #0\\n"
            "movi v20.4s, #0\\n"
            "movi v21.4s, #0\\n"
            "movi v22.4s, #0\\n"
            "movi v23.4s, #0\\n"
            "movi v24.4s, #0\\n"
            "movi v25.4s, #0\\n"
            "movi v26.4s, #0\\n"
            "movi v27.4s, #0\\n"
            "movi v28.4s, #0\\n"
            "movi v29.4s, #0\\n"
            "movi v30.4s, #0\\n"
            "movi v31.4s, #0\\n"
        "1:"
    """

    main_loop = ' "ldr q0, [%[a_ptr]]\\n" '

    if M > 1:
        main_loop += ' "ldr q1, [%[a_ptr], #16]\\n" '
    else:
        main_loop += ' "movi v1.4s, #0\\n" '

    if M > 2:
        main_loop += ' "ldr q2, [%[a_ptr], #32]\\n" '
    else:
        main_loop += ' "movi v2.4s, #0\\n" '

    if M > 3:
        main_loop += ' "ldr q3, [%[a_ptr], #48]\\n" '
    else:
        main_loop += ' "movi v3.4s, #0\\n" '

    main_loop += ' "ldr q4, [%[b_ptr]]\\n" '

    if N > 1:
        main_loop += ' "ldr q5, [%[b_ptr], #16]\\n" '

    if N > 2:
        main_loop += ' "ldr q6, [%[b_ptr], #32]\\n" '

    if N > 3:
        main_loop += ' "ldr q7, [%[b_ptr], #48]\\n" '

    # Main computation can interleave multiply/accumulate instructions
    # or schedule them in batches (first all multiplies then all accumulates)
    if interleave:
        main_loop += gemm_quantized_4_4_interleaved()
    else:
        main_loop += gemm_quantized_4_4_batched()

    blockA = min(64, M * 16)
    blockB = min(64, N * 16)

    main_loop += """// Increment pointers
            "add %[a_ptr], %[a_ptr], #{0}\\n"
            "add %[b_ptr], %[b_ptr], #{1}\\n" """.format(blockA, blockB)

    if unroll:
        k = int(K // 16)
        for _ in range(0, k):
            cc_code += main_loop
    else:
        cc_code += main_loop
        cc_code += """
            "subs %w[k], %w[k], #1\\n"
            "cbnz %w[k], 1b\\n"
        """

    cc_code += """
            // Final additions

            // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d)
            // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h)
            // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l)
            // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p)
            "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h)
            "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p)
            "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d)
            // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h)
            // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l)
            // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p)
            "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h)
            "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p)
            "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d)
            // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h)
            // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l)
            // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p)
            "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b, c+d, e+f, g+h)
            "addp v25.4s, v26.4s, v27.4s\\n" // v25 = (i+j, k+l, m+n, o+p)
            "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d)
            // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h)
            // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l)
            // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p)
            "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h)
            "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p)
            "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

            "str q16, [%[c_ptr]]\\n"
    """

    stepC = min(4, N)

    if M > 1:
        cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4)

    if M > 2:
        cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8)

    if M > 3:
        cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12)

    cc_code += """
        : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k)
        :
        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
          "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
          "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
          "v26", "v27", "v28", "v29", "v30", "v31"
        );
        return 0;
    }
    """

    if data_type == "int8":
        cc_code = cc_code.replace("unsigned char", "char")
        cc_code = cc_code.replace("umull", "smull")
        cc_code = cc_code.replace("uadalp", "sadalp")

    temp = utils.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM IR from the C source code
    ll_code = clang.create_llvm(
        cc_code, options=["--target=aarch64-linux-gnu -mattr=+neon"], output=ll_path
    )
    return ll_code

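# An example of the symbol naming produced above (a usage sketch, not from
# the original source): with M = N = 4, K = 128, unroll and interleave both
# enabled, and the default uint8 type, the emitted function is named
# gemm_quantized_uint8_uint8_int32_4_4_128_interleaved, and the K-loop is
# fully unrolled into K/16 copies of the main block.
ll_code = gemm_quantized_impl(4, 4, 128, unroll=True, interleave=True)
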
import uuid

from tvm.contrib import clang, utils


def _c_to_llvm(c_code: str) -> str:
    # Compile a C source string to LLVM IR text via clang, writing the
    # intermediate .ll file to a uniquely named temp path.
    unique_filename = str(uuid.uuid4())
    temp = utils.tempdir()
    ll_path = temp.relpath(f"{unique_filename}.ll")
    ll_code = clang.create_llvm([c_code], output=ll_path)
    return ll_code

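# A small usage sketch for _c_to_llvm() (the C helper shown is hypothetical):
# clang.create_llvm treats a non-path string as inline source, so a one-off
# C function can be turned into IR text directly.
ir_text = _c_to_llvm('extern "C" int add_one(int x) {\n  return x + 1;\n}\n')
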