def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
    """TVMScript fixture: expected TIR after lowering a cross-thread reduction
    whose reduction axis is split over two thread-bound loops.

    Each row of A (128x128) is reduced into B[vi] via ``tvm_thread_allreduce``
    across threadIdx.x (extent 4) and threadIdx.y (extent 32).
    """
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # Per-thread staging slot that receives the allreduce result.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B_cross_thread_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([A[vi, vk]])
                    T.writes([reduce_temp0[0]])
                    # reduce_scope attr carries the combiner (sum, identity 0)
                    # for the allreduce intrinsic below.
                    T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(T.uint32(1), A[vi, vk], True,
                                               reduce_temp0.data, ko, ki,
                                               dtype="handle"))
                with T.block("B_write_back"):
                    vi = T.axis.spatial(128, i)
                    T.reads([reduce_temp0[0]])
                    T.writes([B[vi]])
                    B[vi] = reduce_temp0[0]
def main(a: T.handle, b: T.handle) -> None:
    """TVMScript fixture: two element-wise kernels (B = A + 1, C = B * 2)
    staged through an intermediate buffer, all in "global.texture" scope."""
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (128, 128, 4), dtype="float32", scope="global.texture")
    # Intermediate texture buffer between the two kernels.
    B = T.alloc_buffer((128, 128, 4), dtype="float32", scope="global.texture")
    C = T.match_buffer(b, (128, 128, 4), dtype="float32", scope="global.texture")
    for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
        for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"):
            for k in T.serial(4):
                with T.block("B"):
                    vb, vt, vk = T.axis.remap(
                        "SSS", [block_idx, thread_idx, k])
                    B[vb, vt, vk] = A[vb, vt, vk] + T.float32(1)
    for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
        for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"):
            for k in T.serial(4):
                with T.block("C"):
                    vb, vt, vk = T.axis.remap(
                        "SSS", [block_idx, thread_idx, k])
                    C[vb, vt, vk] = B[vb, vt, vk] * T.float32(2)
def main(placeholder: T.Buffer[(12, 64, 64), "float32"], T_reshape: T.Buffer[(64, 768), "float32"]) -> None:
    """TVMScript fixture: reshape (12, 64, 64) -> (64, 768) with the fused
    output loop split over blockIdx.x (48) x threadIdx.x (1024).

    48 * 1024 == 64 * 768, so the launch exactly covers the output; the axis
    bindings and the placeholder index expressions are left in the
    simplifier-generated (unsimplified-looking) form on purpose.
    """
    for i0_i1_fused_0_i0_i1_fused_1_fused_0 in T.thread_binding(
            48, thread="blockIdx.x"):
        for i0_i1_fused_0_i0_i1_fused_1_fused_1 in T.thread_binding(
                1024, thread="threadIdx.x"):
            with T.block("T_reshape_1"):
                ax0 = T.axis.spatial(
                    64,
                    ((i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 +
                      i0_i1_fused_0_i0_i1_fused_1_fused_1) // 32 * 32 +
                     (i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 +
                      i0_i1_fused_0_i0_i1_fused_1_fused_1) % 32) // 768,
                )
                ax1 = T.axis.spatial(
                    768,
                    ((i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 +
                      i0_i1_fused_0_i0_i1_fused_1_fused_1) // 32 * 32 +
                     (i0_i1_fused_0_i0_i1_fused_1_fused_0 * 1024 +
                      i0_i1_fused_0_i0_i1_fused_1_fused_1) % 32) % 768,
                )
                T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64, ax1 % 64])
                T.writes(T_reshape[ax0, ax1])
                # NOTE(review): the store-side index expression is a more
                # convoluted but (presumably) equivalent form of the read
                # annotation above — kept verbatim as the fixture expects.
                T_reshape[ax0, ax1] = placeholder[
                    ((ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) // 64 + ax1 % 768 // 64) % 12,
                    (ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) % 64,
                    ax1 % 64 % 64,
                ]
def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None:
    """TVMScript fixture: expected TIR after lowering cross-thread reduction
    when multiple blocks sit under the thread-bound reduction loop.

    Pipeline per (blockIdx.x=i, threadIdx.x=k0o): init a per-thread
    accumulator, compute an rfactor partial sum, fold it serially into the
    accumulator, then allreduce across threadIdx.x and write back to B.
    """
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    # Per-thread accumulator for the in-thread (serial) part of the reduction.
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[
                        0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Combiner for the allreduce: sum with identity 0.
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        k0o,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(16, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    """TVMScript fixture: row-wise softmax over a 256x256 buffer.

    One blockIdx.x per row; each of the three stages (row max, sum of
    exponentials, normalization) splits its 256-wide axis into 8 serial x
    32 threadIdx.x iterations, staging the row max/sum in shared memory.
    """
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    # Shared-memory scalars (one per row) for the max and exp-sum stages.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    # Subtract the row max before exponentiating (the usual
                    # numerically-stable softmax form).
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
def main(X: T.Buffer[(1, 512, 56, 56), "float32"], W: T.Buffer[(512, 512, 3, 3), "float32"], B: T.Buffer[(512, 1, 1), "float32"], bn_scale: T.Buffer[(512, 1, 1), "float32"], bn_offset: T.Buffer[(512, 1, 1), "float32"], compute: T.Buffer[(1, 512, 56, 56), "float32"]) -> None:
    """TVMScript fixture: scheduled 3x3 conv2d (padded via if_then_else) fused
    with bias add, batch-norm scale/offset and ReLU.

    The conv accumulates into a "local" buffer under a
    blockIdx.x / vthread.x / threadIdx.x tiling; a second block writes the
    post-BN, ReLU-clamped result back to `compute`.
    """
    compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local")
    for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(224, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(
                2, thread="vthread.x"):
            for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(
                    8, thread="threadIdx.x"):
                for i4_0, i5_0, i6_0, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(
                        1, 3, 1, 32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28):
                    with T.block("compute"):
                        nn = T.axis.spatial(1, 0)
                        ff = T.axis.spatial(
                            512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                            i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4)
                        yy = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 +
                            i0_1_i1_1_i2_1_i3_1_fused * 4 +
                            i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4)
                        xx = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4)
                        rc = T.axis.reduce(512, i4_1 * 16 + i4_2)
                        ry, rx = T.axis.remap("RR", [i5_0, i6_2])
                        with T.init():
                            compute_local[nn, ff, yy, xx] = T.float32(0)
                        # if_then_else implements the 1-pixel zero padding of
                        # the 3x3 convolution window.
                        compute_local[nn, ff, yy, xx] = compute_local[
                            nn, ff, yy, xx] + T.if_then_else(
                                yy + ry >= 1 and yy + ry < 57 and xx + rx >= 1
                                and xx + rx < 57,
                                X[nn, rc, yy + ry - 1, xx + rx - 1],
                                T.float32(0),
                                dtype="float32") * W[ff, rc, ry, rx]
                for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                            i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1)
                        # NOTE(review): `% 14 // 2` here vs `// 2 % 7` in the
                        # conv block above — arithmetically the same value for
                        # this extent; kept verbatim.
                        v2 = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 +
                            i0_1_i1_1_i2_1_i3_1_fused * 4 +
                            i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2)
                        v3 = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3)
                        # bias -> bn scale -> bn offset -> ReLU (max with 0).
                        compute[v0, v1, v2, v3] = T.max(
                            (compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) *
                            bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0],
                            T.float32(0))
def loop_syntax_sugar(a: T.handle) -> None:
    """TVMScript fixture exercising every loop-kind sugar in one nest:
    serial, parallel, vectorized, unroll, and thread_binding (both the
    positional and the keyword `thread=` spelling)."""
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(128):
        for j in T.parallel(128):
            for k in T.vectorized(128):
                for x in T.unroll(128):
                    # NOTE(review): y and z bind the same thread tag and are
                    # unused in the body — presumably intentional for this
                    # syntax round-trip test.
                    for y in T.thread_binding(128, "threadIdx.x"):
                        for z in T.thread_binding(128, thread="threadIdx.x"):
                            A[i, j, k, x] = A[i, j, k, x] * 2.0
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    """TVMScript fixture: batched (batch=1) 128x128 GEMM Z = X @ Y, tiled for
    CUDA with shared-memory staging of X/Y and a local accumulator.

    The Z_update block carries meta_schedule thread-extent block_attrs
    (low/high inclusive = 1024) — presumably what the surrounding test checks.
    """
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                # Zero the 4x2 accumulator tile owned by this thread.
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    # Cooperative, vectorized copy of a 32-wide K slice of X.
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    # Cooperative copy of the matching K slice of Y.
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 1024,
                                "meta_schedule.thread_extent_high_inclusive": 1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                # Write the accumulator tile back to global Z.
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]
def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    """TVMScript fixture: expected result after unifying thread bindings —
    loop vars are named after (and index via) the unified vthread.x /
    threadIdx.x variables directly."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for vthread_x in T.thread_binding(0, 2, "vthread.x"):
        for threadIdx_x in T.thread_binding(0, 64, "threadIdx.x"):
            for j_1 in T.serial(0, 64):
                with T.block(""):
                    B[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] = (
                        A[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] * 2.0)
def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
    """TVMScript fixture: B = A + 1 over 8 elements, with the axis split as
    blockIdx.x(2) x serial(2) x threadIdx.x(2) — note the serial loop sits
    between the two thread bindings in the index expression."""
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i_0 in T.thread_binding(2, thread="blockIdx.x"):
        for i_2 in T.thread_binding(2, thread="threadIdx.x"):
            for i_1 in T.serial(2):
                with T.block("B"):
                    vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2)
                    T.reads(A[vi])
                    T.writes(B[vi])
                    B[vi] = A[vi] + T.float32(1)
def element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    """TVMScript fixture: element-wise B = A * 2 with two separate vthread.x
    bindings (i_0 and j_0) plus one threadIdx.x binding — input for a
    thread-binding unification pass."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i_0 in T.thread_binding(0, 2, "vthread.x"):
        for i_1 in T.thread_binding(0, 64, "threadIdx.x"):
            for j_0 in T.thread_binding(0, 2, "vthread.x"):
                for j_1 in T.serial(0, 64):
                    with T.block(""):
                        B[i_0 * 64 + i_1, j_0 * 64 + j_1] = A[i_0 * 64 + i_1, j_0 * 64 + j_1] * 2.0
def main(var_A: T.handle, var_B: T.handle) -> None:
    """TVMScript fixture: B = A + 1 over 512x512, with the two spatial axes
    fused and split across blockIdx.x (8192) x threadIdx.x (32)."""
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    for i_j_fused_0 in T.thread_binding(0, 8192, thread="blockIdx.x"):
        for i_j_fused_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("C"):
                vi = T.axis.spatial(
                    512, (i_j_fused_0 * 32 + i_j_fused_1) // 512)
                vj = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) % 512)
                B[vi, vj] = A[vi, vj] + 1.0
def element_wise_two_thread_x_in_same_kernel_not_equal(a: T.handle, b: T.handle, c: T.handle) -> None:
    """TVMScript fixture: one kernel containing two threadIdx.x bindings with
    *different* extents (128 vs 64) — per the name, an input meant to be
    rejected or treated specially by thread-binding unification."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 64])
    for i in T.thread_binding(0, 128, "blockIdx.x"):
        for j0 in T.thread_binding(0, 128, "threadIdx.x"):
            B[i, j0] = A[i, j0] * 2.0
        for j1 in T.thread_binding(0, 64, "threadIdx.x"):
            C[i, j1] = A[i, j1] + 1.0
def element_wise_kernels_with_different_size(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    """TVMScript fixture: two separate kernels (distinct blockIdx.x scopes)
    with different launch extents (128 vs 256) — input for thread-binding
    unification across kernels."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [256, 256])
    D = T.match_buffer(d, [256, 256])
    for i0 in T.thread_binding(0, 128, "blockIdx.x"):
        for j0 in T.thread_binding(0, 128, "threadIdx.x"):
            B[i0, j0] = A[i0, j0] * 2.0
    for i1 in T.thread_binding(0, 256, "blockIdx.x"):
        for j1 in T.thread_binding(0, 256, "threadIdx.x"):
            D[i1, j1] = C[i1, j1] + 1.0
def unified_element_wise_kernels_with_different_size(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    """TVMScript fixture: expected result of unifying the two different-sized
    kernels — each kernel keeps its own extent, but loop vars are renamed to
    the unified blockIdx_x / threadIdx_x names."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [256, 256])
    D = T.match_buffer(d, [256, 256])
    for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(0, 128, "threadIdx.x"):
            B[blockIdx_x, threadIdx_x] = A[blockIdx_x, threadIdx_x] * 2.0
    for blockIdx_x in T.thread_binding(0, 256, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(0, 256, "threadIdx.x"):
            D[blockIdx_x, threadIdx_x] = C[blockIdx_x, threadIdx_x] + 1.0
def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    """TVMScript fixture: two same-extent threadIdx.x loops (j0_0, j1_0)
    under one blockIdx.x — input for unifying them into a single binding."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i in T.thread_binding(0, 128, "blockIdx.x"):
        for j0_0 in T.thread_binding(0, 4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
        for j1_0 in T.thread_binding(0, 4, "threadIdx.x"):
            for j1_1 in T.serial(0, 32):
                with T.block(""):
                    C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0
def single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    """TVMScript fixture: row-wise softmax where each 256-wide axis is bound
    to 512 threads, so every block carries a ``T.where`` predicate masking
    the surplus 256 threads.
    """
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    # Only the first 256 of the 512 threads participate.
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        # float32 lowest value as the max-reduction identity.
                        T_softmax_maxelem_shared[i0_1] = T.float32(
                            -3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[
                        i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                            A[i0_2, k] - T_softmax_maxelem_shared[i0_2],
                            dtype="float32")
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                              dtype="float32") / T_softmax_expsum_shared[i0_3])
def thread_bound_nested_block(
    A: T.Buffer[(16, 16, 16, 16), "float32"], B: T.Buffer[(16, 16, 16), "float32"]
) -> None:
    """TVMScript fixture: an outer block bound to blockIdx.x containing an
    inner reduction block whose reduce axis `l` is bound to threadIdx.x —
    i.e. a cross-thread reduction nested inside another block."""
    for i in T.serial(16):
        for j in T.thread_binding(16, thread="blockIdx.x"):
            with T.block("outer"):
                vi, vj = T.axis.remap("SS", [i, j])
                for k in T.serial(16):
                    for l in T.thread_binding(16, thread="threadIdx.x"):
                        with T.block("inner"):
                            # "SR": vk spatial, vl is the (thread-bound) reduce axis.
                            vk, vl = T.axis.remap("SR", [k, l])
                            with T.init():
                                B[vi, vj, vk] = T.float32(0)
                            B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
def warp_memory(a: T.handle, c: T.handle) -> None:
    """TVMScript fixture: stages data through a "warp"-scope buffer B, indexed
    by (value, warp id = threadIdx.y, lane id = threadIdx.x), between a
    producer block "B" and consumer block "C"."""
    A = T.match_buffer(a, [128, 128])
    C = T.match_buffer(c, [128, 128])
    B = T.alloc_buffer([128, 4, 32], scope="warp")
    for i_o in T.thread_binding(0, 4, thread="threadIdx.y"):
        for i_i in T.thread_binding(0, 32, thread="threadIdx.x"):
            for j in T.serial(0, 128):
                with T.block("B"):
                    warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j])
                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
            for j in T.serial(0, 128):
                with T.block("C"):
                    warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j])
                    C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
def two_bound_loops(a: T.handle, b: T.handle) -> None:
    """TVMScript fixture: row-sum reduction B[i] = sum_k A[i, k] where the
    reduce axis is split over threadIdx.x (4) x threadIdx.y (32) — the
    pre-lowering counterpart of `lowered_two_bound_loops`."""
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([B[vi], A[vi, vk]])
                    T.writes([B[vi]])
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A[vi, vk]
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
    """TVMScript fixture: scheduled 512x512 matmul C = A @ B with A/B staged
    in shared memory and C accumulated in a local buffer, under a
    blockIdx.x / vthread.x / threadIdx.x tiling."""
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    C = T.match_buffer(var_C, [512, 512], dtype="float32")
    # body
    # with T.block("root")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                for i2_0 in T.serial(0, 1):
                    # Cooperative copy: the whole A matrix into shared memory.
                    for ax0_ax1_fused_0 in T.serial(0, 32768):
                        for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
                            with T.block("A_shared"):
                                v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512)
                                v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512)
                                T.reads([A[v0, v1]])
                                T.writes([A_shared[v0, v1]])
                                A_shared[v0, v1] = A[v0, v1]
                    # Cooperative, vectorized copy of this block's 32-wide
                    # column slice of B.
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                with T.block("B_shared"):
                                    v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
                                    v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
                                    T.reads([B[v0, v1]])
                                    T.writes([B_shared[v0, v1]])
                                    B_shared[v0, v1] = B[v0, v1]
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
                        with T.block("C"):
                            i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
                            k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                            T.reads([A_shared[i, k], B_shared[k, j]])
                            T.writes([C_local[i, j]])
                            with T.init():
                                C_local[i, j] = T.float32(0)
                            C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
                # Write the 32x4 local tile back to global C.
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None:
    """TVMScript fixture: producer "B" bound to 128 threads and consumer "C"
    bound to 32 threads (times a serial factor of 4) under the same serial i
    loop — two different-extent threadIdx.x bindings in one i iteration."""
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j0])
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o in T.thread_binding(0, 32, thread="threadIdx.x"):
            for j1i in T.serial(0, 4):
                with T.block("C"):
                    vi = T.axis.S(128, i)
                    vj = T.axis.S(128, j1o * 4 + j1i)
                    C[vi, vj] = B[vi, vj] + 1.0
def element_wise_thread_x_different_dtype(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    """TVMScript fixture: two threadIdx.x loops with the same extent but
    different index dtypes (int32 vs int64) — exercises dtype handling when
    unifying thread bindings."""
    for i in T.thread_binding(128, "blockIdx.x"):
        for j0_0 in T.thread_binding(4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
        # Same loop structure, but with int64 extents/indices.
        for j1_0 in T.thread_binding(T.int64(4), "threadIdx.x"):
            for j1_1 in T.serial(T.int64(32)):
                with T.block(""):
                    C[i, j1_0 * T.int64(32) + j1_1] = B[i, j1_0 * T.int64(32) + j1_1] + 1.0
def transformed_simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
    """TVMScript fixture: expected result of software pipelining
    C = A * 2 + 1 with a double-buffered shared staging buffer B —
    explicit prologue (i = 0 producer), steady-state body (producer i+1
    overlapped with consumer i), and epilogue (last consumer)."""
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        with T.block():
            T.reads([A[tx, 0:16]])
            T.writes([C[tx, 0:16]])
            # Two-slot rotating buffer: index (i % 2) selects the slot.
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            with T.block():
                # Prologue: fill slot 0 for the first iteration.
                T.reads([A[tx, 0]])
                T.writes([B[0, tx, 0]])
                B[0, tx, 0] = A[tx, 0] * T.float32(2)
            with T.block():
                # Steady state: produce iteration i+1 while consuming i.
                T.reads([A[tx, 1:16], B[0:2, tx, 0]])
                T.writes([B[0:2, tx, 0], C[tx, 0:15]])
                for i in T.serial(0, 15):
                    with T.block():
                        T.reads([A[tx, i + 1]])
                        T.writes([B[(i + 1) % 2, tx, 0]])
                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
                    with T.block():
                        T.reads([B[i % 2, tx, 0]])
                        T.writes([C[tx, i]])
                        C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
            with T.block():
                # Epilogue: consume the final staged element (15 % 2 == 1).
                T.reads([B[1, tx, 0]])
                T.writes([C[tx, 15]])
                C[tx, 15] = B[1, tx, 0] + T.float32(1)
def equal_ranked_threads(a: T.handle, c: T.handle) -> None:
    """TVMScript fixture: threadIdx.x (16) and threadIdx.y (8) jointly index
    one 128-wide axis, with a shared-scope intermediate B between the
    producer "B" and consumer "C" blocks (consumer reads transposed)."""
    A = T.match_buffer(a, [128, 128])
    C = T.match_buffer(c, [128, 128])
    B = T.alloc_buffer([128, 128], scope="shared")
    for i_o in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i_i in T.thread_binding(0, 8, thread="threadIdx.y"):
            for j in T.serial(0, 128):
                with T.block("B"):
                    vi = T.axis.S(128, i_o * 8 + i_i)
                    vj = T.axis.S(128, j)
                    B[vi, vj] = A[vi, vj] * 2.0
            for j in T.serial(0, 128):
                with T.block("C"):
                    vi = T.axis.S(128, i_o * 8 + i_i)
                    vj = T.axis.S(128, j)
                    # Note the transposed access relative to block "B".
                    C[vj, vi] = B[vj, vi] + 1.0
def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
    """TVMScript fixture: expected TIR after lowering a cross-thread *max*
    reduction — combiner is T.max with identity min_value("float32")."""
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # Per-thread staging slot for the allreduce result.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B_cross_thread_reduction"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(T.uint32(1), A[vi, vk], True,
                                           reduce_temp0.data, k,
                                           dtype="handle"))
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    """TVMScript fixture: expected result of unifying `element_wise_thread_x`
    — the two equal-extent threadIdx.x loops are merged into one binding
    covering both serial bodies."""
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[blockIdx_x, threadIdx_x * 32 + j0_1] = (A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0)
            for j1_1 in T.serial(0, 32):
                with T.block(""):
                    C[blockIdx_x, threadIdx_x * 32 + j1_1] = (B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0)
def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
    """TVMScript fixture: lowered cross-thread sum reducing A (128) into a
    zero-rank (scalar) buffer B, written via the empty-tuple index B[()]."""
    A = T.match_buffer(a, [128], dtype="float32")
    # Zero-rank output buffer.
    B = T.match_buffer(b, [], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
        with T.block("B_cross_thread_reduction"):
            vk = T.axis.reduce(128, k)
            T.reads([A[vk]])
            T.writes([reduce_temp0[0]])
            T.attr(
                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                "reduce_scope",
                T.reinterpret(T.uint64(0), dtype="handle"),
            )
            T.evaluate(
                T.tvm_thread_allreduce(T.uint32(1), A[vk], True,
                                       reduce_temp0.data, k, dtype="handle"))
        with T.block("B_write_back"):
            T.reads([reduce_temp0[0]])
            T.writes([B[()]])
            B[()] = reduce_temp0[0]
def simple_compute_conflicting_order(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
    """TVMScript fixture: software-pipeline annotations where two statements
    share stage 1 AND order 1 — per the name, a conflicting-order input
    presumably expected to be rejected by the pipelining pass."""
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
                0,
                16,
                annotations={
                    # Both the B->C and C->D stages are given stage 1, order 1.
                    "software_pipeline_stage": [0, 1, 1],
                    "software_pipeline_order": [0, 1, 1],
                },
        ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(D[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                C = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, 0])
                    C[tx, 0] = B[tx, 0] + T.float32(2)
                with T.block():
                    T.reads(C[tx, 0])
                    T.writes(D[tx, i])
                    D[tx, i] = C[tx, 0] + T.float32(1)
def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None:
    """TVMScript fixture: same schedule as the T.axis-based variant elsewhere
    in this file, written in the legacy ``T.block([...]) as [...]`` +
    ``T.bind`` block syntax."""
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i)
                T.bind(vj, j0)
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o in T.thread_binding(0, 32, thread="threadIdx.x"):
            for j1i in T.serial(0, 4):
                with T.block([128, 128], "C") as [vi, vj]:
                    T.bind(vi, i)
                    T.bind(vj, j1o * 4 + j1i)
                    C[vi, vj] = B[vi, vj] + 1.0