def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handle, placeholder_164: T.handle, T_cast_76: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_9", "tir.noalias": True})
    placeholder_165 = T.match_buffer(placeholder_162, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_166 = T.match_buffer(placeholder_163, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_167 = T.match_buffer(placeholder_164, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_77 = T.match_buffer(T_cast_76, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    sid_21 = T.allocate_const([0, 1, 2, 3, 4, 5, 6, 7], "int8", [8])
    # body
    PaddedInput_25 = T.allocate([131072], "int16", "global")
    for i1_35, i2_46, i3_47 in T.grid(16, 16, 512):
        PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), placeholder_165[((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)], T.int16(0), dtype="int16")
    T_add_11 = T.allocate([100352], "int32", "global")
    with T.allocate([100352], "int32", "global") as DepthwiseConv2d_11:
        for i_11, j_11, c_11 in T.grid(14, 14, 512):
            DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0
            for di_11, dj_11 in T.grid(3, 3):
                DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] + (PaddedInput_25[(((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)].astype("int32")*placeholder_166[(((di_11*1536) + (dj_11*512)) + c_11)].astype("int32")))
        for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512):
            T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (DepthwiseConv2d_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] + placeholder_167[ax3_47])
    compute_22 = T.allocate([100352], "int32", "global")
    with T.allocate([100352], "int32", "global") as T_cast_78:
        for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512):
            T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T_add_11[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)]
        for i1_36, i2_47, i3_48 in T.grid(14, 14, 512):
            compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T_cast_78[(((i1_36*7168) + (i2_47*512)) + i3_48)], 1948805937, 31, -5, dtype="int32")
    T_cast_79 = T.allocate([100352], "uint8", "global")
    with T.allocate([100352], "int32", "global") as compute_23:
        for i1_37, i2_48, i3_49 in T.grid(14, 14, 512):
            # clip to the uint8 range [0, 255] before the cast below
            compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.min(compute_22[(((i1_37*7168) + (i2_48*512)) + i3_49)], 255), 0)
        for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512):
            T_cast_79[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)] = compute_23[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)].astype("uint8")
    for ax1_47, ax2_48, ax3_50 in T.grid(14, 14, 512):
        T_cast_77[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)] = T_cast_79[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)].astype("int16")
def conv2d_nhwc_transformed(
    Input: T.Buffer[(1, 224, 224, 3), "float32"],
    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
                Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for ax0, ax_1, ax_2 in T.grid(12544, 64, 147):
        with T.block("conv2d_nhwc"):
            bv0, bv1, bv2 = T.axis.remap("SSR", [ax0, ax_1, ax_2])
            T.reads(
                PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3],
                Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1],
            )
            T.writes(Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1])
            with T.init():
                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = T.float32(0)
            Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = (
                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1]
                + PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3]
                * Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1]
            )
def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    blockIdx_x = T.env_thread("blockIdx.x")
    T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
    T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
    T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
    # body
    T.launch_thread(blockIdx_x, 64)
    conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
    PadInput_shared = T.allocate([768], "float32", "shared")
    weight_shared = T.allocate([4096], "float32", "shared")
    T.launch_thread(threadIdx_x, 32)
    for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
        conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0)
    for i6_0 in T.serial(16):
        for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
            PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(
                4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5,
                inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560],
                T.float32(0),
                dtype="float32",
            )
        for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
            weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
        for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
            conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else(
                (i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0,
                PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2],
                T.float32(0),
                dtype="float32",
            ) * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024]
    for ax1, ax2 in T.grid(2, 4):
        conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        for ax0 in T.serial(0, 10):
            for ax1 in T.serial(0, 10):
                with T.block("cache"):
                    h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
                    w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
                    T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225)
                    cache[h, w] = X[h, w]
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                h = T.axis.spatial(224, hh_0 * 8 + hh_1)
                w = T.axis.spatial(224, ww_0 * 8 + ww_1)
                kh, kw = T.axis.remap("RR", [khh, kww])
                with T.init():
                    Y[h, w] = 0.0
                Y[h, w] = T.max(
                    Y[h, w],
                    T.if_then_else(
                        T.likely(1 <= h + kh, dtype="bool")
                        and T.likely(h + kh < 225, dtype="bool")
                        and T.likely(1 <= w + kw, dtype="bool")
                        and T.likely(w + kw < 225, dtype="bool"),
                        cache[h + kh - 1, w + kw - 1],
                        0.0,
                        dtype="float32",
                    ),
                )
def conv2d_nhwc(
    Input: T.Buffer[(1, 224, 224, 3), "float32"],
    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 227)),
                Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            with T.init():
                Conv2d_nhwc[n, h, w, co] = T.float32(0)
            Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + (
                PadInput[n, ((h * 2) + rh), ((w * 2) + rw), ((T.floordiv(co, 64) * 3) + rc)]
                * Weight[rh, rw, rc, co]
            )
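# For cross-checking fixtures like conv2d_nhwc above, a NumPy reference of the
# same stride-2, pad-3 NHWC convolution can be useful. This is a minimal
# illustrative sketch (the helper name conv2d_nhwc_ref and its interface are
# assumptions, not part of the test fixtures):
import numpy as np

def conv2d_nhwc_ref(x, w, stride=2, pad=3):
    """Reference NHWC conv: x is (1, 224, 224, 3), w is (7, 7, 3, 64)."""
    n, h, wd, c = x.shape
    kh, kw, _, co = w.shape
    # Zero-pad the spatial dims, mirroring the PadInput stage of the TIR fixture.
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    oh = (h + 2 * pad - kh) // stride + 1
    ow = (wd + 2 * pad - kw) // stride + 1
    y = np.zeros((n, oh, ow, co), dtype=x.dtype)
    for rh in range(kh):
        for rw in range(kw):
            # Strided window of the padded input for this kernel tap.
            patch = xp[:, rh : rh + oh * stride : stride, rw : rw + ow * stride : stride, :]
            y += np.tensordot(patch, w[rh, rw], axes=([3], [0]))
    return y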
def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.handle, placeholder_146: T.handle, T_cast_48: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_13", "tir.noalias": True})
    placeholder_147 = T.match_buffer(placeholder_144, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_148 = T.match_buffer(placeholder_145, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_149 = T.match_buffer(placeholder_146, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_49 = T.match_buffer(T_cast_48, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_22 = T.allocate([131072], "int16", "global")
    DepthwiseConv2d_9 = T.allocate([100352], "int32", "global")
    for i1_29, i2_39, i3_40 in T.grid(16, 16, 512):
        PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), placeholder_147[((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)], T.int16(0), dtype="int16")
    for i_9, j_9, c_9 in T.grid(14, 14, 512):
        DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = 0
        for di_9, dj_9 in T.grid(3, 3):
            DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = (DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] + (PaddedInput_22[(((((i_9*8192) + (di_9*8192)) + (j_9*512)) + (dj_9*512)) + c_9)].astype("int32")*placeholder_148[(((di_9*1536) + (dj_9*512)) + c_9)].astype("int32")))
    for ax1_27, ax2_28, ax3_30 in T.grid(14, 14, 512):
        DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] = (DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] + placeholder_149[ax3_30])
    for i1_30, i2_40, i3_41 in T.grid(14, 14, 512):
        DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)] = T.q_multiply_shift(DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)], 1269068532, 31, -4, dtype="int32")
    for i1_31, i2_41, i3_42 in T.grid(14, 14, 512):
        # clip to the uint8 range [0, 255] before the cast below
        DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)] = T.max(T.min(DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)], 255), 0)
    for ax1_28, ax2_29, ax3_31 in T.grid(14, 14, 512):
        PaddedInput_22[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)] = DepthwiseConv2d_9[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)].astype("uint8")
    for ax1_29, ax2_30, ax3_32 in T.grid(14, 14, 512):
        T_cast_49[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)] = PaddedInput_22[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)].astype("int16")
def access_of_padding_pattern() -> None:
    X = T.alloc_buffer([28, 28])
    X_pad = T.alloc_buffer([32, 32])
    Y = T.alloc_buffer([28, 28])
    for i, j in T.grid(32, 32):
        with T.block("padding"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([
                X[
                    T.max(vi - 2, 0):T.min(vi - 2, 27) + 1,
                    T.max(vj - 2, 0):T.min(vj - 2, 27) + 1,
                ]
            ])
            T.writes([X_pad[vi, vj]])
            X_pad[vi, vj] = T.if_then_else(
                2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
            )
        with T.block("padding_reverse"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([X_pad[T.max(vi, 2):T.min(vi, 29) + 1, T.max(vj, 2):T.min(vj, 29) + 1]])
            T.writes([
                Y[
                    T.max(vi - 2, 0):T.min(vi - 2, 27) + 1,
                    T.max(vj - 2, 0):T.min(vj - 2, 27) + 1,
                ]
            ])
            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
                Y[vi - 2, vj - 2] = X_pad[vi, vj]
def tiled_conv2d_with_padding(
    inputs: T.Buffer[(1, 224, 224, 3), "float32"],
    weight: T.Buffer[(7, 7, 3, 64), "float32"],
    conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                3 <= i1_1 and i1_1 < 227 and 3 <= i2_1 and i2_1 < 227,
                inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for (
        i0_0, i1_0, i2_0, i3_0,
        i0_1_1, i1_1_1, i2_1_1, i3_1_1,
        i4_0, i5_0, i6_0,
        i0_2, i1_2, i2_2, i3_2,
        i4_1, i5_1, i6_1,
        i0_3, i1_3, i2_3, i3_3,
    ) in T.grid(1, 1, 4, 1, 1, 2, 4, 1, 7, 7, 1, 1, 1, 1, 1, 1, 1, 3, 1, 56, 7, 64):
        with T.block("conv2d_nhwc"):
            n = T.axis.spatial(1, 0)
            h = T.axis.spatial(112, i1_1_1 * 56 + i1_3)
            w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 * 7 + i2_3)
            co, rh, rw, rc = T.axis.remap("SRRR", [i3_3, i4_0, i5_0, i6_1])
            T.reads(
                conv2d_nhwc[n, h, w, co],
                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc],
                weight[rh, rw, rc, co],
            )
            T.writes(conv2d_nhwc[n, h, w, co])
            with T.init():
                conv2d_nhwc[n, h, w, co] = T.float32(0)
            conv2d_nhwc[n, h, w, co] = (
                conv2d_nhwc[n, h, w, co]
                + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc]
                * weight[rh, rw, rc, co]
            )
def before_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
    for i in range(140):
        with T.block("block"):
            vi = T.axis.remap("S", [i])
            y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32")
def access_in_if_then_else_func() -> None:
    A = T.alloc_buffer([8])
    B = T.alloc_buffer([8])
    with T.block():
        T.reads([A[0:5]])
        T.writes([B[0:8]])
        for i in T.serial(0, 8):
            B[i] = T.if_then_else(i < 5, A[i], 0.0, dtype="float32")
def main(X: T.Buffer[(1, 512, 56, 56), "float32"], W: T.Buffer[(512, 512, 3, 3), "float32"], B: T.Buffer[(512, 1, 1), "float32"], bn_scale: T.Buffer[(512, 1, 1), "float32"], bn_offset: T.Buffer[(512, 1, 1), "float32"], compute: T.Buffer[(1, 512, 56, 56), "float32"]) -> None:
    compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local")
    for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(224, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"):
            for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(8, thread="threadIdx.x"):
                for i4_0, i5_0, i6_0, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 3, 1, 32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28):
                    with T.block("compute"):
                        nn = T.axis.spatial(1, 0)
                        ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4)
                        yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4)
                        xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4)
                        rc = T.axis.reduce(512, i4_1 * 16 + i4_2)
                        ry, rx = T.axis.remap("RR", [i5_0, i6_2])
                        with T.init():
                            compute_local[nn, ff, yy, xx] = T.float32(0)
                        compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + T.if_then_else(
                            yy + ry >= 1 and yy + ry < 57 and xx + rx >= 1 and xx + rx < 57,
                            X[nn, rc, yy + ry - 1, xx + rx - 1],
                            T.float32(0),
                            dtype="float32",
                        ) * W[ff, rc, ry, rx]
                for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1)
                        v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2)
                        v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3)
                        compute[v0, v1, v2, v3] = T.max(
                            (compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) * bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0],
                            T.float32(0),
                        )
def main(a: T.handle, b: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    threadIdx_y = T.env_thread("threadIdx.y")
    blockIdx_x = T.env_thread("blockIdx.x")
    blockIdx_y = T.env_thread("blockIdx.y")
    blockIdx_z = T.env_thread("blockIdx.z")
    A = T.match_buffer(a, [14 * 14 * 256 * 256], dtype="float32")
    B = T.match_buffer(b, [14 * 14 * 512 * 256], dtype="float32")
    # body
    T.launch_thread(blockIdx_z, 196)
    B_local = T.allocate([64], "float32", "local")
    Apad_shared = T.allocate([512000], "float32", "shared")
    Apad_shared_local = T.allocate([8], "float32", "local")
    T.launch_thread(blockIdx_y, 8)
    T.launch_thread(blockIdx_x, 4)
    T.launch_thread(threadIdx_y, 8)
    T.launch_thread(threadIdx_x, 8)
    for ff_c_init, nn_c_init in T.grid(8, 8):
        B_local[ff_c_init * 8 + nn_c_init] = T.float32(0)
    for rc_outer, ry, rx in T.grid(32, 3, 3):
        for ax3_inner_outer in T.serial(0, 2):
            Apad_shared[T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)] = T.if_then_else(
                1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15,
                A[T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4)],
                T.broadcast(T.float32(0), 4),
                dtype="float32x4",
            )
        # Access of the last element of Apad_shared prevents buffer
        # compacting from reducing the amount of shared memory used.
        Apad_shared[512000 - 1] = 0.0
        for rc_inner in T.serial(0, 8):
            for ax3 in T.serial(0, 8):
                Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3]
            for ff_c, nn_c in T.grid(8, 8):
                B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c]
    for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
        B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]
def expected_after(A: T.Buffer[128, "float32"], B: T.Buffer[130, "float32"]):
    for i, j in T.grid(2, 65):
        if i * 65 + j >= 0 and i * 65 + j < 128:
            A[i * 65 + j] = T.float32(0)
    for i, j in T.grid(2, 65):
        B[i * 65 + j] = T.if_then_else(
            i * 65 + j >= 0 and i * 65 + j < 128, A[i * 65 + j], T.float32(0), dtype="float32"
        )
def loop_carried_dependency(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128,))
    B = T.match_buffer(b, (128,))
    C = T.match_buffer(c, (128,))
    for i in range(0, 128):
        with T.block("B"):
            vi = T.axis.S(128, i)
            B[vi] = A[vi] * 2.0
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = T.if_then_else(vi >= 1, B[vi - 1] + 1.0, 0.0, dtype="float32")
def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None:  # type: ignore
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2, i3 in T.grid(1, 8, 16, 16):
        with T.block("T_layout_trans"):
            ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4])  # type: ignore
            T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
            T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(
                ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16,
                placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4],
                T.float32(0),
                dtype="float32",
            )  # type: ignore
def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
    placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
    T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    tensor_2 = T.allocate([200704], "uint8", "global")
    for ax0_ax1_fused_4 in T.serial(0, 56):
        for ax2_4 in T.serial(0, 56):
            for ax3_init in T.serial(0, 64):
                T.store(tensor_2, (((ax0_ax1_fused_4 * 3584) + (ax2_4 * 64)) + ax3_init), T.uint8(0), True)
            for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
                T.store(
                    tensor_2,
                    (((ax0_ax1_fused_4 * 3584) + (ax2_4 * 64)) + ax3_2),
                    T.max(
                        T.load("uint8", tensor_2, (((ax0_ax1_fused_4 * 3584) + (ax2_4 * 64)) + ax3_2)),
                        T.if_then_else(
                            ((((ax0_ax1_fused_4 * 2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4 * 2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)),
                            T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4 * 14336) + (T.floordiv(rv0_rv1_fused_1, 3) * 7168)) + (ax2_4 * 128)) + (T.floormod(rv0_rv1_fused_1, 3) * 64)) + ax3_2)),
                            T.uint8(0),
                            dtype="uint8",
                        ),
                    ),
                    True,
                )
    for ax0_ax1_fused_5 in T.serial(0, 56):
        for ax2_5, ax3_3 in T.grid(56, 64):
            T.store(
                T_cast_7.data,
                (((ax0_ax1_fused_5 * 3584) + (ax2_5 * 64)) + ax3_3),
                T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5 * 3584) + (ax2_5 * 64)) + ax3_3)), "int16"),
                True,
            )
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True})
    placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
    placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
    placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
    T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16")
    # body
    PaddedInput_1 = T.allocate([379456], "int16", "global")
    for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
        PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(
            1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76,
            placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864],
            T.int16(0),
            dtype="int16",
        )
    for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
        Conv2dOutput_1 = T.allocate([64], "int32", "global")
        for ff_1 in T.serial(0, 64):
            Conv2dOutput_1[ff_1] = 0
            for ry, rx, rc_1 in T.grid(3, 3, 64):
                Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(
                    PaddedInput_1[T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1],
                    "int32",
                ) * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32")
        for ax3_inner_2 in T.serial(0, 64):
            T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(
                T.cast(
                    T.max(
                        T.min(
                            T.q_multiply_shift(Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"),
                            255,
                        ),
                        0,
                    ),
                    "uint8",
                ),
                "int16",
            )
def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None:
    X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32")
    W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32")
    B = T.match_buffer(var_B, [512, 1, 1], dtype="float32")
    bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32")
    bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32")
    compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32")
    pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32")
    compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 512, 58, 58):
        with T.block("pad_temp"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57,
                X[i0_1, i1_1, i2_1 - 1, i3_1 - 1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3):
        with T.block("compute"):
            nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            with T.init():
                compute_1[nn, ff, yy, xx] = T.float32(0)
            compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bias_add"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bn_mul"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bn_add"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("compute_1"):
            i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0))
def compacted_padding_pattern_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [16, 16], dtype="float32")
    C = T.match_buffer(c, [20, 20], dtype="float32")
    with T.block():
        B = T.alloc_buffer([16, 16], dtype="float32")
        for i, j in T.grid(16, 16):
            with T.block():
                B[i, j] = A[i, j]
        for i, j in T.grid(20, 20):
            with T.block():
                C[i, j] = T.if_then_else(
                    2 <= i and i < 18 and 2 <= j and j < 18, B[i - 2, j - 2], 0.0, dtype="float32"
                )
def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [16], "float32")
    B = T.alloc_buffer([16], "float32")
    C = T.match_buffer(c, [16], "float32")
    for j in T.serial(0, 16):
        for i in T.serial(0, T.min(1, 15 - j) + 1):
            with T.block("B"):
                v = T.axis.S(16, j + i)
                B[v] = A[v]
        with T.block("C"):
            v = T.axis.S(16, j)
            T.reads([B[v : v + 2]])
            C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32")
def read_out_of_bound(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [16], "float32")
    B = T.alloc_buffer([16], "float32")
    C = T.match_buffer(c, [16], "float32")
    for i in T.serial(0, 16):
        with T.block("B"):
            v = T.axis.S(16, i)
            B[v] = A[v]
    for j in T.serial(0, 16):
        with T.block("C"):
            v = T.axis.S(16, j)
            T.reads(B[v : v + 2])
            C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32")
def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(130,), "float32"]):
    for i, j in T.grid(T.int64(2), T.int64(65)):
        if i * T.int64(65) + j >= T.int64(0) and i * T.int64(65) + j < T.int64(128):
            A[i * T.int64(65) + j] = 0.0
    for i, j in T.grid(T.int64(2), T.int64(65)):
        B[i * T.int64(65) + j] = T.if_then_else(
            i * T.int64(65) + j >= T.int64(0) and i * T.int64(65) + j < T.int64(128),
            A[i * T.int64(65) + j],
            0.0,
            dtype="float32",
        )
def compacted_spatial_tiled_pad_and_pooling(
    X: T.Buffer[(64, 112, 112), "int32"], Y: T.Buffer[(64, 56, 56), "int32"]
) -> None:
    for h_o, w_o in T.grid(14, 14):
        with T.block():
            T.reads(X[0:64, h_o * 8 - 1 : h_o * 8 + 8, w_o * 8 - 1 : w_o * 8 + 8])
            T.writes(Y[h_o * 4 : h_o * 4 + 4, w_o * 4 : w_o * 4 + 4, 0:64])
            X_cache = T.alloc_buffer([9, 9, 64], dtype="int32")
            for ax0, ax1, ax2 in T.grid(64, 9, 9):
                with T.block("cache"):
                    T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2)
                    T.reads(X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1])
                    T.writes(
                        X_cache[
                            h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                            w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                            ax0,
                        ]
                    )
                    X_cache[
                        h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                        w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                        ax0,
                    ] = X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1]
            for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64):
                with T.block("compute"):
                    T.reads(
                        X_cache[
                            h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                            w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                            c,
                        ]
                    )
                    T.writes(Y[h_o * 4 + h_i, w_o * 4 + w_i, c])
                    if kh == 0 and kw == 0:
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = 0
                    Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = T.max(
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c],
                        T.if_then_else(
                            T.likely(1 <= h_o * 8 + h_i * 2 + kh, dtype="bool")
                            and T.likely(1 <= w_o * 8 + w_i * 2 + kw, dtype="bool"),
                            X_cache[
                                h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                                w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                                c,
                            ],
                            0,
                            dtype="int32",
                        ),
                    )
def main(placeholder: T.Buffer[(1, 16, 7, 7, 32), "float32"], placeholder_1: T.Buffer[(25088,), "float32"], T_layout_trans: T.Buffer[(1, 1, 7, 7, 512), "float32"]) -> None:
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i0_i1_i2_i3_i4_fused in T.parallel(25088, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
        with T.block("T_layout_trans_1"):
            ax0 = T.axis.spatial(1, 0)
            ax1 = T.axis.spatial(1, 0)
            ax2 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused // 3584)
            ax3 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused % 3584 // 512)
            ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512)
            T.reads(placeholder[0, (ax4 * 49 + ax2 * 7 + ax3) % 25088 // 1568, (ax2 * 7 + ax3) % 49 // 7, ax3 % 7, (ax4 * 49 + ax2 * 7 + ax3) % 1568 // 49], placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088])
            T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
            T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
                ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7,
                T.Select(
                    T.float32(0) < T.if_then_else(
                        0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32],
                        T.float32(0),
                        dtype="float32",
                    ),
                    T.if_then_else(
                        0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32],
                        T.float32(0),
                        dtype="float32",
                    ),
                    T.if_then_else(
                        0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32],
                        T.float32(0),
                        dtype="float32",
                    ) * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088],
                ),
                T.float32(0),
                dtype="float32",
            )
def main(
    placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"],
    placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"],
    conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"],
) -> None:  # type: ignore
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32")
    for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
            T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
            data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(
                2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18,  # pylint: disable=R1716
                placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1],
                T.float32(0),
                dtype="float32",
            )  # type: ignore
    for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
        with T.block("conv2d_NCHWc"):
            n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
            T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block])  # type: ignore
            T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
            T.block_attr({
                "workload": [
                    "conv2d_NCHWc.x86",
                    ["TENSOR", [1, 1, 16, 16, 3], "float32"],
                    ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"],
                    [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32",
                ]
            })
            with T.init():
                conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
            conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = (
                conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]
                + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3]
                * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]  # type: ignore
            )
def main(  # type: ignore
    placeholder: T.Buffer[(1, 3, 16, 16), "float32"],  # type: ignore
    T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"],  # type: ignore
) -> None:  # type: ignore
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
        with T.block("T_layout_trans"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3])
            T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
            T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
                ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16,  # type: ignore
                placeholder[ax0, ax1 * 3 + ax4, ax2, ax3],
                T.float32(0),
                dtype="float32",
            )
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
    placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_7 = T.allocate([157323], "int16", "global")
    for i0_i1_fused_7 in T.serial(0, 229):
        for i2_7, i3_7 in T.grid(229, 3):
            T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True)
    for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
        Conv2dOutput_7 = T.allocate([64], "int32", "global")
        for ff_3 in T.serial(0, 64):
            T.store(Conv2dOutput_7, ff_3, 0, True)
            for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
                T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True)
        for ax3_inner_7 in T.serial(0, 64):
            T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)
def conv2d_nhwc_reindex_weight(var_inputs: T.handle, var_weight: T.handle, var_conv2d_nhwc: T.handle) -> None:
    inputs = T.match_buffer(var_inputs, [1, 224, 224, 3], dtype="float32")
    weight = T.match_buffer(var_weight, [7, 7, 3, 64], dtype="float32")
    conv2d_nhwc = T.match_buffer(var_conv2d_nhwc, [1, 112, 112, 64], dtype="float32")
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    weight_reindex = T.alloc_buffer([64, 7, 7, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
                inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(1, 1, 1, 64, 7, 7, 3):
        with T.block("weight_reindex"):
            v0, v1, v2, v3, v4, v5, v6 = T.axis.remap("SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6])
            T.reads(weight[v4, v5, v6, v3])
            T.writes(weight_reindex[v3, v4, v5, v6])
            weight_reindex[v3, v4, v5, v6] = weight[v4, v5, v6, v3]
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            T.reads(
                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc],
                weight_reindex[co, rh, rw, rc],
            )
            T.writes(conv2d_nhwc[n, h, w, co])
            with T.init():
                conv2d_nhwc[n, h, w, co] = T.float32(0)
            conv2d_nhwc[n, h, w, co] = (
                conv2d_nhwc[n, h, w, co]
                + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc]
                * weight_reindex[co, rh, rw, rc]
            )
def compacted_padding_pattern_inlined(
    X: T.Buffer[(224, 224), "float32"], Y: T.Buffer[(224, 224), "float32"]
) -> None:
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for h, w in T.grid(224, 224):
        with T.block("cache"):
            cache[h, w] = X[h, w]
    for h, w, kh, kw in T.grid(224, 224, 3, 3):
        with T.block("compute"):
            Y[h, w] = T.max(
                Y[h, w],
                T.if_then_else(
                    T.likely(1 <= h + kh, dtype="bool")
                    and T.likely(h + kh < 225, dtype="bool")
                    and T.likely(1 <= w + kw, dtype="bool")
                    and T.likely(w + kw < 225, dtype="bool"),
                    cache[h + kh - 1, w + kw - 1],
                    0.0,
                    dtype="float32",
                ),
            )
def sum_pool_2d(x: T.Buffer[(1, 16, 225, 225), "int8"], tensor: T.Buffer[(1, 16, 225, 225), "int8"]):
    pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
    for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
        with T.block("pad_temp"):
            ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(
                3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
                x[ax0, ax1, ax2 - 3, ax3 - 3],
                T.int8(0),
                dtype="int8",
            )
    for i0, i1, i2, i3, i4, i5 in T.grid(1, 16, 225, 225, 7, 7):
        with T.block("tensor"):
            ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
            with T.init():
                tensor[ax0, ax1, ax2, ax3] = T.int8(0)
            tensor[ax0, ax1, ax2, ax3] = tensor[ax0, ax1, ax2, ax3] + pad_temp[ax0, ax1, ax2 + rv0, ax3 + rv1]
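# All of the functions above are TVMScript fixtures; in the test suites they
# are declared under @T.prim_func with `from tvm.script import tir as T` in
# scope. Below is a minimal, self-contained sketch of that workflow in the
# same TVMScript dialect the fixtures use (the toy shift_right prim_func, the
# llvm target, and the shapes are illustrative assumptions, not taken from
# the fixtures above):
import numpy as np
import tvm
from tvm.script import tir as T


@T.prim_func
def shift_right(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16], "float32")
    B = T.match_buffer(b, [16], "float32")
    for i in T.serial(0, 16):
        with T.block("B"):
            vi = T.axis.spatial(16, i)
            # Guard the out-of-bounds read the same way the fixtures guard
            # their padding: select a zero instead of reading A[-1].
            B[vi] = T.if_then_else(vi >= 1, A[vi - 1], T.float32(0), dtype="float32")


# Build for CPU and run; mod(a_nd, b_nd) executes the lowered kernel.
mod = tvm.build(shift_right, target="llvm")
a_nd = tvm.nd.array(np.arange(16, dtype="float32"))
b_nd = tvm.nd.array(np.zeros(16, dtype="float32"))
mod(a_nd, b_nd)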