예제 #1
0
def test_rewrite_tensorize_conv2d_nchwc_vnni():
    """Check the postproc pipeline on the tiled NCHWc conv2d VNNI module.

    Builds a context for a Cascade Lake LLVM target with the
    RewriteReductionBlock and RewriteTensorize postprocessors, applies each
    of the context's postprocessors to a schedule of the tiled module, and
    asserts the result is structurally equal to the expected tensorized
    module.
    """
    tiled_mod = Conv2dNCHWcVNNIModuleTiled
    cpu_target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4")
    pipeline = [
        postproc.RewriteReductionBlock(),
        postproc.RewriteTensorize(vectorize_init_loop=True),
    ]
    tune_ctx = _create_context(tiled_mod, cpu_target, pipeline)

    schedule = tvm.tir.Schedule(tiled_mod, debug_mask="all")
    schedule.enter_postproc()
    # Apply the postprocessors in the order the context exposes them.
    for step in tune_ctx.postprocs:
        step.apply(schedule)

    tvm.ir.assert_structural_equal(schedule.mod, Conv2dNCHWcVNNIModuleTensorized)
예제 #2
0
def test_rewrite_tensorize_dense_dp4a():
    """Check the postproc pipeline on the tiled dense DP4A module.

    Builds a context for the "nvidia/geforce-rtx-3070" target with the
    RewriteCooperativeFetch, RewriteReductionBlock and RewriteTensorize
    postprocessors, applies each of the context's postprocessors to a
    schedule of the tiled module, and asserts the result is structurally
    equal to the expected tensorized module.
    """
    tiled_mod = DenseDP4ATiled
    gpu_target = tvm.target.Target("nvidia/geforce-rtx-3070")
    pipeline = [
        postproc.RewriteCooperativeFetch(),
        postproc.RewriteReductionBlock(),
        postproc.RewriteTensorize(),
    ]
    tune_ctx = _create_context(tiled_mod, gpu_target, pipeline)

    schedule = tvm.tir.Schedule(tiled_mod, debug_mask="all")
    schedule.enter_postproc()
    # Apply the postprocessors in the order the context exposes them.
    for step in tune_ctx.postprocs:
        step.apply(schedule)

    tvm.ir.assert_structural_equal(schedule.mod, DenseDP4ATensorized)
        schedule_rule.ParallelizeVectorizeUnroll(
            max_jobs_per_core=-1,  # disable parallelize
            max_vectorize_extent=-1,  # disable vectorize
            unroll_max_steps=[0, 16, 64, 512, 1024],
            unroll_explicit=True,
        ),
    ]


# Schedule-rule sets produced by get_sch_rules_for_dp4a, one per intrinsic
# constant; both variants share the same rule builder.
sch_rules_for_dp4a = get_sch_rules_for_dp4a(DP4A_INTRIN)
sch_rules_for_sdot4 = get_sch_rules_for_dp4a(AMDGPU_SDOT4_INTRIN)

# Postprocessor list for the VNNI configuration.
# NOTE(review): list order is presumably the application order — confirm
# against the code that consumes this list.
postprocs_for_vnni = [
    postproc.DisallowDynamicLoop(),
    postproc.RewriteParallelVectorizeUnroll(),
    postproc.RewriteReductionBlock(),
    # The only entry constructed with a non-default argument.
    postproc.RewriteTensorize(vectorize_init_loop=True),
]

# Postprocessor list for the DP4A configuration. Compared with the VNNI list
# it additionally contains RewriteCooperativeFetch, RewriteUnboundBlock and
# VerifyGPUCode, and RewriteTensorize is constructed with default arguments.
# NOTE(review): presumably also reused for the SDOT4 rules — confirm at the
# call sites.
postprocs_for_dp4a = [
    postproc.DisallowDynamicLoop(),
    postproc.RewriteCooperativeFetch(),
    postproc.RewriteUnboundBlock(),
    postproc.RewriteParallelVectorizeUnroll(),
    postproc.RewriteReductionBlock(),
    postproc.RewriteTensorize(),
    postproc.VerifyGPUCode(),
]


def tune_and_test(relay_mod, data_np, weight_np, op_name, target, sch_rules,