def test_gpu_batch_norm_bmn():
    """Check the traces generated for ``norm_bmn`` under ``cross_thread_reduction``.

    Builds the ``te_workload.norm_bmn`` workload with ``B=1, M=512, N=512``
    for the ``nvidia/geforce-rtx-3090`` target, generates the design space
    with the ``cross_thread_reduction`` rule, and expects exactly two design
    spaces, which are handed to ``check_trace`` together with the expected
    trace strings.
    """
    # Trace with no instructions at all.
    # NOTE(review): presumably the design space where the rule is not applied
    # -- confirm against the rule implementation.
    empty_trace = []

    # Trace operating on block "C" and its single consumer.
    reduction_trace = [
        'b0 = sch.get_block(name="C", func_name="main")',
        "b1, = sch.get_consumers(block=b0)",
        "l2, = sch.get_loops(block=b1)",
        "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
        "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)",
        'sch.bind(loop=l5, thread_axis="threadIdx.x")',
        "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)",
        'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
        "l6, l7, l8, l9 = sch.get_loops(block=b0)",
        "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)",
        "l11, l12 = sch.split(loop=l10, factors=[None, v3], preserve_unit_iters=True)",
        'sch.bind(loop=l12, thread_axis="threadIdx.x")',
    ]
    expected = [empty_trace, reduction_trace]

    target = Target("nvidia/geforce-rtx-3090", host="llvm")
    workload = create_prim_func(
        te_workload.norm_bmn(
            B=1,
            M=512,
            N=512,
        )
    )
    context = _create_context(
        workload,
        target=target,
        rule=cross_thread_reduction(target=target),
    )

    design_spaces = context.space_generator.generate_design_space(mod=context.mod)
    assert len(design_spaces) == 2
    check_trace(design_spaces, expected)
def test_gpu_softmax_mn_after_inline():
    """Check the traces generated for ``Softmax_mn_after_inline``.

    Generates the design space of the ``Softmax_mn_after_inline`` module with
    the ``cross_thread_reduction`` rule for the ``nvidia/geforce-rtx-3090``
    target and expects exactly four design spaces, which are handed to
    ``check_trace`` together with the expected trace strings.
    """
    # Trace with no instructions at all.
    # NOTE(review): presumably the design space where the rule is not applied
    # to either block -- confirm against the rule implementation.
    empty_trace = []

    # Trace that only touches the "T_softmax_maxelem" block.
    maxelem_trace = [
        'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
        "v1 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
        "l2, l3 = sch.get_loops(block=b0)",
        "l4, l5 = sch.split(loop=l3, factors=[None, v1], preserve_unit_iters=True)",
        'sch.bind(loop=l5, thread_axis="threadIdx.x")',
    ]

    # Trace that only touches the "T_softmax_expsum" block and its consumer.
    expsum_trace = [
        'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")',
        "b1, = sch.get_consumers(block=b0)",
        "l2, l3 = sch.get_loops(block=b1)",
        "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
        "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
        'sch.bind(loop=l6, thread_axis="threadIdx.x")',
        "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
        'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
        "l7, l8, l9 = sch.get_loops(block=b0)",
        "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
        'sch.bind(loop=l11, thread_axis="threadIdx.x")',
    ]

    # Trace that touches both "T_softmax_maxelem" and "T_softmax_expsum".
    maxelem_and_expsum_trace = [
        'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
        'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
        "b2, = sch.get_consumers(block=b1)",
        "l3, l4 = sch.get_loops(block=b2)",
        "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
        "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
        'sch.bind(loop=l7, thread_axis="threadIdx.x")',
        "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
        'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
        "l8, l9, l10 = sch.get_loops(block=b1)",
        "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
        'sch.bind(loop=l12, thread_axis="threadIdx.x")',
        "b13, b14 = sch.get_consumers(block=b0)",
        "l15, l16, l17, l18 = sch.get_loops(block=b13)",
        "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)",
        'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
        "l19, l20, l21 = sch.get_loops(block=b0)",
        "l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)",
        'sch.bind(loop=l23, thread_axis="threadIdx.x")',
    ]

    expected = [
        empty_trace,
        maxelem_trace,
        expsum_trace,
        maxelem_and_expsum_trace,
    ]

    target = Target("nvidia/geforce-rtx-3090", host="llvm")
    context = _create_context(
        mod=Softmax_mn_after_inline,
        target=target,
        rule=cross_thread_reduction(target=target),
    )

    design_spaces = context.space_generator.generate_design_space(mod=context.mod)
    assert len(design_spaces) == 4
    check_trace(design_spaces, expected)