def test_func_call():
    """Check that a Python helper called inside a prim_func matches inline indices.

    Two prim_funcs are defined: one computes its buffer indices by calling the
    local helper ``shared_16x16_to_ldmatrix_32x8_layout``, the other spells the
    same index arithmetic out by hand. The test asserts that both are
    structurally equal.
    """

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        """Map an index pair ``(i, j)`` to a ``(thread_id, local_id)`` pair."""
        # For i, j in [0, 16) this yields thread_id in [0, 32) and a second
        # index in [0, 8), matching the (32, 8) buffers declared below.
        thread_id = (i % 8) * 4 + (j % 8) // 2
        return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)

    # Variant 1: indices are produced by calling the helper above.
    @T.prim_func
    def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

        with T.block("root"):
            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
            T.writes(C[0:32, 0:8])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("C"):
                    # "SSR": presumably i, j spatial and k reduction, per the
                    # TVMScript axis-remap convention -- confirm against TVM docs.
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    # One helper call per operand: C uses (i, j), A uses (i, k),
                    # B uses (k, j).
                    thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
                    thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
                    thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)

                    T.reads(
                        C[thread_id_C, local_id_C],
                        A[thread_id_A, local_id_A],
                        B[thread_id_B, local_id_B],
                    )
                    T.writes(C[thread_id_C, local_id_C])

                    # Multiply-accumulate into C.
                    C[thread_id_C, local_id_C] += (
                        A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
                    )

    # Variant 2: the same index arithmetic written out by hand at every use,
    # with the "+=" expanded into an explicit "C = C + A * B".
    @T.prim_func
    def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

        with T.block("root"):
            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
            T.writes(C[0:32, 0:8])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    T.reads(
                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                        A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                        B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                    )
                    T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
                        + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
                        * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
                    )

    # The helper-based and hand-written versions must be structurally equal.
    assert_structural_equal(mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual)
def test_match_buffer_1d():
    """Check that a 1-D buffer annotation equals an explicit ``T.match_buffer``.

    ``func_no_sugar`` takes a handle and matches a ``(16,)`` buffer with no
    explicit dtype; ``func_with_sugar`` declares the buffer directly in the
    parameter annotation as ``T.Buffer[16, "float32"]``. The test asserts that
    both prim_funcs are structurally equal.
    """

    # Explicit form: handle parameter plus match_buffer, dtype left unspecified.
    @T.prim_func
    def func_no_sugar(a: T.handle):
        A = T.match_buffer(a, shape=(16,))
        for i in T.serial(16):
            A[i] = 0.0

    # Sugared form: scalar extent 16 (not a tuple) and an explicit "float32".
    @T.prim_func
    def func_with_sugar(A: T.Buffer[16, "float32"]):
        for i in T.serial(16):
            A[i] = 0.0

    assert_structural_equal(func_no_sugar, func_with_sugar)
Exemplo n.º 3
0
def test_letstmt_bind_with_constant():
    """Check that binding bare Python constants equals binding typed constants.

    ``constant_binds`` assigns plain Python literals (``1`` and ``42.0``),
    while ``constant_binds_wrapped`` assigns the same values wrapped in
    ``T.int32`` / ``T.float32``. The test asserts that both prim_funcs are
    structurally equal.
    """

    # Bare literals: an int and a float with no explicit TIR dtype.
    @T.prim_func
    def constant_binds():
        x = 1
        y = 42.0
        T.evaluate(T.cast(x, "float32") + y)

    # Same bindings with the dtypes spelled out explicitly.
    @T.prim_func
    def constant_binds_wrapped():
        x = T.int32(1)
        y = T.float32(42.0)
        T.evaluate(T.cast(x, "float32") + y)

    assert_structural_equal(constant_binds, constant_binds_wrapped)
Exemplo n.º 4
0
def test_index_mapping():
    """Check ``IndexMap.map_indices`` for the map ``i -> [i // 4, i % 4]``."""
    split_by_four = IndexMap.from_func(lambda i: [i // 4, i % 4])

    # (flat index, expected [quotient, remainder]) pairs, checked in order.
    cases = (
        (0, [0, 0]),
        (3, [0, 3]),
        (4, [1, 0]),
        (42, [10, 2]),
    )
    for flat_index, expected in cases:
        assert_structural_equal(split_by_four.map_indices([flat_index]), expected)
 def apply(
     self,
     task_scheduler: TaskScheduler,
     task_id: int,
     measure_candidates: List[MeasureCandidate],
     builds: List[BuilderResult],
     results: List[RunnerResult],
 ) -> None:
     """Assert that exactly one candidate, build and run result were passed.

     NOTE(review): presumably a measure-callback hook used by a test; the
     enclosing class is not visible from here -- confirm.

     Checks performed:
       * one measure candidate whose schedule module equals ``Matmul``;
       * one build result with no error and artifact path ``"test_build"``;
       * one run result with no error and two entries in ``run_secs``.
     """
     # Candidate: a single entry wrapping the Matmul module.
     assert len(measure_candidates) == 1
     candidate = measure_candidates[0]
     assert_structural_equal(candidate.sch.mod, Matmul)

     # Build: a single successful record pointing at the expected artifact.
     assert len(builds) == 1
     build = builds[0]
     assert build.error_msg is None
     assert build.artifact_path == "test_build"

     # Run: a single successful record carrying two timings.
     assert len(results) == 1
     run = results[0]
     assert run.error_msg is None
     assert len(run.run_secs) == 2
Exemplo n.º 6
0
def test_shape_mapping():
    """Check ``IndexMap.map_shape`` for the map ``i -> [i // 4, i % 4]``."""
    split_by_four = IndexMap.from_func(lambda i: [i // 4, i % 4])

    # (input extent, expected mapped shape) pairs, checked in order.
    cases = (
        (4, [1, 4]),
        (16, [4, 4]),
        # 14 is not a multiple of 4; the expected shape is still [4, 4].
        (14, [4, 4]),
    )
    for extent, expected_shape in cases:
        assert_structural_equal(split_by_four.map_shape([extent]), expected_shape)
Exemplo n.º 7
0
def test_reads_writes_syntax_sugar():
    """Check that the sugared and unsugared matmul functions are structurally equal."""
    explicit_form = transformed_matmul_no_syntax_sugar
    sugared_form = transformed_matmul_syntax_sugar
    assert_structural_equal(explicit_form, sugared_form)
Exemplo n.º 8
0
def test_match_buffer_int64():
    """Check that ``match_buffer_int64`` equals its after-roundtrip counterpart."""
    assert_structural_equal(
        match_buffer_int64,
        match_buffer_int64_after_roundtrip,
        # NOTE(review): third positional flag is presumably ``map_free_vars``
        # -- confirm against the assert_structural_equal signature.
        True,
    )
Exemplo n.º 9
0
def test_dynamic_shape_gemm():
    """Check that ``gemm_dyn_shape`` survives a script -> parse round trip."""
    # Print the function to script text, then parse that text back.
    script_text = gemm_dyn_shape.script()
    reparsed = from_source(script_text)
    assert_structural_equal(gemm_dyn_shape, reparsed)
Exemplo n.º 10
0
def test_match_buffer_syntax_sugar():
    """Check both buffer-sugar variants against the handle-based function."""
    sugared_variants = (
        elementwise_buffer_kwargs,  # with kwargs
        elementwise_buffer_no_kwargs,  # without kwargs
    )
    for sugared in sugared_variants:
        assert_structural_equal(elementwise_handle, sugared)
Exemplo n.º 11
0
def test_loop_syntax_sugar():
    """Check that the sugared and unsugared loop functions are structurally equal."""
    explicit_form = loop_no_syntax_sugar
    sugared_form = loop_syntax_sugar
    assert_structural_equal(explicit_form, sugared_form)