def main(placeholder: T.Buffer[(1, 384), "int64"],
         placeholder_1: T.Buffer[(30522, 768), "float32"],
         placeholder_2: T.Buffer[(1, 384, 768), "float32"],
         T_add: T.Buffer[(1, 384, 768), "float32"]) -> None:
    # Single-block gather + add.
    #   placeholder   : (1, 384) int64 row indices into placeholder_1
    #   placeholder_1 : (30522, 768) float32 table that rows are taken from
    #   placeholder_2 : (1, 384, 768) float32 addend
    #   T_add         : (1, 384, 768) float32 output
    # For every (ax0, ax1, ax2):
    #   T_add = placeholder_1[row, ax2] + placeholder_2[ax0, ax1, ax2]
    # where `row` is the index from `placeholder`, wrapped by +30522 when
    # negative and then clamped to [0, 30521].
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_add_1"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            # The declared row range of placeholder_1 runs from the clamped raw
            # index up to (and including) the clamped wrapped index
            # (index + 30522), so it covers either outcome of the T.Select
            # used in the assignment below.
            T.reads(
                placeholder[ax0, ax1],
                placeholder_1[
                    T.min(T.max(T.int64(0), placeholder[ax0, ax1]), T.int64(30521)):T.min(
                        T.max(T.int64(0), placeholder[ax0, ax1] + T.int64(30522)),
                        T.int64(30521)) + T.int64(1),
                    ax2],
                placeholder_2[ax0, ax1, ax2])
            T.writes(T_add[ax0, ax1, ax2])
            # Select picks (index + 30522) for negative indices, else the raw
            # index; min/max then clamp the row to the valid range.
            T_add[ax0, ax1, ax2] = placeholder_1[T.min(
                T.max(
                    T.int64(0),
                    T.Select(
                        T.cast(placeholder[ax0, ax1] < T.int64(0), "int32") != 0,
                        placeholder[ax0, ax1] + T.int64(30522),
                        placeholder[ax0, ax1])
                ), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
def compacted_sparse_read_cache(
    A_data: T.Buffer[(819,), "float32"],
    B: T.Buffer[(128,), "float32"],
    A_indptr: T.Buffer[(129,), "int32"],
    A_indices: T.Buffer[(819,), "int32"],
) -> None:
    # Per-row sum with a data-dependent inner extent:
    #   B[i] = sum of A_data[A_indptr[i] + k]
    #          for k in [0, A_indptr[i + 1] - A_indptr[i])
    # Each A_data element is first copied into a one-element "local" buffer
    # (A_data_local) and the accumulation reads from that copy.
    # NOTE(review): A_indptr / A_data look like CSR row pointers and values;
    # A_indices is accepted but never referenced in the body.
    for i in T.serial(128):
        with T.block("rowsum_outer"):
            # Read region of A_data is the whole segment belonging to row i.
            T.reads(
                A_indptr[i : i + 1],
                A_data[A_indptr[i] + 0 : A_indptr[i] + 0 + (A_indptr[i + 1] - A_indptr[i])],
            )
            T.writes(B[i])
            with T.block("rowsum_init"):
                # Zero the accumulator before the row's elements are added.
                T.reads()
                T.writes(B[i])
                B[i] = T.float32(0)
            for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
                with T.block():
                    T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
                    T.writes(B[i])
                    # Single-element staging buffer; it is indexed through
                    # T.min(A_indptr[i] + k, 0) in both blocks below.
                    A_data_local = T.alloc_buffer([1], dtype="float32", scope="local")
                    with T.block("A_data_cache_read"):
                        T.reads(A_indptr[i], A_data[A_indptr[i] + k])
                        T.writes(A_data_local[T.min(A_indptr[i] + k, 0)])
                        A_data_local[T.min(A_indptr[i] + k, 0)] = A_data[A_indptr[i] + k]
                    with T.block("rowsum_inner"):
                        T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
                        T.writes(B[i])
                        B[i] = B[i] + A_data_local[T.min(A_indptr[i] + k, 0)]
def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None:
    # Element-wise copy B[vi, vj] = A[vi, vj] over a 127 x 128 buffer.
    # The rows are walked in 4 outer steps of up to 32 rows each; the row and
    # column loops of one step are fused into the single loop `j_k_fused`.
    A = T.match_buffer(a, [127, 128])
    B = T.match_buffer(b, [127, 128])
    for i in T.grid(4):
        # Rows in this step: T.min(31, 126 - i * 32) + 1, i.e. 32 for i < 3 and
        # 31 for i == 3 (127 rows in total). The fused extent is that row
        # count times 128 columns, written as min(...) * 128 + 128.
        for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128):
            with T.block("B"):
                # Row: step offset plus (fused // 128) modulo the step's row count.
                vi = T.axis.S(
                    127,
                    i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1),
                )
                # Column: fused index modulo 128.
                vj = T.axis.S(128, T.floormod(j_k_fused, 128))
                T.reads([A[vi, vj]])
                T.writes([B[vi, vj]])
                B[vi, vj] = A[vi, vj]
def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None:
    # Fills every element of a locally allocated 16 x 16 buffer with the single
    # value data_buf[index_buf[index_buf[0]], index_buf[0]], i.e. a load whose
    # indices are themselves (nested) loads from index_buf.
    # `out_buf` is allocated inside the root block and is not returned.
    index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
    with T.block("root"):
        T.reads([])
        T.writes([])
        out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
        for i0, i1 in T.grid(16, 16):
            with T.block():
                vi, vj = T.axis.remap("SS", [i0, i1])
                # The index_buf read region is the range between 0 and
                # index_buf[0] (inclusive), which covers both index_buf[0]
                # and index_buf[index_buf[0]].
                T.reads([
                    data_buf[index_buf[index_buf[0]], index_buf[0]],
                    index_buf[T.min(index_buf[0], 0):T.max(index_buf[0], 0) + 1],
                ])
                T.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(
        placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None:
    # Element-wise kernel over 360000 flat elements (75 * 75 * 4 * 16):
    #   1. load the uint8 input and cast it to int32,
    #   2. subtract 94,
    #   3. rescale with T.q_multiply_shift(x, 1843157232, 31, 1),
    #   4. add the int32 value placeholder_3[ax3_outer * 16 + ax3_inner]
    #      (one of 64 entries, selected by the innermost two loop indices),
    #   5. clamp to [0, 255], cast to uint8 and then to int16.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast",
        "tir.noalias": True
    })
    placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
    placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
    T_cast_1 = T.match_buffer(T_cast, [360000], dtype="int16")
    # body
    for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
        # Flat offset: 4800 (= 75 * 64) per outer index, 64 per ax2.
        T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(
            T.cast(
                T.max(
                    T.min(
                        T.q_multiply_shift(
                            T.cast(
                                placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner],
                                "int32") - 94,
                            1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner],
                        255),
                    0),
                "uint8"),
            "int16")
def elementwise_not_affine(a: T.handle, b: T.handle) -> None:
    # Element-wise copy B[vi, vj] = A[vi, vj] over a 127 x 128 buffer, with the
    # 127 rows covered by 4 outer steps of up to 32 rows each.
    A = T.match_buffer(a, (127, 128))
    B = T.match_buffer(b, (127, 128))
    for i in T.serial(0, 4):
        # The row extent depends on the outer loop variable:
        # T.min(31, 126 - i * 32) + 1 is 32 for i < 3 and 31 for i == 3.
        for j, k in T.grid(T.min(31, 126 - i * 32) + 1, 128):
            with T.block("B"):
                vi = T.axis.S(127, i * 32 + j)
                vj = T.axis.S(128, k)
                B[vi, vj] = A[vi, vj]
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(
        placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
    # Kernel written with explicit T.store / T.load on flat offsets.
    #   placeholder_7 : [1, 75, 75, 64] int16 input
    #   placeholder_8 : [1, 1, 64, 64] int16 weights, indexed as rc * 64 + ff
    #   placeholder_9 : [1, 1, 1, 64] int32 per-output-channel addend
    #   T_cast_3      : [1, 75, 75, 64] int16 output
    # For each of the 5625 (= 75 * 75) positions and each of the 64 output
    # channels it accumulates an int32 dot product over 64 input channels,
    # adds placeholder_9, rescales with T.q_multiply_shift, clamps to
    # [0, 255] and casts to uint8 and then int16.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
        "tir.noalias": True
    })
    placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
    placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
    placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
    T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
    # body
    # PaddedInput is an element-for-element copy of the input; no border is
    # added in this kernel.
    PaddedInput = T.allocate([360000], "int16", "global")
    for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
        T.store(
            PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3,
            T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
    for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
        # Per-position int32 accumulator, one slot per output channel.
        Conv2dOutput = T.allocate([64], "int32", "global")
        for ff in T.serial(0, 64):
            T.store(Conv2dOutput, ff, 0, True)
            for rc in T.serial(0, 64):
                T.store(
                    Conv2dOutput, ff,
                    T.load("int32", Conv2dOutput, ff) + T.cast(
                        T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(
                        T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
        # Epilogue: add, rescale, clamp to [0, 255], narrow, then widen to int16.
        for ax3_inner_1 in T.serial(0, 64):
            T.store(
                T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1,
                T.cast(
                    T.cast(
                        T.max(
                            T.min(
                                T.q_multiply_shift(
                                    T.load("int32", Conv2dOutput, ax3_inner_1) +
                                    T.load("int32", placeholder_9.data, ax3_inner_1),
                                    1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
def access_of_padding_pattern() -> None:
    # Two blocks iterating the same 32 x 32 domain, all buffers local:
    #   "padding"         : X_pad = X shifted by (2, 2), 0.0 outside rows/cols 2..29
    #   "padding_reverse" : Y[vi - 2, vj - 2] = X_pad[vi, vj] for the interior only
    # The read/write regions are written with T.max / T.min clamps so that they
    # stay within the 28 x 28 (X, Y) and interior 2..29 (X_pad) bounds.
    X = T.alloc_buffer([28, 28])
    X_pad = T.alloc_buffer([32, 32])
    Y = T.alloc_buffer([28, 28])
    for i, j in T.grid(32, 32):
        with T.block("padding"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([
                X[T.max(vi - 2, 0):T.min(vi - 2, 27) + 1,
                  T.max(vj - 2, 0):T.min(vj - 2, 27) + 1,
                ]
            ])
            T.writes([X_pad[vi, vj]])
            # Interior elements copy from X; the 2-wide border gets 0.0.
            X_pad[vi, vj] = T.if_then_else(2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32")
        with T.block("padding_reverse"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([
                X_pad[T.max(vi, 2):T.min(vi, 29) + 1,
                      T.max(vj, 2):T.min(vj, 29) + 1]
            ])
            T.writes([
                Y[T.max(vi - 2, 0):T.min(vi - 2, 27) + 1,
                  T.max(vj - 2, 0):T.min(vj - 2, 27) + 1,
                ]
            ])
            # Only interior positions are written back; border iterations do nothing.
            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
                Y[vi - 2, vj - 2] = X_pad[vi, vj]
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(
        placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None:
    # Kernel on flat buffers with a 3 x 3 reduction window.
    #   placeholder_13 : [360000] int16 input (75 * 75 * 64)
    #   placeholder_14 : [36864] int16 weights, indexed as
    #                    ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1
    #   placeholder_15 : [64] int32 per-output-channel addend
    #   T_cast_5       : [360000] int16 output
    # The input is first copied into a 77 x 77 x 64 buffer with a one-element
    # zero border; each of the 5625 output positions then accumulates over
    # (ry, rx, rc_1) in 3 x 3 x 64, adds placeholder_15, rescales with
    # T.q_multiply_shift, clamps to [0, 255] and casts to uint8 then int16.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
        "tir.noalias": True
    })
    placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
    placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
    placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
    T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16")
    # body
    # 379456 = 77 * 77 * 64; row stride 4928 = 77 * 64.
    PaddedInput_1 = T.allocate([379456], "int16", "global")
    for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
        # Indices 1..75 in both spatial dims map to the source; the -4864
        # (= 4800 + 64) offset shifts by one source row and one source column.
        # Everything else is filled with int16 zero.
        PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(
            1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76,
            placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16")
    for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
        # Per-position int32 accumulator, one slot per output channel.
        Conv2dOutput_1 = T.allocate([64], "int32", "global")
        for ff_1 in T.serial(0, 64):
            Conv2dOutput_1[ff_1] = 0
            for ry, rx, rc_1 in T.grid(3, 3, 64):
                # floordiv / floormod by 75 split the fused position into its
                # row and column within the padded buffer.
                Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(
                    PaddedInput_1[
                        T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 +
                        T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(
                    placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32")
        for ax3_inner_2 in T.serial(0, 64):
            T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(
                T.cast(
                    T.max(
                        T.min(
                            T.q_multiply_shift(
                                Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2],
                                1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16")
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(
        placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None:
    # Kernel on flat buffers producing 256 int32 outputs per position.
    #   placeholder_19 : [360000] int16 input (75 * 75 * 64)
    #   placeholder_20 : [16384] int16 weights, indexed as
    #                    rc_2 * 256 + ax3_outer_1 * 64 + ff_2
    #   placeholder_21 : [256] int32 per-output-channel addend
    #   T_add_1        : [1440000] int32 output (5625 * 256)
    # The 256 output channels are processed as 4 groups of 64 that reuse the
    # same 64-entry accumulator. After accumulation each value is: + addend,
    # q_multiply_shift, + 132, clamp to [0, 255], cast uint8 -> int32, - 132,
    # q_multiply_shift again, + 136.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_",
        "tir.noalias": True
    })
    placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16")
    placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
    placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
    T_add_1 = T.match_buffer(T_add, [1440000], dtype="int32")
    # body
    # Element-for-element copy of the input; no border is added here.
    PaddedInput_2 = T.allocate([360000], "int16", "global")
    for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
        PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2]
    for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
        Conv2dOutput_2 = T.allocate([64], "int32", "global")
        for ax3_outer_1 in T.serial(0, 4):
            # Accumulate one group of 64 output channels over 64 input channels.
            for ff_2 in T.serial(0, 64):
                Conv2dOutput_2[ff_2] = 0
                for rc_2 in T.serial(0, 64):
                    Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(
                        PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(
                        placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32")
            # Epilogue for this group; must run before the accumulator is
            # reset by the next ax3_outer_1 iteration.
            for ax3_inner_3 in T.serial(0, 64):
                T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(
                    T.cast(
                        T.cast(
                            T.max(
                                T.min(
                                    T.q_multiply_shift(
                                        Conv2dOutput_2[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3],
                                        1711626602, 31, -8, dtype="int32") + 132,
                                    255),
                                0),
                            "uint8"),
                        "int32") - 132,
                    2094289803, 31, -2, dtype="int32") + 136
def reduction_loop_only(
    A: T.Buffer[2, "float32"],
    B: T.Buffer[2, "float32"],
    C: T.Buffer[(), "float32"],
) -> None:
    # Min-reduction into the zero-dimensional buffer C, whose only loop is the
    # reduction loop: C starts at 1.0 and is then replaced by
    # min(C, A[k0] / B[k0]) for k0 in {0, 1}.
    for i0 in T.serial(2):
        with T.block("C"):
            k0 = T.axis.reduce(2, i0)
            T.reads(A[k0], B[k0])
            T.writes(C[()])
            with T.init():
                # Initial value of the reduction.
                C[()] = T.float32(1.0)
            C[()] = T.min(C[()], A[k0] / B[k0])
def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None:
    # For each j in [0, 16): block "B" copies A into the local buffer B at
    # positions j and (when j < 15) j + 1, then block "C" writes
    # C[j] = max(B[j], B[j + 1]) for j < 15 and C[15] = B[15].
    A = T.match_buffer(a, [16], "float32")
    B = T.alloc_buffer([16], "float32")
    C = T.match_buffer(c, [16], "float32")
    for j in T.serial(0, 16):
        # Extent is 2 for j < 15 and 1 for j == 15, so j + i never exceeds 15.
        for i in T.serial(0, T.min(1, 15 - j) + 1):
            with T.block("B"):
                v = T.axis.S(16, j + i)
                B[v] = A[v]
        with T.block("C"):
            v = T.axis.S(16, j)
            # The declared region B[v : v + 2] reaches index 16 when v == 15,
            # one past the 16-element buffer; the T.if_then_else below only
            # uses B[v + 1] when v < 15.
            T.reads([B[v : v + 2]])
            C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32")
def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
    # Computes C[x] = (A[x] + 1.0) * 2.0 over n * 8 elements, 8 at a time,
    # staging each chunk through an intermediate buffer B allocated inside
    # the per-chunk block. `n` is a symbolic int32 argument that sizes A and C.
    A = T.match_buffer(a, (n * 8, ), "float32")
    C = T.match_buffer(c, (n * 8, ), "float32")
    for i in range(0, n):
        with T.block():
            T.reads(A[i * 8:i * 8 + 8])
            T.writes(C[i * 8:i * 8 + 8])
            # B's extent is T.min(n, 1) * 8 and it is indexed by j in [0, 8) only.
            B = T.alloc_buffer((T.min(n, 1) * 8, ), "float32")
            for j in range(0, 8):
                with T.block() as []:
                    T.reads(A[i * 8 + j])
                    T.writes(B[j])
                    B[j] = A[i * 8 + j] + 1.0
            for j in range(0, 8):
                with T.block() as []:
                    T.reads(B[j])
                    T.writes(C[i * 8 + j])
                    C[i * 8 + j] = B[j] * 2.0
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None:
    # Kernel on flat buffers with parallel outer loops.
    #   placeholder_33 : [150528] int16 input (28 * 28 * 192)
    #   placeholder_34 : [3072] int16 weights, indexed as rc_3 * 16 + ax3_2
    #   placeholder_35 : [16] int32 per-output-channel addend
    #   T_cast_9       : [12544] int16 output (784 * 16)
    # For each of the 784 positions and each of the 16 output channels a
    # one-element int32 accumulator sums over 192 input channels; the result
    # gets the addend, is rescaled with T.q_multiply_shift, clamped to
    # [0, 255] and cast to uint8 and then int16.
    # function attr dict
    T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
    placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    # Element-for-element copy of the input (row stride 5376 = 28 * 192).
    PaddedInput_3 = T.allocate([150528], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)]
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            # The accumulator is allocated inside the loops, once per
            # (position, output channel) pair.
            Conv2dOutput_3 = T.allocate([1], "int32", "global")
            Conv2dOutput_3[0] = 0
            for rc_3 in T.serial(0, 192):
                Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32")))
            T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16")
def different_access_indices(a: T.handle, b: T.handle) -> None:
    # Reduction of A[vi, vj, vk] over vk into B, with the reduction loop bound
    # to threadIdx.x. The init statement and the update statement index B
    # differently: init writes B[vj, vi], the update reads/writes B[vi, vj].
    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
    B = T.match_buffer(b, [128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads([B[vi, vj], A[vi, vj, vk]])
                # The write region is the bounding box between min(vi, vj) and
                # max(vi, vj) in both dimensions, which contains both
                # B[vj, vi] and B[vi, vj].
                T.writes([
                    B[T.min(vj, vi):T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)),
                      T.min(vi, vj):T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)),
                    ]
                ])
                with T.init():
                    # Note the transposed indices relative to the update below.
                    B[vj, vi] = T.float32(0)
                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(
        placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
    # Kernel on flat buffers using direct buffer indexing.
    #   placeholder_7 : [360000] int16 input (75 * 75 * 64)
    #   placeholder_8 : [4096] int16 weights, indexed as rc * 64 + ff
    #   placeholder_9 : [64] int32 per-output-channel addend
    #   T_cast_3      : [360000] int16 output
    # For each of the 5625 positions and each of the 64 output channels it
    # accumulates an int32 dot product over 64 input channels, adds
    # placeholder_9, rescales with T.q_multiply_shift, clamps to [0, 255] and
    # casts to uint8 and then int16.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
        "tir.noalias": True
    })
    placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
    placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
    placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
    T_cast_3 = T.match_buffer(T_cast_2, [360000], dtype="int16")
    # body
    # Element-for-element copy of the input; no border is added here.
    PaddedInput = T.allocate([360000], "int16", "global")
    for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
        PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3]
    for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
        # Per-position int32 accumulator, one slot per output channel.
        Conv2dOutput = T.allocate([64], "int32", "global")
        for ff in T.serial(0, 64):
            Conv2dOutput[ff] = 0
            for rc in T.serial(0, 64):
                Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(
                    PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32")
        # Epilogue: add, rescale, clamp to [0, 255], narrow, then widen to int16.
        for ax3_inner_1 in T.serial(0, 64):
            T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(
                T.cast(
                    T.max(
                        T.min(
                            T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"),
                            255), 0), "uint8"), "int16")
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(
        placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
    # Kernel written with explicit T.store / T.load on flat offsets, using a
    # 7 x 7 x 3 reduction window and a step of two input elements per output
    # position in both spatial dimensions.
    #   placeholder_65 : [1, 224, 224, 3] int16 input
    #   placeholder_66 : [7, 7, 3, 64] int16 weights
    #   placeholder_67 : [1, 1, 1, 64] int32 per-output-channel addend
    #   T_cast_21      : [1, 112, 112, 64] uint8 output
    # The input is copied into a 229 x 229 x 3 buffer in which indices outside
    # 2..225 (in either spatial dim) are zero. Each of the 12544 (= 112 * 112)
    # positions accumulates 64 int32 sums, adds placeholder_67, rescales with
    # T.q_multiply_shift, clamps to [0, 255] and casts to uint8.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast",
        "tir.noalias": True
    })
    placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
    # body
    # 157323 = 229 * 229 * 3; padded row stride 687 = 229 * 3.
    PaddedInput_7 = T.allocate([157323], "int16", "global")
    for i0_i1_fused_7 in T.serial(0, 229):
        for i2_7, i3_7 in T.grid(229, 3):
            # Source row stride is 672 (= 224 * 3); the -1350 (= 2 * 672 + 2 * 3)
            # offset shifts by two source rows and two source columns.
            T.store(
                PaddedInput_7, (((i0_i1_fused_7 * 687) + (i2_7 * 3)) + i3_7),
                T.if_then_else(
                    ((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)),
                    T.load("int16", placeholder_65.data, ((((i0_i1_fused_7 * 672) + (i2_7 * 3)) + i3_7) - 1350)),
                    T.int16(0), dtype="int16"), True)
    for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
        # Per-position int32 accumulator, one slot per output channel.
        Conv2dOutput_7 = T.allocate([64], "int32", "global")
        for ff_3 in T.serial(0, 64):
            T.store(Conv2dOutput_7, ff_3, 0, True)
            for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
                # floordiv / floormod by 112 split the fused position into row
                # and column; 1374 (= 2 * 687) and 6 (= 2 * 3) advance two
                # padded rows / columns per output row / column.
                T.store(
                    Conv2dOutput_7, ff_3,
                    (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(
                        T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112) * 1374) + (ry_2 * 687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112) * 6)) + (rx_2 * 3)) + rc_7)),
                        "int32") * T.cast(
                        T.load("int16", placeholder_66.data, ((((ry_2 * 1344) + (rx_2 * 192)) + (rc_7 * 64)) + ff_3)),
                        "int32"))), True)
        # Epilogue: add, rescale, clamp to [0, 255], cast to uint8.
        for ax3_inner_7 in T.serial(0, 64):
            T.store(
                T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7 * 64) + ax3_inner_7),
                T.cast(
                    T.max(
                        T.min(
                            T.q_multiply_shift(
                                (T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)),
                                1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:  # type: ignore
    # Multi-stage convolution pipeline expressed as eight blocks:
    #   data_pad        : copy the 14 x 14 input into a 16 x 16 buffer, zero elsewhere
    #   input_tile      : gather 9 overlapping 6 x 6 tiles (tile step 4) per channel
    #   B               : constant 6 x 6 coefficient matrix
    #   data_pack       : sum over (r_a, r_b) of input_tile * B[r_a, eps] * B[r_b, nu]
    #   bgemm           : per (eps, nu) reduction over the 128 input channels
    #                     against placeholder_1
    #   A               : constant 6 x 4 coefficient matrix
    #   inverse         : sum over (r_a, r_b) of bgemm * A[r_a, vh] * A[r_b, vw]
    #   conv2d_winograd : scatter the 4 x 4 result tiles into the 12 x 12 output
    # NOTE(review): per the function name and schedule-rule annotations, B and A
    # are presumably the Winograd input/output transform matrices and
    # placeholder_1 the pre-transformed weights - confirm against the producer.
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            # Spatial indices 0..13 copy the input; rows/columns 14 and 15 are zero.
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            # Tile index p selects tile row (p % 9) // 3 and tile column p % 3;
            # tiles start every 4 elements and (eps, nu) is the offset inside
            # the 6 x 6 tile.
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # Constant table: one nested T.Select per (i, j) pair, listed from
            # (5, 5) down to (0, 0), with 0 as the final fallback. Kept on a
            # single line so the trailing type-ignore covers the statement.
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            # The B read region is the bounding box that contains both
            # B[r_a, eps_1] and B[r_b, nu_1].
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            # Reduction over the input-channel axis ci_2 for every (eps, nu, p, co).
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # Constant table: one nested T.Select per (i_1, j_1) pair, listed
            # from (5, 3) down to (0, 0), with 0 as the final fallback.
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            # The A read region is the bounding box that contains both
            # A[r_a_1, vh] and A[r_b_1, vw].
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            # Output pixel (h, w) comes from position (h % 4, w % 4) of tile
            # n * 9 + (h // 4) * 3 + (w // 4).
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]
def main(
    placeholder: T.Buffer[(1, 384), "int64"],
    placeholder_1: T.Buffer[(30522, 768), "float32"],
    placeholder_2: T.Buffer[(1, 384, 768), "float32"],
    T_add: T.Buffer[(1, 384, 768), "float32"],
) -> None:
    # Multi-block gather + add, one block per elementary step:
    #   compile_engine_const   : scalar 0
    #   T_less                 : placeholder < 0
    #   compile_engine_const_1 : scalar 30522
    #   T_add (block)          : placeholder + 30522, stored in buffer T_add_1
    #   T_where                : wrapped index where T_less, else the raw index
    #   T_take                 : rows of placeholder_1 at the index clamped to
    #                            [0, 30521]
    #   T_add_1 (block)        : T_take + placeholder_2, stored in output T_add
    # Note that block "T_add" writes buffer T_add_1 while block "T_add_1"
    # writes the output buffer T_add.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    compile_engine_const = T.alloc_buffer([], dtype="int64")
    T_less = T.alloc_buffer([1, 384], dtype="bool")
    compile_engine_const_1 = T.alloc_buffer([], dtype="int64")
    T_add_1 = T.alloc_buffer([1, 384], dtype="int64")
    T_where = T.alloc_buffer([1, 384], dtype="int64")
    T_take = T.alloc_buffer([1, 384, 768], dtype="float32")
    with T.block("compile_engine_const"):
        vi = T.axis.spatial(1, 0)
        T.reads()
        T.writes(compile_engine_const[()])
        compile_engine_const[()] = T.int64(0)
    for i0, i1 in T.grid(1, 384):
        with T.block("T_less"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(placeholder[ax0, ax1], compile_engine_const[()])
            T.writes(T_less[ax0, ax1])
            T_less[ax0, ax1] = placeholder[ax0, ax1] < compile_engine_const[()]
    with T.block("compile_engine_const_1"):
        vi = T.axis.spatial(1, 0)
        T.reads()
        T.writes(compile_engine_const_1[()])
        compile_engine_const_1[()] = T.int64(30522)
    for i0, i1 in T.grid(1, 384):
        with T.block("T_add"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(placeholder[ax0, ax1], compile_engine_const_1[()])
            T.writes(T_add_1[ax0, ax1])
            T_add_1[ax0, ax1] = placeholder[ax0, ax1] + compile_engine_const_1[()]
    for i0, i1 in T.grid(1, 384):
        with T.block("T_where"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0, ax1])
            T.writes(T_where[ax0, ax1])
            T_where[ax0, ax1] = T.Select(
                T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1], placeholder[ax0, ax1])
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_take"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(
                placeholder_1[T.min(T.max(T.int64(0), T_where[ax0, ax1]), T.int64(30521)), ax2],
                T_where[ax0, ax1],
            )
            T.writes(T_take[ax0, ax1, ax2])
            # The row index is clamped to the valid range before the lookup.
            T_take[ax0, ax1, ax2] = placeholder_1[
                T.min(T.max(T.int64(0), T_where[ax0, ax1]), T.int64(30521)), ax2]
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_add_1"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2])
            T.writes(T_add[ax0, ax1, ax2])
            T_add[ax0, ax1, ax2] = T_take[ax0, ax1, ax2] + placeholder_2[ax0, ax1, ax2]
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(
        placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None:
    # Kernel on flat buffers producing 256 uint8 outputs per position and
    # adding a second, element-wise int32 input before the final clamp.
    #   placeholder_29 : [360000] int16 input (75 * 75 * 64)
    #   placeholder_27 : [16384] int16 weights, indexed as
    #                    rc_3 * 256 + ax3_outer_2 * 64 + ff_3
    #   placeholder_26 : [256] int32 per-output-channel addend
    #   placeholder_28 : [1440000] int32 element-wise addend (same layout as output)
    #   T_cast_7       : [1440000] uint8 output (5625 * 256)
    # Note that the match_buffer names are not in handle order: placeholder_22
    # maps to placeholder_29, _23 to _27, _24 to _26 and _25 to _28.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_",
        "tir.noalias": True
    })
    placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16")
    placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
    placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
    placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32")
    T_cast_7 = T.match_buffer(T_cast_6, [1440000], dtype="uint8")
    # body
    # Element-for-element copy of the input; no border is added here.
    PaddedInput_3 = T.allocate([360000], "int16", "global")
    for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
        PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3]
    for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
        Conv2dOutput_3 = T.allocate([64], "int32", "global")
        for ax3_outer_2 in T.serial(0, 4):
            # Accumulate one group of 64 output channels over 64 input channels.
            for ff_3 in T.serial(0, 64):
                Conv2dOutput_3[ff_3] = 0
                for rc_3 in T.serial(0, 64):
                    Conv2dOutput_3[ff_3] = Conv2dOutput_3[ff_3] + T.cast(
                        PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(
                        placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32")
            # Epilogue for this group, innermost first: + placeholder_26,
            # q_multiply_shift, + 136, clamp to [0, 255], cast uint8 -> int32,
            # - 136, q_multiply_shift, + placeholder_28, clamp to [0, 255],
            # cast to uint8.
            for ax3_inner_4 in T.serial(0, 64):
                T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(
                    T.max(
                        T.min(
                            T.q_multiply_shift(
                                T.cast(
                                    T.cast(
                                        T.max(
                                            T.min(
                                                T.q_multiply_shift(
                                                    Conv2dOutput_3[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4],
                                                    1343014664, 31, -8, dtype="int32") + 136,
                                                255),
                                            0),
                                        "uint8"),
                                    "int32") - 136,
                                1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4],
                            255),
                        0),
                    "uint8")
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(
        placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None:
    # Kernel written with explicit T.store / T.load on flat offsets and
    # parallel outer loops.
    #   placeholder_33 : [1, 28, 28, 192] int16 input
    #   placeholder_34 : [1, 1, 192, 16] int16 weights, indexed as rc_3 * 16 + ax3_2
    #   placeholder_35 : [1, 1, 1, 16] int32 per-output-channel addend
    #   T_cast_9       : [1, 28, 28, 16] int16 output
    # For each of the 784 (= 28 * 28) positions and each of the 16 output
    # channels a single int32 accumulator sums over 192 input channels; the
    # result gets the addend, is rescaled with T.q_multiply_shift, clamped to
    # [0, 255] and cast to uint8 and then int16.
    # function attr dict
    T.func_attr({
        "global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2",
        "tir.noalias": True
    })
    placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    # Element-for-element copy of the input (row stride 5376 = 28 * 192).
    PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            T.store(
                PaddedInput_3, (((i0_i1_fused_3 * 5376) + (i2_3 * 192)) + i3_3),
                T.load("int16", placeholder_33.data, (((i0_i1_fused_3 * 5376) + (i2_3 * 192)) + i3_3)), True)
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            # The accumulator is allocated inside the loops, once per
            # (position, output channel) pair, and always accessed at offset 0.
            Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global")
            T.store(Conv2dOutput_3, 0, 0, True)
            for rc_3 in T.serial(0, 192):
                T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(
                    T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3 * 192) + rc_3)), "int32") * T.cast(
                    T.load("int16", placeholder_34.data, ((rc_3 * 16) + ax3_2)), "int32"))), True)
            # Epilogue: add, rescale, clamp to [0, 255], narrow, then widen to int16.
            T.store(
                T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3 * 16) + ax3_2),
                T.cast(
                    T.cast(
                        T.max(
                            T.min(
                                T.q_multiply_shift(
                                    (T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)),
                                    1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)