def main(A: T.handle, tensor: T.handle) -> None:
    # Two chained int8 max-reductions: a 3x3 window over A_1 fills tensor_2,
    # then a 3x5 window over tensor_2 fills tensor_1.  Output rows are produced
    # in two groups of four (ax1_outer), and axis 1 of tensor_2 is always
    # addressed modulo 6.
    # function attr dict
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    # buffer definition
    tensor_2 = T.buffer_decl(
        [1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1
    )
    A_1 = T.match_buffer(
        A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1
    )
    tensor_1 = T.match_buffer(
        tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1
    )
    # body
    T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
    T.realize(tensor_2[0:1, 0:6, 0:12, 0:16], "")
    for ax1_outer in T.serial(0, 2):
        # Stage 1: fill rows of the intermediate tensor_2.
        for ax1, ax2, ax3 in T.grid(6, 12, 16):
            # All six rows are written when ax1_outer == 0; on the following
            # iteration only rows with ax1 >= 2 are written.
            if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype="bool"):
                tensor_2[0, T.floormod((ax1 + (ax1_outer * 4)), 6), ax2, ax3] = T.int8(0)
            for dh, dw in T.grid(3, 3):
                if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype="bool"):
                    tensor_2[0, T.floormod((ax1 + (ax1_outer * 4)), 6), ax2, ax3] = T.max(
                        tensor_2[0, T.floormod((ax1 + (ax1_outer * 4)), 6), ax2, ax3],
                        A_1[0, ((ax1 + (ax1_outer * 4)) + dh), (ax2 + dw), ax3],
                    )
        # Stage 2: reduce tensor_2 into four output rows of tensor_1.
        for ax1_inner, ax2_inner, ax3_inner in T.grid(4, 8, 16):
            tensor_1[0, (ax1_inner + (ax1_outer * 4)), ax2_inner, ax3_inner] = T.int8(0)
            for dh_1, dw_1 in T.grid(3, 5):
                tensor_1[0, (ax1_inner + (ax1_outer * 4)), ax2_inner, ax3_inner] = T.max(
                    tensor_1[0, (ax1_inner + (ax1_outer * 4)), ax2_inner, ax3_inner],
                    tensor_2[
                        0,
                        T.floormod(((ax1_inner + (ax1_outer * 4)) + dh_1), 6),
                        (ax2_inner + dw_1),
                        ax3_inner,
                    ],
                )
def elementwise_fused(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 over a 128x128x128 buffer, iterated with one flat loop of
    # 128 * 128 * 128 = 2097152 steps.
    A = T.match_buffer(a, (128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128))
    for fused in T.grid(2097152):
        with T.block("B"):
            # Split the flat index back into (vi, vj, vk); 16384 = 128 * 128.
            vi = T.axis.S(128, T.floordiv(fused, 16384))
            vj = T.axis.S(128, T.floordiv(T.floormod(fused, 16384), 128))
            vk = T.axis.S(128, T.floormod(fused, 128))
            T.reads([A[vi, vj, vk]])
            T.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
    # Max-reduces nine window taps of the uint8 input into the scratch buffer
    # tensor_2, then casts tensor_2 element-wise to int16 into the output.
    # function attr dict
    T.func_attr(
        {
            "global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast",
            "tir.noalias": True,
        }
    )
    placeholder_29 = T.match_buffer(
        placeholder_28,
        [1, 112, 112, 64],
        dtype="uint8",
        elem_offset=0,
        align=128,
        offset_factor=1,
    )
    T_cast_7 = T.match_buffer(
        T_cast_6,
        [1, 56, 56, 64],
        dtype="int16",
        elem_offset=0,
        align=128,
        offset_factor=1,
    )
    # body
    tensor_2 = T.allocate([200704], "uint8", "global")
    for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56):
        # Start all 64 channels of this output position at zero.
        for ax3_init in T.grid(64):
            T.store(
                tensor_2,
                (((ax0_ax1_fused_4 * 3584) + (ax2_4 * 64)) + ax3_init),
                T.uint8(0),
                True,
            )
        # rv0_rv1_fused_1 enumerates the taps as (row = rv // 3, col = rv % 3).
        for rv0_rv1_fused_1 in T.serial(0, 9):
            for ax3_2 in T.serial(0, 64):
                T.store(
                    tensor_2,
                    (((ax0_ax1_fused_4 * 3584) + (ax2_4 * 64)) + ax3_2),
                    T.max(
                        T.load(
                            "uint8",
                            tensor_2,
                            (((ax0_ax1_fused_4 * 3584) + (ax2_4 * 64)) + ax3_2),
                        ),
                        # Taps whose row or column reaches 112 contribute 0
                        # instead of reading the input.
                        T.if_then_else(
                            ((((ax0_ax1_fused_4 * 2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4 * 2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)),
                            T.load(
                                "uint8",
                                placeholder_29.data,
                                (((((ax0_ax1_fused_4 * 14336) + (T.floordiv(rv0_rv1_fused_1, 3) * 7168)) + (ax2_4 * 128)) + (T.floormod(rv0_rv1_fused_1, 3) * 64)) + ax3_2),
                            ),
                            T.uint8(0),
                            dtype="uint8",
                        ),
                    ),
                    True,
                )
    # Widen the pooled uint8 values to int16 in the output buffer.
    for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64):
        T.store(
            T_cast_7.data,
            (((ax0_ax1_fused_5 * 3584) + (ax2_5 * 64)) + ax3_3),
            T.cast(
                T.load(
                    "uint8",
                    tensor_2,
                    (((ax0_ax1_fused_5 * 3584) + (ax2_5 * 64)) + ax3_3),
                ),
                "int16",
            ),
            True,
        )
def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None:
    # B = A * 2.0 over a (128, 128, n) buffer whose last extent is the
    # symbolic parameter n; one flat loop covers 128 * 128 * n elements.
    A = T.match_buffer(a, (128, 128, n))
    B = T.match_buffer(b, (128, 128, n))
    for i_j_k_fused in T.grid((n * 16384)):
        with T.block("B"):
            # Recover the three indices from the flat loop variable.
            vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128))
            vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, n * 128), n))
            vk = T.axis.S(n, T.floormod(i_j_k_fused, n))
            T.reads([A[vi, vj, vk]])
            T.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None:
    # B = A * 2.0 over 128x128x128, driven by one flat loop.  The computing
    # block "B" is nested inside a block named "opaque" that has no block axes
    # and states its access regions directly in terms of the loop variable.
    B = T.match_buffer(b, [128, 128, 128])
    A = T.match_buffer(a, [128, 128, 128])
    for i_j_k_fused in T.grid(2097152):
        with T.block("opaque"):
            T.reads(
                [
                    A[
                        T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128),
                        T.floormod(T.floordiv(i_j_k_fused, 128), 128),
                        T.floormod(i_j_k_fused, 128),
                    ]
                ]
            )
            T.writes(
                [
                    B[
                        T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128),
                        T.floormod(T.floordiv(i_j_k_fused, 128), 128),
                        T.floormod(i_j_k_fused, 128),
                    ]
                ]
            )
            with T.block("B"):
                # 16384 = 128 * 128, so vi is the slowest-varying index.
                vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384))
                vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, 128), 128))
                vk = T.axis.S(128, T.floormod(i_j_k_fused, 128))
                T.reads([A[vi, vj, vk]])
                T.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None:
    # Copies a 127x128 buffer in four row chunks.  Chunk i holds
    # T.min(31, 126 - i * 32) + 1 rows (32, 32, 32, then 31), and each chunk
    # is walked by one flat loop over rows * 128 elements.
    A = T.match_buffer(a, [127, 128])
    B = T.match_buffer(b, [127, 128])
    for i in T.serial(0, 4):
        for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128):
            with T.block("B"):
                # Row = chunk start + (flat index // 128) wrapped to the
                # number of rows in this chunk.
                vi = T.axis.S(
                    127,
                    i * 32
                    + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1),
                )
                vj = T.axis.S(128, T.floormod(j_k_fused, 128))
                T.reads([A[vi, vj]])
                T.writes([B[vi, vj]])
                B[vi, vj] = A[vi, vj]
def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
    # D[b] = sqrt(sum over (i, j) of A[b, i, j] ** 2), computed in three
    # stages: partial sums into C_rf, a reduction of C_rf into C, then the
    # square root into D.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    C_rf = T.alloc_buffer([1, 16])
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block("C_rf"):
            vi1_i2_fused_inner = T.axis.S(1, i1_i2_fused_inner)
            b = T.axis.S(16, i0)
            # The 65536-step loop is split back into the two 256-wide
            # reduction indices.
            i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256))
            with T.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (
                A[b, i, j] * A[b, i, j]
            )
    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
        with T.block("C"):
            # Reduce over the extent-1 leading axis of C_rf.
            vi1_i2_fused_inner_1 = T.axis.R(1, i1_i2_fused_inner_1)
            b_1 = T.axis.S(16, i0_1)
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]
    for i0_2 in T.grid(16):
        with T.block("D"):
            b_2 = T.axis.S(16, i0_2)
            D[b_2] = T.sqrt(C[b_2], dtype="float32")
def opaque_access_fused(a: T.handle, b: T.handle) -> None:
    # Two independent flat loops over 16 * 16 = 256 positions.  Both blocks
    # declare an empty read set and a whole-buffer write region rather than a
    # per-element one.
    A = T.match_buffer(a, [16, 16])
    B = T.match_buffer(b, [16, 16])
    for i_j_fused in T.grid(256):
        with T.block("A"):
            vi = T.axis.S(16, T.floordiv(i_j_fused, 16))
            vj = T.axis.S(16, T.floormod(i_j_fused, 16))
            T.reads([])
            T.writes([A[0:16, 0:16]])
            A[vi, vj] = 1
    for i_j_fused in T.grid(256):
        with T.block("B"):
            vi = T.axis.S(16, T.floordiv(i_j_fused, 16))
            vj = T.axis.S(16, T.floormod(i_j_fused, 16))
            T.reads([])
            T.writes([B[0:16, 0:16]])
            # B is only touched through its data pointer inside the intrinsic
            # call; the index argument is the flattened (vi, vj).
            T.evaluate(
                T.tvm_fill_fragment(
                    B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle"
                )
            )
def rowsum_transformed(a: T.handle, b: T.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] on a 128x128 input.  The middle
    # loop carries part of the row index (ii_ko_fused // 32) and part of the
    # reduction index (ii_ko_fused % 32).
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, ))
    for io in T.serial(0, 32):
        for ii_ko_fused in T.serial(0, 128):
            for ki in T.serial(0, 4):
                with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
                    T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32))
                    T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki)
                    with T.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + A[vi, vk]
def rowsum_transformed(a: T.handle, b: T.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] on a 128x128 input.  The middle
    # loop carries part of the row index (ii_ko_fused // 32) and part of the
    # reduction index (ii_ko_fused % 32).
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for io in T.serial(0, 32):
        for ii_ko_fused in T.serial(0, 128):
            for ki in T.serial(0, 4):
                with T.block("B"):
                    vi = T.axis.S(128, io * 4 + T.floordiv(ii_ko_fused, 32))
                    vk = T.axis.R(128, T.floormod(ii_ko_fused, 32) * 4 + ki)
                    with T.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + A[vi, vk]
def main(placeholder1: T.Buffer[(100, ), "int8"], placeholder2: T.Buffer[(100, ), "int8"]) -> None:
    # Copies 6 * 16 int8 values from placeholder2 to placeholder1, rotating
    # the 16-element groups: destination group i3 reads source group
    # (i3 + 4) mod 6.  The three leading loops have extent 1.
    T.attr("i0", "pragma_layout", "NHCWB16")
    for i0, i1, i2, i3, i4 in T.grid(1, 1, 1, 6, 16):
        placeholder1[((i3 * 16) + i4)] = placeholder2[
            ((T.floormod((i3 + 4), 6) * 16) + i4)
        ]
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(
    placeholder_10: T.handle,
    placeholder_11: T.handle,
    placeholder_12: T.handle,
    T_cast_4: T.handle,
) -> None:
    # Three phases over flat buffers: build a zero-bordered 77x77x64 copy of
    # the int16 input, accumulate 3x3x64 products into a 64-entry int32
    # accumulator per output position, then add placeholder_15, requantize,
    # clamp to [0, 255] and cast uint8 -> int16.
    # function attr dict
    T.func_attr(
        {
            "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
            "tir.noalias": True,
        }
    )
    placeholder_13 = T.match_buffer(placeholder_10, [360000], dtype="int16")
    placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
    placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
    T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16")
    # body
    PaddedInput_1 = T.allocate([379456], "int16", "global")
    for i0_i1_fused_1 in T.serial(0, 77):
        for i2_1 in T.serial(0, 77):
            for i3_1 in T.serial(0, 64):
                # Rows/columns 0 and 76 form the zero border; interior cells
                # are read from the input shifted by one row and one column
                # (4864 = 4800 + 64).
                PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1] = T.if_then_else(
                    1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76,
                    placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864],
                    T.int16(0),
                    dtype="int16",
                )
    for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
        # Accumulator allocated inside the loop over the 75 * 75 positions.
        Conv2dOutput_1 = T.allocate([64], "int32", "global")
        for ff_1 in T.serial(0, 64):
            Conv2dOutput_1[ff_1] = 0
            for ry in T.serial(0, 3):
                for rx in T.serial(0, 3):
                    for rc_1 in T.serial(0, 64):
                        Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(
                            PaddedInput_1[
                                T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928
                                + ry * 4928
                                + rx * 64
                                + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64
                                + rc_1
                            ],
                            "int32",
                        ) * T.cast(
                            placeholder_14[ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1],
                            "int32",
                        )
        for ax3_inner_2 in T.serial(0, 64):
            T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] = T.cast(
                T.cast(
                    T.max(
                        T.min(
                            T.q_multiply_shift(
                                Conv2dOutput_1[ax3_inner_2] + placeholder_15[ax3_inner_2],
                                1608879842,
                                31,
                                -7,
                                dtype="int32",
                            ),
                            255,
                        ),
                        0,
                    ),
                    "uint8",
                ),
                "int16",
            )
def main(placeholder1: T.Buffer[(100, ), "int8"], placeholder2: T.Buffer[(100, ), "int8"]) -> None:
    # Copies 5 * 6 * 4 int8 values from placeholder2 to placeholder1.  The
    # destination is addressed contiguously as (i1, i2, i3); the source row is
    # derived from i1 through a floordiv/floormod pair plus a fixed offset
    # of 96.
    T.attr("i0", "pragma_layout", "NHWC")
    for i0, i1, i2, i3 in T.grid(1, 5, 6, 4):
        placeholder1[(((i1 * 24) + (i2 * 4)) + i3)] = placeholder2[
            (((((T.floordiv((i1 - 1), 2) * 48) + (T.floormod((i1 + 1), 2) * 24)) + (i2 * 4)) + i3) + 96)
        ]
def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None:
    # D[b] = sqrt(sum over (i, j) of A[b, i, j] ** 2).  The reduction runs on
    # a 65536-step inner loop; the middle loop has extent 1 and is not used by
    # the block.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    for i0 in T.serial(0, 16):
        for i1_i2_fused_outer in T.serial(0, 1):
            for i1_i2_fused_inner in T.serial(0, 65536):
                with T.block("C"):
                    b = T.axis.S(16, i0)
                    i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256))
                    j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256))
                    with T.init():
                        C[b] = 0.0
                    C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in T.grid(16):
        with T.block("D"):
            b_1 = T.axis.S(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
    # D[b] = sqrt(sum over (i, j) of A[b, i, j] ** 2).  The reduction runs on
    # a 65536-step middle loop; the innermost loop has extent 1 and is not
    # used by the block.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    for i0 in T.serial(0, 16):
        for i1_i2_fused_outer in T.serial(0, 65536):
            for i1_i2_fused_inner in T.serial(0, 1):
                with T.block(
                    [16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C"
                ) as [b, i, j]:
                    T.bind(b, i0)
                    T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
                    T.bind(j, T.floormod(i1_i2_fused_outer, 256))
                    T.reads([C[b], A[b, i, j]])
                    T.writes([C[b]])
                    with T.init():
                        C[b] = 0.0
                    C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in T.grid(16):
        with T.block([16], "D") as [b_1]:
            T.bind(b_1, i0_1)
            T.reads([C[b_1]])
            T.writes([D[b_1]])
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
    # D[b] = sqrt(sum over (i, j) of A[b, i, j] ** 2), computed in three
    # stages: partial sums into C_rf, a reduction of C_rf into C, then the
    # square root into D.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    C_rf = T.alloc_buffer([1, 16])
    for i0 in T.serial(0, 16):
        for i1_i2_fused_outer in T.serial(0, 65536):
            for i1_i2_fused_inner in T.serial(0, 1):
                with T.block(
                    [1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf"
                ) as [vi1_i2_fused_inner, b, i, j]:
                    T.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
                    T.bind(b, i0)
                    T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
                    T.bind(j, T.floormod(i1_i2_fused_outer, 256))
                    with T.init():
                        C_rf[vi1_i2_fused_inner, b] = 0.0
                    C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (
                        A[b, i, j] * A[b, i, j]
                    )
    for i0_1 in T.serial(0, 16):
        for i1_i2_fused_inner_1 in T.serial(0, 1):
            # Reduce over the extent-1 leading axis of C_rf.
            with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]:
                T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
                T.bind(b_1, i0_1)
                with T.init():
                    C[b_1] = 0.0
                C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]
    for i0_2 in T.grid(16):
        with T.block([16], "D") as [b_2]:
            T.bind(b_2, i0_2)
            D[b_2] = T.sqrt(C[b_2], dtype="float32")
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(
        placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle,
        T_cast_20: T.handle) -> None:
    # Three phases, all addressed through flat indices:
    #   1. copy the 224x224x3 int16 input into a zero-filled 229x229x3 scratch
    #      buffer (PaddedInput_7),
    #   2. for each of the 112 * 112 = 12544 output positions, accumulate
    #      7x7x3 products into a 64-entry int32 accumulator (Conv2dOutput_7),
    #   3. add placeholder_67, apply T.q_multiply_shift, clamp to [0, 255] and
    #      store as uint8.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast",
        "tir.noalias": True
    })
    placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3],
                                    dtype="int16",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64],
                                    dtype="int16",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64],
                                    dtype="int32",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64],
                               dtype="uint8",
                               elem_offset=0,
                               align=128,
                               offset_factor=1)
    # body
    # 157323 = 229 * 229 * 3 elements of padded input.
    PaddedInput_7 = T.allocate([157323], "int16", "global")
    for i0_i1_fused_7 in T.serial(0, 229):
        for i2_7, i3_7 in T.grid(229, 3):
            # Rows/columns in [2, 226) are copied from the input; everything
            # else is zero.  The -1350 offset equals 2 * 672 + 2 * 3, i.e. two
            # input rows plus two input columns.
            T.store(
                PaddedInput_7, (((i0_i1_fused_7 * 687) + (i2_7 * 3)) + i3_7),
                T.if_then_else(
                    ((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and
                     (i2_7 < 226)),
                    T.load("int16", placeholder_65.data,
                           ((((i0_i1_fused_7 * 672) + (i2_7 * 3)) + i3_7) - 1350)),
                    T.int16(0),
                    dtype="int16"), True)
    for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
        # Accumulator is allocated inside the loop over output positions.
        Conv2dOutput_7 = T.allocate([64], "int32", "global")
        for ff_3 in T.serial(0, 64):
            T.store(Conv2dOutput_7, ff_3, 0, True)
            for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
                # 1374 = 2 * 687 and 6 = 2 * 3: the window origin advances two
                # padded rows / columns per output row / column.
                T.store(
                    Conv2dOutput_7, ff_3,
                    (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(
                        T.load("int16", PaddedInput_7, ((
                            (((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112) * 1374) +
                              (ry_2 * 687)) +
                             (T.floormod(ax0_ax1_fused_ax2_fused_7, 112) * 6)) +
                            (rx_2 * 3)) + rc_7)), "int32") * T.cast(
                                T.load("int16", placeholder_66.data,
                                       ((((ry_2 * 1344) + (rx_2 * 192)) + (rc_7 * 64)) + ff_3)),
                                "int32"))), True)
        for ax3_inner_7 in T.serial(0, 64):
            # Add the per-channel int32 term, requantize, clamp to [0, 255]
            # and narrow to uint8.
            T.store(
                T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7 * 64) + ax3_inner_7),
                T.cast(
                    T.max(
                        T.min(
                            T.q_multiply_shift(
                                (T.load("int32", Conv2dOutput_7, ax3_inner_7) +
                                 T.load("int32", placeholder_67.data, ax3_inner_7)),
                                1939887962,
                                31,
                                -9,
                                dtype="int32"), 255), 0), "uint8"), True)
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:  # type: ignore
    # Eight block stages, in order:
    #   data_pad        copy the input into a 16x16 buffer, zero beyond index 13
    #   input_tile      gather nine 6x6 tiles (3x3 grid, step 4) from data_pad
    #   B               constant 6x6 coefficient table
    #   data_pack       reduce input_tile against B on both tile axes
    #   bgemm           reduce data_pack against placeholder_1 over ci
    #   A               constant 6x4 coefficient table
    #   inverse         reduce bgemm against A on both tile axes -> 4x4 tiles
    #   conv2d_winograd scatter the nine 4x4 tiles into the 12x12 output
    # NOTE(review): judging by the function and block names this is a Winograd
    # convolution workload for meta-schedule; B and A would then be the input
    # and output transform matrices -- confirm against the generating code.
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            # Positions with row or column index >= 14 are filled with zero.
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            # Tile p starts at row (p % 9 // 3) * 4 and column (p % 3) * 4;
            # (eps, nu) is the position inside the 6x6 tile.
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # Constant table spelled as a nested Select chain: one entry per
            # (i, j) pair from (5, 5) down to (0, 0), with a final default of 0.
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5),
                T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5),
                T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1),
                T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0),
                T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1),
                T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            # The declared B region is the bounding box [min, max] of the two
            # indices used on each axis (B[r_a, eps_1] and B[r_b, nu_1]).
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            # For every (eps, nu) tile position: sum over ci of
            # data_pack * placeholder_1.
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # Constant table spelled as a nested Select chain: one entry per
            # (i_1, j_1) pair from (5, 3) down to (0, 0), default 0.
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8),
                T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4),
                T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2),
                T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125),
                T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25),
                T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5),
                T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1),
                T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1),
                T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0),
                T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            # As in "data_pack", the declared A region is the bounding box of
            # the two indices used on each axis.
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            # Output pixel (h, w) comes from tile n * 9 + (h // 4) * 3 + w // 4
            # at in-tile position (h % 4, w % 4).
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]