Ejemplo n.º 1
0
def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Calls the packed function "tvm.contrib.cblas.matmul" on three 128x128
    # buffers rather than expressing the computation with loops here.
    #   a, b: handles bound below to buffers A and B, which block "C" reads.
    #   c:    handle bound below to buffer C, which block "C" writes.
    # NOTE(review): this reads as TVMScript; the decorator that would parse it
    # is not visible in this excerpt -- confirm before relying on it.
    # Documented with comments (not a docstring) so the statement list of the
    # function body stays exactly as written.

    # Function attributes: global symbol "main", and "tir.noalias" set to True.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # Bind each handle to a (128, 128) buffer. No dtype is passed, so each
    # buffer takes whatever element type T.match_buffer uses by default.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    # body: a single block named "C" that wraps the external call.
    with T.block("C"):
        # Declared access regions: all of A and B are read, all of C is written.
        T.reads([A[0:128, 0:128], B[0:128, 0:128]])
        T.writes([C[0:128, 0:128]])
        # The call's result (dtype "int32") is evaluated as a statement and
        # not stored anywhere.
        T.evaluate(
            T.tvm_call_packed(
                "tvm.contrib.cblas.matmul",
                # Each T.tvm_stack_make_array below wraps one buffer's data
                # pointer together with a (128, 128) shape.
                # NOTE(review): TVM's builtin docs list the operands as
                # (data, shape, strides, ndim, dtype, elem_offset); under that
                # reading 0 = no explicit strides, 2 = rank, 0.0 = a value
                # whose type conveys the element dtype, 0 = element offset.
                # Confirm against the TVM version in use.
                T.tvm_stack_make_array(
                    A.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                # NOTE(review): presumably the transpose-A / transpose-B flags
                # expected by the cblas matmul packed function -- verify
                # against its signature.
                0,
                0,
                dtype="int32",
            )
        )
Ejemplo n.º 2
0
def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Calls the packed function "tvm.contrib.cblas.matmul" on three 128x128
    # buffers rather than expressing the computation with loops here.
    #   a, b: handles bound below to buffers A and B, which block "C" reads.
    #   c:    handle bound below to buffer C, which block "C" writes.
    # NOTE(review): this reads as TVMScript; the decorator that would parse it
    # is not visible in this excerpt -- confirm before relying on it.
    # Documented with comments (not a docstring) so the statement list of the
    # function body stays exactly as written.

    # Bind each handle to a (128, 128) buffer. No dtype is passed, so each
    # buffer takes whatever element type T.match_buffer uses by default.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    # body: a single block named "C" that wraps the external call.
    # NOTE(review): the empty list passed first to T.block is presumably the
    # block-axis list of an older TVMScript syntax -- confirm against the TVM
    # version this snippet targets.
    with T.block([], "C"):
        # Declared access regions: all of A and B are read, all of C is written.
        T.reads([A[0:128, 0:128], B[0:128, 0:128]])
        T.writes([C[0:128, 0:128]])
        # The call's result (dtype "int32") is evaluated as a statement and
        # not stored anywhere.
        T.evaluate(
            T.tvm_call_packed(
                "tvm.contrib.cblas.matmul",
                # Each T.tvm_stack_make_array below wraps one buffer's data
                # pointer together with a (128, 128) shape.
                # NOTE(review): TVM's builtin docs list the operands as
                # (data, shape, strides, ndim, dtype, elem_offset); under that
                # reading 0 = no explicit strides, 2 = rank, 0.0 = a value
                # whose type conveys the element dtype, 0 = element offset.
                # Confirm against the TVM version in use.
                T.tvm_stack_make_array(
                    A.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                # NOTE(review): presumably the transpose-A / transpose-B flags
                # expected by the cblas matmul packed function -- verify
                # against its signature.
                0,
                0,
                dtype="int32",
            ))
    def tir_packed_call() -> None:
        # Evaluates one T.tvm_call_cpacked call to "tvm_test_cpacked", passing
        # three stack-built array descriptors followed by a device context
        # handle. Takes no parameters; every operand is a handle variable
        # declared locally below.
        # NOTE(review): this reads as a TVMScript function belonging to an
        # enclosing definition whose header is outside this excerpt -- confirm.
        # Documented with comments (not a docstring) so the statement list of
        # the function body stays exactly as written.

        # Handle-typed variables used as the data pointers of the three arrays.
        A = T.var("handle")
        B = T.var("handle")
        C = T.var("handle")
        # Handle-typed variable passed as the last positional call argument.
        device_context = T.var("handle")

        # body: the call's result (dtype "int32") is evaluated as a statement
        # and not stored anywhere.
        T.evaluate(
            T.tvm_call_cpacked(
                "tvm_test_cpacked",
                # Each T.tvm_stack_make_array below pairs one handle variable
                # with a one-element shape, a handle obtained by
                # reinterpreting uint64 zero, uint32 1, a float32 zero, and 0.
                # NOTE(review): TVM's builtin docs list the operands as
                # (data, shape, strides, ndim, dtype, elem_offset); under that
                # reading the reinterpreted zero is a null strides pointer,
                # 1 is the rank, the float32 zero conveys the element dtype,
                # and the final 0 is the element offset. Confirm against the
                # TVM version in use.
                T.tvm_stack_make_array(
                    A,
                    T.tvm_stack_make_shape(1, dtype="handle"),
                    T.reinterpret(T.uint64(0), dtype="handle"),
                    T.uint32(1),
                    T.cast(0, dtype="float32"),
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B,
                    T.tvm_stack_make_shape(1, dtype="handle"),
                    T.reinterpret(T.uint64(0), dtype="handle"),
                    T.uint32(1),
                    T.cast(0, dtype="float32"),
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C,
                    T.tvm_stack_make_shape(1, dtype="handle"),
                    T.reinterpret(T.uint64(0), dtype="handle"),
                    T.uint32(1),
                    T.cast(0, dtype="float32"),
                    0,
                    dtype="handle",
                ),
                device_context,
                dtype="int32",
            ))