Example #1
 def threadpool_nested_parallel_loop(
         A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4),
                                                     "float32"]) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     for i in T.parallel(4):
         for j in T.parallel(4):
             B[i, j] = A[i, j] * 2.0
Example #2
 def threadpool_nested_parallel_loop(
         A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4),
                                                     "float32"]) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     for i in T.parallel(4):
         for j in T.parallel(4):
             T.store(B.data, i * 4 + j,
                     T.load("float32", A.data, i * 4 + j) * 2.0)
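Examples #1 and #2 are the same 4x4 element-wise kernel in two TIR spellings: #1 indexes the buffers directly, while #2 uses the older T.load/T.store intrinsics. As a minimal sketch (assuming a TVM release, roughly 0.9 through 0.12, that still accepts the T.Buffer[...] annotation shown above; the module name NestedParallel is made up for illustration), Example #1 can be wrapped in a self-contained IRModule and round-tripped through the TVMScript printer:

import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class NestedParallel:  # hypothetical wrapper module, not part of the original example
    @T.prim_func
    def main(A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.parallel(4):
            for j in T.parallel(4):  # nested parallel loop, as in Example #1
                B[i, j] = A[i, j] * 2.0


# Printing the module back out should reproduce the nested T.parallel loops.
print(NestedParallel.script())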
Example #3
def scatter_compute_parallelize(
    A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]
) -> None:
    # body
    # with T.block("root")
    for i in T.parallel(8):
        with T.block("first_half"):
            vi = T.axis.spatial(16, 8 + i)
            T.reads(A[vi - 8])
            T.writes(B[vi])
            B[vi] = A[vi - 8]
    for i in T.parallel(8):
        with T.block("last_half"):
            vi = T.axis.spatial(16, i)
            T.reads(A[vi + 8])
            T.writes(B[vi])
            B[vi] = A[vi + 8]
Example #4
def element_wise_parallelized(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i0 in T.parallel(0, 128):
        for i1 in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i0, i1])
                B[vi, vj] = A[vi, vj] * 2.0
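Example #4 is the shape of IR that the parallel schedule primitive produces. A minimal sketch, assuming tvm.tir.Schedule and starting from an unscheduled element-wise kernel (the ElementWise module name is hypothetical): applying sch.parallel to the outer loop of block "B" yields a function equivalent to element_wise_parallelized.

import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class ElementWise:  # hypothetical unscheduled input, not part of the original example
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (128, 128))
        B = T.match_buffer(b, (128, 128))
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0


sch = tvm.tir.Schedule(ElementWise)
i, j = sch.get_loops(sch.get_block("B"))
sch.parallel(i)          # outer loop becomes T.parallel, inner stays T.serial
print(sch.mod.script())  # expected to be equivalent to element_wise_parallelized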
Example #5
def loop_syntax_sugar(a: T.handle) -> None:
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(128):
        for j in T.parallel(128):
            for k in T.vectorized(128):
                for x in T.unroll(128):
                    for y in T.thread_binding(128, "threadIdx.x"):
                        for z in T.thread_binding(128, thread="threadIdx.x"):
                            A[i, j, k, x] = A[i, j, k, x] * 2.0
Example #6
 def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
     T.func_attr({"global_symbol": "elemwise_sum_parallel", "tir.noalias": True})
     A = T.match_buffer(a, (n,), dtype="float32")
     B = T.match_buffer(b, (n,), dtype="float32")
     C = T.match_buffer(c, (n,), dtype="float32")
     for i in T.parallel(n):
         with T.block("C"):
             vi = T.axis.spatial(n, i)
             C[vi] = A[vi] + B[vi]
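Example #6 keeps its extent symbolic, so the parallel loop is sized by the n argument at call time. A minimal sketch, assuming an LLVM-enabled TVM build (the IRModule wrapper and test values below are added for illustration), of compiling it for the CPU thread pool and running it with n = 1024:

import numpy as np
import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class Module:  # hypothetical wrapper around Example #6
    @T.prim_func
    def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
        T.func_attr({"global_symbol": "elemwise_sum_parallel", "tir.noalias": True})
        A = T.match_buffer(a, (n,), dtype="float32")
        B = T.match_buffer(b, (n,), dtype="float32")
        C = T.match_buffer(c, (n,), dtype="float32")
        for i in T.parallel(n):
            with T.block("C"):
                vi = T.axis.spatial(n, i)
                C[vi] = A[vi] + B[vi]


lib = tvm.build(Module, target="llvm")
n = 1024
a = tvm.nd.array(np.random.rand(n).astype("float32"))
b = tvm.nd.array(np.random.rand(n).astype("float32"))
c = tvm.nd.array(np.zeros(n, dtype="float32"))
lib["elemwise_sum_parallel"](a, b, c, n)  # parallel loop runs on the thread pool
np.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy(), rtol=1e-6)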
Example #7
def element_wise_parallelized(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i0 in T.parallel(0, 128):
        for i1 in T.serial(0, 128):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i0)
                T.bind(vj, i1)
                B[vi, vj] = A[vi, vj] * 2.0
Example #8
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
    placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_3 = T.allocate([150528], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)]
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            Conv2dOutput_3 = T.allocate([1], "int32", "global")
            Conv2dOutput_3[0] = 0
            for rc_3 in T.serial(0, 192):
                Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32")))
            T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16")
Example #9
def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.serial(0, 128):
        for j_0 in T.parallel(0, 13):
            for j_1 in T.serial(0, 10):
                with T.block("B"):
                    T.where(j_0 * 10 + j_1 < 128)
                    vi = T.axis.S(128, i)
                    vj = T.axis.S(128, j_0 * 10 + j_1)
                    B[vi, vj] = A[vi, vj] * 2.0
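The T.where predicate in Example #9 is the residue of a non-divisible split: 13 * 10 = 130 covers the 128 iterations of j, so the last two are masked off. A minimal sketch, assuming tvm.tir.Schedule (the unscheduled ElementWise module below is hypothetical), of how split followed by parallel arrives at this loop nest:

import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class ElementWise:  # hypothetical unscheduled input
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (128, 128))
        B = T.match_buffer(b, (128, 128))
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0


sch = tvm.tir.Schedule(ElementWise)
i, j = sch.get_loops(sch.get_block("B"))
j_0, j_1 = sch.split(j, factors=[13, 10])  # 13 * 10 > 128, so a T.where guard is added
sch.parallel(j_0)                          # parallelize the predicated outer piece of j
print(sch.mod.script())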
Example #10
def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))

    for i in T.serial(0, 128):
        for k in T.parallel(0, 128):
            with T.block("B"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
Example #11
def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, ))

    for i in T.serial(0, 128):
        for k in T.parallel(0, 128):
            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
                T.bind(vi, i)
                T.bind(vk, k)
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
Example #12
 def main(placeholder: T.Buffer[(1, 16, 7, 7, 32), "float32"], placeholder_1: T.Buffer[(25088,), "float32"], T_layout_trans: T.Buffer[(1, 1, 7, 7, 512), "float32"]) -> None:
     # function attr dict
     T.func_attr({"tir.noalias": True, "global_symbol": "main"})
     # body
     # with T.block("root")
     for i0_i1_i2_i3_i4_fused in T.parallel(25088, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}):
         with T.block("T_layout_trans_1"):
             ax0 = T.axis.spatial(1, 0)
             ax1 = T.axis.spatial(1, 0)
             ax2 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused // 3584)
             ax3 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused % 3584 // 512)
             ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512)
             T.reads(placeholder[0, (ax4 * 49 + ax2 * 7 + ax3) % 25088 // 1568, (ax2 * 7 + ax3) % 49 // 7, ax3 % 7, (ax4 * 49 + ax2 * 7 + ax3) % 1568 // 49], placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088])
             T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
             T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) < T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]), T.float32(0), dtype="float32")
Example #13
def Move_PUV0(a: T.handle, b: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    # body
    with T.block("root"):
        for i0_j0_fused in T.parallel(0, 8192):
            for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8):
                for k1_fused in T.vectorized(0, 32):
                    with T.block("move"):
                        vi = T.axis.spatial(
                            1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2)
                        vj = T.axis.spatial(
                            1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2)
                        vk = T.axis.spatial(1024, k0 * 32 + k1_fused)
                        T.where(i0_j0_fused // 64 * 16 + i1 * 4 + i2 < 1024
                                and i0_j0_fused % 64 * 32 + j1 * 8 + j2 < 1024
                                and k0 * 32 + k1_fused < 1024)
                        T.reads([A[vi, vj, vk]])
                        T.writes([B[vi, vj, vk]])
                        B[vi, vj, vk] = A[vi, vj, vk]
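Example #13's shape (a fused, predicated outer parallel loop with a vectorized innermost loop) is what split/reorder/fuse scheduling of a plain 3-D copy tends to produce. A minimal sketch, assuming tvm.tir.Schedule; the Move module and the factor choices below are illustrative and give a similar, though not identical, loop nest:

import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class Move:  # hypothetical unscheduled 3-D copy
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
        B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("move"):
                vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                B[vi, vj, vk] = A[vi, vj, vk]


sch = tvm.tir.Schedule(Move)
i, j, k = sch.get_loops(sch.get_block("move"))
i0, i1, i2 = sch.split(i, factors=[None, 4, 4])
j0, j1, j2 = sch.split(j, factors=[None, 4, 8])
k0, k1 = sch.split(k, factors=[None, 32])
sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1)
sch.parallel(sch.fuse(i0, j0))  # fused outer tile loop runs on the thread pool
sch.vectorize(k1)               # innermost 32-wide loop is vectorized
print(sch.mod.script())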
Example #14
def loops() -> None:
    for i in T.parallel(0, 2):
        for j in T.serial(0, 1):
            for z in T.vectorized(3, 4):
                T.evaluate(0)
Example #15
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(
        placeholder_30: T.handle, placeholder_31: T.handle,
        placeholder_32: T.handle, T_cast_8: T.handle) -> None:
    # function attr dict
    T.func_attr({
        "global_symbol":
        "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2",
        "tir.noalias": True
    })
    placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192],
                                    dtype="int16",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16],
                                    dtype="int16",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16],
                                    dtype="int32",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16],
                              dtype="int16",
                              elem_offset=0,
                              align=128,
                              offset_factor=1)
    # body
    PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            T.store(
                PaddedInput_3,
                (((i0_i1_fused_3 * 5376) + (i2_3 * 192)) + i3_3),
                T.load("int16", placeholder_33.data,
                       (((i0_i1_fused_3 * 5376) + (i2_3 * 192)) + i3_3)), True)
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global")
            T.store(Conv2dOutput_3, 0, 0, True)
            for rc_3 in T.serial(0, 192):
                T.store(Conv2dOutput_3, 0,
                        (T.load("int32", Conv2dOutput_3, 0) + (T.cast(
                            T.load("int16", PaddedInput_3,
                                   ((ax0_ax1_fused_ax2_fused_3 * 192) + rc_3)),
                            "int32") * T.cast(
                                T.load("int16", placeholder_34.data,
                                       ((rc_3 * 16) + ax3_2)), "int32"))),
                        True)
            T.store(
                T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3 * 16) + ax3_2),
                T.cast(
                    T.cast(
                        T.max(
                            T.min(
                                T.q_multiply_shift(
                                    (T.load("int32", Conv2dOutput_3, 0) +
                                     T.load("int32", placeholder_35.data,
                                            ax3_2)),
                                    1764006585,
                                    31,
                                    -7,
                                    dtype="int32"), 255), 0), "uint8"),
                    "int16"), True)