def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None:
    # Match sub-buffers with symbolic shapes inside an opaque block:
    # A is viewed as n row-tiles of (m, m); B as n row-tiles of (2, m * 4)
    # whose strides are recovered through the tir.var placeholders below.
    A = tir.match_buffer(a, (n * m, m))
    B = tir.match_buffer(b, (n * 2, m * 4))
    for i in range(0, n):
        with tir.block([]):
            tir.reads([])
            # FIX: the declared write region of A must agree with the matched
            # sub_A region and the (m, m) store loop below — rows
            # [i*m, i*m + m), not [i*m, i*m + n) as previously declared.
            tir.writes([A[i * m:i * m + m, 0:m], B[i * n:i * n + 2, 0:m * 4]])
            # Placeholder stride vars to be bound by buffer matching.
            Bs_0 = tir.var("int32")
            Bs_1 = tir.var("int32")
            sub_A = tir.match_buffer(
                A[i * m:i * m + m, 0:m], (m, m), offset_factor=1
            )
            sub_B = tir.match_buffer(
                B[i * n:i * n + 2, 0:m * 4],
                (2, m * 4),
                strides=[Bs_0, Bs_1],
                offset_factor=1,
            )
            # Write every element of the matched (m, m) view of A.
            for ii, jj in tir.grid(m, m):
                sub_A[ii, jj] = 1
            # Consume sub_B's matched metadata via the test intrinsic.
            for j in range(0, 4):
                tir.evaluate(
                    tir.intrin_test(
                        sub_B.data,
                        sub_B.elem_offset,
                        sub_B.strides[0],
                        sub_B.strides[1],
                        sub_B.shape[0],
                        sub_B.shape[1],
                        dtype="handle",
                    ))
def recursive_match(a: ty.handle, b: ty.handle) -> None:
    # Two-level (recursive) buffer matching: 16x16 tiles of A and B are
    # matched first, then 4x4 sub-tiles of those matched views are matched
    # again inside a nested block.
    A = tir.match_buffer(a, (64, 64, 64))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(64, 4, 4):
        with tir.block([]):
            tir.reads([])
            tir.writes([
                A[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
                B[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
            ])
            # First-level stride placeholders for the matched view of A.
            As_0 = tir.var("int32")
            As_1 = tir.var("int32")
            sub_A = tir.match_buffer(
                A[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
                (16, 16),
                strides=[As_0, As_1],
                offset_factor=1,
            )
            # B's view is matched without explicit strides.
            sub_B = tir.match_buffer(
                B[i, j * 16:j * 16 + 16, k * 16:k * 16 + 16],
                (16, 16),
                offset_factor=1,
            )
            for jj, kk in tir.grid(4, 4):
                with tir.block([]):
                    tir.reads([])
                    tir.writes([
                        sub_A[jj * 4:jj * 4 + 4, kk * 4:kk * 4 + 4],
                        sub_B[jj * 4:jj * 4 + 4, kk * 4:kk * 4 + 4],
                    ])
                    # Second-level stride placeholders, matched against
                    # regions of the already-matched sub_A.
                    Ass_0 = tir.var("int32")
                    Ass_1 = tir.var("int32")
                    sub_sub_A = tir.match_buffer(
                        sub_A[jj * 4:jj * 4 + 4, kk * 4:kk * 4 + 4],
                        (4, 4),
                        strides=[Ass_0, Ass_1],
                        offset_factor=1,
                    )
                    sub_sub_B = tir.match_buffer(
                        sub_B[jj * 4:jj * 4 + 4, kk * 4:kk * 4 + 4],
                        (4, 4),
                        offset_factor=1,
                    )
                    # Consume sub_sub_A's metadata via the test intrinsic.
                    tir.evaluate(
                        tir.intrin_test(
                            sub_sub_A.data,
                            sub_sub_A.elem_offset,
                            sub_sub_A.strides[0],
                            sub_sub_A.strides[1],
                            sub_sub_A.shape[0],
                            sub_sub_A.shape[1],
                            dtype="handle",
                        ))
                    # Write every element of the innermost matched view of B.
                    for jjj, kkk in tir.grid(4, 4):
                        sub_sub_B[jjj, kkk] = 1
def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # C[m, m] = A[m, x*8] @ B[m, x*8]^T with a symbolic reduction extent
    # constrained to be a multiple of 8 (x * 8).
    x = tir.var("int32")
    m = tir.var("int32")
    A = tir.match_buffer(a, [m, x * 8])
    B = tir.match_buffer(b, [m, x * 8])
    C = tir.match_buffer(c, [m, m])
    with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]:
        with tir.init():
            # Accumulator initialization for the reduction.
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None:
    # Elementwise fused multiply-add over 1-D strided buffers:
    # d[i] = A[i] * B[i] + C[i], written in low-level load/store form.
    # function attr dict
    tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
    n = tir.var("int32")
    # One symbolic stride per buffer (strides are matched independently).
    stride = tir.var("int32")
    stride_1 = tir.var("int32")
    stride_2 = tir.var("int32")
    stride_3 = tir.var("int32")
    A_1 = tir.match_buffer(
        A,
        [n],
        strides=[stride],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    B_1 = tir.match_buffer(
        B,
        [n],
        strides=[stride_1],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    C_1 = tir.match_buffer(
        C,
        [n],
        strides=[stride_2],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    d_1 = tir.match_buffer(
        d,
        [n],
        strides=[stride_3],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    # body
    for i in tir.serial(0, n):
        # Strided store of A[i] * B[i] + C[i] into d.
        d_1.data[(i * stride_3)] = (tir.load("float32", A_1.data, (i * stride)) *
                                    tir.load("float32", B_1.data, (i * stride_1))) + tir.load(
                                        "float32", C_1.data, (i * stride_2))
def tir_packed_call() -> None:
    # Direct tvm_call_cpacked form, calling the packed function with raw
    # handles (no tvm_stack_alloca value staging).
    # NOTE(review): another function later in this file shares the name
    # tir_packed_call (the staged/lowered form) — presumably the two are
    # used in separate test contexts; verify they are not imported together.
    A = tir.var("handle")
    B = tir.var("handle")
    C = tir.var("handle")
    # body
    tir.evaluate(
        tir.tvm_call_cpacked(
            "tvm_test_cpacked",
            A,
            B,
            C,
            dtype="int32",
        ))
def tir_multi_output(a0: ty.handle, a1: ty.handle, b0: ty.handle, b1: ty.handle) -> None:
    # Two output buffers computed under one loop nest, as two sibling
    # blocks sharing the same (i0, i1) loops.
    m = tir.var("int32")
    n = tir.var("int32")
    A0 = tir.match_buffer(a0, (m, n))
    A1 = tir.match_buffer(a1, (m, n))
    B0 = tir.match_buffer(b0, (m, n))
    B1 = tir.match_buffer(b1, (m, n))
    for i0, i1 in tir.grid(m, n):
        with tir.block([m, n], "B.v0") as [i, j]:
            B0[i, j] = A0[i, j] + 2.0
        with tir.block([m, n], "B.v1") as [i, j]:
            B1[i, j] = A1[i, j] * 3.0
def element_wise(a: ty.handle, c: ty.handle) -> None:
    # C = A * 2 + 1, staged through an intermediate allocated buffer B.
    m = tir.var("int32")
    n = tir.var("int32")
    A = tir.match_buffer(a, (m, n), "float32")
    C = tir.match_buffer(c, (m, n), "float32")
    # Intermediate buffer, local to the function.
    B = tir.alloc_buffer((m, n), "float32")
    with tir.block([m, n], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([m, n], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def opaque_access(a: ty.handle, b: ty.handle) -> None:
    # Opaque (intrinsic) accesses through matched sub-buffers:
    # A's tiles use fixed strides, B's tiles use symbolic strides.
    A = tir.match_buffer(a, (32, 64, 128))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(2, 64, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(A[i * 16:i * 16 + 16, j, k * 16:k * 16 + 16])
            # Explicit strides matching A's natural (64*128, 128, 1) layout.
            sub_A = tir.match_buffer(
                A[i * 16:i * 16 + 16, j, k * 16:k * 16 + 16],
                (16, 1, 16),
                strides=[8192, 128, 1],
                offset_factor=1,
            )
            tir.evaluate(
                tir.intrin_test(
                    sub_A.data,
                    sub_A.elem_offset,
                    sub_A.strides[0],
                    sub_A.strides[1],
                    sub_A.shape[0],
                    sub_A.shape[1],
                    dtype="handle",
                ))
    for i, j, k in tir.grid(64, 2, 8):
        with tir.block([]):
            # Symbolic stride placeholders, bound by the match below.
            Bs_0 = tir.var("int32")
            Bs_1 = tir.var("int32")
            tir.reads([])
            tir.writes(B[i, j * 32:j * 32 + 32, k * 8:k * 8 + 8])
            sub_B = tir.match_buffer(
                B[i, j * 32:j * 32 + 32, k * 8:k * 8 + 8],
                (32, 8),
                strides=[Bs_0, Bs_1],
                offset_factor=1,
            )
            tir.evaluate(
                tir.intrin_test(
                    sub_B.data,
                    sub_B.elem_offset,
                    sub_B.strides[0],
                    sub_B.strides[1],
                    sub_B.shape[0],
                    sub_B.shape[1],
                    dtype="handle",
                ))
def fail_buffer_bind(a: ty.handle) -> None:
    # Negative test case (per the function name): the match below reuses the
    # SAME stride var for both dimensions of a (1, 4) view, which cannot be
    # bound consistently — presumably this is expected to raise during
    # buffer binding. Do not "fix" the duplicate stride.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 2):
        with tir.block([]):
            stride = tir.var("int32")
            sub_A = tir.match_buffer(
                A[i, j * 4:j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1
            )
            for jj in range(0, 4):
                # NOTE(review): indices (i, j*4+jj) exceed sub_A's (1, 4)
                # shape — consistent with this being an error-path fixture.
                sub_A[i, j * 4 + jj] = 1
def tir_packed_call() -> None:
    # Staged tvm_call_cpacked form: each argument handle is written into a
    # stack-allocated TVM value slot via tvm_struct_set before the call.
    # NOTE(review): shares its name with the direct-call form earlier in
    # this file — presumably used as an expected before/after pair; verify.
    A = tir.var("handle")
    B = tir.var("handle")
    C = tir.var("handle")
    # body
    tvm_value_2 = tir.var("handle")
    tvm_value_1 = tir.var("handle")
    tvm_value_0 = tir.var("handle")
    with tir.let(tvm_value_2, tir.tvm_stack_alloca("array", 1, dtype="handle")):
        with tir.let(tvm_value_1, tir.tvm_stack_alloca("array", 1, dtype="handle")):
            with tir.let(tvm_value_0, tir.tvm_stack_alloca("array", 1, dtype="handle")):
                # Stage A, B, C into the three stack slots.
                tir.evaluate(
                    tir.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle"))
                tir.evaluate(
                    tir.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle"))
                tir.evaluate(
                    tir.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle"))
                tir.evaluate(
                    tir.tvm_call_cpacked(
                        "tvm_test_cpacked",
                        tvm_value_0,
                        tvm_value_1,
                        tvm_value_2,
                        dtype="int32",
                    ))
def param_in_arith_exprs_n_16(a: ty.handle, b: ty.handle) -> None:
    # Looks like the n = 16 specialization of param_in_arith_exprs below
    # (shapes and constants folded: n // 8 -> 2, n - 1 -> 15,
    # (n + 1) * 42 -> 714) — presumably an expected-output fixture; the
    # unused n var is kept to mirror the generic function. TODO confirm.
    n = tir.var("int32")
    A = tir.match_buffer(a, [2, 8], "int32")
    B = tir.match_buffer(b, [16], "int32")
    with tir.block([15], "") as [vi]:
        B[vi] = A[vi // 8, vi % 8] + 714
def param_in_arith_exprs(a: ty.handle, b: ty.handle) -> None:
    # Symbolic var n appears in buffer shapes, the block extent, and the
    # computed expression — exercises params inside arithmetic exprs.
    n = tir.var("int32")
    A = tir.match_buffer(a, [n // 8, 8], "int32")
    B = tir.match_buffer(b, [n], "int32")
    with tir.block([n - 1], "") as [vi]:
        B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42
def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 1024x1024 fp16 GEMM accumulated in fp32 using WMMA tensor-core
    # intrinsics: global -> shared copies, fragment loads, mma_sync
    # accumulation, and a final fragment store back to C.
    # match buffer
    A = tir.match_buffer(a, [1024, 1024], "float16")
    B = tir.match_buffer(b, [1024, 1024], "float16")
    C = tir.match_buffer(c, [1024, 1024], "float32")
    # body
    for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"):
            with tir.block([16, 8]) as [bx, by]:
                tir.bind(bx, blockIdx_x)
                tir.bind(by, blockIdx_y)
                # Staging buffers: shared-memory tiles plus WMMA fragment
                # buffers for A, B, and the accumulator.
                shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                # NOTE(review): the loop var `ty` shadows the module alias
                # used in the annotations above; the TVMScript parser treats
                # it as a fresh loop var here — verify against the parser.
                for ty in tir.thread_binding(0, 2, "threadIdx.y"):
                    for tz in tir.thread_binding(0, 2, "threadIdx.z"):
                        # Zero-fill the 2x4 accumulator fragments.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads([])
                                tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                tir.evaluate(
                                    tir.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.float32(0),
                                        dtype="handle",
                                    )
                                )
                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in tir.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in tir.grid(1, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            # Stored at column vj + 8 — an
                                            # 8-column offset also used by the
                                            # fragment loads below (elem_offset
                                            # + 8); presumably padding to avoid
                                            # bank conflicts — TODO confirm.
                                            shared_A[vi, vj + 8] = A[vi, vj]
                                for i0, j0 in tir.grid(2, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]
                            for ki in range(0, 2):
                                # Load A fragments (row_major) from shared.
                                for i in range(0, 2):
                                    with tir.block([64, 64]) as [vi, vk]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        A0 = tir.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load B fragments (col_major) from shared.
                                for j in range(0, 4):
                                    with tir.block([64, 64]) as [vj, vk]:
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        B0 = tir.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Accumulate: C fragment += A fragment * B fragment.
                                for i, j in tir.grid(2, 4):
                                    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        tir.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = tir.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Store accumulator fragments back to global C.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = tir.var("int32")
                                s1 = tir.var("int32")
                                wmma_C2 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = tir.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                tir.evaluate(
                                    tir.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.tvm_access_ptr(
                                            tir.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )