def primfunc_global_allocates(
    placeholder_144: T.handle,
    placeholder_145: T.handle,
    placeholder_146: T.handle,
    T_cast_48: T.handle,
) -> None:
    # 3x3 per-channel convolution over a zero-padded 16x16x512 int16 tile,
    # followed by a per-channel add, a fixed-point multiply, a max/max step
    # and uint8 -> int16 casts. Both scratch buffers are allocated once at
    # function scope and are reused by the later stages.
    # function attr dict
    T.func_attr({
        "global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_13",
        "tir.noalias": True,
    })
    placeholder_147 = T.match_buffer(placeholder_144, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_148 = T.match_buffer(placeholder_145, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_149 = T.match_buffer(placeholder_146, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_49 = T.match_buffer(T_cast_48, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_22 = T.allocate([131072], "int16", "global")
    DepthwiseConv2d_9 = T.allocate([100352], "int32", "global")
    # Copy the 14x14 input into the centre of the 16x16 tile; the border is zero.
    for i1_29, i2_39, i3_40 in T.grid(16, 16, 512):
        PaddedInput_22[i1_29 * 8192 + i2_39 * 512 + i3_40] = T.if_then_else(
            (((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15),
            placeholder_147[i1_29 * 7168 + i2_39 * 512 + i3_40 - 7680],
            T.int16(0),
            dtype="int16",
        )
    # Accumulate the 3x3 window in int32.
    for i_9, j_9, c_9 in T.grid(14, 14, 512):
        DepthwiseConv2d_9[i_9 * 7168 + j_9 * 512 + c_9] = 0
        for di_9, dj_9 in T.grid(3, 3):
            DepthwiseConv2d_9[i_9 * 7168 + j_9 * 512 + c_9] = (
                DepthwiseConv2d_9[i_9 * 7168 + j_9 * 512 + c_9]
                + PaddedInput_22[i_9 * 8192 + di_9 * 8192 + j_9 * 512 + dj_9 * 512 + c_9].astype("int32")
                * placeholder_148[di_9 * 1536 + dj_9 * 512 + c_9].astype("int32")
            )
    # Per-channel add, in place.
    for ax1_27, ax2_28, ax3_30 in T.grid(14, 14, 512):
        DepthwiseConv2d_9[ax1_27 * 7168 + ax2_28 * 512 + ax3_30] = (
            DepthwiseConv2d_9[ax1_27 * 7168 + ax2_28 * 512 + ax3_30] + placeholder_149[ax3_30]
        )
    # Fixed-point multiply, in place.
    for i1_30, i2_40, i3_41 in T.grid(14, 14, 512):
        DepthwiseConv2d_9[i1_30 * 7168 + i2_40 * 512 + i3_41] = T.q_multiply_shift(
            DepthwiseConv2d_9[i1_30 * 7168 + i2_40 * 512 + i3_41], 1269068532, 31, -4, dtype="int32"
        )
    # NOTE(review): this is max(max(x, 255), 0), not the max(min(x, 255), 0)
    # clamp used by other kernels in this file; kept as-is -- confirm intent.
    for i1_31, i2_41, i3_42 in T.grid(14, 14, 512):
        DepthwiseConv2d_9[i1_31 * 7168 + i2_41 * 512 + i3_42] = T.max(
            T.max(DepthwiseConv2d_9[i1_31 * 7168 + i2_41 * 512 + i3_42], 255), 0
        )
    # The uint8 cast result is written back into the int16 padded-input buffer.
    for ax1_28, ax2_29, ax3_31 in T.grid(14, 14, 512):
        PaddedInput_22[ax1_28 * 7168 + ax2_29 * 512 + ax3_31] = DepthwiseConv2d_9[
            ax1_28 * 7168 + ax2_29 * 512 + ax3_31
        ].astype("uint8")
    for ax1_29, ax2_30, ax3_32 in T.grid(14, 14, 512):
        T_cast_49[ax1_29 * 7168 + ax2_30 * 512 + ax3_32] = PaddedInput_22[
            ax1_29 * 7168 + ax2_30 * 512 + ax3_32
        ].astype("int16")
def main(
    placeholder: T.Buffer[(1, 384), "int64"],
    placeholder_1: T.Buffer[(30522, 768), "float32"],
    placeholder_2: T.Buffer[(1, 384, 768), "float32"],
    T_add: T.Buffer[(1, 384, 768), "float32"],
) -> None:
    # Row lookup into placeholder_1 using the int64 indices in placeholder,
    # added element-wise to placeholder_2. Negative indices are shifted by
    # 30522 and the result is clamped to [0, 30521] before the load.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_add_1"):
            ax0 = T.axis.spatial(1, i0)
            ax1 = T.axis.spatial(384, i1)
            ax2 = T.axis.spatial(768, i2)
            T.reads(
                placeholder[ax0, ax1],
                placeholder_1[
                    T.min(T.max(T.int64(0), placeholder[ax0, ax1]), T.int64(30521)):
                    T.min(T.max(T.int64(0), placeholder[ax0, ax1] + T.int64(30522)), T.int64(30521)) + T.int64(1),
                    ax2,
                ],
                placeholder_2[ax0, ax1, ax2],
            )
            T.writes(T_add[ax0, ax1, ax2])
            T_add[ax0, ax1, ax2] = (
                placeholder_1[
                    T.min(
                        T.max(
                            T.int64(0),
                            T.Select(
                                T.cast(placeholder[ax0, ax1] < T.int64(0), "int32") != 0,
                                placeholder[ax0, ax1] + T.int64(30522),
                                placeholder[ax0, ax1],
                            ),
                        ),
                        T.int64(30521),
                    ),
                    ax2,
                ]
                + placeholder_2[ax0, ax1, ax2]
            )
def primfunc_local_allocates(
    placeholder_162: T.handle,
    placeholder_163: T.handle,
    placeholder_164: T.handle,
    T_cast_76: T.handle,
) -> None:
    # Same pipeline shape as a padded 3x3 per-channel convolution followed by
    # add / fixed-point multiply / max-max / casts, but every intermediate
    # lives in its own allocation. Some are function-scope allocations and
    # some are scoped `with T.allocate(...)` regions.
    # function attr dict
    T.func_attr({
        "global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_9",
        "tir.noalias": True,
    })
    placeholder_165 = T.match_buffer(placeholder_162, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_166 = T.match_buffer(placeholder_163, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_167 = T.match_buffer(placeholder_164, [512], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_77 = T.match_buffer(T_cast_76, [100352], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # Constant allocation; it is not read anywhere in this function body.
    sid_21 = T.allocate_const([0,1,2,3,4,5,6,7], "int8", [8])
    # body
    PaddedInput_25 = T.allocate([131072], "int16", "global")
    for i1_35, i2_46, i3_47 in T.grid(16, 16, 512):
        PaddedInput_25[i1_35 * 8192 + i2_46 * 512 + i3_47] = T.if_then_else(
            (((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15),
            placeholder_165[i1_35 * 7168 + i2_46 * 512 + i3_47 - 7680],
            T.int16(0),
            dtype="int16",
        )
    T_add_11 = T.allocate([100352], "int32", "global")
    # The convolution accumulator is only live until the add has consumed it.
    with T.allocate([100352], "int32", "global") as DepthwiseConv2d_11:
        for i_11, j_11, c_11 in T.grid(14, 14, 512):
            DepthwiseConv2d_11[i_11 * 7168 + j_11 * 512 + c_11] = 0
            for di_11, dj_11 in T.grid(3, 3):
                DepthwiseConv2d_11[i_11 * 7168 + j_11 * 512 + c_11] = (
                    DepthwiseConv2d_11[i_11 * 7168 + j_11 * 512 + c_11]
                    + PaddedInput_25[i_11 * 8192 + di_11 * 8192 + j_11 * 512 + dj_11 * 512 + c_11].astype("int32")
                    * placeholder_166[di_11 * 1536 + dj_11 * 512 + c_11].astype("int32")
                )
        for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512):
            T_add_11[ax1_44 * 7168 + ax2_45 * 512 + ax3_47] = (
                DepthwiseConv2d_11[ax1_44 * 7168 + ax2_45 * 512 + ax3_47] + placeholder_167[ax3_47]
            )
    compute_22 = T.allocate([100352], "int32", "global")
    with T.allocate([100352], "int32", "global") as T_cast_78:
        for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512):
            T_cast_78[ax1_45 * 7168 + ax2_46 * 512 + ax3_48] = T_add_11[ax1_45 * 7168 + ax2_46 * 512 + ax3_48]
        for i1_36, i2_47, i3_48 in T.grid(14, 14, 512):
            compute_22[i1_36 * 7168 + i2_47 * 512 + i3_48] = T.q_multiply_shift(
                T_cast_78[i1_36 * 7168 + i2_47 * 512 + i3_48], 1948805937, 31, -5, dtype="int32"
            )
    T_cast_79 = T.allocate([100352], "uint8", "global")
    with T.allocate([100352], "int32", "global") as compute_23:
        # NOTE(review): max(max(x, 255), 0) rather than max(min(x, 255), 0);
        # kept as-is -- confirm this is intentional.
        for i1_37, i2_48, i3_49 in T.grid(14, 14, 512):
            compute_23[i1_37 * 7168 + i2_48 * 512 + i3_49] = T.max(
                T.max(compute_22[i1_37 * 7168 + i2_48 * 512 + i3_49], 255), 0
            )
        for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512):
            T_cast_79[ax1_46 * 7168 + ax2_47 * 512 + ax3_49] = compute_23[
                ax1_46 * 7168 + ax2_47 * 512 + ax3_49
            ].astype("uint8")
    for ax1_47, ax2_48, ax3_50 in T.grid(14, 14, 512):
        T_cast_77[ax1_47 * 7168 + ax2_48 * 512 + ax3_50] = T_cast_79[
            ax1_47 * 7168 + ax2_48 * 512 + ax3_50
        ].astype("int16")
def main(
    A: T.Buffer[(256, 256), "float32"],
    T_softmax_norm: T.Buffer[(256, 256), "float32"],
) -> None:
    # Row-wise softmax in three stages: row maximum, sum of exp(x - max),
    # then exp(x - max) / sum.
    T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
    T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_maxelem"):
            i0_1 = T.axis.spatial(256, i0)
            k = T.axis.reduce(256, i1)
            with T.init():
                T_softmax_maxelem[i0_1] = T.min_value("float32")
            T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_expsum"):
            i0_2 = T.axis.spatial(256, i0)
            k = T.axis.reduce(256, i1)
            with T.init():
                T_softmax_expsum[i0_2] = T.float32(0)
            T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
                A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32"
            )
    for i0_3, i1 in T.grid(256, 256):
        with T.block("T_softmax_norm"):
            i0_4 = T.axis.spatial(256, i0_3)
            i1_1 = T.axis.spatial(256, i1)
            T_softmax_norm[i0_4, i1_1] = (
                T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32")
                / T_softmax_expsum[i0_4]
            )
def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(
    placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle
) -> None:
    # Element-wise kernel over a flat 75*75*64 buffer: widen uint8 to int32,
    # subtract 94, fixed-point multiply, add a per-channel value, clamp to
    # [0, 255], then cast uint8 -> int16.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast",
        "tir.noalias": True
    })
    placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
    placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
    T_cast_1 = T.match_buffer(T_cast, [360000], dtype="int16")
    # body
    for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
        T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner] = T.cast(
            T.cast(
                T.max(
                    T.min(
                        T.q_multiply_shift(
                            T.cast(
                                placeholder_2[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner],
                                "int32",
                            ) - 94,
                            1843157232,
                            31,
                            1,
                            dtype="int32",
                        ) + placeholder_3[ax3_outer * 16 + ax3_inner],
                        255,
                    ),
                    0,
                ),
                "uint8",
            ),
            "int16",
        )
def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
    # 3x3 max over a 224x224 input, tiled 28x28 with 8x8 outputs per tile.
    # Each tile first copies a 10x10 window (one-element halo) into `cache`;
    # the block predicate drops halo positions that fall outside the input.
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        for ax0, ax1 in T.grid(10, 10):
            with T.block("cache"):
                h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
                w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
                T.where(
                    1 <= hh_0 * 8 + ax0
                    and hh_0 * 8 + ax0 < 225
                    and 1 <= ww_0 * 8 + ax1
                    and ww_0 * 8 + ax1 < 225
                )
                cache[h, w] = X[h, w]
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                h = T.axis.spatial(224, hh_0 * 8 + hh_1)
                w = T.axis.spatial(224, ww_0 * 8 + ww_1)
                kh = T.axis.reduce(3, khh)
                kw = T.axis.reduce(3, kww)
                with T.init():
                    Y[h, w] = 0.0
                Y[h, w] = T.max(
                    Y[h, w],
                    T.if_then_else(
                        T.likely(1 <= h + kh, dtype="bool")
                        and T.likely(h + kh < 225, dtype="bool")
                        and T.likely(1 <= w + kw, dtype="bool")
                        and T.likely(w + kw < 225, dtype="bool"),
                        cache[h + kh - 1, w + kw - 1],
                        0.0,
                        dtype="float32",
                    ),
                )
def main(A: T.handle, tensor: T.handle) -> None:
    # Two chained max reductions in legacy TE-schedule form: a 3x3 window over
    # A_1 into tensor_2, then a 3x5 window over tensor_2 into tensor_1.
    # tensor_2 is realized six rows at a time per outer step and carries the
    # "rolling_buffer_scope" attribute.
    # function attr dict
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    # buffer definition
    tensor_2 = T.buffer_decl([1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
    A_1 = T.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
    tensor_1 = T.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
    # body
    T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
    for ax1_outer in T.serial(0, 2):
        T.realize(tensor_2[0:1, ax1_outer * 4:ax1_outer * 4 + 6, 0:12, 0:16], "")
        T.attr(tensor_2, "rolling_buffer_scope", True)
        for ax1, ax2, ax3 in T.grid(6, 12, 16):
            tensor_2[0, ax1 + ax1_outer * 4, ax2, ax3] = T.int8(0)
            for dh, dw in T.grid(3, 3):
                tensor_2[0, ax1 + ax1_outer * 4, ax2, ax3] = T.max(
                    tensor_2[0, ax1 + ax1_outer * 4, ax2, ax3],
                    A_1[0, ax1 + ax1_outer * 4 + dh, ax2 + dw, ax3],
                )
        for ax1_inner, ax2_inner, ax3_inner in T.grid(4, 8, 16):
            tensor_1[0, ax1_inner + ax1_outer * 4, ax2_inner, ax3_inner] = T.int8(0)
            for dh_1, dw_1 in T.grid(3, 5):
                tensor_1[0, ax1_inner + ax1_outer * 4, ax2_inner, ax3_inner] = T.max(
                    tensor_1[0, ax1_inner + ax1_outer * 4, ax2_inner, ax3_inner],
                    tensor_2[0, ax1_inner + ax1_outer * 4 + dh_1, ax2_inner + dw_1, ax3_inner],
                )
def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None:
    # Every element of a 16x16 scratch buffer is filled with one value of
    # data_buf whose coordinates are themselves loaded from index_buf
    # (an index loaded through another index load).
    index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
    with T.block("root"):
        T.reads([])
        T.writes([])
        out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
        for i0, i1 in T.grid(16, 16):
            with T.block():
                vi = T.axis.spatial(16, i0)
                vj = T.axis.spatial(16, i1)
                T.reads([
                    data_buf[index_buf[index_buf[0]], index_buf[0]],
                    index_buf[T.min(index_buf[0], 0):T.max(index_buf[0], 0) + 1],
                ])
                T.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
    # Row-wise max of a 128x128 buffer, with the reduction over `k` expressed
    # as a tvm_thread_allreduce across threadIdx.x and the result written
    # back from a one-element local buffer.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B_cross_thread_reduction"):
                vi = T.axis.spatial(128, i)
                vk = T.axis.reduce(128, k)
                T.reads([A[vi, vk]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle"
                    )
                )
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def different_access_indices(a: T.handle, b: T.handle) -> None:
    # Sum over the last axis of A with the reduction loop bound to
    # threadIdx.x. The init statement writes B[vj, vi] while the update
    # writes B[vi, vj]; the declared write region covers both.
    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
    B = T.match_buffer(b, [128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.reduce(128, k)
                T.reads([B[vi, vj], A[vi, vj, vk]])
                T.writes([
                    B[
                        T.min(vj, vi):T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)),
                        T.min(vi, vj):T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)),
                    ]
                ])
                with T.init():
                    B[vj, vi] = T.float32(0)
                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
def compacted_spatial_tiled_pad_and_pooling(
    X: T.Buffer[(64, 112, 112), "int32"], Y: T.Buffer[(64, 56, 56), "int32"]
) -> None:
    # Tiled 3x3, stride-2 max over X with a one-element pad on the low edge.
    # Each of the 14x14 outer tiles stages a 9x9x64 window of X in X_cache
    # and then produces a 4x4x64 patch of Y from it.
    # NOTE(review): Y is declared (64, 56, 56) but is indexed as [h, w, c]
    # below (c runs to 63 on an axis of extent 56) -- confirm this is intended.
    for h_o, w_o in T.grid(14, 14):
        with T.block():
            T.reads(X[0:64, h_o * 8 - 1 : h_o * 8 + 8, w_o * 8 - 1 : w_o * 8 + 8])
            T.writes(Y[h_o * 4 : h_o * 4 + 4, w_o * 4 : w_o * 4 + 4, 0:64])
            X_cache = T.alloc_buffer([9, 9, 64], dtype="int32")
            for ax0, ax1, ax2 in T.grid(64, 9, 9):
                with T.block("cache"):
                    # Skip the halo row/column that lies before the start of X.
                    T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2)
                    T.reads(X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1])
                    T.writes(
                        X_cache[
                            h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                            w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                            ax0,
                        ]
                    )
                    # Cache coordinates are the global coordinates minus the
                    # tile origin, which is clamped at 0 via T.max(0, ...).
                    X_cache[
                        h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                        w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                        ax0,
                    ] = X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1]
            for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64):
                with T.block("compute"):
                    T.reads(
                        X_cache[
                            h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                            w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                            c,
                        ]
                    )
                    T.writes(Y[h_o * 4 + h_i, w_o * 4 + w_i, c])
                    # Explicit reduction init on the first window element.
                    if kh == 0 and kw == 0:
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = 0
                    Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = T.max(
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c],
                        T.if_then_else(
                            T.likely(1 <= h_o * 8 + h_i * 2 + kh, dtype="bool")
                            and T.likely(1 <= w_o * 8 + w_i * 2 + kw, dtype="bool"),
                            X_cache[
                                h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                                w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                                c,
                            ],
                            0,
                            dtype="int32",
                        ),
                    )
def main(X: T.Buffer[(1, 512, 56, 56), "float32"], W: T.Buffer[(512, 512, 3, 3), "float32"], B: T.Buffer[(512, 1, 1), "float32"], bn_scale: T.Buffer[(512, 1, 1), "float32"], bn_offset: T.Buffer[(512, 1, 1), "float32"], compute: T.Buffer[(1, 512, 56, 56), "float32"]) -> None:
    # Thread-bound 3x3 convolution with implicit zero padding, accumulated in
    # a "local"-scope buffer, followed by a fused epilogue per thread:
    # (conv + B) * bn_scale + bn_offset, then max with 0.
    compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local")
    for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(224, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(
                2, thread="vthread.x"):
            for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(
                    8, thread="threadIdx.x"):
                for i4_0, i5_0, i6_0, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(
                        1, 3, 1, 32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28):
                    with T.block("compute"):
                        nn = T.axis.spatial(1, 0)
                        # Output channel / row / column are rebuilt from the
                        # fused block, vthread and thread indices plus the
                        # innermost serial tile indices.
                        ff = T.axis.spatial(
                            512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                            i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4)
                        yy = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 +
                            i0_1_i1_1_i2_1_i3_1_fused * 4 +
                            i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4)
                        xx = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4)
                        rc = T.axis.reduce(512, i4_1 * 16 + i4_2)
                        ry, rx = T.axis.remap("RR", [i5_0, i6_2])
                        with T.init():
                            compute_local[nn, ff, yy, xx] = T.float32(0)
                        # Positions outside the 56x56 input contribute 0.
                        compute_local[nn, ff, yy, xx] = compute_local[
                            nn, ff, yy, xx] + T.if_then_else(
                                yy + ry >= 1 and yy + ry < 57 and xx + rx >= 1
                                and xx + rx < 57,
                                X[nn, rc, yy + ry - 1, xx + rx - 1],
                                T.float32(0),
                                dtype="float32") * W[ff, rc, ry, rx]
                # Write-back of this thread's 1x8x2x28 tile with the epilogue applied.
                for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                            i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1)
                        v2 = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 +
                            i0_1_i1_1_i2_1_i3_1_fused * 4 +
                            i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2)
                        v3 = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3)
                        compute[v0, v1, v2, v3] = T.max(
                            (compute_local[v0, v1, v2, v3] + B[v1, 0,
                                                               0]) * bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0],
                            T.float32(0))
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Row-wise softmax with one blockIdx.x per row. Each of the three stages
    # (max, exp-sum, normalise) walks the 256 columns as 8 serial steps of
    # 32 threadIdx.x lanes; the two reduction buffers are in "shared" scope.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2]])
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3]])
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                        / T_softmax_expsum_shared[i0_3]
                    )
def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
    # 3x3, stride-2 max over a 112x112x64 uint8 input into a 56x56x64 scratch
    # buffer, then a widening cast to int16. Written with explicit
    # T.store / T.load on flat indices.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast",
        "tir.noalias": True
    })
    placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
    T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    tensor_2 = T.allocate([200704], "uint8", "global")
    for ax0_ax1_fused_4 in T.serial(0, 56):
        for ax2_4 in T.serial(0, 56):
            for ax3_init in T.serial(0, 64):
                T.store(tensor_2, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True)
            # rv0_rv1_fused_1 enumerates the window; floordiv/floormod by 3
            # recover the row and column offsets.
            for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
                T.store(
                    tensor_2,
                    ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2,
                    T.max(
                        T.load("uint8", tensor_2, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2),
                        T.if_then_else(
                            ax0_ax1_fused_4 * 2 + T.floordiv(rv0_rv1_fused_1, 3) < 112
                            and ax2_4 * 2 + T.floormod(rv0_rv1_fused_1, 3) < 112,
                            T.load(
                                "uint8",
                                placeholder_29.data,
                                ax0_ax1_fused_4 * 14336
                                + T.floordiv(rv0_rv1_fused_1, 3) * 7168
                                + ax2_4 * 128
                                + T.floormod(rv0_rv1_fused_1, 3) * 64
                                + ax3_2,
                            ),
                            T.uint8(0),
                            dtype="uint8",
                        ),
                    ),
                    True,
                )
    for ax0_ax1_fused_5 in T.serial(0, 56):
        for ax2_5, ax3_3 in T.grid(56, 64):
            T.store(
                T_cast_7.data,
                ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3,
                T.cast(T.load("uint8", tensor_2, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"),
                True,
            )
def reducer_max(a: T.handle, b: T.handle) -> None:
    # Row-wise max of a 128x128 buffer; the reduction loop is bound to
    # threadIdx.x.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi = T.axis.spatial(128, i)
                vk = T.axis.reduce(128, k)
                T.reads([B[vi], A[vi, vk]])
                T.writes([B[vi]])
                with T.init():
                    B[vi] = T.min_value("float32")
                B[vi] = T.max(B[vi], A[vi, vk])
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(
    placeholder_4: T.handle,
    placeholder_5: T.handle,
    placeholder_6: T.handle,
    T_cast_2: T.handle,
) -> None:
    # 1x1 convolution (64 -> 64 channels) over 75x75 positions, then a
    # per-channel add, fixed-point multiply, clamp to [0, 255] and
    # uint8 -> int16 casts. Uses explicit T.store / T.load on flat indices.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
        "tir.noalias": True
    })
    placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
    placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
    placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
    T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
    # body
    PaddedInput = T.allocate([360000], "int16", "global")
    # No border is added here: this is a straight copy of the input.
    for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
        T.store(
            PaddedInput,
            i0_i1_fused * 4800 + i2 * 64 + i3,
            T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3),
            True,
        )
    for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
        # Per-position accumulator, one int32 per output channel.
        Conv2dOutput = T.allocate([64], "int32", "global")
        for ff in T.serial(0, 64):
            T.store(Conv2dOutput, ff, 0, True)
            for rc in T.serial(0, 64):
                T.store(
                    Conv2dOutput,
                    ff,
                    T.load("int32", Conv2dOutput, ff)
                    + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32")
                    * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"),
                    True,
                )
        for ax3_inner_1 in T.serial(0, 64):
            T.store(
                T_cast_3.data,
                ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1,
                T.cast(
                    T.cast(
                        T.max(
                            T.min(
                                T.q_multiply_shift(
                                    T.load("int32", Conv2dOutput, ax3_inner_1)
                                    + T.load("int32", placeholder_9.data, ax3_inner_1),
                                    1843106743,
                                    31,
                                    -6,
                                    dtype="int32",
                                ),
                                255,
                            ),
                            0,
                        ),
                        "uint8",
                    ),
                    "int16",
                ),
                True,
            )
def main(a: T.handle, b: T.handle, d: T.handle) -> None:  # pylint: disable=no-self-argument
    # 16x16 matmul into a temporary C followed by max(C, 0) into D, written
    # with the block-signature form `with T.block([...], name) as [...]`.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")
    D = T.match_buffer(d, (16, 16), "float32")
    C = T.alloc_buffer((16, 16), "float32")
    with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]:
        with T.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    with T.block([16, 16], "relu") as [vi, vj]:
        D[vi, vj] = T.max(C[vi, vj], 0.0)
def access_of_padding_pattern() -> None:
    # Pads a 28x28 buffer to 32x32 with a two-element zero border, then
    # copies the interior back out. The declared read/write regions use
    # T.max / T.min clamping to describe the guarded accesses.
    X = T.alloc_buffer([28, 28])
    X_pad = T.alloc_buffer([32, 32])
    Y = T.alloc_buffer([28, 28])
    for i, j in T.grid(32, 32):
        with T.block("padding"):
            vi = T.axis.spatial(32, i)
            vj = T.axis.spatial(32, j)
            T.reads([
                X[
                    T.max(vi - 2, 0):T.min(vi - 2, 27) + 1,
                    T.max(vj - 2, 0):T.min(vj - 2, 27) + 1,
                ]
            ])
            T.writes([X_pad[vi, vj]])
            X_pad[vi, vj] = T.if_then_else(
                2 <= vi and vi < 30 and 2 <= vj and vj < 30,
                X[vi - 2, vj - 2],
                0.0,
                dtype="float32",
            )
        with T.block("padding_reverse"):
            vi = T.axis.spatial(32, i)
            vj = T.axis.spatial(32, j)
            T.reads([
                X_pad[T.max(vi, 2):T.min(vi, 29) + 1, T.max(vj, 2):T.min(vj, 29) + 1]
            ])
            T.writes([
                Y[
                    T.max(vi - 2, 0):T.min(vi - 2, 27) + 1,
                    T.max(vj - 2, 0):T.min(vj - 2, 27) + 1,
                ]
            ])
            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
                Y[vi - 2, vj - 2] = X_pad[vi, vj]
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(
    placeholder_16: T.handle,
    placeholder_17: T.handle,
    placeholder_18: T.handle,
    T_add: T.handle,
) -> None:
    # 1x1 convolution from 64 to 256 channels, computed 64 output channels at
    # a time (ax3_outer_1), followed by a chain of add / fixed-point multiply /
    # +132 / clamp / casts / -132 / fixed-point multiply / +136.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_",
        "tir.noalias": True
    })
    placeholder_19 = T.match_buffer(placeholder_16, [360000], dtype="int16")
    placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
    placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
    T_add_1 = T.match_buffer(T_add, [1440000], dtype="int32")
    # body
    PaddedInput_2 = T.allocate([360000], "int16", "global")
    for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
        PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[
            i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2
        ]
    for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
        # The 64-entry accumulator is reused for each group of output channels.
        Conv2dOutput_2 = T.allocate([64], "int32", "global")
        for ax3_outer_1 in T.serial(0, 4):
            for ff_2 in T.serial(0, 64):
                Conv2dOutput_2[ff_2] = 0
                for rc_2 in T.serial(0, 64):
                    Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(
                        PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 + rc_2], "int32"
                    ) * T.cast(placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 + ff_2], "int32")
            for ax3_inner_3 in T.serial(0, 64):
                T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3] = T.q_multiply_shift(
                    T.cast(
                        T.cast(
                            T.max(
                                T.min(
                                    T.q_multiply_shift(
                                        Conv2dOutput_2[ax3_inner_3] + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3],
                                        1711626602,
                                        31,
                                        -8,
                                        dtype="int32",
                                    ) + 132,
                                    255,
                                ),
                                0,
                            ),
                            "uint8",
                        ),
                        "int32",
                    ) - 132,
                    2094289803,
                    31,
                    -2,
                    dtype="int32",
                ) + 136
def main(
    var_X: T.handle,
    var_W: T.handle,
    var_B: T.handle,
    var_bn_scale: T.handle,
    var_bn_offset: T.handle,
    var_compute: T.handle,
) -> None:
    # Unscheduled pipeline, one block per stage with its own intermediate
    # buffer: zero-pad -> 3x3 convolution -> per-channel add ->
    # per-channel scale -> per-channel offset -> max with 0.
    X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32")
    W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32")
    B = T.match_buffer(var_B, [512, 1, 1], dtype="float32")
    bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32")
    bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32")
    compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32")
    pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32")
    compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 512, 58, 58):
        with T.block("pad_temp"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57,
                X[i0_1, i1_1, i2_1 - 1, i3_1 - 1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3):
        with T.block("compute"):
            nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            with T.init():
                compute_1[nn, ff, yy, xx] = T.float32(0)
            compute_1[nn, ff, yy, xx] = (
                compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx]
            )
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bias_add"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bn_mul"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bn_add"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("compute_1"):
            i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0))
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(
    placeholder_10: T.handle,
    placeholder_11: T.handle,
    placeholder_12: T.handle,
    T_cast_4: T.handle,
) -> None:
    # 3x3 convolution (64 -> 64 channels) over a 75x75 input zero-padded to
    # 77x77, then a per-channel add, fixed-point multiply, clamp to [0, 255]
    # and uint8 -> int16 casts.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
        "tir.noalias": True
    })
    placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
    placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
    placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
    T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16")
    # body
    PaddedInput_1 = T.allocate([379456], "int16", "global")
    for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
        PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(
            1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76,
            placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864],
            T.int16(0),
            dtype="int16",
        )
    for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
        Conv2dOutput_1 = T.allocate([64], "int32", "global")
        for ff_1 in T.serial(0, 64):
            Conv2dOutput_1[ff_1] = 0
            # floordiv / floormod by 75 recover the output row and column
            # from the fused position index.
            for ry, rx, rc_1 in T.grid(3, 3, 64):
                Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(
                    PaddedInput_1[
                        T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928
                        + ry * 4928
                        + rx * 64
                        + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64
                        + rc_1
                    ],
                    "int32",
                ) * T.cast(placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1], "int32")
        for ax3_inner_2 in T.serial(0, 64):
            T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(
                T.cast(
                    T.max(
                        T.min(
                            T.q_multiply_shift(
                                Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2],
                                1608879842,
                                31,
                                -7,
                                dtype="int32",
                            ),
                            255,
                        ),
                        0,
                    ),
                    "uint8",
                ),
                "int16",
            )
def read_out_of_bound(a: T.handle, c:T.handle) -> None:
    # Copies A into B, then computes C[v] = max(B[v], B[v + 1]) with a guard
    # for the last element. The declared read region B[v : v + 2] extends one
    # element past the buffer when v == 15.
    A = T.match_buffer(a, [16], "float32")
    B = T.alloc_buffer([16], "float32")
    C = T.match_buffer(c, [16], "float32")
    for i in T.serial(0, 16):
        with T.block("B"):
            v = T.axis.spatial(16, i)
            B[v] = A[v]
    for j in T.serial(0, 16):
        with T.block("C"):
            v = T.axis.spatial(16, j)
            T.reads(B[v : v + 2])
            C[v] = T.if_then_else(
                v < 15,
                T.max(B[v], B[v + 1]),
                B[v],
                dtype="float32",
            )
def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None:
    # Block "B" is nested under the consumer loop `j`; its inner loop extent
    # T.min(1, 15 - j) + 1 yields two iterations except at j == 15, where it
    # yields one, so B is never written past index 15.
    A = T.match_buffer(a, [16], "float32")
    B = T.alloc_buffer([16], "float32")
    C = T.match_buffer(c, [16], "float32")
    for j in T.serial(0, 16):
        for i in T.serial(0, T.min(1, 15 - j) + 1):
            with T.block("B"):
                v = T.axis.spatial(16, j + i)
                B[v] = A[v]
        with T.block("C"):
            v = T.axis.spatial(16, j)
            T.reads([B[v : v + 2]])
            C[v] = T.if_then_else(
                v < 15,
                T.max(B[v], B[v + 1]),
                B[v],
                dtype="float32",
            )
def single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"],
    T_softmax_norm: T.Buffer[(256, 256), "float32"],
) -> None:
    # Row-wise softmax where each stage binds 512 threadIdx.x lanes to a
    # 256-wide axis; a T.where predicate masks the lanes at or above 256.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                        / T_softmax_expsum_shared[i0_3]
                    )
def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None:
    # 1024x1024 matmul into a temporary C, followed by max(C, 0) into D.
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("matmul"):
            vi = T.axis.spatial(1024, i)
            vj = T.axis.spatial(1024, j)
            vk = T.axis.reduce(1024, k)
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi = T.axis.spatial(1024, i)
            vj = T.axis.spatial(1024, j)
            D[vi, vj] = T.max(C[vi, vj], 0.0)
def main(a: T.handle, b: T.handle, d: T.handle) -> None:  # pylint: disable=no-self-argument
    # 1024x1024 float32 matrix multiply followed by an element-wise
    # max(x, 0.0), with the function attributes set explicitly.
    #   a, b: handles matched to the input buffers A and B.
    #   d:    handle matched to the output buffer D.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (1024, 1024), "float32")
    B = T.match_buffer(b, (1024, 1024), "float32")
    D = T.match_buffer(d, (1024, 1024), "float32")
    # Intermediate product; allocated inside the function, not a parameter.
    C = T.alloc_buffer((1024, 1024), "float32")
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("matmul"):
            # "SSR": i and j are bound as spatial axes, k as the reduce axis.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.max(C[vi, vj], 0.0)
def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None:
    # Same matmul + max(x, 0.0) computation as a plain 1024x1024 matmul/relu,
    # but the two outer matmul loops are written as explicit T.serial loops
    # so that each can carry a loop annotation.
    #   a, b: handles matched to the input buffers A and B.
    #   d:    handle matched to the output buffer D.
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    # Loop i carries a string-valued annotation, loop j an integer-valued one;
    # loop k has none.
    for i in T.serial(0, 1024, annotations={"test1": "aaa"}):
        for j in T.serial(0, 1024, annotations={"test2": 612}):
            for k in T.serial(0, 1024):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.max(C[vi, vj], 0.0)
def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None:
    # Same matmul + max(x, 0.0) computation as a plain 1024x1024 matmul/relu,
    # but with attributes attached to the blocks (via T.block_attr) rather
    # than to the loops.
    #   a, b: handles matched to the input buffers A and B.
    #   d:    handle matched to the output buffer D.
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            # Block-level attribute with a string value; it follows the
            # T.init() section and precedes the update statement.
            T.block_attr({"test1": "aaa"})
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Two attributes on one block: a float, and a list mixing a
            # string with an integer.
            T.block_attr({"test2": 0.22, "test3": ["aa", 1]})
            D[vi, vj] = T.max(C[vi, vj], 0.0)
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
    # For each of 5625 positions, multiplies a 64-element int16 input vector
    # by a 64x64 int16 weight table (accumulating in int32), adds a
    # per-channel int32 term, rescales with T.q_multiply_shift, clamps to
    # [0, 255], and stores the result as int16 after a uint8 cast.
    #   placeholder_4: handle for the flat int16 input (360000 elements).
    #   placeholder_5: handle for the flat int16 weights (4096 elements).
    #   placeholder_6: handle for the int32 per-channel addend (64 elements).
    #   T_cast_2:      handle for the flat int16 output (360000 elements).
    # function attr dict
    T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
    placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
    placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
    placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
    T_cast_3 = T.match_buffer(T_cast_2, [360000], dtype="int16")
    # body
    PaddedInput = T.allocate([360000], "int16", "global")
    # Source and destination use the same flat index, so this is an
    # element-for-element copy of the input (75 * 75 * 64 = 360000).
    for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
        PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3]
    for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
        # int32 accumulator for the 64 output channels of this position;
        # allocated inside the loop body.
        Conv2dOutput = T.allocate([64], "int32", "global")
        for ff in T.serial(0, 64):
            Conv2dOutput[ff] = 0
            # Dot product over the 64 input channels rc; both operands are
            # widened to int32 before the multiply.
            for rc in T.serial(0, 64):
                Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32")
        # Add the per-channel term, rescale, clamp to [0, 255] (min then
        # max), then cast to uint8 and finally to int16 for storage.
        for ax3_inner_1 in T.serial(0, 64):
            T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16")
def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None:
    # 512x512 float32 matrix multiply followed by an element-wise
    # max(x, 0), with the read and write regions of each block spelled out.
    #   var_A, var_B: handles matched to the input buffers A and B.
    #   var_compute:  handle matched to the output buffer `compute`.
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    compute = T.match_buffer(var_compute, [512, 512], dtype="float32")
    # Intermediate product; allocated inside the function, not a parameter.
    C = T.alloc_buffer([512, 512], dtype="float32")
    for i0, i1, i2 in T.grid(512, 512, 512):
        with T.block("C"):
            # "SSR": i and j are bound as spatial axes, k as the reduce axis.
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            # C[i, j] is listed as both read and written because the update
            # accumulates into it.
            T.reads([C[i, j], A[i, k], B[k, j]])
            T.writes([C[i, j]])
            with T.init():
                C[i, j] = T.float32(0)
            C[i, j] = C[i, j] + A[i, k] * B[k, j]
    for i0, i1 in T.grid(512, 512):
        with T.block("compute"):
            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
            T.reads([C[i0_1, i1_1]])
            T.writes([compute[i0_1, i1_1]])
            compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))