def test_func_call():
    """Check that calling a plain Python helper inside a prim_func matches hand-inlined indices.

    ``mma_sync_m16n16k16_desc`` computes its buffer indices by calling the
    local helper ``shared_16x16_to_ldmatrix_32x8_layout``;
    ``mma_sync_m16n16k16_desc_manual`` spells the same index expressions out
    inline.  The test asserts that the two prim_funcs are structurally equal.
    """

    # Plain Python helper (not a prim_func): maps an index pair (i, j) to a
    # pair of index expressions.  The callers below unpack the result as
    # (thread_id_*, local_id_*).
    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        thread_id = (i % 8) * 4 + (j % 8) // 2
        return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)

    # Variant 1: indices come from calling the helper above.
    @T.prim_func
    def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

        with T.block("root"):
            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
            T.writes(C[0:32, 0:8])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    # C is indexed through (i, j), A through (i, k), B through (k, j).
                    thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
                    thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
                    thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)

                    T.reads(
                        C[thread_id_C, local_id_C],
                        A[thread_id_A, local_id_A],
                        B[thread_id_B, local_id_B],
                    )
                    T.writes(C[thread_id_C, local_id_C])

                    # Multiply-accumulate written with the augmented-assignment form.
                    C[thread_id_C, local_id_C] += (
                        A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
                    )

    # Variant 2: the helper's index expressions expanded by hand.  The
    # expressions omit the helper's parentheses but have the same meaning under
    # Python operator precedence.
    @T.prim_func
    def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

        with T.block("root"):
            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
            T.writes(C[0:32, 0:8])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    T.reads(
                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                        A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                        B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                    )
                    T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                    # Multiply-accumulate written as an explicit C = C + A * B.
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
                        + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
                        * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
                    )

    assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual)
def test_match_buffer_1d():
    """Check that a 1-D ``T.Buffer[...]`` parameter annotation matches an explicit ``T.match_buffer``.

    Both prim_funcs write ``0.0`` to each of 16 elements; they differ only in
    how the buffer ``A`` is declared.
    """

    # Explicit form: handle parameter plus T.match_buffer with a 1-tuple shape.
    # No dtype is given here; the test expects this to equal the "float32"
    # buffer below (presumably the match_buffer default -- confirm against TVM).
    @T.prim_func
    def func_no_sugar(a: T.handle):
        A = T.match_buffer(a, shape=(16,))
        for i in T.serial(16):
            A[i] = 0.0

    # Sugared form: the buffer is declared in the parameter annotation, with
    # the shape given as a bare integer rather than a tuple.
    @T.prim_func
    def func_with_sugar(A: T.Buffer[16, "float32"]):
        for i in T.serial(16):
            A[i] = 0.0

    assert_structural_equal(func_no_sugar, func_with_sugar)
def test_letstmt_bind_with_constant():
    """Check that binding bare Python literals matches binding explicitly typed constants.

    ``constant_binds`` assigns the literals ``1`` and ``42.0`` directly, while
    ``constant_binds_wrapped`` wraps them as ``T.int32(1)`` and
    ``T.float32(42.0)``.  The test asserts the two prim_funcs are structurally
    equal.
    """

    # Bare literals bound to names.
    @T.prim_func
    def constant_binds():
        x = 1
        y = 42.0
        T.evaluate(T.cast(x, "float32") + y)

    # The same values with their dtypes spelled out.
    @T.prim_func
    def constant_binds_wrapped():
        x = T.int32(1)
        y = T.float32(42.0)
        T.evaluate(T.cast(x, "float32") + y)

    assert_structural_equal(constant_binds, constant_binds_wrapped)
def test_index_mapping():
    """Check ``IndexMap.map_indices`` for the mapping ``i -> [i // 4, i % 4]``."""
    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])

    # (input indices, expected mapped indices), checked in this order.
    cases = (
        ([0], [0, 0]),
        ([3], [0, 3]),
        ([4], [1, 0]),
        ([42], [10, 2]),
    )
    for indices, expected in cases:
        assert_structural_equal(index_map.map_indices(indices), expected)
def apply(
    self,
    task_scheduler: TaskScheduler,
    task_id: int,
    measure_candidates: List[MeasureCandidate],
    builds: List[BuilderResult],
    results: List[RunnerResult],
) -> None:
    """Assert that exactly one candidate, build and run result arrived with the expected contents.

    Checks, in order:

    * there is one measure candidate and its schedule's module is
      structurally equal to ``Matmul``;
    * there is one build result with no error message and artifact path
      ``"test_build"``;
    * there is one runner result with no error message and two entries in
      ``run_secs``.

    ``task_scheduler`` and ``task_id`` are accepted but not inspected.

    Raises
    ------
    AssertionError
        If any of the length, error-message, artifact-path or run-count
        checks fails.
    """
    # Measure candidate: exactly one, scheduling the Matmul workload.
    assert len(measure_candidates) == 1
    candidate = measure_candidates[0]
    assert_structural_equal(candidate.sch.mod, Matmul)

    # Build result: exactly one, error-free, with the expected artifact path.
    assert len(builds) == 1
    build = builds[0]
    assert build.error_msg is None
    assert build.artifact_path == "test_build"

    # Runner result: exactly one, error-free, with two recorded run times.
    assert len(results) == 1
    result = results[0]
    assert result.error_msg is None
    assert len(result.run_secs) == 2
def test_shape_mapping():
    """Check ``IndexMap.map_shape`` for the mapping ``i -> [i // 4, i % 4]``."""
    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])

    # (input shape, expected mapped shape).  The last case uses an extent (14)
    # that is not a multiple of 4 and is expected to give the same result as 16.
    cases = (
        ([4], [1, 4]),
        ([16], [4, 4]),
        ([14], [4, 4]),
    )
    for shape, expected in cases:
        assert_structural_equal(index_map.map_shape(shape), expected)
def test_reads_writes_syntax_sugar():
    """Check that the sugared and unsugared ``transformed_matmul`` functions are structurally equal."""
    plain = transformed_matmul_no_syntax_sugar
    sugared = transformed_matmul_syntax_sugar
    assert_structural_equal(plain, sugared)
def test_match_buffer_int64():
    """Check that ``match_buffer_int64`` is structurally equal to its after-roundtrip counterpart."""
    # The third positional argument (True) is forwarded unchanged.
    # NOTE(review): presumably this is the map_free_vars flag -- confirm
    # against the assert_structural_equal signature.
    assert_structural_equal(
        match_buffer_int64,
        match_buffer_int64_after_roundtrip,
        True,
    )
def test_dynamic_shape_gemm():
    """Check that ``gemm_dyn_shape`` survives a script -> ``from_source`` round trip."""
    # Print the function to script text, then parse that text back.
    script_text = gemm_dyn_shape.script()
    reparsed = from_source(script_text)
    assert_structural_equal(gemm_dyn_shape, reparsed)
def test_match_buffer_syntax_sugar():
    """Check that both buffer-sugar variants match the handle-based ``elementwise_handle``."""
    sugared_variants = (
        elementwise_buffer_kwargs,  # with kwargs
        elementwise_buffer_no_kwargs,  # without kwargs
    )
    for sugared in sugared_variants:
        assert_structural_equal(elementwise_handle, sugared)
def test_loop_syntax_sugar():
    """Check that the sugared and unsugared loop functions are structurally equal."""
    plain = loop_no_syntax_sugar
    sugared = loop_syntax_sugar
    assert_structural_equal(plain, sugared)