def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None:
    # Element-wise TIR fixture on 128x128 buffers: B = A * 2, then C = B + 1.
    # Block "B" carries a "buffer_dim_align" annotation whose value is the
    # flat list [0]; other fixtures in this file use a list of 4-element
    # lists (e.g. [[0, 0, 16, 15]]). Given the function name, the malformed
    # value is presumably deliberate so that a consumer of this annotation
    # rejects it -- NOTE(review): confirm against the test using this fixture.
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        # Intermediate produced by block "B" and consumed by block "C".
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    # Intentionally malformed annotation value (see header).
                    T.block_attr({"buffer_dim_align": [0]})
                    vi, vj = T.axis.remap("SS", [i0, ax1])
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj] * T.float32(2))
            # Consumer loop nested under the same i0, so row i0 of B is
            # produced just before it is consumed.
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Row-wise softmax of a 256x256 float32 matrix, scheduled for GPU.
    # Each row i0 is bound to one blockIdx.x; within a row, 32 threadIdx.x
    # lanes each cover 8 columns (ax0_0 / i1_0 range over 8, 8 * 32 == 256).
    # The per-row max and per-row sum of exp(A - max) are kept in
    # "shared"-scope buffers indexed by row. The reductions over threadIdx.x
    # are written as ordinary reduce blocks here; lowered_softmax in this
    # file appears to show the cross-thread (tvm_thread_allreduce) form.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # Stage 1: row maximum.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        # Stage 2: row sum of exp(A - row max).
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        # Stage 3: normalize each element of the row.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    # Batched (batch extent 1) 128x128x128 GEMM,
    # Z[b, i, j] = sum_k X[b, i, k] * Y[b, k, j], scheduled for GPU:
    #   * 16 blockIdx.x blocks; block f owns the 32x32 tile of Z with row
    #     offset (f // 4) * 32 and column offset (f % 4) * 32.
    #   * 128 threadIdx.x threads per block; thread t owns a 4x2 sub-tile
    #     (rows (t // 16) * 4 + [0, 4), cols (t % 16) * 2 + [0, 2)),
    #     accumulated in the "local"-scope Z_local.
    #   * k is split into 4 chunks of 32 (i3_0); for each chunk the 32x32
    #     tiles of X and Y are staged into "shared"-scope buffers first.
    # NOTE(review): Z_update is annotated with
    # meta_schedule.thread_extent_{low,high}_inclusive == 1024 while the
    # threadIdx.x extent bound here is 128. This looks deliberate (a fixture
    # a thread-extent check is expected to reject) -- confirm with its test.
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                # Zero this thread's 4x2 accumulator sub-tile.
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    # Stage the 32x32 X tile: 4 * 128 * 2 == 1024 elements,
                    # two vectorized elements per thread per iteration.
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    # Stage the 32x32 Y tile: 8 * 128 == 1024 elements.
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    # Accumulate this k-chunk into the thread's sub-tile.
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 1024,
                                "meta_schedule.thread_extent_high_inclusive": 1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                # Write the finished 4x2 sub-tile back to Z.
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]
def matmul_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 128x128x128 matmul C[i, j] = sum_k A[i, k] * B[j, k] (B is indexed
    # [vj, vk], i.e. used transposed). Block "update" carries the block
    # attribute {"test_annotation": 1}.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            T.block_attr({"test_annotation": 1})
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def square_sum_with_annotation(a: T.handle, c: T.handle) -> None:
    # Per-batch sum of squares: C[b] = sum over (i, j) of A[b, i, j] * A[b, i, j],
    # with b spatial and i, j reduction axes ("SRR"). Block "C" carries the
    # block attribute {"test_annotation": 1}.
    A = T.match_buffer(a, [16, 256, 256])
    C = T.match_buffer(c, [16])
    for b0, i0, j0 in T.grid(16, 256, 256):
        with T.block("C"):
            T.block_attr({"test_annotation": 1})
            b, i, j = T.axis.remap("SRR", [b0, i0, j0])
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + A[b, i, j] * A[b, i, j]
def annotated_matmul(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    # 128x128x128 float32 matmul C[i, j] = sum_k A[i, k] * B[j, k] (B used
    # transposed). Block "update" carries {"test_annotation": True}.
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.block_attr({"test_annotation": True})
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def nested_pipeline_double_buffer(A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]):
    # Software-pipelining fixture with nested pipelines. Both the loop over
    # i and the innermost loop over j carry "software_pipeline_stage" /
    # "software_pipeline_order" annotations, and the A_shared -> A_local copy
    # block is marked {"double_buffer_scope": 0}. Data flow per (tx, i):
    # A -> A_shared ("shared") -> A_local ("local") -> B ("shared") -> C,
    # computing C = A * 2 + 1.
    # NOTE(review): several indices disagree with the declared buffer shapes
    # or read/write regions (flagged inline). They are left untouched because
    # the fixture is most likely compared structurally against an expected
    # result; confirm before "fixing".
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 0, 0, 1, 1],
                "software_pipeline_order": [0, 2, 3, 1, 4],
            },
        ):
            with T.block():
                T.reads(A[tx, i, 0:16])
                T.writes(C[tx, i, 0:16])
                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
                A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local")
                for j in T.serial(0, 16):
                    with T.block():
                        T.reads(A[tx, i, j])
                        T.writes(A_shared[tx, 0, j])
                        A_shared[tx, 0, j] = A[tx, i, j]
                for j in T.serial(0, 16):
                    with T.block():
                        T.block_attr({"double_buffer_scope": 0})
                        T.reads(A_shared[tx, 0, j])
                        T.writes(A_local[0, 0, j])
                        # NOTE(review): reads A_shared[tx, i, j], but T.reads
                        # declares A_shared[tx, 0, j] and dim 1 of A_shared
                        # has extent 1.
                        A_local[0, 0, j] = A_shared[tx, i, j]
                for j in T.serial(
                    0,
                    16,
                    annotations={
                        "software_pipeline_stage": [0, 1],
                        "software_pipeline_order": [0, 1],
                    },
                ):
                    with T.block():
                        T.reads(A_local[0, 0, j])
                        T.writes(C[tx, i, j])
                        # NOTE(review): B has shape (16, 1, 1) but is indexed
                        # B[tx, i, 0] below (dim 1 indexed by i).
                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
                        with T.block():
                            # NOTE(review): declares A_local[tx, i, j], while
                            # A_local has shape (1, 1, 16) and the body reads
                            # A_local[0, 0, j].
                            T.reads(A_local[tx, i, j])
                            T.writes(B[tx, i, 0])
                            B[tx, i, 0] = A_local[0, 0, j] * T.float32(2)
                        with T.block():
                            T.reads(B[tx, i, 0])
                            T.writes(C[tx, i, j])
                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
def annotated_mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 16x16x16 multiply-accumulate C[i, j] += A[i, k] * B[j, k] (B used
    # transposed, no T.init, so C is accumulated into). Block "update"
    # carries {"test_annotation": True}. The name suggests it serves as a
    # tensor-intrinsic description -- confirm against its use site.
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                T.block_attr({"test_annotation": True})
                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
def single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    # Row-wise softmax over 256x256 float32 where each reduction/normalize
    # stage binds 512 threadIdx.x lanes to a 256-wide row. The block
    # predicate T.where(... < 256) masks off the 256 out-of-range lanes, so
    # each reduction has a single (thread-bound) loop plus a predicate.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.serial(256):
        # Stage 1: row maximum.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        # -3.4028234663852886e38 is the most negative finite
                        # float32 value (-FLT_MAX).
                        T_softmax_maxelem_shared[i0_1] = T.float32(
                            -3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        # Stage 2: row sum of exp(A - row max).
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[
                        i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                            A[i0_2, k] - T_softmax_maxelem_shared[i0_2],
                            dtype="float32")
        # Stage 3: normalize each in-range element of the row.
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                              dtype="float32") / T_softmax_expsum_shared[i0_3])
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 1024x1024x1024 float32 matmul C = A @ B. The root block carries the
    # meta-schedule annotations parallel=128, vectorize=16 and
    # unroll_explicit=2, presumably consumed by a later meta-schedule
    # postprocessor -- confirm against the test using this module.
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, (1024, 1024), "float32")
    B = T.match_buffer(b, (1024, 1024), "float32")
    C = T.match_buffer(c, (1024, 1024), "float32")
    with T.block("root"):
        T.reads([])
        T.writes([])
        T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 16, "meta_schedule.unroll_explicit": 2})
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 1024x1024x1024 float32 matmul C = A @ B. Block "matmul" carries a
    # "schedule_rule" annotation naming
    # "tvm.meta_schedule.test.custom_search_space" -- presumably a function
    # registered by the test that meta-schedule looks up to build this
    # block's search space; confirm.
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, (1024, 1024), "float32")
    B = T.match_buffer(b, (1024, 1024), "float32")
    C = T.match_buffer(c, (1024, 1024), "float32")
    with T.block("root"):
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                T.block_attr({
                    "schedule_rule": "tvm.meta_schedule.test.custom_search_space"
                })
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
def main(
    placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"],
    placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"],
    conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]
) -> None:  # type: ignore
    # conv2d in NCHWc layout: input 1x1x16x16x3, kernel 2x1x5x5x3x4,
    # output 1x2x16x16x4 (layouts recorded as "NCHW3c" -> "NCHW4c" in the
    # "workload" attribute below, which lists the op name and its args).
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    # Zero-pad H and W by 2 on each side (16 -> 20): in-range where
    # 2 <= h < 18 and 2 <= w < 18, otherwise 0.
    data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap(
                "SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
            T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
            data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32")  # type: ignore # pylint: disable=R1716
    # Convolution: reduce over input channel ic (3 = 1 chunk x 3 lanes,
    # hence ic // 3 and ic % 3) and the 5x5 kernel window (kh, kw).
    for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
        with T.block("conv2d_NCHWc"):
            n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap(
                "SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
            T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block])  # type: ignore
            T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
            T.block_attr({
                "workload": [
                    "conv2d_NCHWc.x86",
                    ["TENSOR", [1, 1, 16, 16, 3], "float32"],
                    ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"],
                    [1, 1],
                    [2, 2, 2, 2],
                    [1, 1],
                    "NCHW3c",
                    "NCHW4c",
                    "float32"
                ]
            })
            with T.init():
                conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
            conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[
                n, oc_chunk, oh, ow, oc_block] + data_pad[
                    n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]  # type: ignore
def main(a: T.handle, b: T.handle) -> None:
    # Tiled 3-D copy B = A over 1024x1024x1024 float32 buffers, with
    # meta_schedule.parallel / meta_schedule.vectorize annotations on the
    # root block. The tiling covers 2048 points per axis
    # (i: 128*4*4, j: 64*4*8, k: 64*32), so the T.where predicate is what
    # restricts execution to the in-bounds 1024 points per axis.
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    # body
    with T.block("root"):
        T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32})
        for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
            with T.block("move"):
                vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2)
                vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2)
                vk = T.axis.spatial(1024, k0 * 32 + k1)
                # Each clause equals the corresponding vi / vj / vk < 1024.
                T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024)
                T.reads([A[vi, vj, vk]])
                T.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk]
def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None:
    # Matmul followed by ReLU, D = max(A @ B, 0), through the intermediate
    # buffer C. The blocks carry annotations of several value types: a
    # string on "matmul", and a float plus a mixed-type list on "relu".
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            T.block_attr({"test1": "aaa"})
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.block_attr({"test2": 0.22, "test3": ["aa", 1]})
            D[vi, vj] = T.max(C[vi, vj], 0.0)
def compacted_storage_align_func(a: T.handle, c: T.handle) -> None:
    # Per row i: B[0, j] = A[i, j] + 1, then C[i, j] = B[0, j] * 2. The
    # intermediate B is only (1, 16) with row stride 31, and the producer
    # block carries "buffer_dim_align": [[0, 0, 16, 15]]. 31 % 16 == 15, so
    # the stride matches factor 16 / offset 15 if the entries mean
    # [buffer_index, axis, factor, offset] (assumed -- confirm against the
    # storage_align primitive). The name suggests this is the expected
    # output of buffer-region compaction.
    # Uses the legacy `with T.block() as []:` syntax for the inner blocks.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            # NOTE(review): the keyword here is spelled `dtypes`, whereas
            # every other alloc_buffer call in this file uses `dtype`;
            # confirm the parser accepts/ignores it as intended.
            B = T.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32")
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(A[i, j])
                    T.writes(B[0, j])
                    T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
                    B[0, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(B[0, j])
                    T.writes(C[i, j])
                    C[i, j] = B[0, j] * 2.0
def square_sum_with_annotation_rfactor(a: T.handle, c: T.handle) -> None:
    # Per-batch sum of squares C[b] = sum over (i, j) of A[b, i, j]^2, split
    # along the j axis. Block "C_rf" accumulates, for each (b, j), the sum
    # over i into C_rf[b, j]. Block "C" then reduces C_rf over j. Both
    # blocks keep {"test_annotation": 1}. This appears to be the expected
    # result of rfactor applied to square_sum_with_annotation -- confirm.
    A = T.match_buffer(a, [16, 256, 256])
    C = T.match_buffer(c, [16])
    C_rf = T.alloc_buffer([16, 256])
    for i0, i1, i2 in T.grid(16, 256, 256):
        with T.block("C_rf"):
            T.block_attr({"test_annotation": 1})
            vi2, b, i = T.axis.remap("SSR", [i2, i0, i1])
            with T.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
    for i0_1, i2_1 in T.grid(16, 256):
        with T.block("C"):
            T.block_attr({"test_annotation": 1})
            vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1])
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
def rewritten_tir_matmul(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    # 16x16 matmul C = A @ B after a layout rewrite of B.
    # Block "layout_rewrite" copies B[i0, i1] into
    # B_reindex[i1, i0 // 4, i0 % 4], so the matmul reads B[vk, vj] as
    # B_reindex[vj, vk // 4, vk % 4]. "layout_free_buffers": [1] presumably
    # refers to parameter index 1 (B) -- confirm.
    T.func_attr({"layout_free_buffers": [1]})
    B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32")
    for ax0, ax1 in T.grid(16, 16):
        with T.block("layout_rewrite"):
            i0, i1 = T.axis.remap("SS", [ax0, ax1])
            T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
            B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1]
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.spatial(16, i0 * 4 + i1)
            vj = T.axis.spatial(16, j)
            vk = T.axis.reduce(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B_reindex[vj, vk // 4, vk % 4]
def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
    # Element-wise B = A * 2, C = B + 1 on 128x128 buffers, written in the
    # legacy block syntax (`T.block([extents], name) as [vars]` plus
    # T.bind). Block "B" carries the annotation
    # "buffer_dim_align": [[0, 0, 128, 127]].
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block([128, 128], "B") as [vi, vj]:
                    T.bind(vi, i0)
                    T.bind(vj, ax1)
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    T.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]})
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            # Consumer nested under the same i0 as the producer.
            for i1 in T.serial(0, 128):
                with T.block([128, 128], "C") as [vi_1, vj_1]:
                    T.bind(vi_1, i0)
                    T.bind(vj_1, i1)
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
    # 512x512x512 float32 matmul C = A @ B scheduled for GPU:
    #   * 16 blockIdx.x blocks; block f computes the 512x32 column strip
    #     C[:, f*32 : f*32 + 32].
    #   * 16 vthread.x each own 32 rows; 8 threadIdx.x each own 4 columns,
    #     accumulating a 32x4 tile in the "local"-scope C_local.
    #   * All of A (32768 * 8 == 512 * 512 elements) and the 512x32 strip of
    #     B (1024 * 8 * 2 elements) are staged into "shared"-scope buffers;
    #     the staging blocks carry "meta_schedule.cooperative_fetch" values
    #     1 and 2 respectively.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    C = T.match_buffer(var_C, [512, 512], dtype="float32")
    # body
    # with T.block("root")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                # Single outer k iteration: the whole k range is staged at once.
                for i2_0 in T.serial(0, 1):
                    for ax0_ax1_fused_0 in T.serial(0, 32768):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.x"):
                            with T.block("A_shared"):
                                v0 = T.axis.spatial(
                                    512, (ax0_ax1_fused_0 * 8 +
                                          ax0_ax1_fused_1) // 512)
                                v1 = T.axis.spatial(
                                    512, (ax0_ax1_fused_0 * 8 +
                                          ax0_ax1_fused_1) % 512)
                                T.reads([A[v0, v1]])
                                T.writes([A_shared[v0, v1]])
                                T.block_attr(
                                    {"meta_schedule.cooperative_fetch": 1})
                                A_shared[v0, v1] = A[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.x"):
                            for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                with T.block("B_shared"):
                                    v0 = T.axis.spatial(
                                        512, (ax0_ax1_fused_0 * 16 +
                                              ax0_ax1_fused_1 * 2 +
                                              ax0_ax1_fused_2) // 32)
                                    v1 = T.axis.spatial(
                                        512, i0_0_i1_0_fused * 32 +
                                        (ax0_ax1_fused_0 * 16 +
                                         ax0_ax1_fused_1 * 2 +
                                         ax0_ax1_fused_2) % 32)
                                    T.reads([B[v0, v1]])
                                    T.writes([B_shared[v0, v1]])
                                    T.block_attr({
                                        "meta_schedule.cooperative_fetch": 2
                                    })
                                    B_shared[v0, v1] = B[v0, v1]
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                            16, 2, 2, 32, 16, 2):
                        with T.block("C"):
                            i = T.axis.spatial(
                                512,
                                i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(
                                512,
                                i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 +
                                i1_3 * 2 + i1_4)
                            k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                            T.reads([
                                C_local[i, j], A_shared[i, k], B_shared[k, j]
                            ])
                            T.writes([C_local[i, j]])
                            with T.init():
                                C_local[i, j] = T.float32(0)
                            C_local[i, j] = C_local[
                                i, j] + A_shared[i, k] * B_shared[k, j]
                # Write this thread's 32x4 accumulator tile back to C.
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
def main(
    X: T.Buffer[(128, 128), "int8"],
    W: T.Buffer[(128, 128), "int8"],
    compute: T.Buffer[(128, 128), "int32"],
) -> None:
    # int8 x int8 -> int32 matmul compute[i, j] = sum_k X[i, k] * W[j, k]
    # (W used transposed), with the inner 4-wide k reduction replaced by an
    # extern "__dp4a" call on 4-element int8 slices (presumably the CUDA
    # 4-way int8 dot-product intrinsic -- confirm). Schedule:
    #   * 16 blockIdx.x blocks; block f owns rows (f // 2) * 16 + [0, 16)
    #     and columns (f % 2) * 64 + [0, 64).
    #   * 2 vthread.x x 2 threadIdx.x; each pair owns a 16x16 tile in the
    #     "local"-scope compute_local.
    #   * k (128) is handled as k_o in [0, 32) times 4 lanes; per i2_0_0
    #     chunk, 16x64 of X (1024 elems) and 64x64 of W (4096 elems) are
    #     staged into "shared"-scope buffers.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local")
    X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"):
                # Zero-initialize the accumulator tile (outer block keeps the
                # auto_tensorize annotation; inner block does the store).
                for i0_3_init, i1_3_init, i0_4_init in T.grid(4, 16, 4):
                    with T.block("compute_o_init"):
                        i = T.axis.spatial(
                            128, i0_0_i1_0_fused // 2 * 16 + i0_3_init * 4 + i0_4_init)
                        j = T.axis.spatial(
                            128,
                            i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + i1_3_init,
                        )
                        T.reads()
                        T.writes(compute_local[i, j])
                        T.block_attr(
                            {"meta_schedule.auto_tensorize": "dp4a"})
                        with T.block("compute_init"):
                            T.reads()
                            T.writes(compute_local[i, j])
                            compute_local[i, j] = 0
                for i2_0_0 in T.serial(2):
                    for ax0_ax1_fused in T.serial(1024):
                        with T.block("X_shared"):
                            v0 = T.axis.spatial(
                                128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(
                                128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(X[v0, v1])
                            T.writes(X_shared[v0, v1])
                            T.block_attr(
                                {"meta_schedule.cooperative_fetch": 4})
                            X_shared[v0, v1] = X[v0, v1]
                    for ax0_ax1_fused in T.serial(4096):
                        with T.block("W_shared"):
                            v0 = T.axis.spatial(
                                128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(
                                128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(W[v0, v1])
                            T.writes(W_shared[v0, v1])
                            T.block_attr(
                                {"meta_schedule.cooperative_fetch": 1})
                            W_shared[v0, v1] = W[v0, v1]
                    for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                            2, 4, 16, 8, 4, 1):
                        with T.block("compute_o_update"):
                            i = T.axis.spatial(
                                128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4)
                            j = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + i1_3,
                            )
                            # Outer reduction index; each k_o covers 4
                            # consecutive k values (k_o * 4 .. k_o * 4 + 3).
                            k_o = T.axis.reduce(
                                32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                            T.reads(
                                compute_local[i, j],
                                X_shared[i, k_o * 4:k_o * 4 + 4],
                                W_shared[j, k_o * 4:k_o * 4 + 4],
                            )
                            T.writes(compute_local[i, j])
                            # Views of the 4-element X/W slices and the
                            # scalar accumulator passed to the extern call.
                            A = T.match_buffer(
                                X_shared[i, k_o * 4:k_o * 4 + 4],
                                [4],
                                dtype="int8",
                                scope="shared",
                                align=4,
                                offset_factor=1,
                            )
                            B = T.match_buffer(
                                W_shared[j, k_o * 4:k_o * 4 + 4],
                                [4],
                                dtype="int8",
                                scope="shared",
                                align=4,
                                offset_factor=1,
                            )
                            C = T.match_buffer(
                                compute_local[i, j],
                                [1],
                                dtype="int32",
                                scope="local",
                                align=4,
                                offset_factor=1,
                            )
                            C[0] = C[0] + T.call_pure_extern(
                                "__dp4a",
                                A[T.ramp(0, 1, 4)],
                                B[T.ramp(0, 1, 4)],
                                0,
                                dtype="int32",
                            )
                # Write the 16x16 accumulator tile back to the output.
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(
                            128, i0_0_i1_0_fused // 2 * 16 + ax0)
                        v1 = T.axis.spatial(
                            128,
                            i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + ax1,
                        )
                        T.reads(compute_local[v0, v1])
                        T.writes(compute[v0, v1])
                        compute[v0, v1] = compute_local[v0, v1]
def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
    # 1024x1024x1024 float32 matmul C = A @ B on a 32x32 grid of
    # (blockIdx.y, blockIdx.x) blocks, each owning a 32x32 tile of C, with
    # 2x2 (threadIdx.y, threadIdx.x) threads each owning a 16x16 sub-tile.
    # k is split into 32 chunks of 32 (k_0). Per chunk, the 32x32 tiles of
    # A and B (64 * 16 == 1024 elements, 4 vectorized per thread) are copied
    # into "shared"-scope buffers by blocks annotated with
    # "tir.manifest_shared_memory_local_stage": 1.
    # The reduction init is expressed as an explicit `if k == 0` store.
    # function attr dict
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    # body
    # with T.block("root")
    for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
        for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
            for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"):
                for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
                    for k_0 in T.serial(32):
                        with T.block():
                            T.reads(
                                A[blockIdx_y * 32:blockIdx_y * 32 + 32, k_0 * 32:k_0 * 32 + 32],
                                B[k_0 * 32:k_0 * 32 + 32, blockIdx_x * 32:blockIdx_x * 32 + 32])
                            T.writes(
                                C[blockIdx_y * 32:blockIdx_y * 32 + 32, blockIdx_x * 32:blockIdx_x * 32 + 32])
                            A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
                            # Stage the 32x32 A tile; fused index
                            # f = ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + lane
                            # maps to (row f // 32, col f % 32).
                            for ax0_ax1_fused_0 in T.serial(64):
                                for ax0_ax1_fused_3 in T.vectorized(4):
                                    with T.block("A_shared"):
                                        T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.writes(A_shared[
                                            blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.block_attr({
                                            "tir.manifest_shared_memory_local_stage": 1
                                        })
                                        A_shared[
                                            blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A[
                                                blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                                k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
                            # Stage the 32x32 B tile with the same index mapping.
                            for ax0_ax1_fused_0 in T.serial(64):
                                for ax0_ax1_fused_3 in T.vectorized(4):
                                    with T.block("B_shared"):
                                        T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.writes(B_shared[
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.block_attr({
                                            "tir.manifest_shared_memory_local_stage": 1
                                        })
                                        B_shared[
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B[
                                                k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                                blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
                            # Accumulate this k chunk into the thread's
                            # 16x16 sub-tile of C.
                            for k_1, i_2, j_2, k_2 in T.grid(
                                    2, 16, 16, 16):
                                with T.block("C"):
                                    T.reads(
                                        A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2],
                                        B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2])
                                    T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2])
                                    # Reduction init on the first k value.
                                    if k_0 * 32 + k_1 * 16 + k_2 == 0:
                                        C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0)
                                    C[
                                        blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                        blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[
                                            blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                            blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[
                                                blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                                k_0 * 32 + k_1 * 16 + k_2] * B_shared[
                                                    k_0 * 32 + k_1 * 16 + k_2,
                                                    blockIdx_x * 32 + threadIdx_x * 16 + j_2]
def implicit_root_has_attrs(): T.block_attr({}) # error: implicit root does not support block_attr T.evaluate(0.0)
def duplicate_annotations() -> None:  # error fixture: T.block_attr called twice in one block
    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            T.block_attr({})
            T.block_attr({})  # error
            # The second T.block_attr in the same block is the line marked as
            # the expected error; keep line positions above it unchanged.
def duplicate_annotations() -> None:  # error fixture (legacy block syntax): duplicate T.block_attr
    with T.block([16, 16]) as [vi, vj]:
        T.block_attr({})
        T.block_attr({})  # error
        # Same check as the grid-based variant, written with the legacy
        # `T.block([extents]) as [vars]` syntax; keep the marked line's
        # position unchanged.
def main(
    X: T.Buffer[(128, 128), "int8"],
    W: T.Buffer[(128, 128), "int8"],
    compute: T.Buffer[(128, 128), "int32"],
) -> None:
    # int8 x int8 -> int32 matmul compute[i, j] = sum_k X[i, k] * W[j, k]
    # (W used transposed), scheduled like the dp4a-tensorized fixture in
    # this file but with the 4-wide inner reduction still written out as a
    # scalar "compute" block inside the "compute_o" block. compute_o carries
    # {"meta_schedule.auto_tensorize": "dp4a"}, so this appears to be the
    # pre-tensorization form -- confirm against the using test.
    # Schedule: 16 blockIdx.x blocks (16 rows x 64 cols each), 2 vthread.x x
    # 2 threadIdx.x each owning a 16x16 "local"-scope accumulator tile;
    # per i2_0_0 chunk, 16x64 of X and 64x64 of W are staged into
    # "shared"-scope buffers.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local")
    X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"):
                for i2_0_0 in T.serial(2):
                    for ax0_ax1_fused in T.serial(1024):
                        with T.block("X_shared"):
                            v0 = T.axis.spatial(
                                128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(
                                128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(X[v0, v1])
                            T.writes(X_shared[v0, v1])
                            T.block_attr(
                                {"meta_schedule.cooperative_fetch": 4})
                            X_shared[v0, v1] = X[v0, v1]
                    for ax0_ax1_fused in T.serial(4096):
                        with T.block("W_shared"):
                            v0 = T.axis.spatial(
                                128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(
                                128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(W[v0, v1])
                            T.writes(W_shared[v0, v1])
                            T.block_attr(
                                {"meta_schedule.cooperative_fetch": 1})
                            W_shared[v0, v1] = W[v0, v1]
                    for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                            2, 4, 16, 8, 4, 1):
                        with T.block("compute_o"):
                            i = T.axis.spatial(
                                128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4)
                            j = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + i1_3,
                            )
                            # Outer reduction index; the inner "compute"
                            # block covers k = k_o * 4 + [0, 4).
                            k_o = T.axis.reduce(
                                32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                            T.reads(
                                X_shared[i, k_o * 4:k_o * 4 + 4],
                                W_shared[j, k_o * 4:k_o * 4 + 4],
                            )
                            T.writes(compute_local[i, j])
                            T.block_attr(
                                {"meta_schedule.auto_tensorize": "dp4a"})
                            with T.init():
                                with T.block("compute_init"):
                                    T.reads()
                                    T.writes(compute_local[i, j])
                                    compute_local[i, j] = 0
                            for i2_1 in T.serial(4):
                                with T.block("compute"):
                                    k = T.axis.reduce(4, i2_1)
                                    T.reads(
                                        compute_local[i, j],
                                        X_shared[i, k_o * 4 + k],
                                        W_shared[j, k_o * 4 + k],
                                    )
                                    T.writes(compute_local[i, j])
                                    T.block_attr({
                                        "meta_schedule.tiling_structure": "SSSRRSRS"
                                    })
                                    # Widen both int8 operands to int32
                                    # before multiplying and accumulating.
                                    compute_local[
                                        i, j] = compute_local[i, j] + T.cast(
                                            X_shared[i, k_o * 4 + k], "int32") * T.cast(
                                                W_shared[j, k_o * 4 + k], "int32")
                # Write the 16x16 accumulator tile back to the output.
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(
                            128, i0_0_i1_0_fused // 2 * 16 + ax0)
                        v1 = T.axis.spatial(
                            128,
                            i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + ax1,
                        )
                        T.reads(compute_local[v0, v1])
                        T.writes(compute[v0, v1])
                        compute[v0, v1] = compute_local[v0, v1]
def main(
    A: T.Buffer[(512, 512), "float32"],
    B: T.Buffer[(512, 512), "float32"],
    C: T.Buffer[(512, 512), "float32"],
) -> None:
    # 512x512x512 float32 matmul C = A @ B whose "C" block is annotated
    # {"warp_execution": 1}. Layout of work:
    #   * 16 blockIdx.x blocks; block f computes the column strip
    #     C[:, f*32 : f*32 + 32] (16 vthread.x x 32 rows).
    #   * The compute loops bind 8 threadIdx.y (4 columns each); the staging
    #     loops bind 8 threadIdx.y x 32 threadIdx.x.
    #   * All of A (1024 * 8 * 32 == 512 * 512 elements) and the 512x32
    #     strip of B (32 * 8 * 32 * 2 elements, 2 vectorized) are staged
    #     into "shared"-scope buffers; accumulation is in "local" C_local.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"):
                for i2_0 in T.serial(0, 1):
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(
                                    0, 32, thread="threadIdx.x"):
                                with T.block("A_shared"):
                                    v0 = T.axis.spatial(
                                        512,
                                        (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 *
                                         32 + ax0_ax1_fused_2) // 512,
                                    )
                                    v1 = T.axis.spatial(
                                        512,
                                        (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 *
                                         32 + ax0_ax1_fused_2) % 512,
                                    )
                                    T.reads([A[v0, v1]])
                                    T.writes([A_shared[v0, v1]])
                                    A_shared[v0, v1] = A[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(0, 32):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(
                                    0, 32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(0, 2):
                                    with T.block("B_shared"):
                                        v0 = T.axis.spatial(
                                            512,
                                            (ax0_ax1_fused_0 * 512 +
                                             ax0_ax1_fused_1 * 64 +
                                             ax0_ax1_fused_2 * 2 +
                                             ax0_ax1_fused_3) // 32,
                                        )
                                        v1 = T.axis.spatial(
                                            512,
                                            i0_0_i1_0_fused * 32 +
                                            (ax0_ax1_fused_0 * 512 +
                                             ax0_ax1_fused_1 * 64 +
                                             ax0_ax1_fused_2 * 2 +
                                             ax0_ax1_fused_3) % 32,
                                        )
                                        T.reads([B[v0, v1]])
                                        T.writes([B_shared[v0, v1]])
                                        B_shared[v0, v1] = B[v0, v1]
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                            16, 2, 2, 32, 16, 2):
                        with T.block("C"):
                            i = T.axis.spatial(
                                512,
                                i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(
                                512,
                                i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 +
                                i1_3 * 2 + i1_4,
                            )
                            k = T.axis.reduce(
                                512, i2_0 * 512 + i2_1 * 32 + i2_2)
                            T.reads([A_shared[i, k], B_shared[k, j]])
                            T.writes([C_local[i, j]])
                            T.block_attr({"warp_execution": 1})
                            with T.init():
                                C_local[i, j] = T.float32(0)
                            C_local[i, j] = C_local[
                                i, j] + A_shared[i, k] * B_shared[k, j]
                # Write this thread's 32x4 accumulator tile back to C.
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Row-wise softmax over a 256x256 float32 input, in cross-thread reduction
    # form (compare the `softmax` function defined earlier in this file).
    # Each row i0 is handled by one blockIdx.x. The 256 reduction elements of a
    # row are split across 32 threadIdx.x lanes. Each lane first reduces 8
    # strided elements (k = ax0_0 * 32 + ax0_1) into a local scalar, and then
    # the 32 partial results are combined with T.tvm_thread_allreduce.
    # NOTE(review): presumably the expected output of a cross-thread-reduction
    # lowering pass used by a test -- keep the structure exactly as-is.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    # Per-row max and per-row sum of exp(x - max), shared by the block's threads.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # normal_reduce_temp*: per-lane partial result of the in-thread loop;
    # reduce_temp*: destination of the cross-thread allreduce.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # Stage 1: row maximum.
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k])
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Declares the combiner (max, identity min_value("float32")),
                # presumably consumed by the allreduce call that follows.
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # Combine the 32 per-lane partials across threadIdx.x (ax0_1)
                # into reduce_temp0.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        # Stage 2: sum of exp(A - row max), same in-thread / cross-thread split.
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                # Combiner: addition with identity 0.
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        # Stage 3: normalize, exp(A - max) / expsum, 32 lanes x 8 steps per row.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])
def lowered_single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    # Row-wise softmax over a 256x256 input, in cross-thread reduction form.
    # threadIdx.x is bound with extent 512 while a row has only 256 elements,
    # so the in-thread and normalization blocks carry
    # T.where(... * 512 + ax1_1 < 256) to mask lanes 256..511.
    # The init, cross-thread and write-back blocks are NOT predicated: masked
    # lanes keep their init value, which equals the reducer identity
    # (-3.4028234663852886e38, the most negative finite float32, for max; 0 for
    # sum), so they do not change the allreduce result.
    # NOTE(review): presumably the expected output of a lowering test -- keep
    # the structure exactly as-is.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # in_thread_*: per-lane partial value; cross_thread_*: allreduce result.
    cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    # Rows are iterated serially here (no blockIdx binding).
    for i0 in T.serial(256):
        # Stage 1: row maximum. Each lane holds at most one element (k = ax1_1).
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    # Mask lanes past the 256-element row.
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    # Combiner: max with identity -FLT_MAX, presumably
                    # consumed by the allreduce call that follows.
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y), [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    # Combine all 512 lanes' values across threadIdx.x (ax1_1).
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        # Stage 2: sum of exp(A - row max), same masking scheme.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    # Combiner: addition with identity 0.
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        # Stage 3: normalize; lanes >= 256 are masked out by T.where.
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5], T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32")
                        / T_softmax_expsum_shared[i0_5])
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:  # type: ignore
    # Winograd convolution, NHWC layout, 1x14x14x128 input -> 1x12x12x128
    # output, as a pipeline of blocks:
    #   data_pad -> input_tile -> data_pack (B^T . d . B) -> bgemm
    #   -> inverse (A^T . m . A) -> conv2d_winograd.
    # Tiles are 6x6 on input (B is 6x6) and 4x4 on output (A is 6x4), which is
    # consistent with F(4x4, 3x3), i.e. presumably a 3x3 kernel -- confirm.
    # Several blocks carry "schedule_rule" / "meta_schedule.*" attributes,
    # presumably read by meta-schedule rules elsewhere.
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    # Zero-pad the 14x14 spatial extent to 16x16 (rows/cols >= 14 become 0).
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            # NOTE(review): the declared read region spans the padded 16x16
            # range; out-of-range elements are never selected by the
            # if_then_else below.
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    # Extract 9 overlapping 6x6 tiles per channel with stride 4:
    # p // 9 is the batch index, (p % 9) // 3 the tile row, p % 3 the tile column.
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    # Constant 6x6 input-transform matrix B, encoded as a nested Select over
    # (i % 6, j % 6); entries not matched fall through to 0.
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1),
                      T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Input transform per tile and channel:
    # data_pack[eps, nu] = sum_{r_a, r_b} tile[r_a, r_b] * B[r_a, eps] * B[r_b, nu].
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            # The B read region is the bounding box covering both B[r_a, eps_1]
            # and B[r_b, nu_1]: rows min(r_a, r_b)..max(r_a, r_b) and columns
            # min(eps_1, nu_1)..max(eps_1, nu_1), inclusive.
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    # Batched GEMM, one per (eps, nu) point of the 6x6 transformed domain:
    # bgemm[eps, nu, p, co] = sum_ci data_pack[eps, nu, p, ci] * placeholder_1[eps, nu, co, ci].
    # NOTE(review): placeholder_1 presumably holds an already-transformed
    # kernel -- confirm with the test that builds its input.
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    # Constant 6x4 output-transform matrix A, encoded as a nested Select over
    # (i % 6, j % 4); entries not matched fall through to 0.
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8),
                          T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4),
                          T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2),
                          T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125),
                          T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25),
                          T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5),
                          T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1),
                          T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1),
                          T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Output transform per tile and output channel:
    # inverse[vh, vw] = sum_{r_a, r_b} bgemm[r_a, r_b] * A[r_a, vh] * A[r_b, vw],
    # producing 4x4 output tiles.
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            # As in data_pack, the A read region is the bounding box covering
            # A[r_a_1, vh] and A[r_b_1, vw].
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    # Gather the 9 4x4 output tiles into the 12x12 output:
    # tile index n * 9 + (h // 4) * 3 + w // 4, in-tile offset (h % 4, w % 4).
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # int8 convolution in a blocked NCHWc layout (uint8 data x int8 weights ->
    # int32 accumulators) with a 1x1 kernel (kh, kw have extent 1). The input
    # channel index read from `placeholder` is ic_outer (dim 1) plus
    # ic_f_inner * 4 + ic_s_inner (last dim). The reduction runs over
    # ic_outer (4) x ic_f_inner (4) x ic_s_inner (4).
    # NOTE(review): presumably an expected-IR fixture for a scheduling /
    # auto-tensorization test -- keep the structure exactly as-is.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # 30 tiled loops, listed 5 per row. The rows appear to follow the
    # "SSRSRS" pattern of the inner block's meta_schedule.tiling_structure
    # attribute: spatial, spatial, reduction, spatial, reduction, spatial.
    # Most extents are 1; the non-trivial ones feed oc_chunk, oh, ow,
    # ic_outer and ic_f_inner below.
    for (
        i0_0, i1_0, i2_0, i3_0, i4_0_0,
        i0_1, i1_1, i2_1, i3_1, i4_0_1,
        i5_0, i6_0, i7_0, i8_0, i9_0_0,
        i0_2, i1_2, i2_2, i3_2, i4_0_2,
        i5_1, i6_1, i7_1, i8_1, i9_0_1,
        i0_3, i1_3, i2_3, i3_3, i4_0_3,
    ) in T.grid(
        1, 1, 2, 1, 1,
        1, 4, 1, 14, 1,
        1, 1, 4, 1, 1,
        1, 4, 7, 1, 1,
        1, 1, 1, 4, 1,
        1, 1, 4, 4, 1,
    ):
        # Outer block: one (n, oc_chunk, oh, ow) output point for one
        # (ic_outer, ic_f_inner) step. It covers all 16 oc_block lanes
        # and a 4-wide ic_s_inner slice.
        with T.block("conv2d_NCHWc_int8_o"):
            n = T.axis.spatial(1, 0)
            oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)            # 4 x 4 = 16
            oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)      # 2 x 7 x 4 = 56
            ow = T.axis.spatial(56, i3_1 * 4 + i3_3)                  # 14 x 4 = 56
            oc_block_o = T.axis.spatial(1, 0)
            kh = T.axis.reduce(1, 0)
            kw = T.axis.reduce(1, 0)
            ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
            ic_s_inner_o = T.axis.reduce(1, 0)
            T.reads(
                placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4],
                placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
            )
            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
            # Marks this block for auto-tensorization with the intrinsic named
            # "dot_16x4_vnni" (presumably registered elsewhere -- confirm).
            T.block_attr({"meta_schedule.auto_tensorize": "dot_16x4_vnni"})
            with T.init():
                # Zero the 16 oc_block accumulators.
                for i4_1 in T.serial(16):
                    with T.block("conv2d_NCHWc_int8_init"):
                        oc_block_init = T.axis.spatial(16, i4_1)
                        T.reads()
                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init])
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0
            # Inner 16 (oc_block) x 4 (ic_s_inner) multiply-accumulate: each
            # uint8 / int8 operand is widened to int32 before multiplying.
            for i4_1, i9_1 in T.grid(16, 4):
                with T.block("conv2d_NCHWc_int8"):
                    oc_block, ic_s_inner = T.axis.remap("SR", [i4_1, i9_1])
                    T.reads(
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block],
                        placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                        placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                    )
                    T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
                    T.block_attr(
                        {"meta_schedule.tiling_structure": "SSRSRS"})
                    conv2d_NCHWc_int8[
                        n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                            n, oc_chunk, oh, ow, oc_block] + T.cast(
                                placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                                "int32",
                            ) * T.cast(
                                placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                                "int32",
                            )