def pooling_decompose_2(
        x: T.Buffer[(1, 16, 225, 225), "int8"],
        tensor: T.Buffer[(1, 16, 225, 225), "int8"]) -> None:
    # TVMScript fixture: 7x7 sliding-window sum over `x`, staged through
    # `pad_temp`, which is 3 elements larger on each side of the two trailing
    # axes (225 + 2 * 3 = 231).
    pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
    # Stage 1: fill pad_temp with zero. The 3 x 3 tiles advance by 75 but each
    # spans 81 elements, so adjacent tiles overlap and jointly cover 0..230.
    for i0, i2_0, i3_0, ax0, ax1, ax2 in T.grid(1, 3, 3, 16, 81, 81):
        with T.block("pad_temp_pad_const"):
            ax0_1 = T.axis.spatial(1, 0)
            ax1_1 = T.axis.spatial(16, ax0)
            ax2_1 = T.axis.spatial(231, i2_0 * 75 + ax1)
            ax3 = T.axis.spatial(231, i3_0 * 75 + ax2)
            T.reads()
            T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3])
            pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0)
    for i0, i2_0, i3_0 in T.grid(1, 3, 3):
        # Stage 2: per tile, copy the 81 x 81 input window into pad_temp.
        for ax0, ax1, ax2 in T.grid(16, 81, 81):
            with T.block("pad_temp"):
                ax0_2 = T.axis.spatial(1, 0)
                ax1_2 = T.axis.spatial(16, ax0)
                # Block axes are expressed in x coordinates (padded index - 3).
                ax2_2 = T.axis.spatial(225, i2_0 * 75 + ax1 - 3)
                ax3 = T.axis.spatial(225, i3_0 * 75 + ax2 - 3)
                # Predicate keeps only padded positions 3..227, i.e. those that
                # map to a valid element of x.
                T.where(3 <= i2_0 * 75 + ax1 and i2_0 * 75 + ax1 < 228 and 3 <= i3_0 * 75 + ax2 and i3_0 * 75 + ax2 < 228)
                T.reads(x[ax0_2, ax1_2, ax2_2, ax3])
                T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3])
                pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3]
        # Stage 3: per tile, accumulate the 7 x 7 window for 75 x 75 outputs.
        for i1, i2_1, i3_1, i4, i5 in T.grid(16, 75, 75, 7, 7):
            with T.block("tensor"):
                ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1])
                ax2_3 = T.axis.spatial(225, i2_0 * 75 + i2_1)
                ax3 = T.axis.spatial(225, i3_0 * 75 + i3_1)
                rv0, rv1 = T.axis.remap("RR", [i4, i5])
                T.reads(pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
                T.writes(tensor[ax0_3, ax1_3, ax2_3, ax3])
                with T.init():
                    tensor[ax0_3, ax1_3, ax2_3, ax3] = T.int8(0)
                tensor[ax0_3, ax1_3, ax2_3, ax3] = (
                    tensor[ax0_3, ax1_3, ax2_3, ax3]
                    + pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None:
    # TVMScript fixture: C = A * 2 + 1 through intermediate buffer B, where
    # block "B" carries a "buffer_dim_align" attribute.
    # NOTE(review): judging by the function name, the attribute value [0] is
    # intentionally malformed to trigger an error - confirm against the test
    # that consumes this fixture, and do not "repair" it.
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            # Producer: one row of B for the current i0.
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    T.block_attr({"buffer_dim_align": [0]})
                    vi, vj = T.axis.remap("SS", [i0, ax1])
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            # Consumer: reads the row of B produced just above.
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
def shared_mem_func(a: T.handle, c: T.handle) -> None:
    # TVMScript fixture: C = (A + 1) * 2 on a 16 x 16 float32 buffer. Each
    # (blockIdx.x, vthread, threadIdx.x) combination handles one row, with the
    # intermediate B allocated in scope "shared".
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
        for i1 in T.thread_binding(0, 2, thread="vthread"):
            for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
                # Row index is i0 * 8 + i1 * 4 + i2 (0..15).
                with T.block():
                    T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
                    T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
                    # B is declared at full 16 x 16 size although each
                    # iteration only touches a single row of it.
                    B = T.alloc_buffer((16, 16), "float32", scope="shared")
                    for j in range(0, 16):
                        with T.block() as []:
                            T.reads(A[i0 * 8 + i1 * 4 + i2, j])
                            T.writes(B[i0 * 8 + i1 * 4 + i2, j])
                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
                    for j in range(0, 16):
                        with T.block() as []:
                            T.reads(B[i0 * 8 + i1 * 4 + i2, j])
                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
def cache_read_shape_int64(var_A: T.handle, var_C: T.handle) -> None:
    # TVMScript fixture: C = A * 2 + 1 where A is first copied into the
    # staging buffer A_global. All buffer shapes are given as T.int64
    # constants; only the copy loop uses T.int64 extents, the two compute
    # loops use plain Python ints.
    A = T.match_buffer(var_A, (T.int64(128), T.int64(128)), dtype="float32")
    C = T.match_buffer(var_C, (T.int64(128), T.int64(128)), dtype="float32")
    B = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    A_global = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    # Element-wise copy A -> A_global.
    for ax0, ax1 in T.grid(T.int64(128), T.int64(128)):
        with T.block("A_global"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_global[v0, v1])
            A_global[v0, v1] = A[v0, v1]
    # B = A_global * 2.
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A_global[vi, vj] * T.float32(2)
    # C = B + 1.
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)
def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
    # TVMScript fixture: C[tx, i] = A[tx, i] * 2 + 1, with the serial loop over
    # i annotated for software pipelining. The loop body holds exactly two
    # child blocks, matching the two-entry stage/order lists.
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 1],
                "software_pipeline_order": [0, 1]
            },
        ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                # One-column staging buffer in scope "shared"; every iteration
                # of i reuses column 0.
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                # First child block: produce B from A.
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                # Second child block: consume B into C.
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = B[tx, 0] + T.float32(1)
def element_wise(a: T.handle, c: T.handle) -> None:
    # TVMScript fixture: C = A * 2 + 1 through intermediate buffer B.
    # Block axes are declared in the `T.block([extents], name) as [vars]` form
    # and bound to loop variables with explicit T.bind calls.
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            # Producer: one row of B for the current i0.
            for ax1 in T.serial(0, 128):
                with T.block([128, 128], "B") as [vi, vj]:
                    T.bind(vi, i0)
                    T.bind(vj, ax1)
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            # Consumer: reads the row of B produced just above.
            for i1 in T.serial(0, 128):
                with T.block([128, 128], "C") as [vi_1, vj_1]:
                    T.bind(vi_1, i0)
                    T.bind(vj_1, i1)
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None:
    # TVMScript fixture: C = (A + 1) * 2 on a 16 x 16 float32 buffer, processed
    # in four 4-row slabs. The intermediate B covers a single slab (4 x 16) and
    # is declared with explicit strides [17, 1] in scope "global".
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i0 in range(0, 4):
        with T.block():
            T.reads(A[i0 * 4:i0 * 4 + 4, 0:16])
            T.writes(C[i0 * 4:i0 * 4 + 4, 0:16])
            # Row stride 17 is one more than the 16 elements in a row.
            B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], scope="global")
            # B is indexed slab-relative (i1); A and C use absolute rows.
            for i1 in range(0, 4):
                for j in range(0, 16):
                    with T.block() as []:
                        T.reads(A[i0 * 4 + i1, j])
                        T.writes(B[i1, j])
                        B[i1, j] = A[i0 * 4 + i1, j] + 1.0
            for i1 in range(0, 4):
                for j in range(0, 16):
                    with T.block() as []:
                        T.reads(B[i1, j])
                        T.writes(C[i0 * 4 + i1, j])
                        C[i0 * 4 + i1, j] = B[i1, j] * 2.0
def before_blockize_rca(
    A: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    # TVMScript fixture: C = A * 2 + 1 over 8 x 8 tiles of 16 x 16 elements.
    # The producer is already wrapped in an outer block "B_o" per tile; the
    # consumer "C" sits next to it inside the same (i, j) tile loops.
    B = T.alloc_buffer([128, 128], dtype="float32")
    for i, j in T.grid(8, 8):
        with T.block("B_o"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16])
            T.writes(B[vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16])
            # Inner block indexes relative to the tile origin (vi * 16, vj * 16).
            for i_1, j_1 in T.grid(16, 16):
                with T.block("B"):
                    vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
                    T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i])
                    T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i])
                    B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0
        # Consumer uses absolute coordinates built from the tile loop variables.
        for ax0, ax1 in T.grid(16, 16):
            with T.block("C"):
                vi = T.axis.spatial(128, i * 16 + ax0)
                vj = T.axis.spatial(128, j * 16 + ax1)
                T.reads(B[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = B[vi, vj] + 1.0
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # TVMScript fixture with three independent loop nests over float16
    # 128 x 128 buffers:
    #   "load_store"   - plain element-wise copy A -> D
    #   "opaque"       - intrinsic call that addresses A and B through raw
    #                    .data pointers, with regions declared manually
    #   "match_buffer" - same intrinsic, but on 16 x 16 sub-buffers obtained
    #                    with T.match_buffer
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")
    for i, j in T.grid(128, 128):
        with T.block("load_store"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(D[vi, vj])
            D[vi, vj] = A[vi, vj]
    for i, j in T.grid(8, 8):
        with T.block("opaque"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            # Offset vi * 2048 + vj * 16 is the flat index of element
            # (vi * 16, vj * 16) in a row-major 128-wide buffer
            # (16 rows * 128 = 2048).
            T.evaluate(
                T.tvm_load_matrix_sync(
                    B.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A.data,
                        vi * 2048 + vj * 16,
                        128,
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
    for i, j in T.grid(8, 8):
        with T.block("match_buffer"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            # 16 x 16 views into A and C that keep the parent row stride of 128.
            A0 = T.match_buffer(
                A[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            C0 = T.match_buffer(
                C[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            # Offset and stride are taken from the matched view rather than
            # being computed by hand as in the "opaque" block.
            T.evaluate(
                T.tvm_load_matrix_sync(
                    C0.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A0.data,
                        A0.elem_offset,
                        A0.strides[0],
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
def mismatch_args() -> None:
    # Negative TVMScript fixture: the T.reads call below is the line the
    # original author marked as the expected failure site.
    # NOTE(review): the exact diagnostic is not visible from this block -
    # confirm against the consuming test. Do not "fix" the call.
    A = T.alloc_buffer((128, 128), "float32")
    with T.block():
        T.reads(A[0, 0], A[1, 1])  # error
        T.evaluate(1.0)
def implicit_root_has_read(): T.reads([]) # error: implicit root does not support reads T.evaluate(0.0)
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    # TVMScript body that issues two T.ptx_mma calls on "warp"-scope buffers.
    # WARP_SIZE, local_size, local_size_out, in_dtype, out_dtype, mma_prefix,
    # in_dtype_abbrv, out_dtype_abbrv and lift are not defined in this block;
    # they must come from an enclosing scope that is not visible here.
    A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        # C is listed under both reads and writes.
        T.reads(
            C[0:WARP_SIZE, 0:local_size_out],
            A[0:WARP_SIZE, 0:local_size],
            B[0:WARP_SIZE, 0:local_size],
        )
        T.writes(C[0:WARP_SIZE, 0:local_size_out])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, WARP_SIZE)
        # First call: operands start at each thread's own row
        # (elem_offset + tx * row length).
        T.evaluate(
            T.ptx_mma(
                mma_prefix,
                "row",
                "col",
                in_dtype_abbrv,
                in_dtype_abbrv,
                out_dtype_abbrv,
                A.data,
                A.elem_offset + tx * lift(local_size),
                B.data,
                B.elem_offset + tx * lift(local_size),
                C.data,
                C.elem_offset + tx * lift(local_size_out),
                False,
                dtype=out_dtype,
            ))
        # Second call: same A operand; the B and C offsets are advanced by half
        # of their respective row lengths.
        T.evaluate(
            T.ptx_mma(
                mma_prefix,
                "row",
                "col",
                in_dtype_abbrv,
                in_dtype_abbrv,
                out_dtype_abbrv,
                A.data,
                A.elem_offset + tx * lift(local_size),
                B.data,
                B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
                C.data,
                C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
                False,
                dtype=out_dtype,
            ))
def duplicate_reads() -> None:
    # Negative TVMScript fixture: a single block declares T.reads twice; the
    # second declaration is the line marked as the expected failure site.
    # Do not merge or remove either call.
    A = T.alloc_buffer((128, 128), "float32")
    with T.block([16, 16]) as [vi, vj]:
        T.reads(A[0:8, 0:8])
        T.reads(A[0:16, 0:16])  # error
        T.evaluate(1.0)
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    # TVMScript fixture: batched (batch = 1) 128 x 128 x 128 matmul Z = X @ Y,
    # tiled over 16 blockIdx.x iterations of 32 x 32 outputs, accumulating in
    # the "local"-scope Z_local and staging operands in "shared"-scope buffers.
    # The thread-bound loop has extent 128, while block "Z_update" carries
    # thread-extent attributes of 1024.
    # NOTE(review): that mismatch looks intentional (a fixture for thread
    # extent verification) - confirm against the consuming test.
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                # Zero the 4 x 2 accumulator patch owned by this thread.
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(
                            128,
                            i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(
                            128,
                            i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                # Reduction axis split into 4 chunks of 32.
                for i3_0 in T.serial(4):
                    # Stage a 32 x 32 tile of X: 4 * 128 * 2 = 1024 elements,
                    # flat index decoded with // 32 and % 32.
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(
                                128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(
                                        128,
                                        i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(
                                        128,
                                        i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    # Stage a 32 x 32 tile of Y: 8 * 128 = 1024 elements, not
                    # vectorized.
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(
                                128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(
                                    128,
                                    i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(
                                    128,
                                    i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    # Multiply-accumulate over the staged tiles.
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(
                            1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(
                                128,
                                i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(
                                128,
                                i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 1024,
                                "meta_schedule.thread_extent_high_inclusive": 1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[
                                b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                # Write the 4 x 2 accumulator patch back to Z.
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(
                            128,
                            i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(
                            128,
                            i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]
def main(
    placeholder: T.Buffer[(1, 384), "int64"],
    placeholder_1: T.Buffer[(30522, 768), "float32"],
    placeholder_2: T.Buffer[(1, 384, 768), "float32"],
    T_add: T.Buffer[(1, 384, 768), "float32"],
) -> None:
    # TVMScript fixture: row gather followed by an element-wise add.
    #   1. indices below 0 get 30522 added (T_less / T_add / T_where)
    #   2. the index is clamped to [0, 30521] and used to pick a row of
    #      placeholder_1 (T_take)
    #   3. placeholder_2 is added element-wise into T_add (T_add_1)
    # Note the naming cross-over: block "T_add" writes buffer T_add_1, and
    # block "T_add_1" writes the output parameter T_add.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    compile_engine_const = T.alloc_buffer([], dtype="int64")
    T_less = T.alloc_buffer([1, 384], dtype="bool")
    compile_engine_const_1 = T.alloc_buffer([], dtype="int64")
    T_add_1 = T.alloc_buffer([1, 384], dtype="int64")
    T_where = T.alloc_buffer([1, 384], dtype="int64")
    T_take = T.alloc_buffer([1, 384, 768], dtype="float32")
    # Scalar (0-d) constant 0; the block axis `vi` is declared but not used.
    with T.block("compile_engine_const"):
        vi = T.axis.spatial(1, 0)
        T.reads()
        T.writes(compile_engine_const[()])
        compile_engine_const[()] = T.int64(0)
    # T_less = (index < 0).
    for i0, i1 in T.grid(1, 384):
        with T.block("T_less"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(placeholder[ax0, ax1], compile_engine_const[()])
            T.writes(T_less[ax0, ax1])
            T_less[ax0, ax1] = placeholder[ax0, ax1] < compile_engine_const[()]
    # Scalar (0-d) constant 30522, the row count of placeholder_1.
    with T.block("compile_engine_const_1"):
        vi = T.axis.spatial(1, 0)
        T.reads()
        T.writes(compile_engine_const_1[()])
        compile_engine_const_1[()] = T.int64(30522)
    # T_add_1 = index + 30522.
    for i0, i1 in T.grid(1, 384):
        with T.block("T_add"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(placeholder[ax0, ax1], compile_engine_const_1[()])
            T.writes(T_add_1[ax0, ax1])
            T_add_1[ax0, ax1] = placeholder[ax0, ax1] + compile_engine_const_1[()]
    # T_where = wrapped index when T_less is set, original index otherwise.
    for i0, i1 in T.grid(1, 384):
        with T.block("T_where"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0, ax1])
            T.writes(T_where[ax0, ax1])
            T_where[ax0, ax1] = T.Select(
                T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1],
                placeholder[ax0, ax1])
    # Gather: the row index is data-dependent, so the read region is written
    # with the same clamped expression as the access itself.
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_take"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(
                placeholder_1[T.min(T.max(T.int64(0), T_where[
                    ax0, ax1]), T.int64(30521)), ax2],
                T_where[ax0, ax1],
            )
            T.writes(T_take[ax0, ax1, ax2])
            T_take[ax0, ax1, ax2] = placeholder_1[
                T.min(T.max(T.int64(0), T_where[ax0, ax1]), T.int64(30521)), ax2]
    # Final element-wise add into the output buffer.
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_add_1"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2])
            T.writes(T_add[ax0, ax1, ax2])
            T_add[ax0, ax1, ax2] = T_take[ax0, ax1, ax2] + placeholder_2[ax0, ax1, ax2]
def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
    # TVMScript fixture: 1024 x 1024 float32 matmul C = A @ B in 32 x 32 output
    # tiles (one per blockIdx.y / blockIdx.x pair) computed by 2 x 2 threads.
    # For every 32-wide step k_0 of the reduction, tiles of A and B are first
    # copied into "shared"-scope buffers by blocks carrying the
    # "tir.manifest_shared_memory_local_stage" attribute.
    # function attr dict
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    # body
    # with T.block("root")
    for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
        for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
            for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"):
                for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
                    for k_0 in T.serial(32):
                        with T.block():
                            T.reads(
                                A[blockIdx_y * 32:blockIdx_y * 32 + 32,
                                  k_0 * 32:k_0 * 32 + 32],
                                B[k_0 * 32:k_0 * 32 + 32,
                                  blockIdx_x * 32:blockIdx_x * 32 + 32])
                            T.writes(
                                C[blockIdx_y * 32:blockIdx_y * 32 + 32,
                                  blockIdx_x * 32:blockIdx_x * 32 + 32])
                            # Declared at full 1024 x 1024 size although only
                            # one 32 x 32 tile is touched per iteration.
                            A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
                            # Copy of the A tile. The flat tile index
                            #   ax0_ax1_fused_0 * 16 + threadIdx_y * 8
                            #   + threadIdx_x * 4 + ax0_ax1_fused_3
                            # ranges over 0..1023 and is split into
                            # (row, col) with // 32 and % 32.
                            for ax0_ax1_fused_0 in T.serial(64):
                                for ax0_ax1_fused_3 in T.vectorized(4):
                                    with T.block("A_shared"):
                                        T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                                  k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.writes(A_shared[
                                            blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.block_attr({
                                            "tir.manifest_shared_memory_local_stage": 1
                                        })
                                        A_shared[
                                            blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A[
                                                blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                                k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
                            # Copy of the B tile, same flat-index decoding.
                            for ax0_ax1_fused_0 in T.serial(64):
                                for ax0_ax1_fused_3 in T.vectorized(4):
                                    with T.block("B_shared"):
                                        T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                                  blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.writes(B_shared[
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.block_attr({
                                            "tir.manifest_shared_memory_local_stage": 1
                                        })
                                        B_shared[
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B[
                                                k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                                blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
                            # Each thread accumulates its 16 x 16 part of
                            # the output tile over the 32-wide k slice.
                            for k_1, i_2, j_2, k_2 in T.grid(
                                    2, 16, 16, 16):
                                with T.block("C"):
                                    T.reads(
                                        A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                                 k_0 * 32 + k_1 * 16 + k_2],
                                        B_shared[k_0 * 32 + k_1 * 16 + k_2,
                                                 blockIdx_x * 32 + threadIdx_x * 16 + j_2])
                                    T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                               blockIdx_x * 32 + threadIdx_x * 16 + j_2])
                                    # Explicit init: zero the output on the
                                    # first reduction step (global k == 0).
                                    if k_0 * 32 + k_1 * 16 + k_2 == 0:
                                        C[blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                          blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0)
                                    C[
                                        blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                        blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[
                                            blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                            blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[
                                                blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                                k_0 * 32 + k_1 * 16 + k_2] * B_shared[
                                                    k_0 * 32 + k_1 * 16 + k_2,
                                                    blockIdx_x * 32 + threadIdx_x * 16 + j_2]
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # TVMScript fixture with three blocks over float16 128 x 128 buffers,
    # written in the `T.block([extents], name) as [vars]` form (no explicit
    # surrounding loops):
    #   "load_store"   - element copy A -> D through flat .data indexing
    #   "opaque"       - intrinsic call addressing A and B via raw pointers
    #   "match_buffer" - same intrinsic on 16 x 16 T.match_buffer views
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")
    with T.block([128, 128], "load_store") as [vi, vj]:
        T.reads(A[vi, vj])
        T.writes(D[vi, vj])
        # vi * 128 + vj is the row-major flat index of element (vi, vj).
        D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj)
    with T.block([8, 8], "opaque") as [vi, vj]:
        T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        # Offset vi * 2048 + vj * 16 is the flat index of element
        # (vi * 16, vj * 16) in a row-major 128-wide buffer.
        T.evaluate(
            T.tvm_load_matrix_sync(
                B.data,
                16,
                16,
                16,
                vi * 8 + vj,
                T.tvm_access_ptr(
                    T.type_annotation(dtype="float16"),
                    A.data,
                    vi * 2048 + vj * 16,
                    128,
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )
    with T.block([8, 8], "match_buffer") as [vi, vj]:
        T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        # 16 x 16 views into A and C that keep the parent row stride of 128.
        A0 = T.match_buffer(
            A[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        C0 = T.match_buffer(
            C[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        # Offset and stride are taken from the matched view rather than being
        # computed by hand as in the "opaque" block.
        T.evaluate(
            T.tvm_load_matrix_sync(
                C0.data,
                16,
                16,
                16,
                vi * 8 + vj,
                T.tvm_access_ptr(
                    T.type_annotation(dtype="float16"),
                    A0.data,
                    A0.elem_offset,
                    A0.strides[0],
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )
def lowered_single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    # TVMScript fixture: row-wise softmax over a 256 x 256 float32 buffer.
    # Both reductions (row max, row sum of exp) are split into an in-thread
    # partial result plus a T.tvm_thread_allreduce call. The thread-bound loop
    # has extent 512 while a row has 256 elements, so the in-thread and norm
    # blocks carry a T.where predicate.
    # NOTE(review): the source arrived with its indentation collapsed; the
    # nesting of the *_cross_thread and *_write_back blocks inside the
    # thread-bound loop is reconstructed (the allreduce references ax1_1) -
    # confirm against the upstream fixture.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.serial(256):
        # Phase 1: row maximum.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                # Seed the partial max with the constant also used as the
                # reducer identity below.
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    # Thread indices 256..511 do not contribute.
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y),
                                       [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        # Phase 2: row sum of exp(A - row max).
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        # Phase 3: normalise each element of the row.
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5],
                            T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32")
                        / T_softmax_expsum_shared[i0_5])
def main(
    X: T.Buffer[(128, 128), "int8"],
    W: T.Buffer[(128, 128), "int8"],
    compute: T.Buffer[(128, 128), "int32"],
) -> None:
    # TVMScript fixture: int8 x int8 -> int32 dense layer,
    # compute[i, j] = sum_k X[i, k] * W[j, k], where the innermost 4-element
    # reduction is expressed as a "__dp4a" extern call on 4-wide views.
    # Initialisation is a separate loop nest ("compute_o_init") ahead of the
    # update nest ("compute_o_update").
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local")
    X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"):
                # Zero the 16 x 16 accumulator patch handled by this thread.
                for i0_3_init, i1_3_init, i0_4_init in T.grid(4, 16, 4):
                    with T.block("compute_o_init"):
                        i = T.axis.spatial(
                            128,
                            i0_0_i1_0_fused // 2 * 16 + i0_3_init * 4 + i0_4_init)
                        j = T.axis.spatial(
                            128,
                            i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + i1_3_init,
                        )
                        T.reads()
                        T.writes(compute_local[i, j])
                        T.block_attr(
                            {"meta_schedule.auto_tensorize": "dp4a"})
                        with T.block("compute_init"):
                            T.reads()
                            T.writes(compute_local[i, j])
                            compute_local[i, j] = 0
                # Reduction axis (128) processed in two halves of 64.
                for i2_0_0 in T.serial(2):
                    # Stage a 16 x 64 slice of X (1024 elements).
                    for ax0_ax1_fused in T.serial(1024):
                        with T.block("X_shared"):
                            v0 = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(
                                128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(X[v0, v1])
                            T.writes(X_shared[v0, v1])
                            T.block_attr(
                                {"meta_schedule.cooperative_fetch": 4})
                            X_shared[v0, v1] = X[v0, v1]
                    # Stage a 64 x 64 slice of W (4096 elements).
                    for ax0_ax1_fused in T.serial(4096):
                        with T.block("W_shared"):
                            v0 = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(
                                128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(W[v0, v1])
                            T.writes(W_shared[v0, v1])
                            T.block_attr(
                                {"meta_schedule.cooperative_fetch": 1})
                            W_shared[v0, v1] = W[v0, v1]
                    for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                            2, 4, 16, 8, 4, 1):
                        with T.block("compute_o_update"):
                            i = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4)
                            j = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + i1_3,
                            )
                            # k_o counts groups of 4 along the reduction axis
                            # (32 groups * 4 = 128).
                            k_o = T.axis.reduce(
                                32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                            T.reads(
                                compute_local[i, j],
                                X_shared[i, k_o * 4:k_o * 4 + 4],
                                W_shared[j, k_o * 4:k_o * 4 + 4],
                            )
                            T.writes(compute_local[i, j])
                            # 4-element views of the operands and a 1-element
                            # view of the accumulator.
                            A = T.match_buffer(
                                X_shared[i, k_o * 4:k_o * 4 + 4],
                                [4],
                                dtype="int8",
                                scope="shared",
                                align=4,
                                offset_factor=1,
                            )
                            B = T.match_buffer(
                                W_shared[j, k_o * 4:k_o * 4 + 4],
                                [4],
                                dtype="int8",
                                scope="shared",
                                align=4,
                                offset_factor=1,
                            )
                            C = T.match_buffer(
                                compute_local[i, j],
                                [1],
                                dtype="int32",
                                scope="local",
                                align=4,
                                offset_factor=1,
                            )
                            # The accumulator argument passed to the extern is
                            # 0; the running sum is added outside the call.
                            C[0] = C[0] + T.call_pure_extern(
                                "__dp4a",
                                A[T.ramp(0, 1, 4)],
                                B[T.ramp(0, 1, 4)],
                                0,
                                dtype="int32",
                            )
                # Write the accumulator patch back to the output buffer.
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(
                            128, i0_0_i1_0_fused // 2 * 16 + ax0)
                        v1 = T.axis.spatial(
                            128,
                            i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + ax1,
                        )
                        T.reads(compute_local[v0, v1])
                        T.writes(compute[v0, v1])
                        compute[v0, v1] = compute_local[v0, v1]
def main(
    X: T.Buffer[(128, 128), "int8"],
    W: T.Buffer[(128, 128), "int8"],
    compute: T.Buffer[(128, 128), "int32"],
) -> None:
    # TVMScript fixture: int8 x int8 -> int32 dense layer,
    # compute[i, j] = sum_k X[i, k] * W[j, k]. The reduction is split into an
    # outer block "compute_o" (groups of 4, tagged with
    # "meta_schedule.auto_tensorize": "dp4a") and an inner scalar block
    # "compute" over the 4 elements of each group.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local")
    X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"):
                # Reduction axis (128) processed in two halves of 64.
                for i2_0_0 in T.serial(2):
                    # Stage a 16 x 64 slice of X (1024 elements).
                    for ax0_ax1_fused in T.serial(1024):
                        with T.block("X_shared"):
                            v0 = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(
                                128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(X[v0, v1])
                            T.writes(X_shared[v0, v1])
                            T.block_attr(
                                {"meta_schedule.cooperative_fetch": 4})
                            X_shared[v0, v1] = X[v0, v1]
                    # Stage a 64 x 64 slice of W (4096 elements).
                    for ax0_ax1_fused in T.serial(4096):
                        with T.block("W_shared"):
                            v0 = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(
                                128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(W[v0, v1])
                            T.writes(W_shared[v0, v1])
                            T.block_attr(
                                {"meta_schedule.cooperative_fetch": 1})
                            W_shared[v0, v1] = W[v0, v1]
                    for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                            2, 4, 16, 8, 4, 1):
                        with T.block("compute_o"):
                            i = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4)
                            j = T.axis.spatial(
                                128,
                                i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + i1_3,
                            )
                            # k_o counts groups of 4 along the reduction axis
                            # (32 groups * 4 = 128).
                            k_o = T.axis.reduce(
                                32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                            T.reads(
                                X_shared[i, k_o * 4:k_o * 4 + 4],
                                W_shared[j, k_o * 4:k_o * 4 + 4],
                            )
                            T.writes(compute_local[i, j])
                            T.block_attr(
                                {"meta_schedule.auto_tensorize": "dp4a"})
                            # Init is itself a nested block.
                            with T.init():
                                with T.block("compute_init"):
                                    T.reads()
                                    T.writes(compute_local[i, j])
                                    compute_local[i, j] = 0
                            # Scalar multiply-accumulate over the 4 elements
                            # of the current group; operands are widened to
                            # int32 before the multiply.
                            for i2_1 in T.serial(4):
                                with T.block("compute"):
                                    k = T.axis.reduce(4, i2_1)
                                    T.reads(
                                        compute_local[i, j],
                                        X_shared[i, k_o * 4 + k],
                                        W_shared[j, k_o * 4 + k],
                                    )
                                    T.writes(compute_local[i, j])
                                    T.block_attr({
                                        "meta_schedule.tiling_structure":
                                        "SSSRRSRS"
                                    })
                                    compute_local[
                                        i, j] = compute_local[i, j] + T.cast(
                                            X_shared[i, k_o * 4 + k],
                                            "int32") * T.cast(
                                                W_shared[j, k_o * 4 + k],
                                                "int32")
                # Write the accumulator patch back to the output buffer.
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(
                            128, i0_0_i1_0_fused // 2 * 16 + ax0)
                        v1 = T.axis.spatial(
                            128,
                            i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + ax1,
                        )
                        T.reads(compute_local[v0, v1])
                        T.writes(compute[v0, v1])
                        compute[v0, v1] = compute_local[v0, v1]
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # TVMScript fixture: blocked-layout 1 x 1 convolution, uint8 data times
    # int8 weights accumulated in int32. The outer block
    # "conv2d_NCHWc_int8_o" is tagged with
    # "meta_schedule.auto_tensorize": "dot_16x4_vnni" and wraps a 16 x 4 inner
    # loop nest (16 output lanes, 4 reduction elements).
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # One flat 30-deep loop nest; most extents are 1. The non-unit loops are
    # i2_0 (2), i1_1 (4), i3_1 (14), i7_0 (4), i1_2 (4), i2_2 (7), i8_1 (4),
    # i2_3 (4) and i3_3 (4).
    for (
            i0_0,
            i1_0,
            i2_0,
            i3_0,
            i4_0_0,
            i0_1,
            i1_1,
            i2_1,
            i3_1,
            i4_0_1,
            i5_0,
            i6_0,
            i7_0,
            i8_0,
            i9_0_0,
            i0_2,
            i1_2,
            i2_2,
            i3_2,
            i4_0_2,
            i5_1,
            i6_1,
            i7_1,
            i8_1,
            i9_0_1,
            i0_3,
            i1_3,
            i2_3,
            i3_3,
            i4_0_3,
    ) in T.grid(
            1,
            1,
            2,
            1,
            1,
            1,
            4,
            1,
            14,
            1,
            1,
            1,
            4,
            1,
            1,
            1,
            4,
            7,
            1,
            1,
            1,
            1,
            1,
            4,
            1,
            1,
            1,
            4,
            4,
            1,
    ):
        with T.block("conv2d_NCHWc_int8_o"):
            n = T.axis.spatial(1, 0)
            oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
            oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
            ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
            oc_block_o = T.axis.spatial(1, 0)
            # Kernel is 1 x 1, so kh and kw are fixed at 0.
            kh = T.axis.reduce(1, 0)
            kw = T.axis.reduce(1, 0)
            ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
            ic_s_inner_o = T.axis.reduce(1, 0)
            T.reads(
                placeholder[n, ic_outer, oh + kh, ow + kw,
                            ic_f_inner * 4:ic_f_inner * 4 + 4],
                placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
            )
            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
            T.block_attr({"meta_schedule.auto_tensorize": "dot_16x4_vnni"})
            # Init zeroes all 16 output lanes of this output position.
            with T.init():
                for i4_1 in T.serial(16):
                    with T.block("conv2d_NCHWc_int8_init"):
                        oc_block_init = T.axis.spatial(16, i4_1)
                        T.reads()
                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init])
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0
            # 16 output lanes x 4 reduction elements; both operands are
            # widened to int32 before the multiply.
            for i4_1, i9_1 in T.grid(16, 4):
                with T.block("conv2d_NCHWc_int8"):
                    oc_block, ic_s_inner = T.axis.remap("SR", [i4_1, i9_1])
                    T.reads(
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block],
                        placeholder[n, ic_outer, oh + kh, ow + kw,
                                    ic_f_inner * 4 + ic_s_inner],
                        placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                      oc_block, ic_s_inner],
                    )
                    T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
                    T.block_attr(
                        {"meta_schedule.tiling_structure": "SSRSRS"})
                    conv2d_NCHWc_int8[
                        n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                            n, oc_chunk,
                            oh, ow, oc_block] + T.cast(
                                placeholder[n, ic_outer, oh + kh, ow + kw,
                                            ic_f_inner * 4 + ic_s_inner],
                                "int32",
                            ) * T.cast(
                                placeholder_1[oc_chunk, ic_outer, kh, kw,
                                              ic_f_inner, oc_block, ic_s_inner],
                                "int32",
                            )
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # NCHWc int8 convolution loop nest in which the reduction has been split
    # into a separate init block and an update block, and the update is
    # expressed as one call to the LLVM intrinsic
    # "llvm.x86.avx512.vpdpbusd.512" on a 16-lane int32 accumulator.
    #
    # Structure under the shared 12-loop outer grid:
    #   1. "conv2d_NCHWc_int8_o_init": zero-fills a 16-wide output slice using
    #      a vectorized inner loop;
    #   2. "conv2d_NCHWc_int8_o_update": views the input/weight/output slices
    #      through T.match_buffer and accumulates via the intrinsic.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid(
            1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1):
        # --- Init: iterate only over the spatial tiles of the output. ---
        for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(
                4, 7, 4, 4):
            with T.block("conv2d_NCHWc_int8_o_init"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2_init * 4 + i2_3_init)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init)
                oc_block_o = T.axis.spatial(1, 0)
                T.reads()
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                for i4_1 in T.vectorized(16):
                    with T.block("conv2d_NCHWc_int8_init"):
                        oc_block_init = T.axis.spatial(16, i4_1)
                        T.reads()
                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init])
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0
        # --- Update: remaining 18 loops, including the reduction loops
        # i7_0 and i8_1 that bind ic_outer / ic_f_inner. ---
        for (
            i7_0,
            i8_0,
            i9_0_0,
            i0_2,
            i1_2,
            i2_2,
            i3_2,
            i4_0_2,
            i5_1,
            i6_1,
            i7_1,
            i8_1,
            i9_0_1,
            i0_3,
            i1_3,
            i2_3,
            i3_3,
            i4_0_3,
        ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1):
            with T.block("conv2d_NCHWc_int8_o_update"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
                oc_block_o = T.axis.spatial(1, 0)
                kh = T.axis.reduce(1, 0)
                kw = T.axis.reduce(1, 0)
                ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
                ic_s_inner_o = T.axis.reduce(1, 0)
                T.reads(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                # A: 4 uint8 inputs, B: 16 x 4 int8 weights, C: 16 int32 outputs.
                A = T.match_buffer(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4],
                    [4],
                    dtype="uint8",
                    offset_factor=1,
                )
                B = T.match_buffer(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
                    [16, 4],
                    dtype="int8",
                    offset_factor=1,
                )
                C = T.match_buffer(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    [16],
                    dtype="int32",
                    offset_factor=1,
                )
                # Load the 4 input bytes as one vector and reinterpret them as
                # a single int32 so they can be broadcast to 16 lanes.
                A_u8x4 = A.vload([0], "uint8x4")
                A_i32 = T.reinterpret(A_u8x4, dtype="int32")
                # Load all 64 weight bytes and reinterpret them as 16 int32 lanes.
                B_i8x64 = B.vload([0, 0], dtype="int8x64")
                B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
                # The intrinsic is called with a zero accumulator operand
                # (T.broadcast(0, 16)); its result is added to C explicitly.
                C[T.ramp(
                    0, 1, 16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin(
                        T.llvm_lookup_intrinsic_id(
                            "llvm.x86.avx512.vpdpbusd.512"),
                        T.uint32(0),
                        T.broadcast(0, 16),
                        T.broadcast(A_i32, 16),
                        B_i32x16,
                        dtype="int32x16",
                    )
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Row-wise softmax over a 256 x 256 float32 buffer, written with explicit
    # cross-thread reductions.
    #
    # For each row i0 (bound to blockIdx.x) the body runs three phases:
    #   1. row maximum  -> T_softmax_maxelem_shared[i0]
    #   2. sum of exp(A - max) -> T_softmax_expsum_shared[i0]
    #   3. T_softmax_norm = exp(A - max) / sum
    # Phases 1 and 2 each use a per-thread partial result
    # (normal_reduce_temp*), a T.tvm_thread_allreduce call over the 32
    # threadIdx.x threads, and a write-back block.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # Single-element scratch buffers: reduce_temp* receive the all-reduce
    # result, normal_reduce_temp* hold each thread's partial value.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # ---- Phase 1: row maximum ----
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            # Each thread visits columns ax0_0 * 32 + ax0_1 (8 of the 256).
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k])
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Max reducer with identity T.min_value("float32").
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        # ---- Phase 2: sum of exponentials ----
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                # Sum reducer with identity 0.
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        # ---- Phase 3: normalize each element of the row ----
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
    # 512 x 512 float32 matrix multiply C = A @ B, staged through "shared"
    # copies of A and B and a "local" accumulator.
    #
    # The two copy-in blocks carry "meta_schedule.cooperative_fetch"
    # annotations (1 for A_shared, 2 for B_shared) and their fused copy loops
    # are already split: the A copy over a threadIdx.x loop of 8, the B copy
    # over a threadIdx.x loop of 8 plus a vectorized loop of 2.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    C = T.match_buffer(var_C, [512, 512], dtype="float32")
    # body
    # with T.block("root")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                for i2_0 in T.serial(0, 1):
                    # Copy A -> A_shared. The fused index
                    # ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1 spans
                    # 32768 * 8 = 512 * 512 elements and is split back into
                    # (row, col) with // 512 and % 512.
                    for ax0_ax1_fused_0 in T.serial(0, 32768):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.x"):
                            with T.block("A_shared"):
                                v0 = T.axis.spatial(
                                    512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512)
                                v1 = T.axis.spatial(
                                    512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512)
                                T.reads([A[v0, v1]])
                                T.writes([A_shared[v0, v1]])
                                T.block_attr(
                                    {"meta_schedule.cooperative_fetch": 1})
                                A_shared[v0, v1] = A[v0, v1]
                    # Copy a 512 x 32 column panel of B -> B_shared. The fused
                    # index spans 1024 * 8 * 2 = 512 * 32 elements; the column
                    # is offset by this block's i0_0_i1_0_fused * 32.
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.x"):
                            for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                with T.block("B_shared"):
                                    v0 = T.axis.spatial(
                                        512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
                                    v1 = T.axis.spatial(
                                        512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
                                    T.reads([B[v0, v1]])
                                    T.writes([B_shared[v0, v1]])
                                    T.block_attr({
                                        "meta_schedule.cooperative_fetch": 2
                                    })
                                    B_shared[v0, v1] = B[v0, v1]
                    # Multiply-accumulate into C_local; k is the reduce axis.
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                            16, 2, 2, 32, 16, 2):
                        with T.block("C"):
                            i = T.axis.spatial(
                                512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(
                                512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
                            k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                            T.reads([
                                C_local[i, j], A_shared[i, k], B_shared[k, j]
                            ])
                            T.writes([C_local[i, j]])
                            with T.init():
                                C_local[i, j] = T.float32(0)
                            C_local[i, j] = C_local[
                                i, j] + A_shared[i, k] * B_shared[k, j]
                # Write this thread's 32 x 4 tile of C_local back to C.
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 1024 x 1024 GEMM (float16 inputs A and B, float32 output C) written with
    # the wmma fragment intrinsics, using the `T.block([...]) as [...]` /
    # `T.bind` block syntax.
    #
    # Per (blockIdx.x, blockIdx.y) block and (threadIdx.y, threadIdx.z) pair:
    #   1. fill 2 x 4 accumulator fragments with 0 (tvm_fill_fragment);
    #   2. for each of 32 `ko` steps: copy tiles of A and B into the shared
    #      buffers, then for each of 2 `ki` steps load A and B fragments
    #      (tvm_load_matrix_sync) and accumulate (tvm_mma_sync);
    #   3. store the accumulator fragments to C (tvm_store_matrix_sync).
    # match buffer
    A = T.match_buffer(a, [1024, 1024], "float16")
    B = T.match_buffer(b, [1024, 1024], "float16")
    C = T.match_buffer(c, [1024, 1024], "float32")
    # body
    for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block([16, 8]) as [bx, by]:
                T.bind(bx, blockIdx_x)
                T.bind(by, blockIdx_y)
                shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # Step 1: zero-fill each 16 x 16 accumulator tile;
                        # i * 4 + j is the fragment index passed to the intrinsic.
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads([])
                                T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.float32(0),
                                        dtype="handle",
                                    )
                                )
                        # Step 2: main reduction loop.
                        for ko in range(0, 32):
                            # copy data from global to shared
                            # NOTE(review): the shared buffers are written at
                            # column vj + 8, which reaches 1031 while the
                            # buffers are declared [1024, 1024] -- presumably
                            # intentional padding for this fixture; confirm
                            # before relying on the declared extents.
                            for tx in T.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in T.grid(1, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]
                                for i0, j0 in T.grid(2, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]
                            for ki in range(0, 2):
                                # Load 2 A fragments ("row_major") from shared_A.
                                for i in range(0, 2):
                                    with T.block([64, 64]) as [vi, vk]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # Symbolic strides for the matched shared region.
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        A0 = T.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        # The access pointer starts at elem_offset + 8,
                                        # matching the + 8 column shift used by the copy.
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load 4 B fragments ("col_major") from shared_B.
                                for j in range(0, 4):
                                    with T.block([64, 64]) as [vj, vk]:
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        B0 = T.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Accumulate: C[i * 4 + j] += A[i] * B[j]; vk is
                                # declared as a reduce axis.
                                for i, j in T.grid(2, 4):
                                    with T.block([64, 64, T.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        T.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = T.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Step 3: store each accumulator fragment to C ("row_major").
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = T.var("int32")
                                s1 = T.var("int32")
                                wmma_C2 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = T.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )
def main(
    A: T.Buffer[(512, 512), "float32"],
    B: T.Buffer[(512, 512), "float32"],
    C: T.Buffer[(512, 512), "float32"],
) -> None:
    # 512 x 512 float32 matrix multiply C = A @ B, staged through "shared"
    # copies of A and B and a "local" accumulator.
    #
    # In this variant the compute loop is bound to threadIdx.y (extent 8) and
    # both copy-in loop nests are split over threadIdx.y (8) and
    # threadIdx.x (32). The compute block carries the annotation
    # {"warp_execution": 1}.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"):
                for i2_0 in T.serial(0, 1):
                    # Copy A -> A_shared. The fused index spans
                    # 1024 * 8 * 32 = 512 * 512 elements.
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(
                                    0, 32, thread="threadIdx.x"):
                                with T.block("A_shared"):
                                    v0 = T.axis.spatial(
                                        512,
                                        (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 512,
                                    )
                                    v1 = T.axis.spatial(
                                        512,
                                        (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 512,
                                    )
                                    T.reads([A[v0, v1]])
                                    T.writes([A_shared[v0, v1]])
                                    A_shared[v0, v1] = A[v0, v1]
                    # Copy a 512 x 32 column panel of B -> B_shared. The fused
                    # index spans 32 * 8 * 32 * 2 = 512 * 32 elements; the
                    # innermost factor of 2 is vectorized.
                    for ax0_ax1_fused_0 in T.serial(0, 32):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(
                                    0, 32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(0, 2):
                                    with T.block("B_shared"):
                                        v0 = T.axis.spatial(
                                            512,
                                            (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 64 + ax0_ax1_fused_2 * 2 + ax0_ax1_fused_3) // 32,
                                        )
                                        v1 = T.axis.spatial(
                                            512,
                                            i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 64 + ax0_ax1_fused_2 * 2 + ax0_ax1_fused_3) % 32,
                                        )
                                        T.reads([B[v0, v1]])
                                        T.writes([B_shared[v0, v1]])
                                        B_shared[v0, v1] = B[v0, v1]
                    # Multiply-accumulate into C_local; k is the reduce axis.
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                            16, 2, 2, 32, 16, 2):
                        with T.block("C"):
                            i = T.axis.spatial(
                                512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(
                                512,
                                i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4,
                            )
                            k = T.axis.reduce(
                                512, i2_0 * 512 + i2_1 * 32 + i2_2)
                            T.reads([A_shared[i, k], B_shared[k, j]])
                            T.writes([C_local[i, j]])
                            T.block_attr({"warp_execution": 1})
                            with T.init():
                                C_local[i, j] = T.float32(0)
                            C_local[i, j] = C_local[
                                i, j] + A_shared[i, k] * B_shared[k, j]
                # Write this thread's 32 x 4 tile of C_local back to C.
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) -> None:
    # pylint: disable=missing-function-docstring
    # 1024 x 1024 GEMM (float16 inputs, float32 output) written with the wmma
    # fragment intrinsics, using the `T.axis.*` block-axis syntax.
    #
    # Per (blockIdx.x, blockIdx.y) block and (threadIdx.y, threadIdx.z) pair:
    #   1. fill 2 x 4 accumulator fragments with 0 (tvm_fill_fragment);
    #   2. for each of 32 `k_o` steps: copy tiles of both inputs into the
    #      shared buffers, then for each of 2 `k_i` steps load fragments
    #      (tvm_load_matrix_sync) and accumulate (tvm_mma_sync);
    #   3. store the accumulator fragments to the output
    #      (tvm_store_matrix_sync).
    # match buffer
    match_buffer_a = T.match_buffer(handle_a, [1024, 1024], "float16")
    match_buffer_b = T.match_buffer(handle_b, [1024, 1024], "float16")
    match_buffer_c = T.match_buffer(handle_c, [1024, 1024], "float32")
    # body
    for block_idx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for block_idx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block():
                axis_bx, axis_by = T.axis.remap("SS", [block_idx_x, block_idx_y])
                shared_a = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_b = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_a = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_b = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_c = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                # pylint: disable=too-many-nested-blocks
                for thread_ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for thread_tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # Step 1: zero-fill each 16 x 16 accumulator tile;
                        # index_i * 4 + index_jj is the fragment index.
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads([])
                                T.writes(wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                match_buffer_c0 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                           new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        match_buffer_c0.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.float32(0),  # pylint: disable=not-callable
                                        dtype="handle",
                                    ))
                        # Step 2: main reduction loop.
                        for k_o in range(0, 32):
                            # copy data from global to shared
                            # NOTE(review): the shared buffers are written at
                            # column new_axis_vj + 8, which reaches 1031 while
                            # the buffers are declared [1024, 1024] --
                            # presumably intentional padding for this fixture;
                            # confirm before relying on the declared extents.
                            for thread_tx in T.thread_binding(
                                    0, 32, "threadIdx.x"):
                                for index_i0, index_j0 in T.grid(1, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_bx * 64 + thread_ty * 32 + thread_tx + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 + index_j0 * 4 + index_j1,
                                            )
                                            shared_a[new_axis_vi, new_axis_vj + 8] = match_buffer_a[
                                                new_axis_vi, new_axis_vj]
                                for index_i0, index_j0 in T.grid(2, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_by * 128 + thread_ty * 64 + thread_tx * 2 + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 + index_j0 * 4 + index_j1,
                                            )
                                            shared_b[new_axis_vi, new_axis_vj + 8] = match_buffer_b[
                                                new_axis_vi, new_axis_vj]
                            for k_i in range(0, 2):
                                # Load 2 A fragments ("row_major") from shared_a.
                                for index_i in range(0, 2):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 + index_i)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                         axis_vk * 16:axis_vk * 16 + 16 + 8, ])
                                        T.writes(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ])
                                        # Symbolic strides for the matched shared region.
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_a0 = T.match_buffer(
                                            shared_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                     axis_vk * 16:axis_vk * 16 + 16 + 8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_a0 = T.match_buffer(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        # The access pointer starts at elem_offset + 8,
                                        # matching the + 8 column shift used by the copy.
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_a0.data,
                                                16,
                                                16,
                                                16,
                                                index_i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_a0.data,
                                                    match_buffer_a0.elem_offset + 8,
                                                    match_buffer_a0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_a0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            ))
                                # Load 4 B fragments ("col_major") from shared_b.
                                for index_jj in range(0, 4):
                                    with T.block():
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 + index_jj)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                         axis_vk * 16:axis_vk * 16 + 16 + 8, ])
                                        T.writes(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ])
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_b0 = T.match_buffer(
                                            shared_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                     axis_vk * 16:axis_vk * 16 + 16 + 8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_b0 = T.match_buffer(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_b0.data,
                                                16,
                                                16,
                                                16,
                                                index_jj,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_b0.data,
                                                    match_buffer_b0.elem_offset + 8,
                                                    match_buffer_b0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_b0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            ))
                                # Accumulate: c[index_i * 4 + index_jj] +=
                                # a[index_i] * b[index_jj]; axis_vk is a reduce
                                # axis here (T.axis.R), unlike in the load blocks.
                                for index_i, index_jj in T.grid(2, 4):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 + index_i)
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 + index_jj)
                                        axis_vk = T.axis.R(64, k_o * 2 + k_i)
                                        T.reads([
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                        ])
                                        T.writes(
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                        wmma_a1 = T.match_buffer(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_b1 = T.match_buffer(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_c1 = T.match_buffer(
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                wmma_a1.data,
                                                index_i,
                                                wmma_b1.data,
                                                index_jj,
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                dtype="handle",
                                            ))
                        # Step 3: store each accumulator fragment to the
                        # output buffer ("row_major").
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads(wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                               new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                T.writes(
                                    match_buffer_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                stride0 = T.var("int32")
                                stride1 = T.var("int32")
                                wmma_c2 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                           new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                match_buffer_c1 = T.match_buffer(
                                    match_buffer_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[stride0, stride1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_c2.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            match_buffer_c1.data,
                                            match_buffer_c1.elem_offset,
                                            match_buffer_c1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        match_buffer_c1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    ))
def transformed_nested_pipeline_double_buffer(
        A: T.Buffer[(16, 16, 16), "float32"],
        C: T.Buffer[(16, 16, 16), "float32"]) -> None:
    # Two-level pipelined form of the computation C = A * 2 + 1, evaluated
    # per threadIdx.x slice `tx`.
    #
    # Data path: A -> A_shared -> A_local -> B (= A_local * 2) -> C (= B + 1).
    # A_local and B each have a leading extent of 2 and are indexed with a
    # `% 2` expression so consecutive iterations alternate between two slots.
    #
    # The body is laid out as three sibling blocks:
    #   * a prologue that loads row 0 and computes B for column 0;
    #   * a main loop over i in [0, 15) that loads row i + 1 while consuming
    #     row i;
    #   * an epilogue that finishes row 15.
    # Inside each row, the j-loop has the same shape: compute B for column
    # j + 1 while emitting C for column j, then emit column 15 separately.
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        with T.block():
            T.reads([A[tx, 0:16, 0:16]])
            T.writes([C[tx, 0:16, 0:16]])
            A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared")
            A_local = T.alloc_buffer([2, 1, 1, 16], dtype="float32", scope="local")
            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
            # ---------------- Prologue (row 0) ----------------
            with T.block():
                T.reads([
                    A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]
                ])
                T.writes([
                    A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0, 0]
                ])
                with T.block():
                    T.reads([A[tx, 0, 0:16]])
                    T.writes([A_shared[tx, 0, 0:16]])
                    for j in T.serial(0, 16):
                        with T.block():
                            T.reads([A[tx, 0, j]])
                            T.writes([A_shared[tx, 0, j]])
                            A_shared[tx, 0, j] = A[tx, 0, j]
                with T.block():
                    T.reads([A_shared[tx, 0, 0:16]])
                    T.writes([A_local[0, 0, 0, 0:16]])
                    for j in T.serial(0, 16):
                        with T.block():
                            T.reads([A_shared[tx, 0, j]])
                            T.writes([A_local[0, 0, 0, j]])
                            T.block_attr({"double_buffer_scope": 0})
                            A_local[0, 0, 0, j] = A_shared[tx, 0, j]
                with T.block():
                    T.reads([A_local[0, tx, 0, 0]])
                    T.writes([B[0, tx, 0, 0]])
                    B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2)
            # ---------------- Main loop (rows 0..14) ----------------
            with T.block():
                T.reads([
                    A[tx, 1:16, 0:16],
                    A_local[0:2, tx, 0:16, 0:16],
                    B[0:2, tx, 0:15, 0],
                    A_shared[tx, 0, 0:16],
                ])
                T.writes([
                    A_shared[tx, 0, 0:16],
                    B[0:2, tx, 0:16, 0],
                    C[tx, 0:15, 0:16],
                    A_local[0:2, 0, 0, 0:16],
                ])
                for i in T.serial(0, 15):
                    # Load row i + 1 of A into A_shared.
                    with T.block():
                        T.reads([A[tx, i + 1, 0:16]])
                        T.writes([A_shared[tx, 0, 0:16]])
                        for j in T.serial(0, 16):
                            with T.block():
                                T.reads([A[tx, i + 1, j]])
                                T.writes([A_shared[tx, 0, j]])
                                A_shared[tx, 0, j] = A[tx, i + 1, j]
                    # Consume row i: B for column j + 1, C for column j.
                    with T.block():
                        T.reads(
                            [A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
                        for j in T.serial(0, 15):
                            with T.block():
                                T.reads([A_local[i % 2, tx, i, j + 1]])
                                T.writes([B[(j + 1) % 2, tx, i, 0]])
                                B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(2)
                            with T.block():
                                T.reads([B[j % 2, tx, i, 0]])
                                T.writes([C[tx, i, j]])
                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
                    # Stage row i + 1 into the other A_local slot.
                    with T.block():
                        T.reads([A_shared[tx, 0, 0:16]])
                        T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
                        for j in T.serial(0, 16):
                            with T.block():
                                T.reads([A_shared[tx, 0, j]])
                                T.writes([A_local[(i + 1) % 2, 0, 0, j]])
                                T.block_attr({"double_buffer_scope": 0})
                                # NOTE(review): the load indexes A_shared's
                                # extent-1 middle axis with i + 1 while the
                                # T.reads above declares index 0 -- presumably
                                # mirrors the fixture this function is compared
                                # against; confirm before changing.
                                A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j]
                    # First B value of row i + 1.
                    with T.block():
                        T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]])
                        T.writes([B[0, tx, i + 1, 0]])
                        B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2)
                    # Last column (15) of row i.
                    with T.block():
                        T.reads([B[1, tx, i, 0]])
                        T.writes([C[tx, i, 15]])
                        C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1)
            # ---------------- Epilogue (row 15) ----------------
            with T.block():
                T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
                T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]])
                with T.block():
                    T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
                    T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]])
                    for j in T.serial(0, 15):
                        with T.block():
                            T.reads([A_local[1, tx, 15, j + 1]])
                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
                            B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2)
                        with T.block():
                            T.reads([B[j % 2, tx, 15, 0]])
                            T.writes([C[tx, 15, j]])
                            C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1)
                with T.block():
                    T.reads([B[1, tx, 15, 0]])
                    T.writes([C[tx, 15, 15]])
                    C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1)
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:  # type: ignore
    # Winograd-style convolution expressed as a sequence of flat loop nests,
    # one block per stage:
    #   data_pad   : copy `placeholder` into a 16 x 16 padded buffer
    #   input_tile : gather 6 x 6 tiles (9 tiles, stepping by 4) from data_pad
    #   B          : 6 x 6 constant matrix
    #   data_pack  : data_pack = sum_{r_a, r_b} input_tile * B * B
    #   bgemm      : per-(eps, nu) product of data_pack with `placeholder_1`,
    #                reduced over the 128 input channels
    #   A          : 6 x 4 constant matrix
    #   inverse    : inverse = sum_{r_a, r_b} bgemm * A * A
    #   conv2d_winograd : scatter the 4 x 4 output tiles into the 12 x 12 result
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    # Stage 1: pad. Indices with i1_1 or i2_1 outside [0, 14) get 0.
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    # Stage 2: tile extraction. Tile index p maps to a (p % 9) // 3 by p % 3
    # grid position; each tile starts 4 elements after the previous one.
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    # Stage 3: constant 6 x 6 matrix B, spelled out as one nested T.Select
    # chain keyed on (i % 6, j % 6), listed from (5, 5) down to (0, 0).
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1),
                T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Stage 4: data_pack[eps, nu] = sum over (r_a, r_b) of
    # input_tile[r_a, r_b] * B[r_a, eps] * B[r_b, nu].
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            # The declared B region is the bounding box of the two elements
            # read: rows min..max of (r_a, r_b), columns min..max of (eps_1, nu_1).
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    # Stage 5: batched product with the weights, reducing over ci_2.
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    # Stage 6: constant 6 x 4 matrix A, as a nested T.Select chain keyed on
    # (i_1 % 6, j_1 % 4), listed from (5, 3) down to (0, 0).
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8),
                T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4),
                T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2),
                T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125),
                T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25),
                T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5),
                T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1),
                T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1),
                T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Stage 7: inverse[vh, vw] = sum over (r_a_1, r_b_1) of
    # bgemm[r_a_1, r_b_1] * A[r_a_1, vh] * A[r_b_1, vw].
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            # As in "data_pack", the A region is declared as a min..max
            # bounding box of the two elements read.
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    # Stage 8: output element (h, w) comes from position (h % 4, w % 4) of
    # tile n * 9 + (h // 4) * 3 + w // 4.
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]
def main(
    placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"],
    placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"],
    placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"],
    T_expand_dims: T.Buffer[(1, 80, 10647), "float32"],
) -> None:
    # Overview of the pipeline, applied independently to each of the three inputs:
    #   1. slice element 4 of the last axis and apply sigmoid          -> [..., 1]
    #   2. slice elements 5..84 of the last axis and apply sigmoid     -> [..., 80]
    #   3. multiply (2) by (1), reusing the single value of (1) for all 80 entries
    #   4. flatten the four leading axes into one                      -> [N, 80]
    # The three flattened results are then concatenated along axis 0
    # (507 + 2028 + 8112 = 10647 rows), transposed to [80, 10647] and given a
    # leading unit axis to produce T_expand_dims of shape [1, 80, 10647].

    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")

    # Intermediates for placeholder_2 (52 x 52 grid).
    T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
    T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
    T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
    T_sigmoid_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
    T_multiply = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
    T_reshape = T.alloc_buffer([8112, 80], dtype="float32")
    # Intermediates for placeholder_1 (26 x 26 grid).
    T_strided_slice_with_axes_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
    T_sigmoid_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
    T_strided_slice_with_axes_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
    T_sigmoid_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
    T_multiply_1 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
    T_reshape_1 = T.alloc_buffer([2028, 80], dtype="float32")
    # Intermediates for placeholder (13 x 13 grid).
    T_strided_slice_with_axes_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
    T_sigmoid_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
    T_strided_slice_with_axes_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
    T_sigmoid_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
    T_multiply_2 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
    T_reshape_2 = T.alloc_buffer([507, 80], dtype="float32")
    # Buffers shared by the final concat / transpose stages.
    T_concat = T.alloc_buffer([10647, 80], dtype="float32")
    T_transpose = T.alloc_buffer([80, 10647], dtype="float32")

    # ---- placeholder_2: 52 x 52 grid ----

    # Take element 4 of the last axis (ax4 has extent 1, offset 4).
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
        with T.block("T_strided_slice_with_axes"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
            )
            T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4] = placeholder_2[
                ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)
            ]
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
        with T.block("T_sigmoid"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(
                T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4], dtype="float32"
            )
    # Take elements 5..84 of the last axis (ax4 has extent 80, offset 5).
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
        with T.block("T_strided_slice_with_axes_1"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
            )
            T.writes(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4] = placeholder_2[
                ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)
            ]
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
        with T.block("T_sigmoid_1"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(
                T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4], dtype="float32"
            )
    # The single-entry sigmoid (last index fixed at 0) scales all 80 entries.
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
        with T.block("T_multiply"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                T_sigmoid[ax0, ax1, ax2, ax3, 0],
                T_sigmoid_1[ax0, ax1, ax2, ax3, ax4],
            )
            T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4])
            T_multiply[ax0, ax1, ax2, ax3, ax4] = (
                T_sigmoid[ax0, ax1, ax2, ax3, 0] * T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]
            )
    # Flatten [1, 52, 52, 3, 80] -> [8112, 80]; 8112 = 52 * 52 * 3, 156 = 52 * 3.
    for i0, i1 in T.grid(8112, 80):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(
                T_multiply[
                    0,
                    (ax1 // 80 + ax0) % 8112 // 156,
                    (ax1 // 80 + ax0) % 156 // 3,
                    (ax1 // 80 + ax0) % 3,
                    ax1 % 80,
                ]
            )
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = T_multiply[
                0,
                (ax1 // 80 + ax0) % 8112 // 156,
                (ax1 // 80 + ax0) % 156 // 3,
                (ax1 // 80 + ax0) % 3,
                ax1 % 80,
            ]

    # ---- placeholder_1: 26 x 26 grid ----

    # Take element 4 of the last axis.
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
        with T.block("T_strided_slice_with_axes_2"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
            )
            T.writes(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4] = placeholder_1[
                ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)
            ]
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
        with T.block("T_sigmoid_2"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_2[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_2[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(
                T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4], dtype="float32"
            )
    # Take elements 5..84 of the last axis.
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
        with T.block("T_strided_slice_with_axes_3"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
            )
            T.writes(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4] = placeholder_1[
                ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)
            ]
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
        with T.block("T_sigmoid_3"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(
                T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4], dtype="float32"
            )
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
        with T.block("T_multiply_1"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                T_sigmoid_2[ax0, ax1, ax2, ax3, 0],
                T_sigmoid_3[ax0, ax1, ax2, ax3, ax4],
            )
            T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4])
            T_multiply_1[ax0, ax1, ax2, ax3, ax4] = (
                T_sigmoid_2[ax0, ax1, ax2, ax3, 0]
                * T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]
            )
    # Flatten [1, 26, 26, 3, 80] -> [2028, 80]; 2028 = 26 * 26 * 3, 78 = 26 * 3.
    for i0, i1 in T.grid(2028, 80):
        with T.block("T_reshape_1"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(
                T_multiply_1[
                    0,
                    (ax1 // 80 + ax0) % 2028 // 78,
                    (ax1 // 80 + ax0) % 78 // 3,
                    (ax1 // 80 + ax0) % 3,
                    ax1 % 80,
                ]
            )
            T.writes(T_reshape_1[ax0, ax1])
            T_reshape_1[ax0, ax1] = T_multiply_1[
                0,
                (ax1 // 80 + ax0) % 2028 // 78,
                (ax1 // 80 + ax0) % 78 // 3,
                (ax1 // 80 + ax0) % 3,
                ax1 % 80,
            ]

    # ---- placeholder: 13 x 13 grid ----

    # Take element 4 of the last axis.
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
        with T.block("T_strided_slice_with_axes_4"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
            )
            T.writes(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4] = placeholder[
                ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)
            ]
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
        with T.block("T_sigmoid_4"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_4[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_4[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(
                T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4], dtype="float32"
            )
    # Take elements 5..84 of the last axis.
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
        with T.block("T_strided_slice_with_axes_5"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
            )
            T.writes(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4] = placeholder[
                ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)
            ]
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
        with T.block("T_sigmoid_5"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(
                T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4], dtype="float32"
            )
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
        with T.block("T_multiply_2"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(
                T_sigmoid_4[ax0, ax1, ax2, ax3, 0],
                T_sigmoid_5[ax0, ax1, ax2, ax3, ax4],
            )
            T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4])
            T_multiply_2[ax0, ax1, ax2, ax3, ax4] = (
                T_sigmoid_4[ax0, ax1, ax2, ax3, 0]
                * T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]
            )
    # Flatten [1, 13, 13, 3, 80] -> [507, 80]; 507 = 13 * 13 * 3, 39 = 13 * 3.
    for i0, i1 in T.grid(507, 80):
        with T.block("T_reshape_2"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(
                T_multiply_2[
                    0,
                    (ax1 // 80 + ax0) % 507 // 39,
                    (ax1 // 80 + ax0) % 39 // 3,
                    (ax1 // 80 + ax0) % 3,
                    ax1 % 80,
                ]
            )
            T.writes(T_reshape_2[ax0, ax1])
            T_reshape_2[ax0, ax1] = T_multiply_2[
                0,
                (ax1 // 80 + ax0) % 507 // 39,
                (ax1 // 80 + ax0) % 39 // 3,
                (ax1 // 80 + ax0) % 3,
                ax1 % 80,
            ]

    # ---- merge the three flattened results ----

    # Row ranges: [0, 507) <- T_reshape_2, [507, 2535) <- T_reshape_1,
    # [2535, 10647) <- T_reshape.
    for i0, i1 in T.grid(10647, 80):
        with T.block("T_concat"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(
                T_reshape[ax0 - 2535, ax1],
                T_reshape_1[ax0 - 507, ax1],
                T_reshape_2[ax0, ax1],
            )
            T.writes(T_concat[ax0, ax1])
            T_concat[ax0, ax1] = T.if_then_else(
                2535 <= ax0,
                T_reshape[ax0 - 2535, ax1],
                T.if_then_else(
                    507 <= ax0,
                    T_reshape_1[ax0 - 507, ax1],
                    T_reshape_2[ax0, ax1],
                    dtype="float32",
                ),
                dtype="float32",
            )
    # [10647, 80] -> [80, 10647]
    for i0, i1 in T.grid(80, 10647):
        with T.block("T_transpose"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_concat[ax1, ax0])
            T.writes(T_transpose[ax0, ax1])
            T_transpose[ax0, ax1] = T_concat[ax1, ax0]
    # Prepend a unit axis: [80, 10647] -> [1, 80, 10647]
    for i0, i1, i2 in T.grid(1, 80, 10647):
        with T.block("T_expand_dims"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(T_transpose[ax1, ax2])
            T.writes(T_expand_dims[ax0, ax1, ax2])
            T_expand_dims[ax0, ax1, ax2] = T_transpose[ax1, ax2]