def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
        with T.block("conv2d_NCHWc_int8"):
            n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
            T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner])
            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
            with T.init():
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32")
def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
    # body
    for i0 in T.serial(10):
        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True})
    placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
    placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
    T_cast_1 = T.match_buffer(T_cast, [360000], dtype="int16")
    # body
    for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
        T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), "uint8"), "int16")
def before(A: T.Buffer[16, "float32"]):
    for i in T.serial(16):
        x = T.cast(T.ceil(T.log2(T.cast(i + 1024 + 1, "float64"), dtype="float64"), dtype="float64"), dtype="int32")
        if x == 11:
            A[i] = 0.0
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
    placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
    placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
    placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
    T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
    # body
    PaddedInput = T.allocate([360000], "int16", "global")
    for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
        T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
    for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
        Conv2dOutput = T.allocate([64], "int32", "global")
        for ff in T.serial(0, 64):
            T.store(Conv2dOutput, ff, 0, True)
            for rc in T.serial(0, 64):
                T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
        for ax3_inner_1 in T.serial(0, 64):
            T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
def dp4a_desc(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])
        for i in range(0, 4):
            with T.block("update"):
                vi = T.axis.remap("R", [i])
                C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
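# A minimal companion sketch, not taken from the listings above: a descriptor such as
# dp4a_desc is normally paired with a hardware implementation and the pair is registered
# as a tensor intrinsic. The function below and the name "dp4a_sketch" are illustrative
# assumptions; the "__dp4a" extern is the CUDA intrinsic such an implementation maps to.
def dp4a_impl_sketch(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])
        # Replace the four-lane reduction with a single __dp4a call on packed int8x4 values.
        C[0] = C[0] + T.call_pure_extern("__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32")


# Registration sketch: tvm.tir.TensorIntrin.register("dp4a_sketch", dp4a_desc, dp4a_impl_sketch)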
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True})
    placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16")
    placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
    placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
    T_add_1 = T.match_buffer(T_add, [1440000], dtype="int32")
    # body
    PaddedInput_2 = T.allocate([360000], "int16", "global")
    for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
        PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2]
    for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
        Conv2dOutput_2 = T.allocate([64], "int32", "global")
        for ax3_outer_1 in T.serial(0, 4):
            for ff_2 in T.serial(0, 64):
                Conv2dOutput_2[ff_2] = 0
                for rc_2 in T.serial(0, 64):
                    Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32") * T.cast(placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32")
            for ax3_inner_3 in T.serial(0, 64):
                T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True})
    placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
    placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
    placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
    T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16")
    # body
    PaddedInput_1 = T.allocate([379456], "int16", "global")
    for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
        PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16")
    for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
        Conv2dOutput_1 = T.allocate([64], "int32", "global")
        for ff_1 in T.serial(0, 64):
            Conv2dOutput_1[ff_1] = 0
            for ry, rx, rc_1 in T.grid(3, 3, 64):
                Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1], "int32") * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32")
        for ax3_inner_2 in T.serial(0, 64):
            T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16")
def dot_product_16x4_u8i8i32_desc(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            for k in T.serial(0, 4):
                with T.block("update"):
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
    placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
    placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    for ax0_ax1_fused_1 in T.serial(0, 224):
        for ax2_1, ax3_inner_1 in T.grid(224, 3):
            T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(T.load("uint8", placeholder_4.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1), "int16") - T.load("int16", placeholder_5.data, 0), True)
def main(placeholder: T.Buffer[(1, 384), "int64"], placeholder_1: T.Buffer[(30522, 768), "float32"], placeholder_2: T.Buffer[(1, 384, 768), "float32"], T_add: T.Buffer[(1, 384, 768), "float32"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_add_1"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(placeholder[ax0, ax1], placeholder_1[T.min(T.max(T.int64(0), placeholder[ax0, ax1]), T.int64(30521)):T.min(T.max(T.int64(0), placeholder[ax0, ax1] + T.int64(30522)), T.int64(30521)) + T.int64(1), ax2], placeholder_2[ax0, ax1, ax2])
            T.writes(T_add[ax0, ax1, ax2])
            T_add[ax0, ax1, ax2] = placeholder_1[T.min(T.max(T.int64(0), T.Select(T.cast(placeholder[ax0, ax1] < T.int64(0), "int32") != 0, placeholder[ax0, ax1] + T.int64(30522), placeholder[ax0, ax1])), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
def main(
    T_reshape: T.Buffer[(1, 12, 384, 384), "float32"],
    placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "bool"],
    T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "float32"],
) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
        for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
            for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
                with T.block("T_where"):
                    ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                    ax1 = T.axis.spatial(T.int64(12), ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(1769472) // T.int64(147456))
                    ax2 = T.axis.spatial(T.int64(384), ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(147456) // T.int64(384))
                    ax3 = T.axis.spatial(384, T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(384), "int32"))
                    T.where((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2 < T.int64(1769472))
                    T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3])
                    T.writes(T_where[ax0, ax1, ax2, ax3])
                    T_where[ax0, ax1, ax2, ax3] = T.Select(T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0, T.float32(-1000000000), T_reshape[ax0, ax1, ax2, ax3])
def unified_element_wise_thread_x_different_dtype(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    for blockIdx_x in T.thread_binding(128, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[blockIdx_x, threadIdx_x * 32 + j0_1] = A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0
            for j1_1 in T.serial(T.int64(32)):
                with T.block(""):
                    C[blockIdx_x, T.cast(threadIdx_x, "int64") * T.int64(32) + j1_1] = B[blockIdx_x, T.cast(threadIdx_x, "int64") * T.int64(32) + j1_1] + 1.0
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
    placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_3 = T.allocate([150528], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            PaddedInput_3[i0_i1_fused_3 * 5376 + i2_3 * 192 + i3_3] = placeholder_33[i0_i1_fused_3 * 5376 + i2_3 * 192 + i3_3]
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            Conv2dOutput_3 = T.allocate([1], "int32", "global")
            Conv2dOutput_3[0] = 0
            for rc_3 in T.serial(0, 192):
                Conv2dOutput_3[0] = Conv2dOutput_3[0] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 192 + rc_3], "int32") * T.cast(placeholder_34[rc_3 * 16 + ax3_2], "int32")
            T_cast_9[ax0_ax1_fused_ax2_fused_3 * 16 + ax3_2] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[0] + placeholder_35[ax3_2], 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16")
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
    placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
    placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
    placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
    T_cast_3 = T.match_buffer(T_cast_2, [360000], dtype="int16")
    # body
    PaddedInput = T.allocate([360000], "int16", "global")
    for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
        PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3]
    for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
        Conv2dOutput = T.allocate([64], "int32", "global")
        for ff in T.serial(0, 64):
            Conv2dOutput[ff] = 0
            for rc in T.serial(0, 64):
                Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32")
        for ax3_inner_1 in T.serial(0, 64):
            T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16")
def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
    placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
    placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    T_subtract_1 = T.match_buffer(T_subtract, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    for ax0_ax1_fused_1 in T.serial(0, 224):
        for ax2_1, ax3_inner_1 in T.grid(224, 3):
            T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") - placeholder_5[0]
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
    placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_21 = T.match_buffer(T_cast_20, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_7 = T.allocate([157323], "int16", "global")
    for i0_i1_fused_7 in T.serial(0, 229):
        for i2_7, i3_7 in T.grid(229, 3):
            PaddedInput_7[i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7] = T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, placeholder_65[i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350], T.int16(0), dtype="int16")
    for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
        Conv2dOutput_7 = T.allocate([64], "int32", "global")
        for ff_3 in T.serial(0, 64):
            Conv2dOutput_7[ff_3] = 0
            for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
                Conv2dOutput_7[ff_3] = Conv2dOutput_7[ff_3] + T.cast(PaddedInput_7[T.floordiv(ax0_ax1_fused_ax2_fused_7, 112) * 1374 + ry_2 * 687 + T.floormod(ax0_ax1_fused_ax2_fused_7, 112) * 6 + rx_2 * 3 + rc_7], "int32") * T.cast(placeholder_66[ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3], "int32")
        for ax3_inner_7 in T.serial(0, 64):
            T_cast_21[ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7] = T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7], 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8")
def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
    placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
    T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    tensor_2 = T.allocate([200704], "uint8", "global")
    for ax0_ax1_fused_4 in T.serial(0, 56):
        for ax2_4 in T.serial(0, 56):
            for ax3_init in T.serial(0, 64):
                T.store(tensor_2, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True)
            for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
                T.store(tensor_2, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(T.load("uint8", tensor_2, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + T.floordiv(rv0_rv1_fused_1, 3) < 112 and ax2_4 * 2 + T.floormod(rv0_rv1_fused_1, 3) < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + T.floordiv(rv0_rv1_fused_1, 3) * 7168 + ax2_4 * 128 + T.floormod(rv0_rv1_fused_1, 3) * 64 + ax3_2), T.uint8(0), dtype="uint8")), True)
    for ax0_ax1_fused_5 in T.serial(0, 56):
        for ax2_5, ax3_3 in T.grid(56, 64):
            T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(T.load("uint8", tensor_2, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"), True)
def main(
    placeholder: T.Buffer[(1024, 1024), "uint8"],
    placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
    compute: T.Buffer[(1024, 1024), "int32"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    with T.block("root"):
        T.reads()
        T.writes()
        for i0, i1, i2 in T.grid(1024, 1024, 1024):
            with T.block("compute"):
                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
                T.writes(compute[i, j])
                with T.init():
                    compute[i, j] = 0
                compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32")
def tir_packed_call() -> None:
    A = T.var("handle")
    B = T.var("handle")
    C = T.var("handle")
    device_context = T.var("handle")
    # body
    T.evaluate(
        T.tvm_call_cpacked(
            "tvm_test_cpacked",
            T.tvm_stack_make_array(A, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"),
            T.tvm_stack_make_array(B, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"),
            T.tvm_stack_make_array(C, T.tvm_stack_make_shape(1, dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"), T.uint32(1), T.cast(0, dtype="float32"), 0, dtype="handle"),
            device_context,
            dtype="int32",
        )
    )
def main(
    placeholder: T.Buffer[(1024, 1024), "uint8"],
    placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
    compute: T.Buffer[(1024, 1024), "int32"],
) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4):
        with T.block("compute"):
            i = T.axis.spatial(1024, i0)
            j = T.axis.spatial(1024, i1_0 * 16 + i1_1)
            k = T.axis.reduce(1024, i2_0 * 4 + i2_1)
            T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
            T.writes(compute[i, j])
            with T.init():
                compute[i, j] = 0
            compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32")
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid(1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4):
        with T.block("conv2d_NCHWc_int8"):
            n = T.axis.spatial(1, 0)
            oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1])
            kh = T.axis.reduce(1, 0)
            kw = T.axis.reduce(1, 0)
            ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1])
            T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner])
            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
            with T.init():
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], "int32")
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True})
    placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16")
    placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
    placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
    placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32")
    T_cast_7 = T.match_buffer(T_cast_6, [1440000], dtype="uint8")
    # body
    PaddedInput_3 = T.allocate([360000], "int16", "global")
    for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
        PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3]
    for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
        Conv2dOutput_3 = T.allocate([64], "int32", "global")
        for ax3_outer_2 in T.serial(0, 4):
            for ff_3 in T.serial(0, 64):
                Conv2dOutput_3[ff_3] = 0
                for rc_3 in T.serial(0, 64):
                    Conv2dOutput_3[ff_3] = Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32")
            for ax3_inner_4 in T.serial(0, 64):
                T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8")
def constant_binds_wrapped():
    x = T.int32(1)
    y = T.float32(42.0)
    T.evaluate(T.cast(x, "float32") + y)
def constant_binds():
    x = 1
    y = 42.0
    T.evaluate(T.cast(x, "float32") + y)
def before(A: T.Buffer[(4, 4), "float32"]):
    for i, j in T.grid(4, 4):
        x = T.var("float32")
        A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))
def expected(A: T.Buffer[(4, 4), "float32"]):
    for i in T.serial(4):
        x = T.cast(i + 1, "float32")
        for j in T.serial(4):
            A[i, j] = 5.0 * x + T.cast(j, "float32")
def main(placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"], placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"], placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"], T_expand_dims: T.Buffer[(1, 80, 10647), "float32"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
    T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
    T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
    T_sigmoid_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
    T_multiply = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
    T_reshape = T.alloc_buffer([8112, 80], dtype="float32")
    T_strided_slice_with_axes_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
    T_sigmoid_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
    T_strided_slice_with_axes_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
    T_sigmoid_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
    T_multiply_1 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
    T_reshape_1 = T.alloc_buffer([2028, 80], dtype="float32")
    T_strided_slice_with_axes_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
    T_sigmoid_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
    T_strided_slice_with_axes_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
    T_sigmoid_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
    T_multiply_2 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
    T_reshape_2 = T.alloc_buffer([507, 80], dtype="float32")
    T_concat = T.alloc_buffer([10647, 80], dtype="float32")
    T_transpose = T.alloc_buffer([80, 10647], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
        with T.block("T_strided_slice_with_axes"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
            T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
        with T.block("T_sigmoid"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
        with T.block("T_strided_slice_with_axes_1"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
            T.writes(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
        with T.block("T_sigmoid_1"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
        with T.block("T_multiply"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_sigmoid[ax0, ax1, ax2, ax3, 0], T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4])
            T_multiply[ax0, ax1, ax2, ax3, ax4] = T_sigmoid[ax0, ax1, ax2, ax3, 0] * T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]
    for i0, i1 in T.grid(8112, 80):
        with T.block("T_reshape"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
        with T.block("T_strided_slice_with_axes_2"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
            T.writes(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
        with T.block("T_sigmoid_2"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_2[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_2[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
        with T.block("T_strided_slice_with_axes_3"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
            T.writes(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
        with T.block("T_sigmoid_3"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
        with T.block("T_multiply_1"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_sigmoid_2[ax0, ax1, ax2, ax3, 0], T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4])
            T_multiply_1[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_2[ax0, ax1, ax2, ax3, 0] * T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]
    for i0, i1 in T.grid(2028, 80):
        with T.block("T_reshape_1"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
            T.writes(T_reshape_1[ax0, ax1])
            T_reshape_1[ax0, ax1] = T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
        with T.block("T_strided_slice_with_axes_4"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
            T.writes(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
        with T.block("T_sigmoid_4"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_4[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_4[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
        with T.block("T_strided_slice_with_axes_5"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
            T.writes(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
            T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
        with T.block("T_sigmoid_5"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
            T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
        with T.block("T_multiply_2"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(T_sigmoid_4[ax0, ax1, ax2, ax3, 0], T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
            T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4])
            T_multiply_2[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_4[ax0, ax1, ax2, ax3, 0] * T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]
    for i0, i1 in T.grid(507, 80):
        with T.block("T_reshape_2"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
            T.writes(T_reshape_2[ax0, ax1])
            T_reshape_2[ax0, ax1] = T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
    for i0, i1 in T.grid(10647, 80):
        with T.block("T_concat"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_reshape[ax0 - 2535, ax1], T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1])
            T.writes(T_concat[ax0, ax1])
            T_concat[ax0, ax1] = T.if_then_else(2535 <= ax0, T_reshape[ax0 - 2535, ax1], T.if_then_else(507 <= ax0, T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1], dtype="float32"), dtype="float32")
    for i0, i1 in T.grid(80, 10647):
        with T.block("T_transpose"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_concat[ax1, ax0])
            T.writes(T_transpose[ax0, ax1])
            T_transpose[ax0, ax1] = T_concat[ax1, ax0]
    for i0, i1, i2 in T.grid(1, 80, 10647):
        with T.block("T_expand_dims"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(T_transpose[ax1, ax2])
            T.writes(T_expand_dims[ax0, ax1, ax2])
            T_expand_dims[ax0, ax1, ax2] = T_transpose[ax1, ax2]
def before(A: T.Buffer[1, "int32"]):
    A[0] = T.cast(T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"), dtype="float64"), dtype="int32")
def main(
    X: T.Buffer[(128, 128), "int8"],
    W: T.Buffer[(128, 128), "int8"],
    compute: T.Buffer[(128, 128), "int32"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local")
    X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"):
                for i2_0_0 in T.serial(2):
                    for ax0_ax1_fused in T.serial(1024):
                        with T.block("X_shared"):
                            v0 = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(X[v0, v1])
                            T.writes(X_shared[v0, v1])
                            T.block_attr({"meta_schedule.cooperative_fetch": 4})
                            X_shared[v0, v1] = X[v0, v1]
                    for ax0_ax1_fused in T.serial(4096):
                        with T.block("W_shared"):
                            v0 = T.axis.spatial(128, i0_0_i1_0_fused % 2 * 64 + ax0_ax1_fused // 64)
                            v1 = T.axis.spatial(128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                            T.reads(W[v0, v1])
                            T.writes(W_shared[v0, v1])
                            T.block_attr({"meta_schedule.cooperative_fetch": 1})
                            W_shared[v0, v1] = W[v0, v1]
                    for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(2, 4, 16, 8, 4, 1):
                        with T.block("compute_o"):
                            i = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 + i0_4)
                            j = T.axis.spatial(128, i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + i1_3)
                            k_o = T.axis.reduce(32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                            T.reads(X_shared[i, k_o * 4:k_o * 4 + 4], W_shared[j, k_o * 4:k_o * 4 + 4])
                            T.writes(compute_local[i, j])
                            T.block_attr({"meta_schedule.auto_tensorize": "dp4a"})
                            with T.init():
                                with T.block("compute_init"):
                                    T.reads()
                                    T.writes(compute_local[i, j])
                                    compute_local[i, j] = 0
                            for i2_1 in T.serial(4):
                                with T.block("compute"):
                                    k = T.axis.reduce(4, i2_1)
                                    T.reads(compute_local[i, j], X_shared[i, k_o * 4 + k], W_shared[j, k_o * 4 + k])
                                    T.writes(compute_local[i, j])
                                    T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                                    compute_local[i, j] = compute_local[i, j] + T.cast(X_shared[i, k_o * 4 + k], "int32") * T.cast(W_shared[j, k_o * 4 + k], "int32")
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(128, i0_0_i1_0_fused // 2 * 16 + ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_fused % 2 * 64 + i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 + ax1)
                        T.reads(compute_local[v0, v1])
                        T.writes(compute[v0, v1])
                        compute[v0, v1] = compute_local[v0, v1]