def elementwise_subblock_uncovered(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    for i, j in T.grid(32, 32):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]])
            T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]])
            for ii, jj in T.grid(2, 2):
                with T.block("B_sub"):
                    vi_i, vj_i = T.axis.remap("SS", [ii, jj])
                    B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
def compacted_elementwise_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer([1, 16], "float32", scope="global")
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(A[i, j])
                    T.writes(B[0, j])
                    B[0, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(B[0, j])
                    T.writes(C[i, j])
                    C[i, j] = B[0, j] * 2.0
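# A minimal sketch (an assumption, not part of the original file) of how a
# "compacted_*" expectation like the function above is usually verified: run
# CompactBufferAllocation over the pre-compaction function and compare
# structurally. The helper name `_check_compact` and the `before`/`expected`
# parameters are illustrative.
import tvm


def _check_compact(before, expected):
    mod = tvm.IRModule.from_expr(before)
    mod = tvm.tir.transform.CompactBufferAllocation()(mod)
    mod = tvm.tir.transform.Simplify()(mod)  # clean up leftover index arithmetic
    tvm.ir.assert_structural_equal(mod["main"], expected)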
def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128])
    B = T.match_buffer(b, [])
    B_rf = T.alloc_buffer([128])
    for i in range(128):
        with T.block("B_rf"):
            vi0 = T.axis.S(128, i)
            with T.init():
                B_rf[vi0] = 0.0
            B_rf[vi0] = B_rf[vi0] + A[vi0]
    for i in range(128):
        with T.block("B"):
            vi0_1 = T.axis.R(128, i)
            with T.init():
                B[()] = 0.0
            B[()] = B[()] + B_rf[vi0_1]
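# Hedged sketch of the schedule step that plausibly produces the rfactor form
# above: factoring the single reduction loop of a zero-dim row-sum. The name
# `rowsum_zero_dim` for the pre-rfactor function is an assumption.
import tvm
from tvm import tir


def _rfactor_rowsum(rowsum_zero_dim):
    sch = tir.Schedule(rowsum_zero_dim)
    block = sch.get_block("B")
    (k,) = sch.get_loops(block)
    sch.rfactor(k, factor_axis=0)  # materializes B_rf plus the final reduction
    return sch.mod["main"]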
def simple_compute_missing_annotation(
    A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
):
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1]}):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = B[tx, 0] + T.float32(1)
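# Hedged sketch: the fixture above carries "software_pipeline_stage" but no
# "software_pipeline_order" annotation, so InjectSoftwarePipeline should
# reject it. The exact exception type raised is an assumption.
import pytest
import tvm


def _expect_missing_annotation_error(func):
    mod = tvm.IRModule.from_expr(func)
    with pytest.raises(tvm.TVMError):
        tvm.tir.transform.InjectSoftwarePipeline()(mod)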
def elementwise_multi_producer_consumer(a: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    D = T.match_buffer(d, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0  # B has two consumers
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
    for i, j in T.grid(128, 128):
        with T.block("D"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj]  # D has two producers
def elementwise_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block([]):
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer((16, 16), "float32")
            for j in range(0, 16):
                with T.block([]) as []:
                    T.reads(A[i, j])
                    T.writes(B[i, j])
                    B[i, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block([]) as []:
                    T.reads(B[i, j])
                    T.writes(C[i, j])
                    C[i, j] = B[i, j] * 2.0
def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
    A = T.match_buffer(a, (n * 8,), "float32")
    C = T.match_buffer(c, (n * 8,), "float32")
    for i in range(0, n):
        with T.block():
            T.reads(A[i * 8 : i * 8 + 8])
            T.writes(C[i * 8 : i * 8 + 8])
            B = T.alloc_buffer((8,), "float32")
            for j in range(0, 8):
                with T.block() as []:
                    T.reads(A[i * 8 + j])
                    T.writes(B[j])
                    B[j] = A[i * 8 + j] + 1.0
            for j in range(0, 8):
                with T.block() as []:
                    T.reads(B[j])
                    T.writes(C[i * 8 + j])
                    C[i * 8 + j] = B[j] * 2.0
def elementwise_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer((16, 16), "float32")
            for j in range(0, 16):
                with T.block():
                    vi = T.axis.S(16, i)
                    vj = T.axis.S(16, j)
                    B[vi, vj] = A[vi, vj] + 1.0
            for j in range(0, 16):
                with T.block():
                    vi = T.axis.S(16, i)
                    vj = T.axis.S(16, j)
                    C[vi, vj] = B[vi, vj] * 2.0
def thread_bound_nested_block_after_cache_read(
    A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16,), "float32"]
) -> None:
    for i in T.thread_binding(16, thread="blockIdx.x"):
        with T.block("outer"):
            vi = T.axis.spatial(16, i)
            A_shared = T.alloc_buffer([1, 16], dtype="float32", scope="shared")
            for ax0, ax1 in T.grid(1, 16):
                with T.block("A_shared"):
                    v0 = T.axis.spatial(16, vi + ax0)
                    v1 = T.axis.spatial(16, ax1)
                    A_shared[v0, v1] = A[v0, v1]
            for j in T.thread_binding(16, thread="threadIdx.x"):
                with T.block("inner"):
                    vj = T.axis.reduce(16, j)
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A_shared[vi, vj]
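# Hedged sketch of the cache_read step that plausibly yields the fixture
# above: staging the read of A into shared memory for the reduction block
# "inner". `before` stands for the pre-transform function (an assumption).
import tvm
from tvm import tir


def _cache_read_shared(before):
    sch = tir.Schedule(before)
    block = sch.get_block("inner")
    sch.cache_read(block, 0, "shared")  # read-buffer index 0 -> A_shared
    return sch.mod["main"]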
def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C"):
            b = T.axis.S(16, i0)
            i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256))
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in T.serial(0, 16):
        with T.block("D"):
            b_1 = T.axis.S(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def substituted_elementwise_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer([16, 16], "float32")
            for j in range(0, 16):
                with T.block():
                    T.reads([A[i, j]])
                    T.writes([B[i, j]])
                    B[i, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block():
                    T.reads([B[i, j]])
                    T.writes([C[i, j]])
                    C[i, j] = B[i, j] * 2.0
def warp_memory_negative(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    C = T.match_buffer(c, [128, 128])
    B = T.alloc_buffer([128, 4, 32], scope="warp")
    for i_o in T.thread_binding(0, 4, thread="threadIdx.y"):
        for i_i in T.thread_binding(0, 32, thread="threadIdx.x"):
            for j in T.serial(0, 128):
                with T.block("B"):
                    warp_id, lane_id, vj = T.axis.remap("SSS", [i_o, i_i, j])
                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
            # The second threadIdx.y binding is nested under the first two,
            # which is what makes this a negative warp-memory case.
            for i_o_prime in T.thread_binding(0, 4, thread="threadIdx.y"):
                for j in T.serial(0, 128):
                    with T.block("C"):
                        _warp_id, warp_id, lane_id, vj = T.axis.remap(
                            "SSSS", [i_o, i_i, i_o_prime, j]
                        )
                        C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
def square_sum_rfactor(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    C = T.match_buffer(c, [16])
    C_rf = T.alloc_buffer([16, 256])
    for i0, i1, i2 in T.grid(16, 256, 256):
        with T.block("C_rf"):
            vi2, b, i = T.axis.remap("SSR", [i2, i0, i1])
            with T.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
    for i0_1, i2_1 in T.grid(16, 256):
        with T.block("C"):
            vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1])
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
def main(
    placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"],
    placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"],
    conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"],
) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
            T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
            data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(
                2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18,
                placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
        with T.block("conv2d_NCHWc"):
            n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap(
                "SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]
            )
            T.reads(
                data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3],
                placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block],
            )
            T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
            T.block_attr(
                {
                    "workload": [
                        "conv2d_NCHWc.x86",
                        ["TENSOR", [1, 1, 16, 16, 3], "float32"],
                        ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"],
                        [1, 1],
                        [2, 2, 2, 2],
                        [1, 1],
                        "NCHW3c",
                        "NCHW4c",
                        "float32",
                    ]
                }
            )
            with T.init():
                conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
            conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = (
                conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]
                + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3]
                * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]
            )
def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
    C = T.alloc_buffer([1], dtype="float32")
    for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
        for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
            for i1, i2 in T.grid(256, 256):
                with T.block("C"):
                    b = T.axis.S(1, 0)
                    i, j = T.axis.remap("RR", [i1, i2])
                    T.where(i0_fused_1 < 1)
                    with T.init():
                        C[b] = T.float32(0)
                    C[b] = C[b] + A[b, i, j] * A[b, i, j]
    for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
        for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
            with T.block("D"):
                b = T.axis.S(1, 0)
                T.where(i0_fused_1 < 1)
                D[b] = T.sqrt(C[b], dtype="float32")
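# Hedged sketch: the T.where(i0_fused_1 < 1) guards above are the typical
# residue of splitting a length-1 spatial loop by the thread extent and then
# binding the pieces. `before` names the unbound function (an assumption).
import tvm
from tvm import tir


def _bind_spatial_loops(before):
    sch = tir.Schedule(before)
    for name in ["C", "D"]:
        loop = sch.get_loops(sch.get_block(name))[0]
        bx, tx = sch.split(loop, factors=[None, 32])  # extent 1 -> predicate
        sch.bind(bx, "blockIdx.x")
        sch.bind(tx, "threadIdx.x")
    return sch.mod["main"]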
def transformed_element_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [16, 16])
    C = T.match_buffer(c, [16, 16])
    for i_0 in range(0, 16):
        with T.block([]):
            T.reads([A[i_0, 0:16]])
            T.writes([C[i_0, 0:16]])
            B = T.alloc_buffer([16, 16])
            for j_0 in T.serial(0, 16):
                with T.block([16, 16], "") as [i, j]:
                    T.bind(i, i_0)
                    T.bind(j, j_0)
                    B[i, j] = A[i, j] + 1.0
            for j_0 in T.serial(0, 16):
                with T.block([16, 16], "") as [i, j]:
                    T.bind(i, i_0)
                    T.bind(j, j_0)
                    C[i, j] = B[i, j] * 2.0
def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None:
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    compute = T.match_buffer(var_compute, [512, 512], dtype="float32")
    C = T.alloc_buffer([512, 512], dtype="float32")
    for i0, i1, i2 in T.grid(512, 512, 512):
        with T.block("C"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            T.reads([C[i, j], A[i, k], B[k, j]])
            T.writes([C[i, j]])
            with T.init():
                C[i, j] = T.float32(0)
            C[i, j] = C[i, j] + A[i, k] * B[k, j]
    for i0, i1 in T.grid(512, 512):
        with T.block("compute"):
            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
            T.reads([C[i0_1, i1_1]])
            T.writes([compute[i0_1, i1_1]])
            compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
def access_under_scope(b: T.handle, c: T.handle) -> None:
    A = T.alloc_buffer((128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    with T.block([8, 8], "scope") as [i, j]:
        for x, y in T.grid(16, 16):
            with T.block([128, 128], "A") as [vi, vj]:
                T.bind(vi, i * 16 + x)
                T.bind(vj, j * 16 + y)
                A[vi, vj] = 1.0
        for x, y in T.grid(16, 16):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i * 16 + x)
                T.bind(vj, j * 16 + y)
                B[vi, vj] = A[vi, vj] + 1.0
    with T.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = A[vi, vj] * 2.0
def factorized(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16], "float32")
    B = T.match_buffer(b, [16], "float32")
    B_rf_local = T.alloc_buffer([16, 16], "float32", scope="local")
    for j in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in T.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in T.grid(4, 16):
                with T.block("B_rf"):
                    vi = T.axis.S(16, i_o * 4 + i_i)
                    vj, vk = T.axis.remap("SR", [j, k])
                    with T.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
    for i, k in T.grid(16, 16):
        with T.block("B"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + B_rf_local[vk, vi]
def transformed_trivial_pipeline(
    A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]
) -> None:
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0])
            T.writes(C[tx, 0])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            with T.block():
                T.reads(A[tx, 0])
                T.writes(B[0, tx, 0])
                B[0, tx, 0] = A[tx, 0] * T.float32(2)
            with T.block():
                T.reads()
                T.writes()
                T.evaluate(0)
            with T.block():
                T.reads(B[0, tx, 0])
                T.writes(C[tx, 0])
                C[tx, 0] = B[0, tx, 0] + T.float32(1)
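# Hedged sketch of the structural check that a "transformed_*" pipeline
# fixture like the one above typically backs: apply InjectSoftwarePipeline to
# the annotated input and compare. `before` is the annotated function
# (an assumption).
import tvm


def _check_pipeline(before, expected):
    mod = tvm.IRModule.from_expr(before)
    mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    tvm.ir.assert_structural_equal(mod["main"], expected)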
def fail_subtree_compact_dataflow(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.alloc_buffer((128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i in range(0, 128):
        for j in range(0, 64):
            with T.block("B_0"):
                vi = T.axis.S(128, i)
                vj = T.axis.S(128, j)
                B[vi, vj] = A[vi, vj] * 2.0
        for j in range(0, 64):
            with T.block("B_1"):
                vi = T.axis.S(128, i)
                vj = T.axis.S(128, j + 64)
                B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
def pooling_decompose_3(
    x: T.Buffer[(1, 16, 225, 225), "int8"], tensor: T.Buffer[(1, 16, 225, 225), "int8"]
) -> None:
    pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
    for i0, i2_0, i3_0 in T.grid(1, 3, 3):
        for ax0, ax1, ax2 in T.grid(16, 86, 86):
            with T.block("pad_temp_pad_const"):
                ax0_1 = T.axis.spatial(1, 0)
                ax1_1 = T.axis.spatial(16, ax0)
                ax2_1 = T.axis.spatial(231, i2_0 * 80 + ax1)
                ax3 = T.axis.spatial(231, i3_0 * 80 + ax2)
                T.where(i2_0 * 80 + ax1 < 231 and i3_0 * 80 + ax2 < 231)
                T.reads()
                T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3])
                pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0)
        for ax0, ax1, ax2 in T.grid(16, 86, 86):
            with T.block("pad_temp"):
                ax0_2 = T.axis.spatial(1, 0)
                ax1_2 = T.axis.spatial(16, ax0)
                ax2_2 = T.axis.spatial(225, i2_0 * 80 + ax1 - 3)
                ax3 = T.axis.spatial(225, i3_0 * 80 + ax2 - 3)
                T.where(
                    3 <= i2_0 * 80 + ax1
                    and i2_0 * 80 + ax1 < 228
                    and 3 <= i3_0 * 80 + ax2
                    and i3_0 * 80 + ax2 < 228
                    and i2_0 * 80 + ax1 < 231
                    and i3_0 * 80 + ax2 < 231
                )
                T.reads(x[ax0_2, ax1_2, ax2_2, ax3])
                T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3])
                pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3]
        for i1, i2_1, i3_1, i4, i5 in T.grid(16, 80, 80, 7, 7):
            with T.block("tensor"):
                ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1])
                ax2_3 = T.axis.spatial(225, i2_0 * 80 + i2_1)
                ax3 = T.axis.spatial(225, i3_0 * 80 + i3_1)
                rv0, rv1 = T.axis.remap("RR", [i4, i5])
                T.where(i2_0 * 80 + i2_1 < 225 and i3_0 * 80 + i3_1 < 225)
                T.reads(pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
                T.writes(tensor[ax0_3, ax1_3, ax2_3, ax3])
                with T.init():
                    tensor[ax0_3, ax1_3, ax2_3, ax3] = T.int8(0)
                tensor[ax0_3, ax1_3, ax2_3, ax3] = (
                    tensor[ax0_3, ax1_3, ax2_3, ax3]
                    + pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1]
                )
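# Hedged sketch: the pad_temp_pad_const / pad_temp block pair above matches
# the shape produced by Schedule.decompose_padding (constant fill split from
# the data copy); whether this exact fixture came from that primitive, and the
# loop it was decomposed at, are assumptions.
import tvm
from tvm import tir


def _decompose_padding(before):
    sch = tir.Schedule(before)
    pad = sch.get_block("pad_temp")
    loop = sch.get_loops(sch.get_block("tensor"))[0]
    sch.decompose_padding(pad, loop)  # split padding into const-fill + copy
    return sch.mod["main"]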
def compacted_storage_align_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer((1, 16), strides=(31, 1), dtype="float32")
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(A[i, j])
                    T.writes(B[0, j])
                    T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
                    B[0, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(B[0, j])
                    T.writes(C[i, j])
                    C[i, j] = B[0, j] * 2.0
def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    C_rf = T.alloc_buffer([4, 128, 128])
    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
        with T.block(
            [4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf"
        ) as [vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer]:
            T.bind(vi2_inner_inner, i2_inner_inner)
            T.bind(vi, i0)
            T.bind(vj, i1)
            T.bind(vi2_outer, i2_outer)
            T.bind(vi2_inner_outer, i2_inner_outer)
            with T.init():
                C_rf[vi2_inner_inner, vi, vj] = 0.0
            C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
                A[vi, vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner]
                * B[vj, vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner]
            )
    for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4):
        with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [
            vi2_inner_inner_1,
            vi_1,
            vj_1,
        ]:
            T.bind(vi2_inner_inner_1, i2_inner_inner_1)
            T.bind(vi_1, i0_1)
            T.bind(vj_1, i1_1)
            with T.init():
                C[vi_1, vj_1] = 0.0
            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
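# Hedged sketch of the split-then-rfactor sequence that plausibly yields the
# fixture above (4 * 8 * 4 = 128 over the reduction axis). The pre-transform
# name `matmul` is an assumption; the split factors are read off the loop
# extents.
import tvm
from tvm import tir


def _split_and_rfactor(matmul):
    sch = tir.Schedule(matmul)
    update = sch.get_block("update")
    _, _, k = sch.get_loops(update)
    _, _, ki = sch.split(k, factors=[4, 8, 4])
    sch.rfactor(ki, factor_axis=0)  # C_rf gains a leading axis of extent 4
    return sch.mod["main"]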
def compacted_match_buffer_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16))
    C = T.match_buffer(c, (16, 16))
    for i in range(0, 16):
        with T.block():
            A0 = T.match_buffer(A[i, 0:16], (16,))
            C0 = T.match_buffer(C[i, 0:16], (16,))
            B = T.alloc_buffer((1, 16))
            with T.block():
                B0 = T.match_buffer(B[0, 0:16], (16,))
                for j in range(0, 16):
                    with T.block() as []:
                        A1 = T.match_buffer(A0[j], ())
                        B1 = T.match_buffer(B0[j], ())
                        B1[()] = A1[()] + 1.0
            for j in range(0, 16):
                with T.block() as []:
                    C1 = T.match_buffer(C0[j], ())
                    B2 = T.match_buffer(B[0, j], ())
                    C1[()] = B2[()] * 2.0
def compacted_warp_mem_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
        for i1 in T.thread_binding(0, 2, thread="vthread"):
            for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
                with T.block():
                    T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
                    T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
                    B = T.alloc_buffer((4, 16), "float32", scope="warp")
                    for j in range(0, 16):
                        with T.block() as []:
                            T.reads(A[i0 * 8 + i1 * 4 + i2, j])
                            T.writes(B[i2, j])
                            B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
                    for j in range(0, 16):
                        with T.block() as []:
                            T.reads(B[i2, j])
                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
                            C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0
def matmul_reindex_write(
    A: T.Buffer[(512, 512), "float32"],
    B: T.Buffer[(512, 512), "float32"],
    C: T.Buffer[(512, 512), "float32"],
) -> None:
    C_reindex = T.alloc_buffer([512, 512], dtype="float32")
    for i0, i1, i2 in T.grid(512, 512, 512):
        with T.block("matmul"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            T.reads(C_reindex[i, j], A[i, k], B[k, j])
            T.writes(C_reindex[i, j])
            with T.init():
                C_reindex[i, j] = T.float32(0)
            C_reindex[i, j] = C_reindex[i, j] + A[i, k] * B[k, j]
    for i0, i1, i2 in T.grid(512, 512, 1):
        with T.block("C_reindex"):
            v0, v1, v2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(C_reindex[v0, v1])
            T.writes(C[v0, v1])
            C[v0, v1] = C_reindex[v0, v1]
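# Hedged sketch: a write-reindex fixture like the above is what
# Schedule.reindex produces when applied to the write buffer of the matmul
# block; `matmul` stands for the pre-transform function (an assumption).
import tvm
from tvm import tir


def _reindex_write(matmul):
    sch = tir.Schedule(matmul)
    block = sch.get_block("matmul")
    sch.reindex(block, ("write", 0))  # route writes through C_reindex
    return sch.mod["main"]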
def func_3(
    C: T.Buffer[(1,), "float32"],
    A: T.Buffer[(16,), "float32"],
    D: T.Buffer[(2,), "float32"],
    E: T.Buffer[(16,), "float32"],
    F: T.Buffer[(16,), "float32"],
):
    for i in T.serial(0, 16):
        with T.block():
            B = T.alloc_buffer((1,), dtype="float32")
            with T.block():
                B[0] = A[i] * T.float32(2)
            with T.block():
                E[i] = A[i]
                F[i] = E[i] + 1.0
                C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0]
                A[i] = B[0] + T.float32(1) + D[1]
def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
    A = T.match_buffer(a, (n, m), "float32")
    C = T.match_buffer(c, (n, m), "float32")
    for i in range(0, n):
        with T.block():
            T.reads(A[i, 0:m])
            T.writes(C[i, 0:m])
            B = T.alloc_buffer((m,), "float32", scope="global")
            for j in range(0, m):
                with T.block() as []:
                    T.reads(A[i, j])
                    T.writes(B[j])
                    B[j] = A[i, j] + 1.0
            for j in range(0, m):
                with T.block() as []:
                    T.reads(B[j])
                    T.writes(C[i, j])
                    C[i, j] = B[j] * 2.0
def main(  # type: ignore
    a: T.handle,
    b: T.handle,
    d: T.handle,
) -> None:  # pylint: disable=no-self-argument
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (1024, 1024), "float32")
    B = T.match_buffer(b, (1024, 1024), "float32")
    D = T.match_buffer(d, (1024, 1024), "float32")
    C = T.alloc_buffer((1024, 1024), "float32")
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0  # type: ignore
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.max(C[vi, vj], 0.0)  # type: ignore