예제 #1
0
def test_gpu_batch_norm_bmn():
    """Check cross-thread-reduction design spaces for the `norm_bmn` workload.

    The workload is `te_workload.norm_bmn(B=1, M=512, N=512)`, compiled for the
    "nvidia/geforce-rtx-3090" target with an LLVM host. Exactly two design
    spaces are expected:

    * one with an empty trace;
    * one where block "C" is computed at the outer split loop of its single
      consumer and staged in shared memory. Its two innermost loops are then
      fused and split, and the inner split loops of both the consumer and "C"
      are bound to threadIdx.x using the same sampled extent.
    """
    # The expected trace samples the thread extent from 8 candidates, uniformly.
    sample_thread_extent = (
        "sch.sample_categorical("
        "candidates=[4, 8, 16, 32, 64, 128, 256, 512], "
        "probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])"
    )
    untouched_trace = []
    reduced_trace = [
        'b0 = sch.get_block(name="C", func_name="main")',
        "b1, = sch.get_consumers(block=b0)",
        "l2, = sch.get_loops(block=b1)",
        f"v3 = {sample_thread_extent}",
        "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)",
        'sch.bind(loop=l5, thread_axis="threadIdx.x")',
        "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)",
        'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
        "l6, l7, l8, l9 = sch.get_loops(block=b0)",
        "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)",
        "l11, l12 = sch.split(loop=l10, factors=[None, v3], preserve_unit_iters=True)",
        'sch.bind(loop=l12, thread_axis="threadIdx.x")',
    ]

    gpu_target = Target("nvidia/geforce-rtx-3090", host="llvm")
    prim_func = create_prim_func(te_workload.norm_bmn(B=1, M=512, N=512))
    context = _create_context(
        prim_func,
        target=gpu_target,
        rule=cross_thread_reduction(target=gpu_target),
    )

    design_spaces = context.space_generator.generate_design_space(mod=context.mod)
    assert len(design_spaces) == 2
    check_trace(design_spaces, [untouched_trace, reduced_trace])
예제 #2
0
def test_gpu_softmax_mn_after_inline():
    """Check cross-thread-reduction design spaces for `Softmax_mn_after_inline`.

    The module is compiled for the "nvidia/geforce-rtx-3090" target with an LLVM
    host. Four design spaces are expected, in this order:

    1. an empty trace;
    2. only "T_softmax_maxelem" rewritten: its inner loop is split and bound to
       threadIdx.x in place;
    3. only "T_softmax_expsum" rewritten: it is computed at its consumer's outer
       loop, staged in shared memory, and its inner split loop is bound to
       threadIdx.x;
    4. both blocks rewritten: "T_softmax_expsum" as in (3), then
       "T_softmax_maxelem" computed at the outer loop of its first consumer and
       staged in shared memory, all sharing a single sampled thread extent.
    """
    # Every expected trace samples the thread extent uniformly from 8 candidates.
    sample_thread_extent = (
        "sch.sample_categorical("
        "candidates=[4, 8, 16, 32, 64, 128, 256, 512], "
        "probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])"
    )
    untouched_trace = []
    maxelem_only_trace = [
        'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
        f"v1 = {sample_thread_extent}",
        "l2, l3 = sch.get_loops(block=b0)",
        "l4, l5 = sch.split(loop=l3, factors=[None, v1], preserve_unit_iters=True)",
        'sch.bind(loop=l5, thread_axis="threadIdx.x")',
    ]
    expsum_only_trace = [
        'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")',
        "b1, = sch.get_consumers(block=b0)",
        "l2, l3 = sch.get_loops(block=b1)",
        f"v4 = {sample_thread_extent}",
        "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)",
        'sch.bind(loop=l6, thread_axis="threadIdx.x")',
        "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
        'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
        "l7, l8, l9 = sch.get_loops(block=b0)",
        "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)",
        'sch.bind(loop=l11, thread_axis="threadIdx.x")',
    ]
    both_reductions_trace = [
        'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
        'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
        "b2, = sch.get_consumers(block=b1)",
        "l3, l4 = sch.get_loops(block=b2)",
        f"v5 = {sample_thread_extent}",
        "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)",
        'sch.bind(loop=l7, thread_axis="threadIdx.x")',
        "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
        'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
        "l8, l9, l10 = sch.get_loops(block=b1)",
        "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)",
        'sch.bind(loop=l12, thread_axis="threadIdx.x")',
        "b13, b14 = sch.get_consumers(block=b0)",
        "l15, l16, l17, l18 = sch.get_loops(block=b13)",
        "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)",
        'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
        "l19, l20, l21 = sch.get_loops(block=b0)",
        "l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)",
        'sch.bind(loop=l23, thread_axis="threadIdx.x")',
    ]

    gpu_target = Target("nvidia/geforce-rtx-3090", host="llvm")
    context = _create_context(
        mod=Softmax_mn_after_inline,
        target=gpu_target,
        rule=cross_thread_reduction(target=gpu_target),
    )

    design_spaces = context.space_generator.generate_design_space(mod=context.mod)
    assert len(design_spaces) == 4
    check_trace(
        design_spaces,
        [untouched_trace, maxelem_only_trace, expsum_only_trace, both_reductions_trace],
    )