Example #1
def gemv_impl():
    cc_code = """
      extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < l; ++j) {
                cc[i] += aa[j] * bb[i * stride + j];
            }
        }
        return 0;
      }
    """
    from tvm.contrib import util, clang
    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM IR from C source code
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    return ll_code
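A minimal sketch of how the returned IR is typically consumed (the schedule below is an illustrative assumption, not part of gemv_impl itself; it mirrors the check_llvm examples further down): the LLVM text is attached to a schedule stage with the "import_llvm" pragma so that gemv_update is linked into the module produced by tvm.build.

import tvm
from tvm import te

n = 16
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")

s = te.create_schedule(B.op)
# Attach the hand-written LLVM IR; gemv_update becomes callable from generated code.
s[B].pragma(s[B].op.axis[0], "import_llvm", gemv_impl())
f = tvm.build(s, [A, B], "llvm")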
Example #2
def gemv_impl():
    cc_code = """
      extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < l; ++j) {
                cc[i] += aa[j] * bb[i * stride + j];
            }
        }
        return 0;
      }
    """
    from tvm.contrib import util, clang
    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM IR from C source code
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    return ll_code
Example #3
def check_llvm(use_file):
    # Nested helper from TVM's test_llvm_import test: clang, utils, te, np,
    # cc_code, A, B and n all come from the enclosing test (see the sketch
    # that follows this example).
    if not clang.find_clang(required=False):
        print("skip because clang is not available")
        return
    temp = utils.tempdir()
    ll_path = temp.relpath("temp.ll")
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    s = te.create_schedule(B.op)
    if use_file:
        s[B].pragma(s[B].op.axis[0], "import_llvm", ll_path)
    else:
        s[B].pragma(s[B].op.axis[0], "import_llvm", ll_code)
    # Build and invoke the kernel.
    f = tvm.build(s, [A, B], "llvm")
    dev = tvm.cpu(0)
    # Launch the kernel.
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
    f(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)
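The helper above is nested inside a larger test; a hedged reconstruction of the enclosing context it assumes (based on TVM's test_llvm_import) looks like this, with cc_code defining the external my_add symbol and A, B, n being the tensors and length it builds and runs.

import numpy as np
import tvm
import tvm.testing
from tvm import te
from tvm.contrib import clang, utils

cc_code = """
extern "C" float my_add(float x, float y) {
  return x + y;
}
"""
n = 10
A = te.placeholder((n,), name="A")
B = te.compute(
    (n,), lambda *i: tvm.tir.call_pure_extern("float32", "my_add", A(*i), 1.0), name="B"
)

check_llvm(use_file=True)
check_llvm(use_file=False)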
Example #4
def check_llvm(use_file):
    # Same helper written against an older TVM API (tvm.module.enabled,
    # tvm.create_schedule, util.tempdir and a context instead of a device).
    if not tvm.module.enabled("llvm"):
        return
    if not clang.find_clang(required=False):
        print("skip because clang is not available")
        return
    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    s = tvm.create_schedule(B.op)
    if use_file:
        s[B].pragma(s[B].op.axis[0], "import_llvm", ll_path)
    else:
        s[B].pragma(s[B].op.axis[0], "import_llvm", ll_code)
    # Build and invoke the kernel.
    f = tvm.build(s, [A, B], "llvm")
    ctx = tvm.cpu(0)
    # Launch the kernel.
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
    f(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)
Example #5
def gemv_quantized_impl(M, N, data_type='uint8'):
    """Assembly implementation of a blocked quantized gemv. Given
    a block a of shape (4, k) and a block b' of shape (4, k),
    produces the output block c = a*b of shape (4, 4)."""

    stepA = min(4, M)
    stepB = min(4, N)
    assert data_type in ['uint8', 'int8'], \
        'Only uint8/int8 supported for this implementation'

    cc_code = """
          extern "C" int gemv_{0}_{0}_int32_{1}_{2}(int *c_buffer,
                                                    unsigned char *a_buffer,
                                                    unsigned char *b_buffer,
                                                    int K, int m, int n)
              """.format(data_type, stepA, stepB)

    cc_code += """
    {
            unsigned char * a_ptr = a_buffer;
            unsigned char * b_ptr = b_buffer;
            int * c_ptr = c_buffer;

            int k = K / 16;

            __asm__  __volatile__ (
                "movi v16.4s, #0\\n"
                "movi v17.4s, #0\\n"
                "movi v18.4s, #0\\n"
                "movi v19.4s, #0\\n"
                "movi v20.4s, #0\\n"
                "movi v21.4s, #0\\n"
                "movi v22.4s, #0\\n"
                "movi v23.4s, #0\\n"
                "movi v24.4s, #0\\n"
                "movi v25.4s, #0\\n"
                "movi v26.4s, #0\\n"
                "movi v27.4s, #0\\n"
                "movi v28.4s, #0\\n"
                "movi v29.4s, #0\\n"
                "movi v30.4s, #0\\n"
                "movi v31.4s, #0\\n"
            "1:"
    """

    cc_code += ' "ldr q0, [%[a_ptr]]\\n" '

    if M > 1:
        cc_code += ' "ldr q1, [%[a_ptr], #16]\\n" '
    else:
        cc_code += ' "movi v1.4s, #0\\n" '

    if M > 2:
        cc_code += ' "ldr q2, [%[a_ptr], #32]\\n" '
    else:
        cc_code += ' "movi v2.4s, #0\\n" '

    if M > 3:
        cc_code += ' "ldr q3, [%[a_ptr], #48]\\n" '
    else:
        cc_code += ' "movi v3.4s, #0\\n" '

    cc_code += ' "ldr q4, [%[b_ptr]]\\n" '

    if N > 1:
        cc_code += ' "ldr q5, [%[b_ptr], #16]\\n" '

    if N > 2:
        cc_code += ' "ldr q6, [%[b_ptr], #32]\\n" '

    if N > 3:
        cc_code += ' "ldr q7, [%[b_ptr], #48]\\n" '

    cc_code += """
                // First half
                // Lower part of a0 * {b0,b1,b2,b3}
                "umull v8.8h, v0.8b, v4.8b\\n"
                "umull v9.8h, v0.8b, v5.8b\\n"
                "umull v10.8h, v0.8b, v6.8b\\n"
                "umull v11.8h, v0.8b, v7.8b\\n"

                // Lower part of a1 * {b0,b1,b2,b3}
                "umull v12.8h, v1.8b, v4.8b\\n"
                "umull v13.8h, v1.8b, v5.8b\\n"
                "umull v14.8h, v1.8b, v6.8b\\n"
                "umull v15.8h, v1.8b, v7.8b\\n"

                // Accumulate
                "uadalp v16.4s, v8.8h\\n"
                "uadalp v17.4s, v9.8h\\n"
                "uadalp v18.4s, v10.8h\\n"
                "uadalp v19.4s, v11.8h\\n"
                "uadalp v20.4s, v12.8h\\n"
                "uadalp v21.4s, v13.8h\\n"
                "uadalp v22.4s, v14.8h\\n"
                "uadalp v23.4s, v15.8h\\n"

                // Higher part of a0 * {b0,b1,b2,b3}
                "umull2 v8.8h, v0.16b, v4.16b\\n"
                "umull2 v9.8h, v0.16b, v5.16b\\n"
                "umull2 v10.8h, v0.16b, v6.16b\\n"
                "umull2 v11.8h, v0.16b, v7.16b\\n"

                // Higher part of a1 * {b0,b1,b2,b3}
                "umull2 v12.8h, v1.16b, v4.16b\\n"
                "umull2 v13.8h, v1.16b, v5.16b\\n"
                "umull2 v14.8h, v1.16b, v6.16b\\n"
                "umull2 v15.8h, v1.16b, v7.16b\\n"

                 // Accumulate again
                "uadalp v16.4s, v8.8h\\n"
                "uadalp v17.4s, v9.8h\\n"
                "uadalp v18.4s, v10.8h\\n"
                "uadalp v19.4s, v11.8h\\n"
                "uadalp v20.4s, v12.8h\\n"
                "uadalp v21.4s, v13.8h\\n"
                "uadalp v22.4s, v14.8h\\n"
                "uadalp v23.4s, v15.8h\\n"

                // Second half

                // Lower part of a2 * {b0,b1,b2,b3}
                "umull v8.8h, v2.8b, v4.8b\\n"
                "umull v9.8h, v2.8b, v5.8b\\n"
                "umull v10.8h, v2.8b, v6.8b\\n"
                "umull v11.8h, v2.8b, v7.8b\\n"

                // Lower part of a3 * {b0,b1,b2,b3}
                "umull v12.8h, v3.8b, v4.8b\\n"
                "umull v13.8h, v3.8b, v5.8b\\n"
                "umull v14.8h, v3.8b, v6.8b\\n"
                "umull v15.8h, v3.8b, v7.8b\\n"

                // Accumulate
                "uadalp v24.4s, v8.8h\\n"
                "uadalp v25.4s, v9.8h\\n"
                "uadalp v26.4s, v10.8h\\n"
                "uadalp v27.4s, v11.8h\\n"
                "uadalp v28.4s, v12.8h\\n"
                "uadalp v29.4s, v13.8h\\n"
                "uadalp v30.4s, v14.8h\\n"
                "uadalp v31.4s, v15.8h\\n"

                // Higher part of a2 * {b0,b1,b2,b3}
                "umull2 v8.8h, v2.16b, v4.16b\\n"
                "umull2 v9.8h, v2.16b, v5.16b\\n"
                "umull2 v10.8h, v2.16b, v6.16b\\n"
                "umull2 v11.8h, v2.16b, v7.16b\\n"

                // Higher part of a3 * {b0,b1,b2,b3}
                "umull2 v12.8h, v3.16b, v4.16b\\n"
                "umull2 v13.8h, v3.16b, v5.16b\\n"
                "umull2 v14.8h, v3.16b, v6.16b\\n"
                "umull2 v15.8h, v3.16b, v7.16b\\n"

                // Accumulate again
                "uadalp v24.4s, v8.8h\\n"
                "uadalp v25.4s, v9.8h\\n"
                "uadalp v26.4s, v10.8h\\n"
                "uadalp v27.4s, v11.8h\\n"
                "uadalp v28.4s, v12.8h\\n"
                "uadalp v29.4s, v13.8h\\n"
                "uadalp v30.4s, v14.8h\\n"
                "uadalp v31.4s, v15.8h\\n"
    """
    blockA = min(64, M * 16)
    blockB = min(64, N * 16)

    cc_code += """
                // Increment pointers and decrement k
                "add %[a_ptr], %[a_ptr], #{0}\\n"
                "add %[b_ptr], %[b_ptr], #{1}\\n"
                "subs %w[k], %w[k], #1\\n"
    """.format(blockA, blockB)

    stepC = min(4, N)

    cc_code += """
                "cbnz %w[k], 1b\\n"

                // Final additions

                // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d)
                // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h)
                // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l)
                // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p)
                "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h)
                "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p)
                "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

                // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d)
                // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h)
                // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l)
                // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p)
                "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h)
                "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p)
                "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

                // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d)
                // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h)
                // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l)
                // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p)
                "addp v24.4s, v24.4s, v25.4s\\n"  // v24 = (a+b, c+d, e+f, g+h)
                "addp v25.4s, v26.4s, v27.4s\\n"  // v25 = (i+j, k+l, m+n, o+p)
                "addp v24.4s, v24.4s, v25.4s\\n"  // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

                // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d)
                // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h)
                // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l)
                // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p)
                "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h)
                "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p)
                "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

                "str q16, [%[c_ptr]]\\n"
            """

    if M > 1:
        cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4)

    if M > 2:
        cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8)

    if M > 3:
        cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12)

    cc_code += """
             : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k)
             :
             : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
                    "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
                    "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
                    "v27", "v28", "v29", "v30", "v31"
             );
        return 0;
        }
    """

    if data_type == 'int8':
        cc_code = cc_code.replace('unsigned char', 'char')
        cc_code = cc_code.replace('umull', 'smull')
        cc_code = cc_code.replace('uadalp', 'sadalp')

    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM IR from C source code
    ll_code = clang.create_llvm(
        cc_code,
        options=["-mtriple=aarch64-linux-gnu -mattr=+neon"],
        output=ll_path)
    return ll_code
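A hypothetical usage sketch (the handle variables and sizes below are assumptions for illustration): the routine is exported under the name gemv_<dtype>_<dtype>_int32_<stepA>_<stepB>, so a tensorize intrinsic that imports the IR would reference that symbol through tvm.tir.call_extern with (c, a, b, K, m, n) arguments.

import tvm

M, N, K, data_type = 4, 4, 64, "uint8"
func_name = "gemv_{0}_{0}_int32_{1}_{2}".format(data_type, min(4, M), min(4, N))

# Handle variables stand in for the buffer pointers a real intrinsic would
# obtain from Buffer.access_ptr("w") / Buffer.access_ptr("r").
c_ptr = tvm.tir.Var("c_ptr", "handle")
a_ptr = tvm.tir.Var("a_ptr", "handle")
b_ptr = tvm.tir.Var("b_ptr", "handle")

call = tvm.tir.call_extern("int32", func_name, c_ptr, a_ptr, b_ptr, K, M, N)
print(call)  # a call_extern node naming gemv_uint8_uint8_int32_4_4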
Example #6
def gemm_quantized_impl(M, N, K, unroll, interleave, data_type="uint8"):
    """Assembly implementation of a blocked quantized gemm. Given
    a block a of shape (4, k) and a block b' of shape (4, k),
    produces the output block c = a*b of shape (4, 4)."""

    stepA = min(4, M)
    stepB = min(4, N)
    assert data_type in ["uint8", "int8"], \
        "Only uint8/int8 supported for this implementation"

    signature = """extern "C" int gemm_quantized_{0}_{0}_int32_{1}_{2}""".format(
        data_type, stepA, stepB)
    if unroll:
        signature += "_" + str(K)

    if interleave:
        signature += "_interleaved"

    signature += """(int *c_buffer,
                      unsigned char *a_buffer,
                      unsigned char *b_buffer,
                      int K, int m, int n)"""

    cc_code = signature
    cc_code += """
    {
            unsigned char * a_ptr = a_buffer;
            unsigned char * b_ptr = b_buffer;
            int * c_ptr = c_buffer;

            int k = K / 16;

            __asm__  __volatile__ (
                "movi v16.4s, #0\\n"
                "movi v17.4s, #0\\n"
                "movi v18.4s, #0\\n"
                "movi v19.4s, #0\\n"
                "movi v20.4s, #0\\n"
                "movi v21.4s, #0\\n"
                "movi v22.4s, #0\\n"
                "movi v23.4s, #0\\n"
                "movi v24.4s, #0\\n"
                "movi v25.4s, #0\\n"
                "movi v26.4s, #0\\n"
                "movi v27.4s, #0\\n"
                "movi v28.4s, #0\\n"
                "movi v29.4s, #0\\n"
                "movi v30.4s, #0\\n"
                "movi v31.4s, #0\\n"
            "1:"
    """

    main_loop = ' "ldr q0, [%[a_ptr]]\\n" '

    if M > 1:
        main_loop += ' "ldr q1, [%[a_ptr], #16]\\n" '
    else:
        main_loop += ' "movi v1.4s, #0\\n" '

    if M > 2:
        main_loop += ' "ldr q2, [%[a_ptr], #32]\\n" '
    else:
        main_loop += ' "movi v2.4s, #0\\n" '

    if M > 3:
        main_loop += ' "ldr q3, [%[a_ptr], #48]\\n" '
    else:
        main_loop += ' "movi v3.4s, #0\\n" '

    main_loop += ' "ldr q4, [%[b_ptr]]\\n" '

    if N > 1:
        main_loop += ' "ldr q5, [%[b_ptr], #16]\\n" '

    if N > 2:
        main_loop += ' "ldr q6, [%[b_ptr], #32]\\n" '

    if N > 3:
        main_loop += ' "ldr q7, [%[b_ptr], #48]\\n" '

    # Main computation can interleave multiply/accumulate instructions
    # or schedule them in batches (first all multiplies then all accumulates)
    if interleave:
        main_loop += gemm_quantized_4_4_interleaved()
    else:
        main_loop += gemm_quantized_4_4_batched()

    blockA = min(64, M * 16)
    blockB = min(64, N * 16)
    main_loop += """// Increment pointers
                    "add %[a_ptr], %[a_ptr], #{0}\\n"
                    "add %[b_ptr], %[b_ptr], #{1}\\n" """.format(
        blockA, blockB)

    if unroll:
        k = int(K // 16)
        for _ in range(0, k):
            cc_code += main_loop
    else:
        cc_code += main_loop
        cc_code += """
                    "subs %w[k], %w[k], #1\\n"
                    "cbnz %w[k], 1b\\n"
                   """
    cc_code += """
                // Final additions

                // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d)
                // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h)
                // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l)
                // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p)
                "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h)
                "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p)
                "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

                // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d)
                // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h)
                // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l)
                // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p)
                "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h)
                "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p)
                "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

                // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d)
                // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h)
                // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l)
                // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p)
                "addp v24.4s, v24.4s, v25.4s\\n"  // v24 = (a+b, c+d, e+f, g+h)
                "addp v25.4s, v26.4s, v27.4s\\n"  // v25 = (i+j, k+l, m+n, o+p)
                "addp v24.4s, v24.4s, v25.4s\\n"  // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

                // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d)
                // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h)
                // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l)
                // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p)
                "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h)
                "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p)
                "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)

                "str q16, [%[c_ptr]]\\n"
            """

    stepC = min(4, N)
    if M > 1:
        cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4)

    if M > 2:
        cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8)

    if M > 3:
        cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12)

    cc_code += """
             : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k)
             :
             : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
                    "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
                    "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
                    "v27", "v28", "v29", "v30", "v31"
             );
        return 0;
        }
    """

    if data_type == "int8":
        cc_code = cc_code.replace("unsigned char", "char")
        cc_code = cc_code.replace("umull", "smull")
        cc_code = cc_code.replace("uadalp", "sadalp")

    temp = utils.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM IR from C source code
    ll_code = clang.create_llvm(
        cc_code,
        options=["--target=aarch64-linux-gnu -mattr=+neon"],
        output=ll_path)
    return ll_code
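A small companion sketch (a hypothetical helper, derived directly from the signature logic above): it reconstructs the symbol name gemm_quantized_impl exports for given M, N, K, unroll and interleave, which is what a caller has to reference.

def gemm_quantized_name(M, N, K, unroll, interleave, data_type="uint8"):
    name = "gemm_quantized_{0}_{0}_int32_{1}_{2}".format(
        data_type, min(4, M), min(4, N))
    if unroll:
        name += "_" + str(K)
    if interleave:
        name += "_interleaved"
    return name

# e.g. "gemm_quantized_uint8_uint8_int32_4_4_64_interleaved"
print(gemm_quantized_name(4, 4, 64, unroll=True, interleave=True))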
Example #7
import uuid

from tvm.contrib import clang, utils

def _c_to_llvm(c_code: str) -> str:
    """Compile a C/C++ source string to LLVM IR text using clang."""
    unique_filename = str(uuid.uuid4())
    temp = utils.tempdir()
    ll_path = temp.relpath(f"{unique_filename}.ll")
    ll_code = clang.create_llvm([c_code], output=ll_path)
    return ll_code
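A minimal usage sketch for _c_to_llvm (the C snippet below is an illustrative assumption): compile a tiny extern "C" helper down to LLVM IR text.

c_code = """
extern "C" int add_one(int x) { return x + 1; }
"""
ll_text = _c_to_llvm(c_code)
print(ll_text[:80])  # start of the generated LLVM IR module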