import KgeN
from KgeN import te

# 1. not vthread
# M = 128
# A = te.placeholder((M, ), name="A")
# B = te.compute((M, ), lambda i: A[i], name="B")
# C = te.compute((M, ), lambda i: B[i], name="C")
# s = te.create_schedule(C.op)
# x, = s[C].op.axis
# xo, xi = s[C].split(x, factor=4)
# s[C].reorder(xi, xo)
# s[B].compute_at(s[C], xi)
# tir = str(KgeN.lower(s, [A, C]))
# print(tir)

# 2. vthread
M = 1024
A = te.placeholder((M, ), name="A")
B = te.compute((M, ), lambda i: A[i], name="B")
C = te.compute((M, ), lambda i: B[i], name="C")

s = te.create_schedule(C.op)
x, = s[C].op.axis
xo, xi = s[C].split(x, factor=64)
xio, xii = s[C].split(xi, factor=2)
# bind the outer loop to a virtual thread axis
s[C].bind(xo, te.thread_axis("vthread", name="vx"))
# s[C].bind(xio, te.thread_axis("vthread", name="vy"))
# compute B inside C's xio loop
s[B].compute_at(s[C], xio)

tir = str(KgeN.lower(s, [A, C]))
print(tir)
import KgeN
from KgeN import te

m = 8
n = 64
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: 2 + A[i, j], name="B")

# schedule: bind the row axis of B to threadIdx.x (extent 8)
s = te.create_schedule(B.op)
ax = te.thread_axis(8, "threadIdx.x")
s[B].bind(s[B].op.axis[0], ax)

# lower
func = KgeN.lower(s, [A, B])
print(str(func))
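# A small follow-up sketch (not part of the original examples): the script above
# binds only the row axis, so the column loop stays serial. Assuming KgeN's
# split/bind/thread_axis behave as in the other examples in this file, the
# column axis can additionally be distributed over blockIdx.x and threadIdx.x;
# the split factor 16 below is an illustrative assumption.
import KgeN
from KgeN import te

m = 8
n = 64
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: 2 + A[i, j], name="B")

s = te.create_schedule(B.op)
i, j = s[B].op.axis
s[B].bind(i, te.thread_axis(8, "threadIdx.y"))
bx, tx = s[B].split(j, factor=16)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis(16, "threadIdx.x"))

print(str(KgeN.lower(s, [A, B])))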
import KgeN
from KgeN import te

# problem sizes are assumed here; the original snippet starts after their definition
M = N = K = 1024

A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis(K, name="k")
C = te.compute((M, N), lambda i, j: te.reduce_sum(A[i, k] * B[k, j], axis=k), name="C")

s = te.create_schedule(C.op)
# cache tiles of A and B in shared memory, then in registers; accumulate C locally
AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])
AAA = s.cache_read(AA, "local", [C])
BBB = s.cache_read(BB, "local", [C])
CCC = s.cache_write(C, "local")

block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")

# tile the output: blocks -> threads -> per-thread 4x4 micro-tile
M, N = s[C].op.axis
Mo, Mi = s[C].split(M, 4)
No, Ni = s[C].split(N, 4)
Bx, Tx = s[C].split(Mo, 4)
By, Ty = s[C].split(No, 4)
s[C].reorder(Bx, By, Tx, Ty, Mi, Ni)

AM, AK = s[AA].op.axis
BK, BN = s[BB].op.axis
ATx, _ = s[AA].split(AM, 4)
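# The snippet above stops right after the first cache-stage split. The lines
# below are a sketch of how such a tiled GEMM schedule is typically finished;
# the reduce-axis split, the compute_at anchors, and the cooperative-fetch
# binding are assumptions, not taken from the original.
ko, ki = s[CCC].split(s[CCC].op.reduce_axis[0], 4)
s[CCC].compute_at(s[C], Ty)      # accumulate the per-thread 4x4 tile in registers
s[AA].compute_at(s[CCC], ko)     # refill the shared-memory tiles once per ko step
s[BB].compute_at(s[CCC], ko)
s[AAA].compute_at(s[CCC], ki)    # then load the working set into registers
s[BBB].compute_at(s[CCC], ki)
s[AA].bind(ATx, thread_x)        # cooperative fetch along the axis split above

s[C].bind(Bx, block_x)
s[C].bind(By, block_y)
s[C].bind(Tx, thread_x)
s[C].bind(Ty, thread_y)

tir = str(KgeN.lower(s, [A, C]))
print(tir)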
# `Apad`, `W`, and `B` are the padded input, the weights, and the convolution
# output; their te.compute definitions precede this snippet and are not shown.
s = te.create_schedule(B.op)

# cache the padded input and the weights in shared memory and registers,
# and accumulate the output tile locally
AA = s.cache_read(Apad, "shared", [B])
WW = s.cache_read(W, "shared", [B])
AL = s.cache_read(AA, "local", [B])
WL = s.cache_read(WW, "local", [B])
BL = s.cache_write(B, "local")
s[Apad].compute_inline()

tile = 8
num_thread = 8
block_factor = 64
step = 8

# Get the GPU thread indices
block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
block_z = te.thread_axis("blockIdx.z")
thread_x = te.thread_axis(num_thread, "threadIdx.x")
thread_y = te.thread_axis(num_thread, "threadIdx.y")

# tile the output: fuse the spatial axes into one block dimension and split the
# channel and batch axes over blocks and threads
hi, wi, fi, ni = s[B].op.axis
bz = s[B].fuse(hi, wi)
by, fi = s[B].split(fi, factor=block_factor)
bx, ni = s[B].split(ni, factor=block_factor)
ty, fi = s[B].split(fi, nparts=num_thread)
tx, ni = s[B].split(ni, nparts=num_thread)
s[B].reorder(bz, by, bx, ty, tx, fi, ni)

# Bind the iteration variables to GPU thread indices
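# A sketch of the bindings the comment above refers to (assumed, not from the
# original snippet): each tile loop is mapped onto the matching GPU index.
# A full schedule would normally continue with compute_at for BL, AA, WW, AL,
# and WL, followed by a final KgeN.lower call.
s[B].bind(bz, block_z)
s[B].bind(by, block_y)
s[B].bind(bx, block_x)
s[B].bind(ty, thread_y)
s[B].bind(tx, thread_x)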