# TVMScript function that evaluates one packed call to
# "tvm.contrib.cblas.matmul" over three 128x128 buffers.
#
# Parameters:
#   a, b, c: opaque handles, bound below to buffers A, B and C.
#
# The block declares A and B as read and C as written, which suggests
# A and B are the operands and C the result
# -- NOTE(review): confirm against the cblas matmul contract.
def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Function attributes: exported symbol name "main", and the
    # "tir.noalias" flag set to True.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # Bind each handle to a (128, 128) buffer. No dtype is given here, so
    # match_buffer's default applies (not visible in this file).
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    # body
    with T.block("C"):
        # Explicit access regions are declared because the packed call
        # below only receives raw data pointers: the whole of A and B is
        # read, the whole of C is written.
        T.reads([A[0:128, 0:128], B[0:128, 0:128]])
        T.writes([C[0:128, 0:128]])
        T.evaluate(
            T.tvm_call_packed(
                "tvm.contrib.cblas.matmul",
                # One stack-allocated array descriptor per buffer. The
                # positional arguments after the data pointer and shape
                # are presumably strides, ndim, a dtype-carrying value
                # and element offset -- TODO confirm against the TVM
                # builtin documentation.
                T.tvm_stack_make_array(
                    A.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                # Two trailing integer arguments passed to the packed
                # function -- presumably the transpose-A / transpose-B
                # flags; verify against the callee.
                0,
                0,
                dtype="int32",
            )
        )
# TVMScript function that evaluates one packed call to
# "tvm.contrib.cblas.matmul" over three 128x128 buffers.
#
# Parameters:
#   a, b, c: opaque handles, bound below to buffers A, B and C.
#
# This definition sets no function attributes and opens its block with
# an explicit empty list as the first argument.
def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Bind each handle to a (128, 128) buffer. No dtype is given here, so
    # match_buffer's default applies (not visible in this file).
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    # body
    # NOTE(review): the leading [] is presumably an empty block-iterator
    # list in the TVMScript dialect this was written for -- confirm that
    # the parser version in use accepts this form.
    with T.block([], "C"):
        # Explicit access regions are declared because the packed call
        # below only receives raw data pointers: the whole of A and B is
        # read, the whole of C is written.
        T.reads([A[0:128, 0:128], B[0:128, 0:128]])
        T.writes([C[0:128, 0:128]])
        T.evaluate(
            T.tvm_call_packed(
                "tvm.contrib.cblas.matmul",
                # One stack-allocated array descriptor per buffer. The
                # positional arguments after the data pointer and shape
                # are presumably strides, ndim, a dtype-carrying value
                # and element offset -- TODO confirm against the TVM
                # builtin documentation.
                T.tvm_stack_make_array(
                    A.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C.data,
                    T.tvm_stack_make_shape(128, 128, dtype="handle"),
                    0,
                    2,
                    0.0,
                    0,
                    dtype="handle",
                ),
                # Two trailing integer arguments passed to the packed
                # function -- presumably the transpose-A / transpose-B
                # flags; verify against the callee.
                0,
                0,
                dtype="int32",
            ))
# TVMScript function that evaluates a single T.tvm_call_cpacked call to
# "tvm_test_cpacked", passing three stack-allocated array descriptors
# followed by a device-context handle.
#
# Takes no parameters; every handle it uses is a free variable declared
# in the body.
def tir_packed_call() -> None:
    # Handle-typed variables used as the data pointers of the three array
    # descriptors, plus the handle passed as the final call argument.
    A = T.var("handle")
    B = T.var("handle")
    C = T.var("handle")
    device_context = T.var("handle")
    # body
    T.evaluate(
        T.tvm_call_cpacked(
            "tvm_test_cpacked",
            # Each descriptor is built from: the data handle, a
            # one-element shape holding 1, a uint64 zero reinterpreted as
            # a handle, uint32 1, a float32-cast zero, and integer 0.
            # Presumably these are data, shape, strides (null), ndim,
            # a dtype-carrying value and element offset -- TODO confirm
            # against the TVM builtin documentation.
            T.tvm_stack_make_array(
                A,
                T.tvm_stack_make_shape(1, dtype="handle"),
                T.reinterpret(T.uint64(0), dtype="handle"),
                T.uint32(1),
                T.cast(0, dtype="float32"),
                0,
                dtype="handle",
            ),
            T.tvm_stack_make_array(
                B,
                T.tvm_stack_make_shape(1, dtype="handle"),
                T.reinterpret(T.uint64(0), dtype="handle"),
                T.uint32(1),
                T.cast(0, dtype="float32"),
                0,
                dtype="handle",
            ),
            T.tvm_stack_make_array(
                C,
                T.tvm_stack_make_shape(1, dtype="handle"),
                T.reinterpret(T.uint64(0), dtype="handle"),
                T.uint32(1),
                T.cast(0, dtype="float32"),
                0,
                dtype="handle",
            ),
            # The device context handle is the last positional argument
            # of the call.
            device_context,
            dtype="int32",
        ))