def main(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
    T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_maxelem"):
            i0_1, k = T.axis.remap("SR", [i0, i1])
            with T.init():
                T_softmax_maxelem[i0_1] = T.min_value("float32")
            T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_expsum"):
            i0_2, k = T.axis.remap("SR", [i0, i1])
            with T.init():
                T_softmax_expsum[i0_2] = T.float32(0)
            T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
                A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32")
    for i0_3, i1 in T.grid(256, 256):
        with T.block("T_softmax_norm"):
            i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
            T_softmax_norm[i0_4, i1_1] = T.exp(
                A[i0_4, i1_1] - T_softmax_maxelem[i0_4],
                dtype="float32") / T_softmax_expsum[i0_4]

def square_sum_square_root_factor_one_2_rfactor(
        A: T.Buffer[(16, 256, 256), "float32"],
        D: T.Buffer[(16,), "float32"]) -> None:
    C = T.alloc_buffer([16], dtype="float32")
    C_rf = T.alloc_buffer([16, 1], dtype="float32")
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C_rf"):
            b = T.axis.spatial(16, i0)
            i = T.axis.reduce(256, i1_i2_fused_inner // 256)
            j = T.axis.reduce(256, i1_i2_fused_inner % 256)
            vi1_i2_fused_outer = T.axis.spatial(1, i1_i2_fused_outer)
            with T.init():
                C_rf[b, vi1_i2_fused_outer] = T.float32(0)
            C_rf[b, vi1_i2_fused_outer] = C_rf[b, vi1_i2_fused_outer] + A[b, i, j] * A[b, i, j]
    for i0, i1_i2_fused_outer in T.grid(16, 1):
        with T.block("C"):
            b, vi1_i2_fused_outer = T.axis.remap("SR", [i0, i1_i2_fused_outer])
            with T.init():
                C[b] = T.float32(0)
            C[b] = C[b] + C_rf[b, vi1_i2_fused_outer]
    for i0_1 in T.serial(16):
        with T.block("D"):
            b_1 = T.axis.spatial(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")

def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    C_rf = T.alloc_buffer([4, 128, 128])
    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
        with T.block("update_rf"):
            vi2_inner_inner = T.axis.S(4, i2_inner_inner)
            vi = T.axis.S(128, i0)
            vj = T.axis.S(128, i1)
            vi2_outer = T.axis.R(4, i2_outer)
            vi2_inner_outer = T.axis.R(8, i2_inner_outer)
            with T.init():
                C_rf[vi2_inner_inner, vi, vj] = 0.0
            C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
                A[vi, (vi2_outer * 32) + (vi2_inner_outer * 4) + vi2_inner_inner]
                * B[vj, (vi2_outer * 32) + (vi2_inner_outer * 4) + vi2_inner_inner])
    for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4):
        with T.block("update"):
            vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap(
                "RSS", [i2_inner_inner_1, i0_1, i1_1])
            with T.init():
                C[vi_1, vj_1] = 0.0
            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
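
# Sketch (not part of the original fixtures): how an rfactor'd PrimFunc like
# `matmul_rfactor` above is typically derived with the TensorIR schedule API.
# The base `matmul` PrimFunc, its block name "update", and the split factors
# [4, 8, 4] are illustrative assumptions.
def _sketch_matmul_rfactor_schedule(matmul_prim_func):
    from tvm import tir

    sch = tir.Schedule(matmul_prim_func)
    update = sch.get_block("update")
    _, _, k = sch.get_loops(update)
    # Split the reduction loop into outer / inner / innermost pieces ...
    _, _, k_inner = sch.split(k, factors=[4, 8, 4])
    # ... and factor the innermost piece out into a new "update_rf" block.
    sch.rfactor(k_inner, factor_axis=0)
    return sch.mod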

def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    C_rf = T.alloc_buffer([1, 16])
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block("C_rf"):
            vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0])
            i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256))
            with T.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])
    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
        with T.block("C"):
            vi1_i2_fused_inner_1, b_1 = T.axis.remap(
                "RS", [i1_i2_fused_inner_1, i0_1])
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]
    for i0_2 in T.serial(0, 16):
        with T.block("D"):
            b_2 = T.axis.S(16, i0_2)
            D[b_2] = T.sqrt(C[b_2], dtype="float32")

def duplicate_init() -> None:
    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            with T.init():
                T.evaluate(1.0)
            with T.init():  # error
                T.evaluate(1.0)

def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16])
    C = T.alloc_buffer([16, 16])
    D = T.alloc_buffer([16, 16])
    E = T.alloc_buffer([16, 16])
    F = T.match_buffer(f, [16, 16])
    C_rf = T.alloc_buffer([16, 16, 4])
    for i, j1, k1o, k1i in T.grid(16, 16, 4, 4):
        with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]:
            T.bind(vk1o, k1o)
            T.bind(ci, i)
            T.bind(cj, j1)
            T.bind(vk1i, k1i)
            with T.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
    for i_1 in T.serial(0, 16):
        for j1_1 in T.serial(0, 16):
            for k1o_1 in T.serial(0, 4):
                with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]:
                    T.bind(vk1o_1, k1o_1)
                    T.bind(ci_1, i_1)
                    T.bind(cj_1, j1_1)
                    with T.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            for k2o, k2i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
                    T.bind(di, i_1)
                    T.bind(dj, j1_1)
                    T.bind(dk, (k2o * 4) + k2i)
                    with T.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        for j2 in T.serial(0, 16):
            for k3o, k3i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
                    T.bind(ei, i_1)
                    T.bind(ej, j2)
                    T.bind(ek, (k3o * 4) + k3i)
                    with T.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
                    T.bind(fi, i_1)
                    T.bind(fj, j2)
                    T.bind(fk, (k4o * 4) + k4i)
                    with T.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]

def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        T_softmax_expsum_shared[i0_2],
                        A[i0_2, k],
                        T_softmax_maxelem_shared[i0_2],
                    ])
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32")
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_3, i1],
                        T_softmax_maxelem_shared[i0_3],
                        T_softmax_expsum_shared[i0_3],
                    ])
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                        / T_softmax_expsum_shared[i0_3])

def single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32")
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                        / T_softmax_expsum_shared[i0_3])

def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128])
    B = T.match_buffer(b, [])
    B_rf = T.alloc_buffer([128])

    with T.block([128], "B_rf") as [vi0]:
        with T.init():
            B_rf[vi0] = 0.0
        B_rf[vi0] = B_rf[vi0] + A[vi0]

    with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]:
        with T.init():
            B[()] = 0.0
        B[()] = B[()] + B_rf[vi0_1]

def before_matmul_vectorize(
    placeholder: T.Buffer[(64, 768), "float32"],
    placeholder_1: T.Buffer[(768, 768), "float32"],
    T_matmul_NT: T.Buffer[(64, 768), "float32"],
) -> None:
    with T.block("root"):
        T.reads()
        T.writes()
        T.block_attr({"meta_schedule.vectorize": 64})
        T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
        for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
            for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(48, 8, 1, 16, 8, 16):
                with T.block("T_matmul_NT"):
                    i = T.axis.spatial(64, i0_2 * 8 + i0_3)
                    j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3)
                    k = T.axis.reduce(768, i2_0 * 16 + i2_1)
                    T.reads(placeholder[i, k], placeholder_1[j, k])
                    T.writes(T_matmul_NT_global[i, j])
                    with T.init():
                        T_matmul_NT_global[i, j] = T.float32(0)
                    T_matmul_NT_global[i, j] = (
                        T_matmul_NT_global[i, j] + placeholder[i, k] * placeholder_1[j, k])
            for ax0, ax1 in T.grid(64, 16):
                with T.block("T_matmul_NT_global"):
                    v0 = T.axis.spatial(64, ax0)
                    v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1)
                    T.reads(T_matmul_NT_global[v0, v1])
                    T.writes(T_matmul_NT[v0, v1])
                    T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]

def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
    for k, i in T.grid(128, 128):
        with T.block("B"):
            vk, vi = T.axis.remap("RS", [k, i])
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]

def cascade_pool_ops(
    x: T.Buffer[(1, 16, 112, 112), "float32"],
    y2: T.Buffer[(1, 16, 108, 108), "float32"],
) -> None:
    y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
    for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3):
        with T.block("pool_0"):
            ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
            with T.init():
                y1[ax0, ax1, ax2, ax3] = 0.0
            y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
    for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3):
        with T.block("pool_1"):
            ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
            with T.init():
                y2[ax0, ax1, ax2, ax3] = 0.0
            y2[ax0, ax1, ax2, ax3] = y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]

def conv2d_nhwc(
    Input: T.Buffer[(1, 224, 224, 3), "float32"],
    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                (i1_1 >= 3) and (i1_1 < 227) and (i2_1 >= 3) and (i2_1 < 227),
                Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            n, h, w, co, rh, rw, rc = T.axis.remap(
                "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            with T.init():
                Conv2d_nhwc[n, h, w, co] = T.float32(0)
            Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + (
                PadInput[n, (h * 2) + rh, (w * 2) + rw, (T.floordiv(co, 64) * 3) + rc]
                * Weight[rh, rw, rc, co])

def after_matmul_vectorize(
    placeholder: T.Buffer[(64, 768), "float32"],
    placeholder_1: T.Buffer[(768, 768), "float32"],
    T_matmul_NT: T.Buffer[(64, 768), "float32"],
) -> None:
    T.func_attr({
        "global_symbol": "main",
        "tir.noalias": True,
        "layout_free_placeholders": [1],
    })
    T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
    for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
        for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8):
            for i1_3_fused in T.vectorized(16):
                with T.block("T_matmul_NT"):
                    i = T.axis.spatial(64, i0_2 * 8 + i0_3)
                    j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3_fused)
                    k = T.axis.reduce(768, i2_0 * 16 + i2_1)
                    T.reads(placeholder[i, k], placeholder_1[j, k])
                    T.writes(T_matmul_NT_global[i, j])
                    with T.init():
                        T_matmul_NT_global[i, j] = T.float32(0)
                    T_matmul_NT_global[i, j] = (
                        T_matmul_NT_global[i, j] + placeholder[i, k] * placeholder_1[j, k])
        for ax0 in T.serial(64):
            for ax1_fused in T.vectorized(16):
                with T.block("T_matmul_NT_global"):
                    v0 = T.axis.spatial(64, ax0)
                    v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1_fused)
                    T.reads(T_matmul_NT_global[v0, v1])
                    T.writes(T_matmul_NT[v0, v1])
                    T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
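
# Sketch (illustrative only): the rewrite that separates `before_matmul_vectorize`
# from `after_matmul_vectorize` is vectorization of the innermost, unit-stride
# spatial loops. Block names and loop positions below are assumptions about the
# scheduled IR, not the exact postprocessor that produced the fixture.
def _sketch_vectorize_innermost(scheduled_prim_func):
    from tvm import tir

    sch = tir.Schedule(scheduled_prim_func)
    matmul_block = sch.get_block("T_matmul_NT")
    *_, innermost = sch.get_loops(matmul_block)
    # Mark the innermost extent-16 spatial loop for vector codegen.
    sch.vectorize(innermost)
    writeback_block = sch.get_block("T_matmul_NT_global")
    *_, wb_innermost = sch.get_loops(writeback_block)
    sch.vectorize(wb_innermost)
    return sch.mod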

def conv2d_nhwc_transformed(
    Input: T.Buffer[(1, 224, 224, 3), "float32"],
    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
                Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for ax0, ax_1, ax_2 in T.grid(12544, 64, 147):
        with T.block("conv2d_nhwc"):
            bv0, bv1, bv2 = T.axis.remap("SSR", [ax0, ax_1, ax_2])
            T.reads(
                PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3],
                Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1],
            )
            T.writes(Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1])
            with T.init():
                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = T.float32(0)
            Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = (
                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1]
                + PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3]
                * Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1])

def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
        with T.block("conv2d_NCHWc_int8"):
            n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap(
                "SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
            T.reads(
                placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
            )
            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
            with T.init():
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                n, oc_chunk, oh, ow, oc_block
            ] + T.cast(
                placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
            ) * T.cast(
                placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                "int32",
            )

def rfactor_spatial_only_after(
    A: T.Buffer[(1, 512, 7, 7), "float32"],
    B: T.Buffer[(1, 512, 1, 1), "float32"],
) -> None:
    # body
    # with T.block("root")
    B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32")
    for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
        with T.block("acc_rf"):
            vi4 = T.axis.spatial(49, i4)
            ax0 = T.axis.spatial(1, 0)
            ax1 = T.axis.spatial(512, i1)
            ax2 = T.axis.spatial(1, 0)
            ax3 = T.axis.spatial(1, 0)
            B_rf[ax0, ax1, ax2, ax3, vi4] = A[ax0, ax1, ax2 * 7 + vi4 // 7, ax3 * 7 + vi4 % 7]
    for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
        with T.block("acc"):
            vi4 = T.axis.reduce(49, i4)
            ax0 = T.axis.spatial(1, 0)
            ax1 = T.axis.spatial(512, i1)
            ax2 = T.axis.spatial(1, 0)
            ax3 = T.axis.spatial(1, 0)
            with T.init():
                B[ax0, ax1, ax2, ax3] = T.float32(0)
            B[ax0, ax1, ax2, ax3] = B[ax0, ax1, ax2, ax3] + B_rf[ax0, ax1, ax2, ax3, vi4]
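
# Sketch (assumptions: a base PrimFunc whose "acc" block reduces over a single
# fused 49-element loop in the fifth loop position): factoring that loop with
# rfactor yields the purely spatial "acc_rf" block above, because the factored
# loop carries the entire reduction.
def _sketch_rfactor_spatial_only(base_prim_func):
    from tvm import tir

    sch = tir.Schedule(base_prim_func)
    acc = sch.get_block("acc")
    loops = sch.get_loops(acc)
    # loops[4] is assumed to be the fused 7*7 reduction axis of extent 49.
    sch.rfactor(loops[4], factor_axis=4)
    return sch.mod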

def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        for ax0 in T.serial(0, 10):
            for ax1 in T.serial(0, 10):
                with T.block("cache"):
                    h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
                    w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
                    T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225
                            and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225)
                    cache[h, w] = X[h, w]
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                h = T.axis.spatial(224, hh_0 * 8 + hh_1)
                w = T.axis.spatial(224, ww_0 * 8 + ww_1)
                kh, kw = T.axis.remap("RR", [khh, kww])
                with T.init():
                    Y[h, w] = 0.0
                Y[h, w] = T.max(
                    Y[h, w],
                    T.if_then_else(
                        T.likely(1 <= h + kh, dtype="bool")
                        and T.likely(h + kh < 225, dtype="bool")
                        and T.likely(1 <= w + kw, dtype="bool")
                        and T.likely(w + kw < 225, dtype="bool"),
                        cache[h + kh - 1, w + kw - 1],
                        0.0,
                        dtype="float32",
                    ),
                )

def tiled_conv2d_with_padding(
    inputs: T.Buffer[(1, 224, 224, 3), "float32"],
    weight: T.Buffer[(7, 7, 3, 64), "float32"],
    conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                3 <= i1_1 and i1_1 < 227 and 3 <= i2_1 and i2_1 < 227,
                inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for (
        i0_0, i1_0, i2_0, i3_0,
        i0_1_1, i1_1_1, i2_1_1, i3_1_1,
        i4_0, i5_0, i6_0,
        i0_2, i1_2, i2_2, i3_2,
        i4_1, i5_1, i6_1,
        i0_3, i1_3, i2_3, i3_3,
    ) in T.grid(1, 1, 4, 1, 1, 2, 4, 1, 7, 7, 1, 1, 1, 1, 1, 1, 1, 3, 1, 56, 7, 64):
        with T.block("conv2d_nhwc"):
            n = T.axis.spatial(1, 0)
            h = T.axis.spatial(112, i1_1_1 * 56 + i1_3)
            w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 * 7 + i2_3)
            co, rh, rw, rc = T.axis.remap("SRRR", [i3_3, i4_0, i5_0, i6_1])
            T.reads(
                conv2d_nhwc[n, h, w, co],
                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc],
                weight[rh, rw, rc, co],
            )
            T.writes(conv2d_nhwc[n, h, w, co])
            with T.init():
                conv2d_nhwc[n, h, w, co] = T.float32(0)
            conv2d_nhwc[n, h, w, co] = (
                conv2d_nhwc[n, h, w, co]
                + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc]
                * weight[rh, rw, rc, co])

def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, (16, 16, 16))
    C = T.alloc_buffer((16, 16))
    D = T.alloc_buffer((16, 16))
    E = T.alloc_buffer((16, 16))
    F = T.match_buffer(f, (16, 16))
    for i in T.serial(0, 16):
        for j1 in T.serial(0, 16):
            for k1o, k1i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]:
                    T.bind(ci, i)
                    T.bind(cj, j1)
                    T.bind(ck, k1o * 4 + k1i)
                    with T.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
                    T.bind(di, i)
                    T.bind(dj, j1)
                    T.bind(dk, k2o * 4 + k2i)
                    with T.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        for j2 in T.serial(0, 16):
            for k3o, k3i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
                    T.bind(ei, i)
                    T.bind(ej, j2)
                    T.bind(ek, k3o * 4 + k3i)
                    with T.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
                    T.bind(fi, i)
                    T.bind(fj, j2)
                    T.bind(fk, k4o * 4 + k4i)
                    with T.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]

def rowsum_not_compact_data_flow(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
        with T.init():
            B[vk] = 0.0
        B[vk] = B[vk] + A[vi, vk]

def rowsum_zero_dim(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128])
    B = T.match_buffer(b, [])
    with T.block([T.reduce_axis(0, 128)], "B") as [k]:
        with T.init():
            B[()] = 0.0
        B[()] = B[()] + A[k]

def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    D = T.match_buffer(d, [128, 128])
    for k, i, j in T.grid(128, 128, 128):
        with T.block("C"):
            ck, ci, cj = T.axis.remap("RSS", [k, i, j])
            with T.init():
                C[ci, cj] = 0.0
            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
        with T.block("D"):
            dk, di, dj = T.axis.remap("RSS", [k, i, j])
            with T.init():
                D[di, dj] = 0.0
            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]

def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
        with T.init():
            B[vi] = 0.0
        B[vi] = B[vi] - A[vi, vk]

def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16])
    C = T.alloc_buffer([16, 16])
    D = T.alloc_buffer([16, 16])
    E = T.alloc_buffer([16, 16])
    F = T.match_buffer(f, [16, 16])
    C_rf = T.alloc_buffer([16, 16, 4])
    for i, j1, k1o, k1i in T.grid(16, 16, 4, 4):
        with T.block("C_rf"):
            vk1o, ci, cj, vk1i = T.axis.remap("SSSR", [k1o, i, j1, k1i])
            with T.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
    for i_1 in T.serial(0, 16):
        for j1_1 in T.serial(0, 16):
            for k1o_1 in T.serial(0, 4):
                with T.block("C"):
                    vk1o_1, ci_1, cj_1 = T.axis.remap("RSS", [k1o_1, i_1, j1_1])
                    with T.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            for k2o, k2i in T.grid(4, 4):
                with T.block("D"):
                    di, dj = T.axis.remap("SS", [i_1, j1_1])
                    dk = T.axis.R(16, k2o * 4 + k2i)
                    with T.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        for j2 in T.serial(0, 16):
            for k3o, k3i in T.grid(4, 4):
                with T.block("E"):
                    ei, ej = T.axis.remap("SS", [i_1, j2])
                    ek = T.axis.R(16, k3o * 4 + k3i)
                    with T.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in T.grid(4, 4):
                with T.block("F"):
                    fi, fj = T.axis.remap("SS", [i_1, j2])
                    fk = T.axis.R(16, k4o * 4 + k4i)
                    with T.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]

def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    B_rf = T.alloc_buffer([128, 13], dtype="float32")
    for i, k_0, k_1 in T.grid(128, 13, 10):
        with T.block("B_rf"):
            vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1])
            T.where(k_0 * 10 + k_1 < 128)
            with T.init():
                B_rf[vi, vk_0] = T.float32(0)
            B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1]
    for i, k_0 in T.grid(128, 13):
        with T.block("B"):
            vk_0, vi = T.axis.remap("RS", [k_0, i])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + B_rf[vi, vk_0]
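
# Sketch (assumptions: a base row-sum PrimFunc with block "B" whose loops are
# ordered [i, k]): a non-divisible split such as [13, 10] over the 128-extent
# reduction loop introduces the T.where predicate seen above, and rfactor on the
# outer piece then produces the B_rf buffer of shape [128, 13].
def _sketch_rowsum_predicate_rfactor(rowsum_prim_func):
    from tvm import tir

    sch = tir.Schedule(rowsum_prim_func)
    block = sch.get_block("B")
    _i, k = sch.get_loops(block)
    # 13 * 10 = 130 > 128, so the split is guarded by a block predicate.
    k_0, _k_1 = sch.split(k, factors=[13, 10])
    sch.rfactor(k_0, factor_axis=1)
    return sch.mod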

def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        k0o,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(16, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
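
# Sketch (illustrative): lowered functions such as the one above, containing
# tvm_thread_allreduce and the local reduce/write-back blocks, are produced by
# the LowerCrossThreadReduction TIR pass from a thread-bound block-level form
# (compare the `softmax` fixture earlier). `blocked_prim_func` is a placeholder
# name for such an input.
def _sketch_lower_cross_thread_reduction(blocked_prim_func):
    import tvm

    mod = tvm.IRModule.from_expr(blocked_prim_func)
    return tvm.tir.transform.LowerCrossThreadReduction()(mod)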

def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    for i, j in T.grid(1, 4):
                                        with T.block([2048, 2048], "A_shared_local") as [v0, v1]:
                                            T.bind(v0, k_0 * 8 + k_1 + i)
                                            T.bind(v1, by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0, v1] = A_shared[v0, v1]
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block(
                                            [2048, 2048, T.reduce_axis(0, 2048)], "C"
                                        ) as [vi, vj, vk]:
                                            T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = T.float32(0)
                                            C_local[vi, vj] = C_local[vi, vj] + (
                                                A_shared_local[vk, vi] * B_shared_local[vk, vj])
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048], "C_local") as [v0, v1]:
                                    T.bind(v0, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]

def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]:
        with T.init():
            C[i, j] = 0.0
        C[i, j] += A[i, k] * B[j, k]

def square_sum_rfactor(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    C = T.match_buffer(c, [16])
    C_rf = T.alloc_buffer([16, 256])
    for i0, i1, i2 in T.grid(16, 256, 256):
        with T.block("C_rf"):
            vi2, b, i = T.axis.remap("SSR", [i2, i0, i1])
            with T.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
    for i0_1, i2_1 in T.grid(16, 256):
        with T.block("C"):
            vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1])
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]