Ejemplo n.º 1
0
def flattened_elementwise_func(a: ty.handle, c: ty.handle) -> None:
    # Flattened-buffer form of a two-stage elementwise computation on a 16x16
    # float32 input, one row at a time: C = (A + 1.0) * 2.0.
    # A and C are accessed through their raw `.data` pointers using row-major
    # flat offsets (i * 16 + j) rather than 2-D indexing.
    # NOTE(review): presumably the expected output of a buffer-flattening
    # transform -- confirm against the test that consumes it.
    A = tir.match_buffer(a, (16, 16), "float32")
    C = tir.match_buffer(c, (16, 16), "float32")
    for i in tir.serial(0, 16):
        # Row-sized (16-element) temporary. The allocation sits inside the `i`
        # loop, so it only ever holds the current row (indexed by j).
        B_new = tir.allocate([16], "float32", "global")
        for j in tir.serial(0, 16):
            B_new[j] = tir.load("float32", A.data, ((i * 16) + j)) + 1.0
        for j in tir.serial(0, 16):
            C.data[((i * 16) + j)] = tir.load("float32", B_new, j) * 2.0
Ejemplo n.º 2
0
def flattened_multi_alloc_func(a: ty.handle, d: ty.handle) -> None:
    # Flattened form with two 32-element float32 temporaries (B, C) allocated
    # inside the loop. For each element i:
    #   B[i] = A[i] + 1.0
    #   C[i] = A[i] + B[i]
    #   D[i] = C[i] * 2.0
    # Each iteration touches only index i of the two temporaries.
    # NOTE(review): `(32)` is the plain int 32, not a 1-tuple; the script
    # parser presumably accepts a scalar extent here -- `(32,)` would be
    # explicit.
    A = tir.match_buffer(a, (32), "float32")
    D = tir.match_buffer(d, (32), "float32")

    for i in range(0, 32):
        B = tir.allocate((32, ), "float32", "global")
        C = tir.allocate((32, ), "float32", "global")
        B[i] = tir.load("float32", A.data, i) + 1.0
        C[i] = tir.load("float32", A.data, i) + tir.load("float32", B, i)
        D.data[i] = tir.load("float32", C, i) * 2.0
Ejemplo n.º 3
0
def flattened_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32) -> None:
    # Same row-wise elementwise pattern as the 16x16 case, but with symbolic
    # extents: A and C are n x m float32 buffers, and C = (A + 1.0) * 2.0.
    # Flat offsets are row-major (i * m + j) on the raw `.data` pointers.
    A = tir.match_buffer(a, (n, m), "float32")
    C = tir.match_buffer(c, (n, m), "float32")

    for i in range(0, n):
        # One-row temporary whose length is the symbolic extent m.
        B = tir.allocate([m], "float32", "global")
        for j in range(0, m):
            B[j] = tir.load("float32", A.data, i * m + j) + 1.0
        for j in range(0, m):
            C.data[i * m + j] = tir.load("float32", B, j) * 2.0
Ejemplo n.º 4
0
 def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle,
                  d: ty.handle) -> None:
     # Computes d[i] = A[i] * B[i] + C[i] for i in [0, n) over four 1-D
     # float32 buffers of symbolic length n. Each buffer has its own symbolic
     # stride, so element i of a buffer lives at flat offset i * <its stride>.
     # All buffers are matched with elem_offset=0, align=128, offset_factor=1.
     # function attr dict
     tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
     n = tir.var("int32")
     # One independent stride variable per buffer (A_1, B_1, C_1, d_1).
     stride = tir.var("int32")
     stride_1 = tir.var("int32")
     stride_2 = tir.var("int32")
     stride_3 = tir.var("int32")
     A_1 = tir.match_buffer(
         A,
         [n],
         strides=[stride],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     B_1 = tir.match_buffer(
         B,
         [n],
         strides=[stride_1],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     C_1 = tir.match_buffer(
         C,
         [n],
         strides=[stride_2],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     d_1 = tir.match_buffer(
         d,
         [n],
         strides=[stride_3],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     # body
     for i in tir.serial(0, n):
         # Multiply-add written as (A * B) + C via raw strided loads/stores.
         d_1.data[(i * stride_3)] = (tir.load("float32", A_1.data,
                                              (i * stride)) *
                                     tir.load("float32", B_1.data,
                                              (i * stride_1))) + tir.load(
                                                  "float32", C_1.data,
                                                  (i * stride_2))
Ejemplo n.º 5
0
 def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
     # Dense 128x128 accumulation C[x, y] = sum_k A[x, k] * B[y, k], i.e. A
     # times the transpose of B (B is indexed y * 128 + k). All accesses go
     # through the raw `.data` pointers with row-major flat offsets.
     # NOTE(review): match_buffer is called without a dtype while every load
     # reads "float32" -- presumably float32 is the default; confirm.
     # function attr dict
     tir.func_attr({"global_symbol": "main", "tir.noalias": True})
     A = tir.match_buffer(a, [128, 128])
     B = tir.match_buffer(b, [128, 128])
     C = tir.match_buffer(c, [128, 128])
     # body
     for x, y in tir.grid(128, 128):
         # Zero-initialise the output element, then accumulate over k.
         C.data[x * 128 + y] = 0.0
         for k in tir.serial(0, 128):
             C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load(
                 "float32", A.data, x * 128 + k
             ) * tir.load("float32", B.data, y * 128 + k)
Ejemplo n.º 6
0
def flattened_predicate_func(a: ty.handle, c: ty.handle) -> None:
    # Computes C = A + 1.0 over 32 float32 elements using a 5x7 loop nest.
    # 5 * 7 = 35 iteration points cover more than the 32 elements, so the
    # flat index i * 7 + j is guarded to stay within the buffers.
    A = tir.match_buffer(a, (32), "float32")
    C = tir.match_buffer(c, (32), "float32")

    for i, j in tir.grid(5, 7):
        # Skip the 3 out-of-range points (flat index 32..34).
        if i * 7 + j < 32:
            C.data[i * 7 + j] = tir.load("float32", A.data, i * 7 + j) + 1.0
Ejemplo n.º 7
0
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
    # Fixture that mixes opaque accesses (raw store/load and an intrinsic
    # call on a data pointer) with a regular block.
    # The first block declares no reads and writes all of B. Inside it:
    #   * a 256-element float32 scratch buffer A is filled via tir.store,
    #   * every element of A is loaded back inside tir.evaluate (the value
    #     is discarded),
    #   * the tvm_fill_fragment intrinsic is called on B's raw data pointer.
    # The second block copies B into C element by element.
    B = tir.match_buffer(b, [16, 16], "float32")
    C = tir.match_buffer(c, [16, 16], "float32")

    with tir.block([]):
        tir.reads([])
        # The only write to B in this block is the opaque tvm_fill_fragment
        # call below; the write region is declared explicitly here.
        tir.writes(B[0:16, 0:16])
        A = tir.allocate([256], "float32", "global")
        for i, j in tir.grid(16, 16):
            # Stores the integer literal 1 into the float32 buffer A at flat
            # offset i * 16 + j.
            tir.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            for j in range(0, 16):
                tir.evaluate(tir.load("float32", A, i * 16 + j))
            for j in range(0, 16):
                # The arguments depend on neither i nor j, so the identical
                # call is issued 16 * 16 times.
                # NOTE(review): the meaning of the positional arguments
                # (16, 16, 16, 0, 0.0) is defined by the intrinsic --
                # presumably fragment shape, fragment index and fill value.
                tir.evaluate(
                    tir.tvm_fill_fragment(B.data,
                                          16,
                                          16,
                                          16,
                                          0,
                                          tir.float32(0),
                                          dtype="handle"))

    for i, j in tir.grid(16, 16):
        # Plain element-wise copy C[i, j] = B[i, j].
        with tir.block([16, 16]) as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            C[vi, vj] = B[vi, vj]
Ejemplo n.º 8
0
def flattened_gpu_func(a: ty.handle, c: ty.handle) -> None:
    # GPU-bound variant of C = (A + 1.0) * 2.0 on 16x16 float32 buffers.
    # The launch extents are blockIdx.x = 4, threadIdx.x = 2 and vthread = 2.
    # Each (i0, i1, i2) triple handles one contiguous run of 16 elements at
    # flat offset i0 * 64 + i1 * 32 + i2 * 16, so 4 * 2 * 2 * 16 = 256 =
    # 16 * 16 covers the whole buffer.
    A = tir.match_buffer(a, (16, 16), "float32")
    C = tir.match_buffer(c, (16, 16), "float32")

    i0 = tir.env_thread("blockIdx.x")
    i1 = tir.env_thread("threadIdx.x")
    i2 = tir.env_thread("vthread")

    tir.launch_thread(i0, 4)
    tir.launch_thread(i1, 2)
    tir.launch_thread(i2, 2)
    # 16-element temporary allocated with storage scope "local".
    B = tir.allocate([16], "float32", "local")
    for j in range(0, 16):
        B[j] = tir.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0
    for j in range(0, 16):
        C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = tir.load("float32", B, j) * 2.0
Ejemplo n.º 9
0
def fail_match_load(a: ty.handle) -> None:
    # Negative fixture (the name suggests it is expected to be rejected).
    # For each (i, j), a block matches the single element A[i, j] as a 0-d
    # buffer `sub_A`, then reads it through the raw pointer sub_A.data.
    # NOTE(review): presumably the failure comes from opaquely loading a
    # match_buffer'd sub-region via `.data` -- confirm against the test.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 8):
        with tir.block([]):
            tir.reads(A[i, j])
            tir.writes([])
            sub_A = tir.match_buffer(A[i, j], ())
            tir.evaluate(tir.load("float32", sub_A.data, 0))
Ejemplo n.º 10
0
def opaque_access_load(a: ty.handle, c: ty.handle) -> None:
    # Two-stage fixture on 128x128 buffers with an intermediate B:
    #   block "B": B = A * 2.0            (regular buffer indexing)
    #   block "C": C = B + 1.0            (B read through its raw pointer)
    # Block "C" cannot expose its access to B through indexing, so it
    # declares the full B read region and full C write region explicitly.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        tir.reads(B[0:128, 0:128])
        tir.writes(C[0:128, 0:128])
        C[vi, vj] = tir.load("float32", B.data, vi * 128 + vj) + 1.0
Ejemplo n.º 11
0
 def main(
     placeholder: ty.handle,
     placeholder_1: ty.handle,
     placeholder_2: ty.handle,
     ethosu_conv2d: ty.handle,
 ) -> None:
     # Single call to the external "ethosu_conv2d" operator.
     # Buffers matched below:
     #   placeholder_3   [1, 8, 8, 3]  uint8
     #   placeholder_4   [48]          uint8
     #   placeholder_5   [16]          int32
     #   ethosu_conv2d_1 [1, 8, 8, 16] uint8
     # NOTE(review): the positional argument layout is whatever the
     # ethosu_conv2d extern defines and is not visible here. The groups
     # marked "presumably" below are assumptions to confirm against the
     # Ethos-U codegen.
     # function attr dict
     tir.func_attr({"global_symbol": "main", "tir.noalias": True})
     placeholder_3 = tir.match_buffer(placeholder, [1, 8, 8, 3],
                                      dtype="uint8",
                                      elem_offset=0,
                                      align=128,
                                      offset_factor=1)
     placeholder_4 = tir.match_buffer(placeholder_1, [48],
                                      dtype="uint8",
                                      elem_offset=0,
                                      align=128,
                                      offset_factor=1)
     placeholder_5 = tir.match_buffer(placeholder_2, [16],
                                      dtype="int32",
                                      elem_offset=0,
                                      align=128,
                                      offset_factor=1)
     ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 8, 8, 16],
                                        dtype="uint8",
                                        elem_offset=0,
                                        align=128,
                                        offset_factor=1)
     # body
     tir.evaluate(
         tir.call_extern(
             "ethosu_conv2d",
             # presumably the input feature-map descriptor (dtype, dims,
             # buffer, quantisation, layout, strides)
             "uint8",
             8,
             8,
             3,
             8,
             0,
             8,
             # Element 0 of the input buffer -- presumably taken as its base
             # address by the extern.
             tir.load("uint8", placeholder_3.data, 0),
             0,
             0,
             0,
             tir.float32(0.5),
             10,
             "NHWC",
             24,
             3,
             1,
             # presumably the output feature-map descriptor, same layout as
             # the input group above
             "uint8",
             8,
             8,
             16,
             8,
             0,
             8,
             tir.load("uint8", ethosu_conv2d_1.data, 0),
             0,
             0,
             0,
             tir.float32(0.25),
             14,
             "NHWC",
             128,
             16,
             1,
             1,
             1,
             1,
             1,
             1,
             1,
             # presumably the weight buffer (placeholder_4) and the int32
             # bias/scale buffer (placeholder_5)
             tir.load("uint8", placeholder_4.data, 0),
             0,
             12,
             tir.load("uint8", placeholder_5.data, 0),
             0,
             0,
             0,
             0,
             0,
             "CLIP",
             0,
             0,
             "NONE",
             dtype="uint8",
         ))
Ejemplo n.º 12
0
def opaque_access_during_complete(a: ty.handle) -> None:  # error
    # Negative fixture (marked "# error" above). The 16x16 block declares no
    # reads/writes and touches A only through an opaque raw-pointer load.
    # NOTE(review): presumably rejected because the block's access regions
    # cannot be inferred from the opaque load -- confirm against the test.
    A = tir.match_buffer(a, (16, 16), "float32")
    with tir.block([16, 16]) as [vi, vj]:
        tir.evaluate(tir.load("float32", A.data, vi * 16 + vj))
Ejemplo n.º 13
0
def flattened_unit_loop_func(a: ty.handle, c: ty.handle) -> None:
    # Computes C = A + 1.0 over 32 float32 elements with a 4x8 loop nest.
    # 4 * 8 = 32 exactly, so no bounds guard is needed.
    # NOTE(review): the name suggests this is the expected result after
    # unit-extent loops were removed -- confirm against the test.
    A = tir.match_buffer(a, (32), "float32")
    C = tir.match_buffer(c, (32), "float32")

    for x, z in tir.grid(4, 8):
        C.data[x * 8 + z] = tir.load("float32", A.data, x * 8 + z) + 1.0
def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
    # Fixture with three blocks on 128x128 float16 buffers, each accessing
    # memory opaquely while declaring explicit read/write regions:
    #   "load_store":   element-wise copy A -> D through raw pointers.
    #   "opaque":       per 16x16 tile, tvm_load_matrix_sync into B from an
    #                   access pointer into A computed by hand.
    #   "match_buffer": the same tile load into C, but with the tiles bound
    #                   via match_buffer and addressed through their
    #                   elem_offset / strides.
    A = tir.match_buffer(a, (128, 128), dtype="float16")
    B = tir.match_buffer(b, (128, 128), dtype="float16")
    C = tir.match_buffer(c, (128, 128), dtype="float16")
    D = tir.match_buffer(d, (128, 128), dtype="float16")

    with tir.block([128, 128], "load_store") as [vi, vj]:
        tir.reads(A[vi, vj])
        tir.writes(D[vi, vj])
        D.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj)
    with tir.block([8, 8], "opaque") as [vi, vj]:
        # Each (vi, vj) of the 8x8 block handles one 16x16 tile.
        tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        tir.evaluate(
            tir.tvm_load_matrix_sync(
                B.data,
                16,
                16,
                16,
                # presumably the fragment index within B -- confirm.
                vi * 8 + vj,
                tir.tvm_access_ptr(
                    tir.type_annotation(dtype="float16"),
                    A.data,
                    # Flat offset of the tile's top-left element:
                    # (vi * 16) * 128 + vj * 16 = vi * 2048 + vj * 16.
                    vi * 2048 + vj * 16,
                    128,
                    1,
                    dtype="handle",
                ),
                # presumably the row stride (leading dimension) of A.
                128,
                "row_major",
                dtype="handle",
            )
        )
    with tir.block([8, 8], "match_buffer") as [vi, vj]:
        tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        # A0 / C0 bind the current 16x16 tile of A / C as sub-buffers with
        # the parent's row stride (128, 1).
        A0 = tir.match_buffer(
            A[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        C0 = tir.match_buffer(
            C[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        tir.evaluate(
            tir.tvm_load_matrix_sync(
                C0.data,
                16,
                16,
                16,
                vi * 8 + vj,
                tir.tvm_access_ptr(
                    tir.type_annotation(dtype="float16"),
                    A0.data,
                    # Offset and extent come from the matched sub-buffer
                    # instead of hand-written arithmetic.
                    A0.elem_offset,
                    A0.strides[0],
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )
def primfunc_global_allocates(placeholder_144: ty.handle,
                              placeholder_145: ty.handle,
                              placeholder_146: ty.handle,
                              T_cast_48: ty.handle) -> None:
    # Lowered 3x3 depthwise conv2d pipeline on a [1, 14, 14, 512] int16 input.
    # All intermediates live in two flat function-level allocations:
    #   PaddedInput_22    131072 int16 (= 16 * 16 * 512, padded input)
    #   DepthwiseConv2d_9 100352 int32 (= 14 * 14 * 512, accumulator)
    # Stages: zero-pad -> depthwise conv -> + bias -> q_multiply_shift ->
    # clamp -> cast to uint8 (stored back into PaddedInput_22) -> cast to
    # int16 into T_cast_49.
    # function attr dict
    tir.func_attr({
        "global_symbol":
        "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_13",
        "tir.noalias": True
    })
    placeholder_147 = tir.match_buffer(placeholder_144, [1, 14, 14, 512],
                                       dtype="int16",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
    placeholder_148 = tir.match_buffer(placeholder_145, [3, 3, 512, 1],
                                       dtype="int16",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
    placeholder_149 = tir.match_buffer(placeholder_146, [1, 1, 1, 512],
                                       dtype="int32",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
    T_cast_49 = tir.match_buffer(T_cast_48, [1, 14, 14, 512],
                                 dtype="int16",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
    # body
    PaddedInput_22 = tir.allocate([131072], "int16", "global")
    DepthwiseConv2d_9 = tir.allocate([100352], "int32", "global")
    # Zero-pad by one pixel on each spatial border. Interior points read the
    # input shifted back by one row and one column: 7168 + 512 = 7680.
    for i1_29, i2_39, i3_40 in tir.grid(16, 16, 512):
        PaddedInput_22[(((i1_29 * 8192) + (i2_39 * 512)) +
                        i3_40)] = tir.if_then_else(
                            ((((1 <= i1_29) and (i1_29 < 15)) and
                              (1 <= i2_39)) and (i2_39 < 15)),
                            tir.load("int16", placeholder_147.data,
                                     ((((i1_29 * 7168) +
                                        (i2_39 * 512)) + i3_40) - 7680)),
                            tir.int16(0),
                            dtype="int16")
    # 3x3 depthwise convolution: per channel c, accumulate
    # padded[i + di, j + dj, c] * weight[di, dj, c] in int32.
    for i_9, j_9, c_9 in tir.grid(14, 14, 512):
        DepthwiseConv2d_9[(((i_9 * 7168) + (j_9 * 512)) + c_9)] = 0
        for di_9, dj_9 in tir.grid(3, 3):
            DepthwiseConv2d_9[(((i_9 * 7168) + (j_9 * 512)) + c_9)] = (
                tir.load("int32", DepthwiseConv2d_9,
                         (((i_9 * 7168) + (j_9 * 512)) + c_9)) +
                (tir.load("int16", PaddedInput_22,
                          (((((i_9 * 8192) + (di_9 * 8192)) + (j_9 * 512)) +
                            (dj_9 * 512)) + c_9)).astype("int32") *
                 tir.load("int16", placeholder_148.data,
                          (((di_9 * 1536) +
                            (dj_9 * 512)) + c_9)).astype("int32")))
    # Per-channel add of placeholder_149 (indexed by channel only).
    for ax1_27, ax2_28, ax3_30 in tir.grid(14, 14, 512):
        DepthwiseConv2d_9[(((ax1_27 * 7168) + (ax2_28 * 512)) + ax3_30)] = (
            tir.load("int32", DepthwiseConv2d_9,
                     (((ax1_27 * 7168) + (ax2_28 * 512)) + ax3_30)) +
            tir.load("int32", placeholder_149.data, ax3_30))
    # In-place q_multiply_shift with fixture constants; presumably the
    # "fixed_point_multiply" step named in global_symbol.
    for i1_30, i2_40, i3_41 in tir.grid(14, 14, 512):
        DepthwiseConv2d_9[(((i1_30 * 7168) + (i2_40 * 512)) +
                           i3_41)] = tir.q_multiply_shift(tir.load(
                               "int32", DepthwiseConv2d_9,
                               (((i1_30 * 7168) + (i2_40 * 512)) + i3_41)),
                                                          1269068532,
                                                          31,
                                                          -4,
                                                          dtype="int32")
    # NOTE(review): max(max(x, 255), 0) always yields a value >= 255, so this
    # is not a clip to [0, 255] (that would be max(min(x, 255), 0)). It is
    # left unchanged because this is fixture IR -- confirm intent.
    for i1_31, i2_41, i3_42 in tir.grid(14, 14, 512):
        DepthwiseConv2d_9[(((i1_31 * 7168) + (i2_41 * 512)) +
                           i3_42)] = tir.max(
                               tir.max(
                                   tir.load("int32", DepthwiseConv2d_9,
                                            (((i1_31 * 7168) +
                                              (i2_41 * 512)) + i3_42)), 255),
                               0)
    # Reuses the int16 PaddedInput_22 allocation to hold uint8 values
    # (stored here, then loaded as "uint8" in the next loop).
    for ax1_28, ax2_29, ax3_31 in tir.grid(14, 14, 512):
        PaddedInput_22[(((ax1_28 * 7168) +
                         (ax2_29 * 512)) + ax3_31)] = tir.load(
                             "int32", DepthwiseConv2d_9,
                             (((ax1_28 * 7168) +
                               (ax2_29 * 512)) + ax3_31)).astype("uint8")
    for ax1_29, ax2_30, ax3_32 in tir.grid(14, 14, 512):
        T_cast_49.data[(((ax1_29 * 7168) +
                         (ax2_30 * 512)) + ax3_32)] = tir.load(
                             "uint8", PaddedInput_22,
                             (((ax1_29 * 7168) +
                               (ax2_30 * 512)) + ax3_32)).astype("int16")
def primfunc_local_allocates(placeholder_162: ty.handle,
                             placeholder_163: ty.handle,
                             placeholder_164: ty.handle,
                             T_cast_76: ty.handle) -> None:
    # Same depthwise conv2d pipeline as the global-allocation variant.
    # Here each intermediate gets its own allocation, and several are scoped
    # with `with tir.allocate(...) as X:` so they exist only for the stage
    # that uses them:
    #   PaddedInput_25 [1,16,16,512] int16 (function level)
    #   T_add_11       [1,14,14,512] int32 (function level)
    #   DepthwiseConv2d_11           int32 (scoped to conv + bias add)
    #   compute_22     [1,14,14,512] int32 (function level)
    #   T_cast_78                    int32 (scoped to copy + q_multiply_shift)
    #   T_cast_79      [1,14,14,512] uint8 (function level)
    #   compute_23                   int32 (scoped to clamp + uint8 cast)
    # All buffers are indexed with flat row-major offsets despite their
    # multi-dimensional allocation shapes.
    # function attr dict
    tir.func_attr({
        "global_symbol":
        "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_9",
        "tir.noalias": True
    })
    placeholder_165 = tir.match_buffer(placeholder_162, [1, 14, 14, 512],
                                       dtype="int16",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
    placeholder_166 = tir.match_buffer(placeholder_163, [3, 3, 512, 1],
                                       dtype="int16",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
    placeholder_167 = tir.match_buffer(placeholder_164, [1, 1, 1, 512],
                                       dtype="int32",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
    T_cast_77 = tir.match_buffer(T_cast_76, [1, 14, 14, 512],
                                 dtype="int16",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
    # body
    # Zero-pad by one pixel on each spatial border. Interior points read the
    # input shifted back by one row and one column: 7168 + 512 = 7680.
    PaddedInput_25 = tir.allocate([1, 16, 16, 512], "int16", "global")
    for i1_35, i2_46, i3_47 in tir.grid(16, 16, 512):
        PaddedInput_25[(((i1_35 * 8192) + (i2_46 * 512)) +
                        i3_47)] = tir.if_then_else(
                            ((((1 <= i1_35) and (i1_35 < 15)) and
                              (1 <= i2_46)) and (i2_46 < 15)),
                            tir.load("int16", placeholder_165.data,
                                     ((((i1_35 * 7168) +
                                        (i2_46 * 512)) + i3_47) - 7680)),
                            tir.int16(0),
                            dtype="int16")
    T_add_11 = tir.allocate([1, 14, 14, 512], "int32", "global")
    with tir.allocate([1, 14, 14, 512], "int32",
                      "global") as DepthwiseConv2d_11:
        # 3x3 depthwise convolution accumulated in int32.
        for i_11, j_11, c_11 in tir.grid(14, 14, 512):
            DepthwiseConv2d_11[(((i_11 * 7168) + (j_11 * 512)) + c_11)] = 0
            for di_11, dj_11 in tir.grid(3, 3):
                DepthwiseConv2d_11[(((i_11 * 7168) + (j_11 * 512)) + c_11)] = (
                    tir.load("int32", DepthwiseConv2d_11,
                             (((i_11 * 7168) + (j_11 * 512)) + c_11)) +
                    (tir.load("int16", PaddedInput_25,
                              (((((i_11 * 8192) + (di_11 * 8192)) +
                                 (j_11 * 512)) +
                                (dj_11 * 512)) + c_11)).astype("int32") *
                     tir.load("int16", placeholder_166.data,
                              (((di_11 * 1536) +
                                (dj_11 * 512)) + c_11)).astype("int32")))
        # Per-channel add of placeholder_167 into T_add_11.
        for ax1_44, ax2_45, ax3_47 in tir.grid(14, 14, 512):
            T_add_11[(((ax1_44 * 7168) + (ax2_45 * 512)) + ax3_47)] = (
                tir.load("int32", DepthwiseConv2d_11,
                         (((ax1_44 * 7168) + (ax2_45 * 512)) + ax3_47)) +
                tir.load("int32", placeholder_167.data, ax3_47))
    compute_22 = tir.allocate([1, 14, 14, 512], "int32", "global")
    with tir.allocate([1, 14, 14, 512], "int32", "global") as T_cast_78:
        # int32 -> int32 copy (the "cast" stage), then q_multiply_shift with
        # fixture constants into compute_22.
        for ax1_45, ax2_46, ax3_48 in tir.grid(14, 14, 512):
            T_cast_78[(((ax1_45 * 7168) +
                        (ax2_46 * 512)) + ax3_48)] = tir.load(
                            "int32", T_add_11,
                            (((ax1_45 * 7168) + (ax2_46 * 512)) + ax3_48))
        for i1_36, i2_47, i3_48 in tir.grid(14, 14, 512):
            compute_22[(((i1_36 * 7168) + (i2_47 * 512)) +
                        i3_48)] = tir.q_multiply_shift(tir.load(
                            "int32", T_cast_78,
                            (((i1_36 * 7168) + (i2_47 * 512)) + i3_48)),
                                                       1948805937,
                                                       31,
                                                       -5,
                                                       dtype="int32")
    T_cast_79 = tir.allocate([1, 14, 14, 512], "uint8", "global")
    with tir.allocate([1, 14, 14, 512], "int32", "global") as compute_23:
        # NOTE(review): max(max(x, 255), 0) always yields a value >= 255, so
        # this is not a clip to [0, 255] (that would be max(min(x, 255), 0)).
        # It is left unchanged because this is fixture IR -- confirm intent.
        for i1_37, i2_48, i3_49 in tir.grid(14, 14, 512):
            compute_23[(((i1_37 * 7168) + (i2_48 * 512)) + i3_49)] = tir.max(
                tir.max(
                    tir.load("int32", compute_22,
                             (((i1_37 * 7168) + (i2_48 * 512)) + i3_49)), 255),
                0)
        for ax1_46, ax2_47, ax3_49 in tir.grid(14, 14, 512):
            T_cast_79[(((ax1_46 * 7168) +
                        (ax2_47 * 512)) + ax3_49)] = tir.load(
                            "int32", compute_23,
                            (((ax1_46 * 7168) +
                              (ax2_47 * 512)) + ax3_49)).astype("uint8")
    # Final uint8 -> int16 cast into the output buffer.
    for ax1_47, ax2_48, ax3_50 in tir.grid(14, 14, 512):
        T_cast_77.data[(((ax1_47 * 7168) +
                         (ax2_48 * 512)) + ax3_50)] = tir.load(
                             "uint8", T_cast_79,
                             (((ax1_47 * 7168) +
                               (ax2_48 * 512)) + ax3_50)).astype("int16")
Ejemplo n.º 17
0
 def main(placeholder: ty.handle, placeholder_1: ty.handle,
          placeholder_2: ty.handle, ethosu_write: ty.handle,
          placeholder_3: ty.handle, placeholder_4: ty.handle,
          placeholder_5: ty.handle, placeholder_6: ty.handle,
          placeholder_7: ty.handle, placeholder_8: ty.handle,
          placeholder_9: ty.handle, placeholder_10: ty.handle) -> None:
     # function attr dict
     tir.func_attr({
         "from_legacy_te_schedule": True,
         "global_symbol": "main",
         "tir.noalias": True
     })
     buffer = tir.match_buffer(placeholder_7, [80],
                               dtype="uint8",
                               elem_offset=0,
                               align=128,
                               offset_factor=1)
     buffer_1 = tir.match_buffer(placeholder_5, [80],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     buffer_2 = tir.match_buffer(placeholder_3, [80],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     buffer_3 = tir.match_buffer(placeholder_4, [32],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     buffer_4 = tir.match_buffer(placeholder_9, [80],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     buffer_5 = tir.match_buffer(placeholder_6, [32],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     placeholder_11 = tir.match_buffer(placeholder, [1, 16, 16, 32],
                                       dtype="int8",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
     buffer_6 = tir.match_buffer(placeholder_1, [592],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8],
                                       dtype="int8",
                                       elem_offset=0,
                                       align=128,
                                       offset_factor=1)
     buffer_7 = tir.match_buffer(placeholder_2, [160],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     buffer_8 = tir.match_buffer(placeholder_8, [32],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     buffer_9 = tir.match_buffer(placeholder_10, [32],
                                 dtype="uint8",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
     # body
     ethosu_write_2 = tir.allocate([4096], "int8", "global")
     placeholder_global = tir.allocate([80], "uint8", "global")
     placeholder_d_global = tir.allocate([32], "uint8", "global")
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         32,
                         16,
                         0,
                         16,
                         tir.load("int8", placeholder_11.data, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         512,
                         32,
                         1,
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         256,
                         16,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", buffer_6.data, 0),
                         592,
                         12,
                         tir.load("uint8", buffer_7.data, 0),
                         160,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_2.data, 0),
                         80,
                         tir.load("uint8", placeholder_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_3.data, 0),
                         32,
                         tir.load("uint8", placeholder_d_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         256,
                         16,
                         1,
                         "int8",
                         16,
                         16,
                         2,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_1.data, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         128,
                         8,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", placeholder_global, 0),
                         80,
                         12,
                         tir.load("uint8", placeholder_d_global, 0),
                         32,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_1.data, 0),
                         80,
                         tir.load("uint8", placeholder_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_5.data, 0),
                         32,
                         tir.load("uint8", placeholder_d_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         256,
                         16,
                         1,
                         "int8",
                         16,
                         16,
                         2,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_1.data, 2),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         128,
                         8,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", placeholder_global, 0),
                         80,
                         12,
                         tir.load("uint8", placeholder_d_global, 0),
                         32,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer.data, 0),
                         80,
                         tir.load("uint8", placeholder_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_8.data, 0),
                         32,
                         tir.load("uint8", placeholder_d_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         256,
                         16,
                         1,
                         "int8",
                         16,
                         16,
                         2,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_1.data, 4),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         128,
                         8,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", placeholder_global, 0),
                         80,
                         12,
                         tir.load("uint8", placeholder_d_global, 0),
                         32,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_4.data, 0),
                         80,
                         tir.load("uint8", placeholder_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_copy",
                         tir.load("uint8", buffer_9.data, 0),
                         32,
                         tir.load("uint8", placeholder_d_global, 0),
                         dtype="handle"))
     tir.evaluate(
         tir.call_extern("ethosu_conv2d",
                         "int8",
                         16,
                         16,
                         16,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_2, 0),
                         0,
                         0,
                         0,
                         tir.float32(0.5),
                         10,
                         "NHWC",
                         256,
                         16,
                         1,
                         "int8",
                         16,
                         16,
                         2,
                         16,
                         0,
                         16,
                         tir.load("int8", ethosu_write_1.data, 6),
                         0,
                         0,
                         0,
                         tir.float32(0.25),
                         14,
                         "NHWC",
                         128,
                         8,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         tir.load("uint8", placeholder_global, 0),
                         80,
                         12,
                         tir.load("uint8", placeholder_d_global, 0),
                         32,
                         0,
                         0,
                         0,
                         0,
                         "NONE",
                         0,
                         0,
                         "NONE",
                         dtype="handle"))
Ejemplo n.º 18
0
def intrin_except_assign(a: ty.handle) -> None:  # Intentionally invalid TVMScript (see `# error` below); NOTE(review): keep the line layout unchanged — the error line's position relative to `def` may be asserted by a test harness (verify)
    A = tir.match_buffer(a, (16, 16), "float32")  # bind handle `a` to a 16x16 float32 buffer
    A[0, 0] = tir.load(A, A, A)  # error: tir.load expects (dtype string, data handle, index), e.g. tir.load("float32", A.data, 0); here buffer A is passed for all three