示例#1
0
文件: x86.py 项目: saudet/tvm
def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
    """Assemble the batch_matmul op strategy for x86 targets.

    Exactly one base implementation is registered at plevel 10:

    * ``topi.nn.batch_matmul`` with the generic schedule when ``out_type``
      is dynamic or the auto-scheduler is enabled;
    * ``topi.x86.batch_matmul`` with the x86 schedule otherwise.

    After that, a library-backed implementation is added at plevel 15 for
    each of ``"cblas"`` and ``"mkl"`` that appears in ``target.libs``.

    Parameters
    ----------
    attrs
        Operator attributes (not read by this function).
    inputs
        Operator inputs (not read by this function).
    out_type
        Output type of the op; only passed to ``is_dynamic``.
    target
        Compilation target; only its ``libs`` attribute is consulted.

    Returns
    -------
    strategy
        The ``_op.OpStrategy`` holding every implementation added above.
    """
    strategy = _op.OpStrategy()

    # Choose the base implementation first, then register it in one place.
    if is_dynamic(out_type) or is_auto_scheduler_enabled():
        base_name = "batch_matmul.generic"
        base_compute = wrap_compute_batch_matmul(
            topi.nn.batch_matmul, need_auto_scheduler_layout=True, need_out_dtype=True
        )
        base_schedule = wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul)
    else:
        base_name = "batch_matmul.x86"
        base_compute = wrap_compute_batch_matmul(topi.x86.batch_matmul, need_out_dtype=True)
        base_schedule = wrap_topi_schedule(topi.x86.schedule_batch_matmul)
    strategy.add_implementation(base_compute, base_schedule, name=base_name, plevel=10)

    # Library-backed kernels are registered at plevel 15 (vs. 10 for the base).
    enabled_libs = target.libs
    if "cblas" in enabled_libs:
        cblas_compute = wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas)
        cblas_schedule = wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas)
        strategy.add_implementation(
            cblas_compute, cblas_schedule, name="batch_matmul_cblas.x86", plevel=15
        )
    if "mkl" in enabled_libs:
        mkl_compute = wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl)
        mkl_schedule = wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl)
        strategy.add_implementation(
            mkl_compute, mkl_schedule, name="batch_matmul_mkl.x86", plevel=15
        )

    return strategy
示例#2
0
 def _vm_wrapper(*args, **kwargs):
     """Run the virtual machine on the given call arguments.

     The positional and keyword arguments are first normalised through
     ``self._convert_args`` and then forwarded to ``self.vm.run``.
     ``self`` and ``main`` are taken from the enclosing scope, which is
     not visible here.

     Raises
     ------
     ValueError
         If the return type of ``self.mod["main"]`` is dynamic and the
         string form of ``self.target`` contains neither ``"llvm"`` nor
         ``"arm"``.
     """
     call_args = self._convert_args(main, args, kwargs)
     out_type = self.mod["main"].checked_type.ret_type

     if is_dynamic(out_type):
         target_desc = str(self.target)
         runs_on_cpu = "llvm" in target_desc or "arm" in target_desc
         if not runs_on_cpu:
             # NOTE(review): the message pieces are passed as separate
             # exception args rather than one formatted string; kept as-is
             # so that ``e.args`` stays unchanged for callers.
             raise ValueError(
                 "Virtual Machine only supports dynamic graphs on CPU, got output type",
                 out_type,
                 "on target",
                 self.target,
             )

     return self.vm.run(*call_args)
示例#3
0
文件: x86.py 项目: chenghanpeng/tvm
def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
    """Assemble the batch_matmul op strategy for x86 targets.

    Exactly one base implementation is registered at plevel 10, chosen in
    this order:

    1. ``topi.x86.batch_matmul_vnni_compute`` when ``transpose_a`` is unset,
       ``transpose_b`` is set, ``target_has_vnni`` accepts the current
       target's ``mcpu``, the operands are uint8 x int8, and the last two
       dimensions of ``inputs[1]`` are multiples of 16 and 4 respectively;
    2. ``topi.nn.batch_matmul`` with the generic schedule when ``out_type``
       is dynamic, or the auto-scheduler or meta-schedule is enabled;
    3. ``topi.x86.batch_matmul`` with the x86 schedule otherwise.

    After that, a library-backed implementation is added at plevel 15 for
    each of ``"cblas"`` and ``"mkl"`` that appears in ``target.libs``.

    Parameters
    ----------
    attrs
        Operator attributes; ``transpose_a`` and ``transpose_b`` are read.
    inputs
        Operator inputs; ``dtype`` of the first two and ``shape`` of the
        second are read.
    out_type
        Output type of the op; only passed to ``is_dynamic``.
    target
        Compilation target; only its ``libs`` attribute is consulted.

    Returns
    -------
    strategy
        The ``_op.OpStrategy`` holding every implementation added above.
    """
    strategy = _op.OpStrategy()

    # NOTE(review): mcpu comes from the ambient Target.current(), not from
    # the `target` parameter -- confirm that this is intentional.
    mcpu = Target.current().mcpu
    need_auto_scheduler_layout = is_auto_scheduler_enabled()
    need_meta_schedule_layout = is_meta_schedule_enabled()

    def _vnni_applicable():
        """Return True when the VNNI kernel's preconditions all hold.

        Checks run in a fixed order and stop at the first failure, so later
        attributes are only touched once the earlier checks have passed.
        """
        if attrs.transpose_a or not attrs.transpose_b:
            return False
        if not target_has_vnni(mcpu):
            return False
        if not (inputs[0].dtype == "uint8" and inputs[1].dtype == "int8"):
            return False
        rhs_shape = inputs[1].shape
        # Presumably these divisors match the VNNI kernel's tiling -- TODO confirm.
        if not (rhs_shape[-2] % 16 == 0):
            return False
        if not (rhs_shape[-1] % 4 == 0):
            return False
        return True

    # Choose the base implementation first, then register it in one place.
    if _vnni_applicable():
        base_name = "batch_matmul_vnni.x86"
        base_compute = wrap_compute_batch_matmul(
            topi.x86.batch_matmul_vnni_compute, need_out_dtype=True
        )
        base_schedule = wrap_topi_schedule(topi.x86.schedule_batch_matmul_vnni)
    elif is_dynamic(out_type) or need_auto_scheduler_layout or need_meta_schedule_layout:
        base_name = "batch_matmul.generic"
        base_compute = wrap_compute_batch_matmul(
            topi.nn.batch_matmul,
            need_out_dtype=True,
            need_auto_scheduler_layout=need_auto_scheduler_layout,
            need_meta_schedule_layout=need_meta_schedule_layout,
        )
        base_schedule = wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul)
    else:
        base_name = "batch_matmul.x86"
        base_compute = wrap_compute_batch_matmul(topi.x86.batch_matmul, need_out_dtype=True)
        base_schedule = wrap_topi_schedule(topi.x86.schedule_batch_matmul)
    strategy.add_implementation(base_compute, base_schedule, name=base_name, plevel=10)

    # Library-backed kernels are registered at plevel 15 (vs. 10 for the base).
    enabled_libs = target.libs
    if "cblas" in enabled_libs:
        cblas_compute = wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas)
        cblas_schedule = wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas)
        strategy.add_implementation(
            cblas_compute, cblas_schedule, name="batch_matmul_cblas.x86", plevel=15
        )
    if "mkl" in enabled_libs:
        mkl_compute = wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl)
        mkl_schedule = wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl)
        strategy.add_implementation(
            mkl_compute, mkl_schedule, name="batch_matmul_mkl.x86", plevel=15
        )

    return strategy