Example no. 1
def main(A: T.Buffer[(256, 256), "float32"],
         T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
    T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_maxelem"):
            i0_1, k = T.axis.remap("SR", [i0, i1])
            with T.init():
                T_softmax_maxelem[i0_1] = T.min_value("float32")
            T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1],
                                            A[i0_1, k])
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_expsum"):
            i0_2, k = T.axis.remap("SR", [i0, i1])
            with T.init():
                T_softmax_expsum[i0_2] = T.float32(0)
            T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
                A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32")
    for i0_3, i1 in T.grid(256, 256):
        with T.block("T_softmax_norm"):
            i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
            T_softmax_norm[i0_4, i1_1] = T.exp(
                A[i0_4, i1_1] - T_softmax_maxelem[i0_4],
                dtype="float32") / T_softmax_expsum[i0_4]
Example no. 2
def square_sum_square_root_factor_one_2_rfactor(
        A: T.Buffer[(16, 256, 256), "float32"],
        D: T.Buffer[(16,), "float32"]) -> None:
    C = T.alloc_buffer([16], dtype="float32")
    C_rf = T.alloc_buffer([16, 1], dtype="float32")
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C_rf"):
            b = T.axis.spatial(16, i0)
            i = T.axis.reduce(256, i1_i2_fused_inner // 256)
            j = T.axis.reduce(256, i1_i2_fused_inner % 256)
            vi1_i2_fused_outer = T.axis.spatial(1, i1_i2_fused_outer)
            with T.init():
                C_rf[b, vi1_i2_fused_outer] = T.float32(0)
            C_rf[b, vi1_i2_fused_outer] = (C_rf[b, vi1_i2_fused_outer] +
                                           A[b, i, j] * A[b, i, j])
    for i0, i1_i2_fused_outer in T.grid(16, 1):
        with T.block("C"):
            b, vi1_i2_fused_outer = T.axis.remap("SR", [i0, i1_i2_fused_outer])
            with T.init():
                C[b] = T.float32(0)
            C[b] = C[b] + C_rf[b, vi1_i2_fused_outer]
    for i0_1 in T.serial(16):
        with T.block("D"):
            b_1 = T.axis.spatial(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
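Blocks C_rf, C, and D above together compute D[b] = sqrt(sum over i, j of A[b, i, j]^2), with the i1/i2 axes fused into a single reduction of extent 65536 = 256 * 256. A minimal NumPy oracle (square_sum_sqrt_ref is an illustrative name):

import numpy as np

def square_sum_sqrt_ref(A: np.ndarray) -> np.ndarray:
    # D[b] = sqrt(sum_{i,j} A[b, i, j]^2); the rfactor split does not change the result.
    return np.sqrt((A * A).sum(axis=(1, 2)))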
Example no. 3
def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    C_rf = T.alloc_buffer([4, 128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
            128, 128, 4, 8, 4):
        with T.block("update_rf"):
            vi2_inner_inner = T.axis.S(4, i2_inner_inner)
            vi = T.axis.S(128, i0)
            vj = T.axis.S(128, i1)
            vi2_outer = T.axis.R(4, i2_outer)
            vi2_inner_outer = T.axis.R(8, i2_inner_outer)
            with T.init():
                C_rf[vi2_inner_inner, vi, vj] = 0.0
            C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
                A[vi, vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner] *
                B[vj, vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner])

    for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4):
        with T.block("update"):
            vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap(
                "RSS", [i2_inner_inner_1, i0_1, i1_1])
            with T.init():
                C[vi_1, vj_1] = 0.0
            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
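A sketch of how such an rfactor form can be derived with tvm.tir.Schedule, assuming an IRModule whose matmul prim_func has a reduction block named "update" over loops i, j, k (the names and factors mirror the printed output above; the module itself is not shown here):

import tvm
from tvm import tir

def rfactor_matmul(mod: tvm.IRModule) -> tvm.IRModule:
    sch = tir.Schedule(mod)
    update = sch.get_block("update")
    i, j, k = sch.get_loops(update)
    # Split the reduction axis 128 -> 4 x 8 x 4 (i2_outer, i2_inner_outer, i2_inner_inner).
    ko, kio, kii = sch.split(k, factors=[4, 8, 4])
    # Materialize partial sums indexed by the innermost split; factor_axis=0
    # makes it the leading axis of the new buffer, as in C_rf[4, 128, 128] above.
    sch.rfactor(kii, factor_axis=0)
    return sch.mod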
Example no. 4
def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    C_rf = T.alloc_buffer([1, 16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block("C_rf"):
            vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0])
            i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256))
            with T.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = (C_rf[vi1_i2_fused_inner, b] +
                                           A[b, i, j] * A[b, i, j])

    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
        with T.block("C"):
            vi1_i2_fused_inner_1, b_1 = T.axis.remap(
                "RS", [i1_i2_fused_inner_1, i0_1])
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]

    for i0_2 in T.serial(0, 16):
        with T.block("D"):
            b_2 = T.axis.S(16, i0_2)
            D[b_2] = T.sqrt(C[b_2], dtype="float32")
Example no. 5
def duplicate_init() -> None:
    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            with T.init():
                T.evaluate(1.0)
            with T.init():  # error
                T.evaluate(1.0)
Example no. 6
def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16])
    C = T.alloc_buffer([16, 16])
    D = T.alloc_buffer([16, 16])
    E = T.alloc_buffer([16, 16])
    F = T.match_buffer(f, [16, 16])
    C_rf = T.alloc_buffer([16, 16, 4])

    for i, j1, k1o, k1i in T.grid(16, 16, 4, 4):
        with T.block([4, 16, 16, T.reduce_axis(0, 4)],
                     "C_rf") as [vk1o, ci, cj, vk1i]:
            T.bind(vk1o, k1o)
            T.bind(ci, i)
            T.bind(cj, j1)
            T.bind(vk1i, k1i)
            with T.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj,
                                                        ((vk1o * 4) + vk1i)]
    for i_1 in T.serial(0, 16):
        for j1_1 in T.serial(0, 16):
            for k1o_1 in T.serial(0, 4):
                with T.block([T.reduce_axis(0, 4), 16, 16],
                             "C") as [vk1o_1, ci_1, cj_1]:
                    T.bind(vk1o_1, k1o_1)
                    T.bind(ci_1, i_1)
                    T.bind(cj_1, j1_1)
                    with T.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            for k2o, k2i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)],
                             "D") as [di, dj, dk]:
                    T.bind(di, i_1)
                    T.bind(dj, j1_1)
                    T.bind(dk, (k2o * 4) + k2i)
                    with T.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        for j2 in T.serial(0, 16):
            for k3o, k3i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)],
                             "E") as [ei, ej, ek]:
                    T.bind(ei, i_1)
                    T.bind(ej, j2)
                    T.bind(ek, (k3o * 4) + k3i)
                    with T.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)],
                             "F") as [fi, fj, fk]:
                    T.bind(fi, i_1)
                    T.bind(fj, j2)
                    T.bind(fk, (k4o * 4) + k4i)
                    with T.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
Example no. 7
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
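A hedged sketch of the kind of binding schedule behind this layout: one CUDA block per row, with 32 threads cooperating within it. Only the normalization block is shown; the shared-memory staging of the max and sum blocks is omitted (sch is assumed to be a tvm.tir.Schedule over a plain softmax prim_func):

from tvm import tir

def bind_softmax_norm(sch: tir.Schedule) -> None:
    norm = sch.get_block("T_softmax_norm")
    i0, i1 = sch.get_loops(norm)
    sch.bind(i0, "blockIdx.x")  # one thread block per row
    i1_0, i1_1 = sch.split(i1, factors=[8, 32])
    sch.bind(i1_1, "threadIdx.x")  # 32 threads, 8 serial iterations each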
Example no. 8
def single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.float32(
                            -3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = (
                        T_softmax_expsum_shared[i0_2] +
                        T.exp(A[i0_2, k] - T_softmax_maxelem_shared[i0_2],
                              dtype="float32"))
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                              dtype="float32") / T_softmax_expsum_shared[i0_3])
Example no. 9
def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128])
    B = T.match_buffer(b, [])
    B_rf = T.alloc_buffer([128])

    with T.block([128], "B_rf") as [vi0]:
        with T.init():
            B_rf[vi0] = 0.0
        B_rf[vi0] = B_rf[vi0] + A[vi0]

    with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]:
        with T.init():
            B[()] = 0.0
        B[()] = B[()] + B_rf[vi0_1]
Example no. 10
def before_matmul_vectorize(
    placeholder: T.Buffer[(64, 768), "float32"],
    placeholder_1: T.Buffer[(768, 768), "float32"],
    T_matmul_NT: T.Buffer[(64, 768), "float32"],
) -> None:
    with T.block("root"):
        T.reads()
        T.writes()
        T.block_attr({"meta_schedule.vectorize": 64})
        T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
        for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
            for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(
                    48, 8, 1, 16, 8, 16):
                with T.block("T_matmul_NT"):
                    i = T.axis.spatial(64, i0_2 * 8 + i0_3)
                    j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3)
                    k = T.axis.reduce(768, i2_0 * 16 + i2_1)
                    T.reads(placeholder[i, k], placeholder_1[j, k])
                    T.writes(T_matmul_NT_global[i, j])
                    with T.init():
                        T_matmul_NT_global[i, j] = T.float32(0)
                    T_matmul_NT_global[i, j] = (
                        T_matmul_NT_global[i, j] +
                        placeholder[i, k] * placeholder_1[j, k])
            for ax0, ax1 in T.grid(64, 16):
                with T.block("T_matmul_NT_global"):
                    v0 = T.axis.spatial(64, ax0)
                    v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1)
                    T.reads(T_matmul_NT_global[v0, v1])
                    T.writes(T_matmul_NT[v0, v1])
                    T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
Example no. 11
def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
    for k, i in T.grid(128, 128):
        with T.block("B"):
            vk, vi = T.axis.remap("RS", [k, i])
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
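Note that the reduction loop k is outermost here; because the semantics live in the block axes (the "RS" remap), not in the loop order, the block still computes a plain row sum. An illustrative NumPy equivalent:

import numpy as np

def rowsum_ref(A: np.ndarray) -> np.ndarray:
    # B[i] = sum_k A[i, k], independent of the k/i loop order above.
    return A.sum(axis=1)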
Example no. 12
def cascade_pool_ops(
    x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
) -> None:
    y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
    for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3):
        with T.block("pool_0"):
            ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
            with T.init():
                y1[ax0, ax1, ax2, ax3] = 0.0
            y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
    for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3):
        with T.block("pool_1"):
            ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
            with T.init():
                y2[ax0, ax1, ax2, ax3] = 0.0
            y2[ax0, ax1, ax2, ax3] = y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
Example no. 13
def conv2d_nhwc(
    Input: T.Buffer[(1, 224, 224, 3), "float32"],
    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and
                 (i2_1 < 227)),
                Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            n, h, w, co, rh, rw, rc = T.axis.remap(
                "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            with T.init():
                Conv2d_nhwc[n, h, w, co] = T.float32(0)
            Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + (
                PadInput[n, ((h * 2) + rh), ((w * 2) + rw), (
                    (T.floordiv(co, 64) * 3) + rc)] * Weight[rh, rw, rc, co])
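This is a 7x7, stride-2 convolution with padding 3 (224 + 2 * 3 = 230 padded rows and columns, (230 - 7) // 2 + 1 = 112 outputs per axis). A hedged NumPy reference for checking the TIR (conv2d_nhwc_ref is an illustrative name):

import numpy as np

def conv2d_nhwc_ref(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # x: (1, 224, 224, 3), w: (7, 7, 3, 64) -> (1, 112, 112, 64)
    xp = np.pad(x, ((0, 0), (3, 3), (3, 3), (0, 0)))  # mirrors PadInput
    out = np.zeros((1, 112, 112, 64), dtype=x.dtype)
    for h in range(112):
        for v in range(112):
            patch = xp[0, h * 2:h * 2 + 7, v * 2:v * 2 + 7, :]  # (7, 7, 3) window
            out[0, h, v, :] = np.tensordot(patch, w, axes=([0, 1, 2], [0, 1, 2]))
    return out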
Example no. 14
def after_matmul_vectorize(
    placeholder: T.Buffer[(64, 768), "float32"],
    placeholder_1: T.Buffer[(768, 768), "float32"],
    T_matmul_NT: T.Buffer[(64, 768), "float32"],
) -> None:
    T.func_attr({
        "global_symbol": "main",
        "tir.noalias": True,
        "layout_free_placeholders": [1]
    })
    T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
    for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
        for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8):
            for i1_3_fused in T.vectorized(16):
                with T.block("T_matmul_NT"):
                    i = T.axis.spatial(64, i0_2 * 8 + i0_3)
                    j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3_fused)
                    k = T.axis.reduce(768, i2_0 * 16 + i2_1)
                    T.reads(placeholder[i, k], placeholder_1[j, k])
                    T.writes(T_matmul_NT_global[i, j])
                    with T.init():
                        T_matmul_NT_global[i, j] = T.float32(0)
                    T_matmul_NT_global[i, j] = (
                        T_matmul_NT_global[i, j] +
                        placeholder[i, k] * placeholder_1[j, k])
        for ax0 in T.serial(64):
            for ax1_fused in T.vectorized(16):
                with T.block("T_matmul_NT_global"):
                    v0 = T.axis.spatial(64, ax0)
                    v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1_fused)
                    T.reads(T_matmul_NT_global[v0, v1])
                    T.writes(T_matmul_NT[v0, v1])
                    T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
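The step from before_matmul_vectorize to after_matmul_vectorize corresponds to acting on the "meta_schedule.vectorize": 64 annotation: the innermost spatial loop of extent 16 becomes T.vectorized (the cache write-back loop is handled the same way). A hand-written sketch with tvm.tir.Schedule, assuming sch is a Schedule over the "before" function:

from tvm import tir

def vectorize_innermost(sch: tir.Schedule) -> None:
    block = sch.get_block("T_matmul_NT")
    *_, i1_3 = sch.get_loops(block)  # innermost spatial loop, extent 16
    sch.vectorize(i1_3)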
Example no. 15
def conv2d_nhwc_transformed(
    Input: T.Buffer[(1, 224, 224, 3), "float32"],
    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
                Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for ax0, ax_1, ax_2 in T.grid(12544, 64, 147):
        with T.block("conv2d_nhwc"):
            bv0, bv1, bv2 = T.axis.remap("SSR", [ax0, ax_1, ax_2])
            T.reads(
                PadInput[0, bv0 // 112 * 2 + bv2 // 21,
                         bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3],
                Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1],
            )
            T.writes(Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1])
            with T.init():
                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = T.float32(0)
            Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = (
                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] +
                PadInput[0, bv0 // 112 * 2 + bv2 // 21,
                         bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3] *
                Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1])
Example no. 16
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
        with T.block("conv2d_NCHWc_int8"):
            n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = (
                T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]))
            T.reads(
                placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
            )
            T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
            with T.init():
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                n, oc_chunk, oh, ow, oc_block
            ] + T.cast(
                placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
            ) * T.cast(
                placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                "int32",
            )
Example no. 17
def rfactor_spatial_only_after(
    A: T.Buffer[(1, 512, 7, 7), "float32"],
    B: T.Buffer[(1, 512, 1, 1), "float32"],
) -> None:
    # body
    # with T.block("root")
    B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32")
    for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
        with T.block("acc_rf"):
            vi4 = T.axis.spatial(49, i4)
            ax0 = T.axis.spatial(1, 0)
            ax1 = T.axis.spatial(512, i1)
            ax2 = T.axis.spatial(1, 0)
            ax3 = T.axis.spatial(1, 0)
            B_rf[ax0, ax1, ax2, ax3, vi4] = A[ax0, ax1, ax2 * 7 + vi4 // 7, ax3 * 7 + vi4 % 7]
    for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
        with T.block("acc"):
            vi4 = T.axis.reduce(49, i4)
            ax0 = T.axis.spatial(1, 0)
            ax1 = T.axis.spatial(512, i1)
            ax2 = T.axis.spatial(1, 0)
            ax3 = T.axis.spatial(1, 0)
            with T.init():
                B[ax0, ax1, ax2, ax3] = T.float32(0)
            B[ax0, ax1, ax2, ax3] = B[ax0, ax1, ax2, ax3] + B_rf[ax0, ax1, ax2, ax3, vi4]
Example no. 18
def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        for ax0 in T.serial(0, 10):
            for ax1 in T.serial(0, 10):
                with T.block("cache"):
                    h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
                    w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
                    T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225
                            and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225)
                    cache[h, w] = X[h, w]
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                h = T.axis.spatial(224, hh_0 * 8 + hh_1)
                w = T.axis.spatial(224, ww_0 * 8 + ww_1)
                kh, kw = T.axis.remap("RR", [khh, kww])
                with T.init():
                    Y[h, w] = 0.0
                Y[h, w] = T.max(
                    Y[h, w],
                    T.if_then_else(
                        T.likely(1 <= h + kh, dtype="bool")
                        and T.likely(h + kh < 225, dtype="bool")
                        and T.likely(1 <= w + kw, dtype="bool")
                        and T.likely(w + kw < 225, dtype="bool"),
                        cache[h + kh - 1, w + kw - 1],
                        0.0,
                        dtype="float32",
                    ),
                )
Example no. 19
def tiled_conv2d_with_padding(
    inputs: T.Buffer[(1, 224, 224, 3), "float32"],
    weight: T.Buffer[(7, 7, 3, 64), "float32"],
    conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                3 <= i1_1 and i1_1 < 227 and 3 <= i2_1 and i2_1 < 227,
                inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for (i0_0, i1_0, i2_0, i3_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0,
         i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3,
         i3_3) in T.grid(1, 1, 4, 1, 1, 2, 4, 1, 7, 7, 1, 1, 1, 1, 1, 1, 1, 3,
                         1, 56, 7, 64):
        with T.block("conv2d_nhwc"):
            n = T.axis.spatial(1, 0)
            h = T.axis.spatial(112, i1_1_1 * 56 + i1_3)
            w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 * 7 + i2_3)
            co, rh, rw, rc = T.axis.remap("SRRR", [i3_3, i4_0, i5_0, i6_1])
            T.reads(
                conv2d_nhwc[n, h, w, co],
                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc],
                weight[rh, rw, rc, co],
            )
            T.writes(conv2d_nhwc[n, h, w, co])
            with T.init():
                conv2d_nhwc[n, h, w, co] = T.float32(0)
            conv2d_nhwc[n, h, w, co] = (
                conv2d_nhwc[n, h, w, co] +
                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] *
                weight[rh, rw, rc, co])
Example no. 20
def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, (16, 16, 16))
    C = T.alloc_buffer((16, 16))
    D = T.alloc_buffer((16, 16))
    E = T.alloc_buffer((16, 16))
    F = T.match_buffer(f, (16, 16))

    for i in T.serial(0, 16):
        for j1 in T.serial(0, 16):
            for k1o, k1i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)],
                             "C") as [ci, cj, ck]:
                    T.bind(ci, i)
                    T.bind(cj, j1)
                    T.bind(ck, k1o * 4 + k1i)
                    with T.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)],
                             "D") as [di, dj, dk]:
                    T.bind(di, i)
                    T.bind(dj, j1)
                    T.bind(dk, k2o * 4 + k2i)
                    with T.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        for j2 in T.serial(0, 16):
            for k3o, k3i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)],
                             "E") as [ei, ej, ek]:
                    T.bind(ei, i)
                    T.bind(ej, j2)
                    T.bind(ek, k3o * 4 + k3i)
                    with T.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in T.grid(4, 4):
                with T.block([16, 16, T.reduce_axis(0, 16)],
                             "F") as [fi, fj, fk]:
                    T.bind(fi, i)
                    T.bind(fj, j2)
                    T.bind(fk, k4o * 4 + k4i)
                    with T.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
Example no. 21
def rowsum_not_compact_data_flow(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))

    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
        with T.init():
            B[vk] = 0.0
        B[vk] = B[vk] + A[vi, vk]
Example no. 22
def rowsum_zero_dim(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128])
    B = T.match_buffer(b, [])

    with T.block([T.reduce_axis(0, 128)], "B") as [k]:
        with T.init():
            B[()] = 0.0
        B[()] = B[()] + A[k]
Example no. 23
def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    D = T.match_buffer(d, [128, 128])

    for k, i, j in T.grid(128, 128, 128):
        with T.block("C"):
            ck, ci, cj = T.axis.remap("RSS", [k, i, j])
            with T.init():
                C[ci, cj] = 0.0
            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
        with T.block("D"):
            dk, di, dj = T.axis.remap("RSS", [k, i, j])
            with T.init():
                D[di, dj] = 0.0
            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
Example no. 24
def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, ))

    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
        with T.init():
            B[vi] = 0.0
        B[vi] = B[vi] - A[vi, vk]
Example no. 25
def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16])
    C = T.alloc_buffer([16, 16])
    D = T.alloc_buffer([16, 16])
    E = T.alloc_buffer([16, 16])
    F = T.match_buffer(f, [16, 16])
    C_rf = T.alloc_buffer([16, 16, 4])

    for i, j1, k1o, k1i in T.grid(16, 16, 4, 4):
        with T.block("C_rf"):
            vk1o, ci, cj, vk1i = T.axis.remap("SSSR", [k1o, i, j1, k1i])
            with T.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj,
                                                        ((vk1o * 4) + vk1i)]
    for i_1 in T.serial(0, 16):
        for j1_1 in T.serial(0, 16):
            for k1o_1 in T.serial(0, 4):
                with T.block("C"):
                    vk1o_1, ci_1, cj_1 = T.axis.remap("RSS",
                                                      [k1o_1, i_1, j1_1])
                    with T.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            for k2o, k2i in T.grid(4, 4):
                with T.block("D"):
                    di, dj = T.axis.remap("SS", [i_1, j1_1])
                    dk = T.axis.R(16, k2o * 4 + k2i)
                    with T.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        for j2 in T.serial(0, 16):
            for k3o, k3i in T.grid(4, 4):
                with T.block("E"):
                    ei, ej = T.axis.remap("SS", [i_1, j2])
                    ek = T.axis.R(16, k3o * 4 + k3i)
                    with T.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in T.grid(4, 4):
                with T.block("F"):
                    fi, fj = T.axis.remap("SS", [i_1, j2])
                    fk = T.axis.R(16, k4o * 4 + k4i)
                    with T.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
Example no. 26
def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    B_rf = T.alloc_buffer([128, 13], dtype="float32")
    for i, k_0, k_1 in T.grid(128, 13, 10):
        with T.block("B_rf"):
            vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1])
            T.where(k_0 * 10 + k_1 < 128)
            with T.init():
                B_rf[vi, vk_0] = T.float32(0)
            B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1]
    for i, k_0 in T.grid(128, 13):
        with T.block("B"):
            vk_0, vi = T.axis.remap("RS", [k_0, i])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + B_rf[vi, vk_0]
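This predicated form is what an imperfect reduction split produces: 13 * 10 = 130 > 128, so the guard T.where(k_0 * 10 + k_1 < 128) is required. A sketch of a schedule that yields it, assuming a rowsum prim_func with a reduction block "B" over loops i, k:

import tvm
from tvm import tir

def rfactor_with_predicate(mod: tvm.IRModule) -> tvm.IRModule:
    sch = tir.Schedule(mod)
    block = sch.get_block("B")
    i, k = sch.get_loops(block)
    # Imperfect split: the last iteration of k_0 is masked by a predicate.
    k0, k1 = sch.split(k, factors=[13, 10])
    # The new buffer gains an axis of extent 13 at position 1, matching B_rf[128, 13].
    sch.rfactor(k0, factor_axis=1)
    return sch.mod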
Example no. 27
def lowered_multiple_blocks_under_reduction_loop(a: T.handle,
                                                 b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = (normal_reduce_temp0[0] +
                                              B_rf_local[vk0, vi])
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        k0o,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(16, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
Example no. 28
def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    for i, j in T.grid(1, 4):
                                        with T.block([2048, 2048],
                                                     "A_shared_local") as [v0, v1]:
                                            T.bind(v0, k_0 * 8 + k_1 + i)
                                            T.bind(v1, by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0, v1] = A_shared[v0, v1]
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block([2048, 2048, T.reduce_axis(0, 2048)],
                                                     "C") as [vi, vj, vk]:
                                            T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = T.float32(0)
                                            C_local[vi, vj] = (
                                                C_local[vi, vj] +
                                                A_shared_local[vk, vi] *
                                                B_shared_local[vk, vj])
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048], "C_local") as [v0, v1]:
                                    T.bind(v0, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
Example no. 29
def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))

    with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]:
        with T.init():
            C[i, j] = 0.0
        C[i, j] += A[i, k] * B[j, k]
Example no. 30
def square_sum_rfactor(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    C = T.match_buffer(c, [16])
    C_rf = T.alloc_buffer([16, 256])

    for i0, i1, i2 in T.grid(16, 256, 256):
        with T.block("C_rf"):
            vi2, b, i = T.axis.remap("SSR", [i2, i0, i1])
            with T.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])

    for i0_1, i2_1 in T.grid(16, 256):
        with T.block("C"):
            vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1])
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]