def transformed_three_stage_compute(
    A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]
) -> None:
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0:16])
            T.writes(D[tx, 0:16])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            with T.block():
                T.reads(A[tx, 0:2], B[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0])
                for i in T.unroll(2):
                    with T.block():
                        T.reads(A[tx, i])
                        T.writes(B[0:2, tx, 0])
                        B[i, tx, 0] = A[tx, i] * T.float32(2)
                    with T.block():
                        T.where(1 <= i)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
            with T.block():
                T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
                for i in T.serial(14):
                    with T.block():
                        T.reads(A[tx, i + 2])
                        T.writes(B[0:2, tx, 0])
                        B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
                    with T.block():
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i])
                        D[tx, i] = C[i % 2, tx, 0] + T.float32(1)
            with T.block():
                T.reads(B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(C[0:2, tx, 0], D[tx, 14:16])
                for i in T.unroll(2):
                    with T.block():
                        T.where(i < 1)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i + 14])
                        D[tx, i + 14] = C[i, tx, 0] + T.float32(1)
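# A hedged sketch (an assumption, not taken from the function above) of how a
# three-stage pipelined function like `transformed_three_stage_compute` is
# typically obtained: annotate the serial loop with pipeline stage/order lists
# and run the InjectSoftwarePipeline pass. The name `three_stage_compute` and
# the exact block layout below are illustrative.
import tvm
from tvm.script import tir as T


@T.prim_func
def three_stage_compute(
    A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]
) -> None:
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        for i in T.serial(
            16,
            annotations={
                "software_pipeline_stage": [0, 1, 2],
                "software_pipeline_order": [0, 1, 2],
            },
        ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(D[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                C = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, 0])
                    C[tx, 0] = B[tx, 0] + T.float32(2)
                with T.block():
                    T.reads(C[tx, 0])
                    T.writes(D[tx, i])
                    D[tx, i] = C[tx, 0] + T.float32(1)


# Applying the pass rewrites the annotated loop into prologue / body / epilogue
# blocks with double-buffered intermediates, as in the function above.
mod = tvm.IRModule({"main": three_stage_compute})
mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)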
def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    for i, j in T.grid(1, 4):
                                        with T.block([2048, 2048], "A_shared_local") as [v0, v1]:
                                            T.bind(v0, k_0 * 8 + k_1 + i)
                                            T.bind(v1, by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0, v1] = A_shared[v0, v1]
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block(
                                            [2048, 2048, T.reduce_axis(0, 2048)], "C"
                                        ) as [vi, vj, vk]:
                                            T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = T.float32(0)
                                            C_local[vi, vj] = (
                                                C_local[vi, vj]
                                                + A_shared_local[vk, vi] * B_shared_local[vk, vj]
                                            )
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048], "C_local") as [v0, v1]:
                                    T.bind(v0, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
def loop_syntax_sugar(a: T.handle) -> None:
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(128):
        for j in T.parallel(128):
            for k in T.vectorized(128):
                for x in T.unroll(128):
                    for y in T.thread_binding(128, "threadIdx.x"):
                        for z in T.thread_binding(128, thread="threadIdx.x"):
                            A[i, j, k, x] = A[i, j, k, x] * 2.0
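# A companion sketch (assumed, not part of the snippet above): the same loop
# nest written without the syntactic sugar, with explicit start/stop ranges and
# the `thread` keyword spelled out. The name `loop_no_syntax_sugar` is
# illustrative.
def loop_no_syntax_sugar(a: T.handle) -> None:
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(0, 128):
        for j in T.parallel(0, 128):
            for k in T.vectorized(0, 128):
                for x in T.unroll(0, 128):
                    for y in T.thread_binding(0, 128, thread="threadIdx.x"):
                        for z in T.thread_binding(0, 128, thread="threadIdx.x"):
                            A[i, j, k, x] = A[i, j, k, x] * 2.0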
def rowsum_unrolled(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i0 in T.unroll(0, 128):
        for i1 in T.serial(0, 128):
            with T.block("B"):
                vi, vk = T.axis.remap("SR", [i0, i1])
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def rowsum_unrolled(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i0 in T.unroll(0, 128):
        for i1 in T.serial(0, 128):
            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
                T.bind(vi, i0)
                T.bind(vk, i1)
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
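# A hedged sketch (not taken from the snippets above) of how `rowsum_unrolled`
# is typically produced: start from a plain row-sum PrimFunc and apply
# `Schedule.unroll` to the outer spatial loop. The `rowsum` function below is
# an assumed starting point for illustration.
import tvm
from tvm.script import tir as T


@T.prim_func
def rowsum(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i0, i1 in T.grid(128, 128):
        with T.block("B"):
            vi, vk = T.axis.remap("SR", [i0, i1])
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]


sch = tvm.tir.Schedule(rowsum)
i, _ = sch.get_loops(sch.get_block("B"))
sch.unroll(i)  # the outer loop becomes `for i0 in T.unroll(...)` as above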
def after_unrolled_loop(
    placeholder: T.Buffer[(1, 56, 56, 64), "float32"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32")
    inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32")
    for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(13, thread="blockIdx.x"):
        for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            for i0 in T.unroll(4):
                for i1 in T.unroll(4):
                    for i4 in T.unroll(6):
                        for i5 in T.unroll(6):
                            with T.block("inverse"):
                                vh, vw = T.axis.remap("SS", [i0, i1])
                                p = T.axis.spatial(
                                    196,
                                    (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1)
                                    // 128
                                    * 2
                                    + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1)
                                    % 32
                                    // 16,
                                )
                                co = T.axis.spatial(
                                    64,
                                    (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1)
                                    % 128
                                    // 32
                                    * 16
                                    + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 + i2_0_i3_0_i2_1_i3_1_fused_1)
                                    % 16,
                                )
                                r_a, r_b = T.axis.remap("RR", [i4, i5])
                                T.where(
                                    i2_0_i3_0_i2_1_i3_1_fused_0 * 1024
                                    + i2_0_i3_0_i2_1_i3_1_fused_1
                                    < 12544
                                )
                                T.reads(bgemm[r_a, r_b, p, co])
                                T.writes(inverse[vh, vw, p, co])
                                with T.init():
                                    inverse[vh, vw, p, co] = T.float32(0)
                                inverse[vh, vw, p, co] = (
                                    inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co]
                                )
def cuda_matmul_5(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                for i, j in T.grid(8, 64):
                                    with T.block("A_shared"):
                                        v0 = T.axis.S(2048, k0 * 8 + i)
                                        v1 = T.axis.S(2048, by * 64 + j)
                                        A_shared[v0, v1] = A[v0, v1]
                                for i, j in T.grid(8, 64):
                                    with T.block("B_shared"):
                                        v0 = T.axis.S(2048, k0 * 8 + i)
                                        v1 = T.axis.S(2048, bx * 64 + j)
                                        B_shared[v0, v1] = B[v0, v1]
                                for k1 in T.unroll(0, 8):
                                    for i, j in T.grid(1, 4):
                                        with T.block("A_shared_local"):
                                            v0 = T.axis.S(2048, k0 * 8 + k1 + i)
                                            v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0, v1] = A_shared[v0, v1]
                                    for i, j in T.grid(1, 4):
                                        with T.block("B_shared_local"):
                                            v0 = T.axis.S(2048, k0 * 8 + k1 + i)
                                            v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            B_shared_local[v0, v1] = B_shared[v0, v1]
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            with T.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = (
                                                C_local[vi, vj]
                                                + A_shared_local[vk, vi] * B_shared_local[vk, vj]
                                            )
                            for i, j in T.grid(4, 4):
                                with T.block("C_local"):
                                    v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                    v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
def before_unrolled_loop(
    placeholder: T.Buffer[(1, 56, 56, 64), "float32"],
) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32")
    inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32")
    for i2_0, i3_0, i2_1, i3_1 in T.grid(98, 4, 2, 16):
        for i0 in T.unroll(4):
            for i1 in T.unroll(4):
                for i4 in T.unroll(6):
                    for i5 in T.unroll(6):
                        with T.block("inverse"):
                            vh, vw = T.axis.remap("SS", [i0, i1])
                            p = T.axis.spatial(196, i2_0 * 2 + i2_1)
                            co = T.axis.spatial(64, i3_0 * 16 + i3_1)
                            r_a, r_b = T.axis.remap("RR", [i4, i5])
                            T.reads(bgemm[r_a, r_b, p, co])
                            T.writes(inverse[vh, vw, p, co])
                            with T.init():
                                inverse[vh, vw, p, co] = T.float32(0)
                            inverse[vh, vw, p, co] = (
                                inverse[vh, vw, p, co] + bgemm[r_a, r_b, p, co]
                            )
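# A hedged sketch (an assumption, not shown in the snippets above) of schedule
# primitives that relate `before_unrolled_loop` to `after_unrolled_loop`: the
# four outer spatial loops are fused, split by a factor of 1024, and bound to
# blockIdx.x / threadIdx.x, which introduces the `T.where(... < 12544)`
# predicate seen in the transformed version.
import tvm

sch = tvm.tir.Schedule(before_unrolled_loop)  # assumes the function is a @T.prim_func
block = sch.get_block("inverse")
i2_0, i3_0, i2_1, i3_1 = sch.get_loops(block)[:4]
fused = sch.fuse(i2_0, i3_0, i2_1, i3_1)         # extent 98 * 4 * 2 * 16 = 12544
bx, tx = sch.split(fused, factors=[None, 1024])  # 13 x 1024, with a tail predicate
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")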
def main():
    for i in T.unroll(2):
        with T.allocate([16], "float32", "global") as buf:
            buf[0] = 0.0