def test_cpu_matmul():
    """Check the design spaces generated by the ``add_rfactor`` rule on a small CPU matmul.

    The workload is a matmul with n=4, m=4, k=512 on ``llvm --num-cores=32``.
    Three design spaces are expected:

    * one with an empty trace;
    * two that split the innermost loop ``l3`` into ``[l6, l7]`` and rfactor
      either ``l7`` or ``l6``, then annotate ``b0`` with
      ``meta_schedule.random_compute_producer``.

    NOTE(review): another function named ``test_cpu_matmul`` appears in this
    view. If both are defined in the same module, the later definition shadows
    this one and pytest will not collect it. Confirm, and rename one if so.
    """
    # Trace prefix shared by both rfactor variants, up to the 2-way split.
    split_reduction = [
        'b0 = sch.get_block(name="C", func_name="main")',
        "l1, l2, l3 = sch.get_loops(block=b0)",
        "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
        "l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)",
    ]
    # Annotation that closes both rfactor variants.
    mark_producer = 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)'
    expected = [
        [],
        split_reduction + ["b8 = sch.rfactor(loop=l7, factor_axis=2)", mark_producer],
        split_reduction + ["b8 = sch.rfactor(loop=l6, factor_axis=2)", mark_producer],
    ]

    target = Target("llvm --num-cores=32")
    workload = te_workload.matmul(n=4, m=4, k=512)
    ctx = _create_context(
        create_prim_func(workload),
        target=target,
        rule=add_rfactor(target=target),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == len(expected)
    check_trace(spaces, expected)
def test_cpu_matmul():
    """Check the design spaces ``multi_level_tiling`` generates for a 512x512x512 CPU matmul.

    Every space annotates the block with tiling structure "SSRSRS". It then
    splits the three loops 4-, 4- and 2-way and reorders the resulting tiles.

    * Two spaces additionally add a "global" ``cache_write`` that is
      reverse-computed at ``l17`` and at ``l16`` respectively.
    * The third space stops right after the reorder.
    """
    # Instructions common to all three expected traces.
    tiling = [
        'b0 = sch.get_block(name="C", func_name="main")',
        'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
        "l1, l2, l3 = sch.get_loops(block=b0)",
        "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
        "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
        "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
        "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
        "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
        "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
        "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
    ]
    cache_write = 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")'
    expected = [
        tiling + [
            cache_write,
            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)",
        ],
        tiling + [
            cache_write,
            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)",
        ],
        list(tiling),
    ]

    target = Target("llvm")
    workload = te_workload.matmul(n=512, m=512, k=512)
    ctx = _create_context(
        create_prim_func(workload),
        target=target,
        rule=multi_level_tiling(target=target),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == len(expected)
    check_trace(spaces, expected)
def test_cuda_matmul():
    """Check the single design space ``multi_level_tiling`` generates for a 512x512x512 CUDA matmul.

    The expected trace, section by section:

    * annotates the block with tiling structure "SSSRRSRS";
    * splits the loops 5-, 5- and 3-way and reorders the tiles;
    * fuses the outer tile pairs and binds them to blockIdx.x, vthread.x and
      threadIdx.x;
    * records thread-extent bounds of 32 and 1024;
    * adds a "local" ``cache_write``;
    * adds two "shared" ``cache_read`` blocks, each annotated with
      ``meta_schedule.cooperative_fetch``.
    """
    # pylint: disable=line-too-long
    tile = [
        'b0 = sch.get_block(name="C", func_name="main")',
        'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
        "l1, l2, l3 = sch.get_loops(block=b0)",
        "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
        "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])",
        "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
        "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])",
        "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
        "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])",
        "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
    ]
    bind = [
        "l30 = sch.fuse(l9, l19)",
        'sch.bind(loop=l30, thread_axis="blockIdx.x")',
        "l31 = sch.fuse(l10, l20)",
        'sch.bind(loop=l31, thread_axis="vthread.x")',
        "l32 = sch.fuse(l11, l21)",
        'sch.bind(loop=l32, thread_axis="threadIdx.x")',
        'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
        'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
    ]
    write_cache = [
        'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
        "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)",
    ]
    read_cache = [
        'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
        "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)",
        "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
        "l41 = sch.fuse(l39, l40)",
        "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
        'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
        'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
        "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)",
        "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
        "l50 = sch.fuse(l48, l49)",
        "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
        'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
    ]
    # pylint: enable=line-too-long
    expected = [tile + bind + write_cache + read_cache]

    target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm")
    workload = te_workload.matmul(n=512, m=512, k=512)
    ctx = _create_context(
        create_prim_func(workload),
        target=target,
        rule=multi_level_tiling(target=target),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == len(expected)
    check_trace(spaces, expected)
def test_rewrite_cooperative_fetch():
    """Apply the context's first postprocessor to a hand-built, thread-bound matmul schedule.

    A 512x512x512 matmul is scheduled by hand with pinned sampling decisions:

    * a "local" cache_write of the output;
    * 5-, 5- and 3-way tiling;
    * the fused outer tiles bound to blockIdx.x, vthread.x and threadIdx.x;
    * two "shared" cache_read blocks computed at ``l28``, whose two innermost
      loops are fused, split, and annotated with
      ``meta_schedule.cooperative_fetch``.

    After ``sch.enter_postproc()``, ``ctx.postprocs[0].apply(sch)`` must
    succeed, and the resulting module must be structurally equal to
    ``AfterRewrite0``.

    NOTE(review): ``ctx.postprocs[0]`` is presumably the cooperative-fetch
    rewrite postprocessor set up by ``_create_context``. Confirm against that
    helper.
    """
    mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512))
    target = _target()
    ctx = _create_context(mod, target)
    sch = tir.Schedule(mod, debug_mask="all")
    # fmt: off
    # pylint: disable=line-too-long,invalid-name
    b0 = sch.get_block(name="C", func_name="main")
    # Output cache; it is moved under the threadIdx-bound loop l33 at the end.
    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
    l2, l3, l4 = sch.get_loops(block=b0)
    # Sampling decisions are pinned so the result is comparable to AfterRewrite0.
    v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16])
    l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])
    v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2])
    l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])
    v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32])
    l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])
    sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)
    # Fuse each pair of outer tiles and bind it to a GPU thread axis.
    l31 = sch.fuse(l10, l20)
    sch.bind(loop=l31, thread_axis="blockIdx.x")
    l32 = sch.fuse(l11, l21)
    sch.bind(loop=l32, thread_axis="vthread.x")
    l33 = sch.fuse(l12, l22)
    sch.bind(loop=l33, thread_axis="threadIdx.x")
    # First shared-memory read cache: fuse its two innermost loops, split them, and tag for cooperative fetch.
    b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")
    sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)
    _, _, _, _, l39, l40 = sch.get_loops(block=b34)
    l41 = sch.fuse(l39, l40)
    # NOTE(review): 262144 == 512 * 512, presumably the whole extent of l41 — confirm.
    _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1])
    sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)
    # Second shared-memory read cache, handled the same way.
    b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")
    sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)
    _, _, _, _, l49, l50 = sch.get_loops(block=b44)
    l51 = sch.fuse(l49, l50)
    _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2])
    sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)
    sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)
    # pylint: enable=line-too-long,invalid-name
    # fmt: on
    sch.enter_postproc()
    assert ctx.postprocs[0].apply(sch)
    tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0)
def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected):
    """Check the auto-tensorize index map of a float16-input, float32-output matmul.

    Builds an ``n x m x k`` matmul PrimFunc. It then delegates to
    ``check_index_map`` for block "C", the
    ``WMMA_SYNC_16x16x16_f16f16f32_INTRIN`` intrinsic, and the ``expected``
    mapping supplied by the test parameters.
    """
    workload = te_workload.matmul(n, m, k, in_dtype="float16", out_dtype="float32")
    matmul = create_prim_func(workload)
    check_index_map(matmul, "C", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, expected)