def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Row-wise softmax of the 256x256 float32 buffer A into T_softmax_norm.
    # The row loop i0 is bound to blockIdx.x. Each of the three stages walks the
    # 256 columns as 8 serial steps of 32 iterations bound to threadIdx.x.
    # Per-row max and exp-sum live in buffers allocated with scope="shared".
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # Stage 1: running maximum of row i0.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        # Stage 2: sum of exp(A - row max) for row i0.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        # Stage 3: normalise every element of row i0 by the row's exp-sum.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
def opaque_access_reorder(a: T.handle, b: T.handle) -> None:
    # Two 16x16 loop nests that touch A and B only through opaque accesses.
    # Both nests iterate `j` outermost and `i` innermost, while vi is bound to
    # `i` and vj to `j`. Each block declares no reads and a whole-buffer write.
    A = T.match_buffer(a, [16, 16], "float32")
    B = T.match_buffer(b, [16, 16], "float32")
    for j, i in T.grid(16, 16):
        with T.block([16, 16], "A") as [vi, vj]:
            T.bind(vi, i)
            T.bind(vj, j)
            T.reads([])
            T.writes([A[0:16, 0:16]])
            # Raw store of 1 at flat offset vi * 16 + vj of A's data pointer.
            T.store(A.data, vi * 16 + vj, 1)
    for j, i in T.grid(16, 16):
        with T.block([16, 16], "B") as [vi, vj]:
            T.bind(vi, i)
            T.bind(vj, j)
            T.reads([])
            T.writes([B[0:16, 0:16]])
            # B is only reached through its data pointer inside the intrinsic call.
            T.evaluate(
                T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle"))
def lowered_loop_split(a: T.handle, b: T.handle) -> None:
    # Row sum B[i] = sum over k of A[i, k] for a 128x128 A, written with an
    # explicit cross-thread reduction. For each row, every ki (bound to
    # threadIdx.x) accumulates 4 elements (k = ko * 32 + ki) into
    # normal_reduce_temp0, the per-thread values are combined by
    # T.tvm_thread_allreduce into reduce_temp0, and that result is stored to B.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            # Zero the in-thread accumulator before the serial part of the reduction.
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for ko in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vi = T.axis.S(128, i)
                    vk = T.axis.R(128, ko * 32 + ki)
                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # The reducer is addition with identity 0, attached as "reduce_scope".
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ki,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.S(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None:
    # Two-stage reduction of the 16x16x16 buffer A into B[16].
    # Block "B_rf" sums A[vi, vk0, :] into B_rf_local[vk0, vi]; block "B" then
    # sums B_rf_local[:, vi] into B[vi]. Both blocks sit under the same
    # threadIdx.x-bound loop k0o, and each handles the 4 vk0 values k0o * 4 + (0..3).
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    # vk0 is spatial here: each (vk0, vi) pair owns one partial sum.
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B"):
                    # vk0 is a reduction axis here: the partial sums are folded into B[vi].
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B[vi], B_rf_local[vk0, vi]])
                    T.writes([B[vi]])
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + B_rf_local[vk0, vi]
def blockized_1(a: T.handle, c: T.handle) -> None:
    # Computes B = A * 2.0 element-wise, then C = B + 1.0, on 128x128 float32
    # buffers. The second stage is wrapped in an 8x8 grid of "C_outer" blocks,
    # each covering one 16x16 tile that the nested "C_inner" block fills.
    A = T.match_buffer(a, [128, 128], "float32")
    B = T.alloc_buffer([128, 128], "float32")
    C = T.match_buffer(c, [128, 128], "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(8, 8):
        with T.block("C_outer"):
            vi_o, vj_o = T.axis.remap("SS", [i, j])
            # The outer block's regions are the 16x16 tile selected by (vi_o, vj_o).
            T.reads([B[
                vi_o * 16 : vi_o * 16 + 16,
                vj_o * 16 : vj_o * 16 + 16,
            ]])
            T.writes([C[
                vi_o * 16 : vi_o * 16 + 16,
                vj_o * 16 : vj_o * 16 + 16
            ]])
            for i_i, j_i in T.grid(16, 16):
                with T.block("C_inner"):
                    vi = T.axis.S(128, vi_o * 16 + i_i)
                    vj = T.axis.S(128, vj_o * 16 + j_i)
                    C[vi, vj] = B[vi, vj] + 1.0
def high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
    # Hands 16x16 sub-regions of a (16, 32, 64) buffer with explicit strides
    # [2576, 80, 1] to T.intrin_test. Each (i, j, k) iteration matches the tile
    # A[i, j*16 : j*16+16, k*16 : k*16+16] as `sub_A`, then passes the sub-buffer's
    # data pointer, element offset, strides and shape to the intrinsic.
    A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
    for i, j, k in T.grid(16, 2, 4):
        with T.block():
            # Symbolic stride variables that the sub-buffer match binds.
            As_0 = T.var("int32")
            As_1 = T.var("int32")
            T.reads([])
            T.writes(A[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16])
            sub_A = T.match_buffer(
                A[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
                (16, 16),
                strides=[As_0, As_1],
                offset_factor=1,
            )
            T.evaluate(
                T.intrin_test(
                    sub_A.data,
                    sub_A.elem_offset,
                    sub_A.strides[0],
                    sub_A.strides[1],
                    sub_A.shape[0],
                    sub_A.shape[1],
                    dtype="handle",
                ))
def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Implementation side of an "outer_product" intrinsic: a single root block
    # that calls the extern symbol "outer_product" with the data pointers and
    # element offsets of C (16x16), A (16x1) and B (16x1).
    # C appears in both the read and the write region.
    A = T.match_buffer(a, (16, 1), offset_factor=1)
    B = T.match_buffer(b, (16, 1), offset_factor=1)
    C = T.match_buffer(c, (16, 16), offset_factor=1)
    with T.block("root"):
        T.reads(
            C[0 : 16, 0 : 16],
            A[0 : 16, 0 : 1],
            B[0 : 16, 0 : 1],
        )
        T.writes(C[0 : 16, 0 : 16])
        T.evaluate(
            T.call_extern(
                "outer_product",
                C.data,
                C.elem_offset,
                A.data,
                A.elem_offset,
                B.data,
                B.elem_offset,
                dtype="int32",
            )
        )
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # Blocked-layout convolution: accumulates uint8 `placeholder` times int8
    # `placeholder_1` into the int32 output `conv2d_NCHWc_int8`.
    # The kh and kw loops have extent 1. The reduction runs over ic_outer (4),
    # ic_f_inner (4) and ic_s_inner (4); the input's last index is
    # ic_f_inner * 4 + ic_s_inner.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid(
            1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4):
        with T.block("conv2d_NCHWc_int8"):
            n, oc_chunk, oh, ow = T.axis.remap("SSSS", [i0, i1, i2, i3])
            # oc_block and ic_s_inner are each rebuilt from a split (outer, inner) loop pair.
            oc_block = T.axis.spatial(16, i4_0 * 16 + i4_1)
            kh, kw, ic_outer, ic_f_inner = T.axis.remap(
                "RRRR", [i5, i6, i7, i8])
            ic_s_inner = T.axis.reduce(4, i9_0 * 4 + i9_1)
            T.reads(
                placeholder[n, ic_outer, oh + kh, ow + kw,
                            ic_f_inner * 4 + ic_s_inner],
                placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                              oc_block, ic_s_inner],
            )
            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
            with T.init():
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
            # Both operands are cast to int32 before the multiply-accumulate.
            conv2d_NCHWc_int8[
                n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                    n, oc_chunk, oh, ow, oc_block] + T.cast(
                        placeholder[n, ic_outer, oh + kh, ow + kw,
                                    ic_f_inner * 4 + ic_s_inner], "int32"
                    ) * T.cast(
                        placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                      oc_block, ic_s_inner],
                        "int32",
                    )
def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None:
    # Copies one column of the 16x16 `data_buf` into every column of a locally
    # allocated `out_buf`; the column index is loaded from index_buf[0].
    # Because that index is data-dependent, the block's read region spans the
    # whole row data_buf[vi, 0:16].
    index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
        for i0, i1 in T.grid(16, 16):
            with T.block([16, 16], "") as [vi, vj]:
                T.bind(vi, i0)
                T.bind(vj, i1)
                T.reads([data_buf[vi, 0:16], index_buf[0]])
                T.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[vi, index_buf[0]]
def tir_argmax_val_idx(var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle) -> None:
    # Row-wise argmax over symbolic-shape [m, n] inputs, carried as a
    # two-output reduction: argmax_v0[i] holds the running maximum of val[i, :]
    # and argmax_v1[i] holds the matching entry of idx[i, :].
    # The update keeps the current pair when argmax_v0[i] >= val[i, k], so a
    # tie does not replace the stored pair.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    val = T.match_buffer(var_val, [m, n], dtype="float32")
    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
    argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
    argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
    for i0, i1 in T.grid(m, n):
        with T.block("argmax"):
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(val[i, k], idx[i, k])
            T.writes(argmax_v0[i], argmax_v1[i])
            with T.init():
                argmax_v0[i] = T.min_value("float32")
                argmax_v1[i] = T.int32(-1)
            # Both selects are evaluated before either output is overwritten, so
            # the index select still compares against the old maximum.
            v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
            v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
            argmax_v0[i] = v_argmax_v0
            argmax_v1[i] = v_argmax_v1
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Implementation side of a WMMA matrix-multiply-accumulate intrinsic: one
    # root block that calls T.tvm_mma_sync with C as both destination and
    # accumulator operand, and A and B as the multiplicands.
    # m_dim, n_dim, k_dim, in_dtype, out_dtype, maybe_swap, b_shape_0,
    # b_shape_1 and get_wmma_fragment_index are not defined in this function;
    # presumably they come from an enclosing scope - verify at the definition site.
    A = T.match_buffer(a, (m_dim, k_dim), in_dtype, align=128, offset_factor=16, scope="wmma.matrix_a")
    B = T.match_buffer(
        b,
        maybe_swap(k_dim, n_dim),
        in_dtype,
        align=128,
        offset_factor=16,
        scope="wmma.matrix_b",
    )
    C = T.match_buffer(c, (m_dim, n_dim), out_dtype, align=128, offset_factor=16, scope="wmma.accumulator")
    with T.block("root"):
        T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:b_shape_0, 0:b_shape_1])
        T.writes(C[0:m_dim, 0:n_dim])
        T.evaluate(
            T.tvm_mma_sync(
                C.data,
                get_wmma_fragment_index(C, m_dim, n_dim),
                A.data,
                get_wmma_fragment_index(A, m_dim, k_dim),
                B.data,
                get_wmma_fragment_index(B, b_shape_0, b_shape_1),
                C.data,
                get_wmma_fragment_index(C, m_dim, n_dim),
                dtype="handle",
            ))
def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
    # An "opaque" block reaches A and A_cache only through T.tvm_access_ptr
    # (512 elements from offset 0, "r" and "w" respectively). It is followed by
    # an element-wise block computing B[vi] = A_cache[vi] * 2.0 + 1.0 over the
    # first 512 elements.
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    A_cache = T.alloc_buffer([1024], dtype="float32")
    with T.block("opaque"):
        # annotated opaque partial access should be kept
        T.reads(A[0:512])
        T.writes([A_cache[0:512]])
        T.evaluate(
            T.tvm_access_ptr(
                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
            )
        )
        T.evaluate(
            T.tvm_access_ptr(
                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
            )
        )
    for i in T.serial(0, 512):
        with T.block("B"):
            vi = T.axis.spatial(512, i)
            T.reads([A_cache[vi]])
            T.writes([B[vi]])
            B[vi] = A_cache[vi] * 2.0 + 1.0
def dot_product_4x4_i8i8i32_sdot(
    A: T.Buffer((4, ), "int8", offset_factor=1),
    B: T.Buffer((4, 4), "int8", offset_factor=1),
    C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # Accumulates into the 4 int32 lanes of C the result of the LLVM intrinsic
    # "llvm.aarch64.neon.sdot.v4i32.v16i8", applied to a replicated copy of A
    # and the 16 int8 values of B.
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])
        # Replicate A's 4 int8 values four times: view them as one int32,
        # broadcast to 4 lanes, then view the result as int8x16.
        A_i8x4 = A.vload([0], "int8x4")
        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
        vec_ai32 = T.broadcast(A_i32, 4)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
        vec_b = B.vload([0, 0], dtype="int8x16")
        # The intrinsic receives a zero int32x4 vector as its first vector
        # operand; the `+=` adds the call's result onto the existing C.
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
            T.uint32(3),
            T.int32x4(0),
            vec_a,
            vec_b,
            dtype="int32x4",
        )
def transformed_trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]) -> None:
    # Computes C[tx, 0] = A[tx, 0] * 2 + 1 for 16 threadIdx.x-bound iterations.
    # The work is three consecutive blocks: a producer writing B[0, tx, 0], an
    # empty block that only evaluates 0, and a consumer reading B[0, tx, 0].
    # B is allocated with a leading dimension of 2, but only index 0 is used here.
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0])
            T.writes(C[tx, 0])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            with T.block():
                T.reads(A[tx, 0])
                T.writes(B[0, tx, 0])
                B[0, tx, 0] = A[tx, 0] * T.float32(2)
            with T.block():
                # Empty body: no reads, no writes.
                T.reads()
                T.writes()
                T.evaluate(0)
            with T.block():
                T.reads(B[0, tx, 0])
                T.writes(C[tx, 0])
                C[tx, 0] = B[0, tx, 0] + T.float32(1)
def single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    # Row-wise softmax of a 256x256 buffer. Each stage uses a 512-wide
    # threadIdx.x-bound loop over a 256-wide axis; the T.where predicate
    # (outer * 512 + inner < 256) masks off the iterations past column 255.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.serial(256):
        # Stage 1: row maximum.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        # The initial value is written as a float literal, not T.min_value.
                        T_softmax_maxelem_shared[i0_1] = T.float32(
                            -3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        # Stage 2: sum of exp(A - row max).
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[
                        i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                            A[i0_2, k] - T_softmax_maxelem_shared[i0_2],
                            dtype="float32")
        # Stage 3: normalisation.
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                              dtype="float32") /
                        T_softmax_expsum_shared[i0_3])
def compacted_spatial_tiled_pad_and_pooling(
    X: T.Buffer[(64, 112, 112), "int32"], Y: T.Buffer[(64, 56, 56), "int32"]
) -> None:
    # Tiled 3x3, stride-2 max reduction with an implicit one-element border on
    # the low side. Each of the 14x14 (h_o, w_o) tiles first copies a 9x9x64
    # window of X into the tile-local buffer X_cache, then reduces it into a
    # 4x4x64 patch of Y. X_cache indices subtract T.max(0, h_o * 8 - 1) (and the
    # w_o counterpart) so they are relative to the tile's window start.
    # NOTE(review): Y is declared (64, 56, 56) but is indexed as
    # [h_o * 4 + h_i, w_o * 4 + w_i, c] with c in 0..63; the index order does not
    # match the declared shape - confirm whether this is intentional.
    for h_o, w_o in T.grid(14, 14):
        with T.block():
            T.reads(X[0:64, h_o * 8 - 1 : h_o * 8 + 8, w_o * 8 - 1 : w_o * 8 + 8])
            T.writes(Y[h_o * 4 : h_o * 4 + 4, w_o * 4 : w_o * 4 + 4, 0:64])
            X_cache = T.alloc_buffer([9, 9, 64], dtype="int32")
            for ax0, ax1, ax2 in T.grid(64, 9, 9):
                with T.block("cache"):
                    # Skip window positions that would map to row/column -1 of X.
                    T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2)
                    T.reads(X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1])
                    T.writes(
                        X_cache[
                            h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                            w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                            ax0,
                        ]
                    )
                    X_cache[
                        h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                        w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                        ax0,
                    ] = X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1]
            for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64):
                with T.block("compute"):
                    T.reads(
                        X_cache[
                            h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                            w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                            c,
                        ]
                    )
                    T.writes(Y[h_o * 4 + h_i, w_o * 4 + w_i, c])
                    # Explicit initialisation on the first (kh, kw) step of each window.
                    if kh == 0 and kw == 0:
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = 0
                    # Out-of-range window positions contribute 0 to the max.
                    Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = T.max(
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c],
                        T.if_then_else(
                            T.likely(1 <= h_o * 8 + h_i * 2 + kh, dtype="bool")
                            and T.likely(1 <= w_o * 8 + w_i * 2 + kw, dtype="bool"),
                            X_cache[
                                h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                                w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                                c,
                            ],
                            0,
                            dtype="int32",
                        ),
                    )
def func(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Element-wise copy chain A -> B -> C over 64x32 float32 buffers.
    # Each (i, j) iteration is one anonymous block containing two nested
    # blocks, "B" (B = A) and "C" (C = B), which index with the loop variables
    # directly. The outer block lists B as both read and written.
    A = T.match_buffer(a, [64, 32], dtype="float32")
    B = T.match_buffer(b, [64, 32], dtype="float32")
    C = T.match_buffer(c, [64, 32], dtype="float32")
    for i, j in T.grid(64, 32):  # type: ignore
        with T.block():
            T.reads([A[i, j], B[i, j]])  # type: ignore
            T.writes([B[i, j], C[i, j]])  # type: ignore
            with T.block("B"):
                T.reads([A[i, j]])  # type: ignore
                T.writes([B[i, j]])  # type: ignore
                B[i, j] = A[i, j]  # type: ignore
            with T.block("C"):
                T.reads([B[i, j]])  # type: ignore
                T.writes([C[i, j]])  # type: ignore
                C[i, j] = B[i, j]  # type: ignore
def expected_match_buffer_func(a: T.handle) -> None:
    # Writes 1.0 to every element of the 16x16 buffer A through two levels of
    # sub-region matching: row A[i, 0:16] is matched as A0, then the single
    # element A0[j] is matched as the zero-dimensional buffer A1.
    # Each block declares no reads and the sub-region it writes.
    A = T.match_buffer(a, (16, 16))
    for i in range(0, 16):
        with T.block():
            T.reads([])
            T.writes(A[i, 0:16])
            # `(16)` is the plain integer 16 here, not a 1-tuple.
            A0 = T.match_buffer(A[i, 0:16], (16))
            with T.block():
                T.reads([])
                T.writes(A0[0:16])
                for j in range(0, 16):
                    with T.block():
                        T.reads([])
                        T.writes(A0[j])
                        A1 = T.match_buffer(A0[j], ())
                        A1[()] = 1.0
def opaque_block_func() -> None:
    # Computes B = A + 1.0 on two locally allocated 16x16 float32 buffers using
    # blocks without block iterators: a per-row block around a per-element
    # block, both indexing with the loop variables directly.
    with T.block("root"):
        A = T.alloc_buffer((16, 16), "float32")
        B = T.alloc_buffer((16, 16), "float32")
        T.reads([])
        T.writes([])
        # Read/write regions are added manually to avoid triggering the block
        # access region detector.
        for i in range(0, 16):
            with T.block():
                T.reads(A[i, 0:16])
                T.writes([B[i, 0:16]])
                for j in range(0, 16):
                    with T.block():
                        T.reads(A[i, j])
                        T.writes(B[i, j])
                        B[i, j] = A[i, j] + 1.0
def compacted_elementwise_func(a: T.handle, c: T.handle) -> None:
    # Computes C = (A + 1.0) * 2.0 on 16x16 float32 buffers, one row per outer
    # block. The intermediate B is allocated inside the row block with shape
    # (1, 16) and is always indexed with 0 in its first dimension.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer((1, 16), "float32")
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(A[i, j])
                    T.writes(B[0, j])
                    B[0, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(B[0, j])
                    T.writes(C[i, j])
                    C[i, j] = B[0, j] * 2.0
def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
    # Computes C = (A + 1.0) * 2.0 on 1-D buffers of symbolic length n * 8,
    # processed in n chunks of 8 elements. The per-chunk intermediate B is
    # allocated inside the chunk block with extent T.min(n, 1) * 8 and is
    # indexed by the in-chunk offset j only.
    A = T.match_buffer(a, (n * 8, ), "float32")
    C = T.match_buffer(c, (n * 8, ), "float32")
    for i in range(0, n):
        with T.block():
            T.reads(A[i * 8:i * 8 + 8])
            T.writes(C[i * 8:i * 8 + 8])
            B = T.alloc_buffer((T.min(n, 1) * 8, ), "float32")
            for j in range(0, 8):
                with T.block() as []:
                    T.reads(A[i * 8 + j])
                    T.writes(B[j])
                    B[j] = A[i * 8 + j] + 1.0
            for j in range(0, 8):
                with T.block() as []:
                    T.reads(B[j])
                    T.writes(C[i * 8 + j])
                    C[i * 8 + j] = B[j] * 2.0
def simple_compute_missing_annotation(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
    # Computes C[tx, i] = A[tx, i] * 2 + 1 through a (16, 1) buffer B allocated
    # with scope="shared". The serial loop over `i` carries only the
    # "software_pipeline_stage" annotation.
    # NOTE(review): judging by the name, a companion pipeline annotation
    # (presumably "software_pipeline_order") is deliberately absent - confirm
    # against the test that consumes this function.
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1]}):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = B[tx, 0] + T.float32(1)
def substituted_elementwise_func(a: T.handle, c: T.handle) -> None:
    # Computes C = (A + 1.0) * 2.0 on 16x16 float32 buffers, one row per outer
    # block. The intermediate B is allocated inside the row block with the full
    # [16, 16] shape and is indexed with the same (i, j) as A and C.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer([16, 16], "float32")
            for j in range(0, 16):
                with T.block():
                    T.reads([A[i, j]])
                    T.writes([B[i, j]])
                    B[i, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block():
                    T.reads([B[i, j]])
                    T.writes([C[i, j]])
                    C[i, j] = B[i, j] * 2.0
def compacted_storage_align_func(a: T.handle, c: T.handle) -> None:
    # Computes C = (A + 1.0) * 2.0 on 16x16 float32 buffers, one row per outer
    # block. The intermediate B is allocated inside the row block with shape
    # (1, 16) and explicit strides (31, 1). The producer block carries a
    # "buffer_dim_align" attribute of [[0, 0, 16, 15]].
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            # The keyword is `dtype`; the previous spelling `dtypes` is not a
            # parameter of T.alloc_buffer. The element type stays float32.
            B = T.alloc_buffer((1, 16), strides=(31, 1), dtype="float32")
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(A[i, j])
                    T.writes(B[0, j])
                    T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
                    B[0, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(B[0, j])
                    T.writes(C[i, j])
                    C[i, j] = B[0, j] * 2.0
def block_in_opaque_block(a: T.handle, b: T.handle) -> None:
    # Block "B" iterates vi over 128 rows, copies column 0 of A into B, then
    # branches on A[vi, 0] == 0.0. Each branch holds an iterator-less block
    # ("C" or "E") with whole-buffer regions, wrapping an inner 128-wide block
    # ("D" or "F") that writes B[vi, vj] = A[vi, vj] * 3.0 or * 2.0.
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    with T.block([128], "B") as vi:
        T.reads([A[0:128, 0:128]])
        T.writes([B[0:128, 0:128]])
        B[vi, 0] = A[vi, 0]
        if A[vi, 0] == 0.0:
            with T.block([], "C"):
                T.reads([A[0:128, 0:128]])
                T.writes([B[0:128, 0:128]])
                with T.block([128], "D") as vj:
                    B[vi, vj] = A[vi, vj] * 3.0
        else:
            with T.block([], "E"):
                T.reads([A[0:128, 0:128]])
                T.writes([B[0:128, 0:128]])
                with T.block([128], "F") as vj:
                    B[vi, vj] = A[vi, vj] * 2.0
def pooling_decompose_3(
        x: T.Buffer[(1, 16, 225, 225), "int8"],
        tensor: T.Buffer[(1, 16, 225, 225), "int8"]) -> None:
    # 7x7 window sum over a zero-padded copy of `x`, tiled into a 3x3 grid of
    # 80-wide spatial tiles. For each tile, three stages run in order:
    #   1. "pad_temp_pad_const" zero-fills an 86x86 window of pad_temp,
    #   2. "pad_temp" copies the in-range part of `x` to an offset of +3,
    #   3. "tensor" accumulates pad_temp over the 7x7 (rv0, rv1) window.
    # T.where predicates mask off iterations that fall outside the 231-wide
    # padded buffer or the 225-wide output.
    pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
    for i0, i2_0, i3_0 in T.grid(1, 3, 3):
        for ax0, ax1, ax2 in T.grid(16, 86, 86):
            with T.block("pad_temp_pad_const"):
                ax0_1 = T.axis.spatial(1, 0)
                ax1_1 = T.axis.spatial(16, ax0)
                ax2_1 = T.axis.spatial(231, i2_0 * 80 + ax1)
                ax3 = T.axis.spatial(231, i3_0 * 80 + ax2)
                T.where(i2_0 * 80 + ax1 < 231 and i3_0 * 80 + ax2 < 231)
                T.reads()
                T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3])
                pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0)
        for ax0, ax1, ax2 in T.grid(16, 86, 86):
            with T.block("pad_temp"):
                ax0_2 = T.axis.spatial(1, 0)
                ax1_2 = T.axis.spatial(16, ax0)
                # Block axes are in `x` coordinates, i.e. padded coordinate minus 3.
                ax2_2 = T.axis.spatial(225, i2_0 * 80 + ax1 - 3)
                ax3 = T.axis.spatial(225, i3_0 * 80 + ax2 - 3)
                T.where(3 <= i2_0 * 80 + ax1 and i2_0 * 80 + ax1 < 228
                        and 3 <= i3_0 * 80 + ax2 and i3_0 * 80 + ax2 < 228
                        and i2_0 * 80 + ax1 < 231 and i3_0 * 80 + ax2 < 231)
                T.reads(x[ax0_2, ax1_2, ax2_2, ax3])
                T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3])
                pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3]
        for i1, i2_1, i3_1, i4, i5 in T.grid(16, 80, 80, 7, 7):
            with T.block("tensor"):
                ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1])
                ax2_3 = T.axis.spatial(225, i2_0 * 80 + i2_1)
                ax3 = T.axis.spatial(225, i3_0 * 80 + i3_1)
                rv0, rv1 = T.axis.remap("RR", [i4, i5])
                T.where(i2_0 * 80 + i2_1 < 225 and i3_0 * 80 + i3_1 < 225)
                T.reads(pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
                T.writes(tensor[ax0_3, ax1_3, ax2_3, ax3])
                with T.init():
                    tensor[ax0_3, ax1_3, ax2_3, ax3] = T.int8(0)
                tensor[ax0_3, ax1_3, ax2_3, ax3] = (
                    tensor[ax0_3, ax1_3, ax2_3, ax3] +
                    pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
def compacted_warp_mem_func(a: T.handle, c: T.handle) -> None:
    # Computes C = (A + 1.0) * 2.0 on 16x16 float32 buffers. The row index is
    # i0 * 8 + i1 * 4 + i2, where the loops are bound to blockIdx.x (2),
    # vthread (2) and threadIdx.x (4). The intermediate B is allocated with
    # scope="warp" and shape (4, 16); its first index is the threadIdx.x loop
    # variable i2 only.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
        for i1 in T.thread_binding(0, 2, thread="vthread"):
            for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
                with T.block():
                    T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
                    T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
                    B = T.alloc_buffer((4, 16), "float32", scope="warp")
                    for j in range(0, 16):
                        with T.block() as []:
                            T.reads(A[i0 * 8 + i1 * 4 + i2, j])
                            T.writes(B[i2, j])
                            B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
                    for j in range(0, 16):
                        with T.block() as []:
                            T.reads(B[i2, j])
                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
                            C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0
def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
    # Computes C = (A + 1.0) * 2.0 on symbolic-shape (n, m) float32 buffers,
    # one row per outer block. The intermediate B is allocated inside the row
    # block with shape (m,) and scope="global", and is indexed by column j only.
    A = T.match_buffer(a, (n, m), "float32")
    C = T.match_buffer(c, (n, m), "float32")
    for i in range(0, n):
        with T.block():
            # NOTE(review): these regions name the single element at column `m`,
            # while the inner loops touch columns 0..m-1; `0:m` may have been
            # intended. Confirm against the function this one is compared with
            # before changing it.
            T.reads(A[i, m])
            T.writes(C[i, m])
            B = T.alloc_buffer((m, ), "float32", scope="global")
            for j in range(0, m):
                with T.block() as []:
                    T.reads(A[i, j])
                    T.writes(B[j])
                    B[j] = A[i, j] + 1.0
            for j in range(0, m):
                with T.block() as []:
                    T.reads(B[j])
                    T.writes(C[i, j])
                    C[i, j] = B[j] * 2.0
def conv2d_nhwc_reindex_weight(var_inputs: T.handle, var_weight: T.handle, var_conv2d_nhwc: T.handle) -> None:
    # 7x7, stride-2 convolution of a [1, 224, 224, 3] input into a
    # [1, 112, 112, 64] output, in three stages:
    #   1. "PadInput" builds a [1, 230, 230, 3] copy with a 3-element zero border,
    #   2. "weight_reindex" permutes weight[kh, kw, ci, co] into
    #      weight_reindex[co, kh, kw, ci],
    #   3. "conv2d_nhwc" accumulates PadInput * weight_reindex over (rh, rw, rc).
    inputs = T.match_buffer(var_inputs, [1, 224, 224, 3], dtype="float32")
    weight = T.match_buffer(var_weight, [7, 7, 3, 64], dtype="float32")
    conv2d_nhwc = T.match_buffer(var_conv2d_nhwc, [1, 112, 112, 64], dtype="float32")
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    weight_reindex = T.alloc_buffer([64, 7, 7, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            # Rows/columns 3..226 copy the input; everything else is zero.
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
                inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 1, 1, 64, 7, 7, 3):
        with T.block("weight_reindex"):
            # v0, v1, v2 have extent 1 and do not appear in the body.
            v0, v1, v2, v3, v4, v5, v6 = T.axis.remap(
                "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6])
            T.reads(weight[v4, v5, v6, v3])
            T.writes(weight_reindex[v3, v4, v5, v6])
            weight_reindex[v3, v4, v5, v6] = weight[v4, v5, v6, v3]
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            n, h, w, co, rh, rw, rc = T.axis.remap(
                "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            # co ranges over 0..63, so `co // 64 * 3` is 0 and the channel index is rc.
            T.reads(
                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc],
                weight_reindex[co, rh, rw, rc],
            )
            T.writes(conv2d_nhwc[n, h, w, co])
            with T.init():
                conv2d_nhwc[n, h, w, co] = T.float32(0)
            conv2d_nhwc[n, h, w, co] = (
                conv2d_nhwc[n, h, w, co] +
                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] *
                weight_reindex[co, rh, rw, rc])
def element_wise(a: T.handle, c: T.handle) -> None:
    # Computes B = A * 2 and then C = B + 1 on 128x128 buffers. Both stages
    # share the outer row loop i0: for each row, the "B" loop over all 128
    # columns runs first, then the "C" loop.
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i0, ax1])
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))