def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
    """batch_matmul x86 strategy"""
    strategy = _op.OpStrategy()
    # Dynamic output shapes and auto-scheduler tuning both go through the
    # generic TOPI kernel; everything else gets the x86-specialized one.
    use_generic = is_dynamic(out_type) or is_auto_scheduler_enabled()
    if use_generic:
        strategy.add_implementation(
            wrap_compute_batch_matmul(
                topi.nn.batch_matmul,
                need_auto_scheduler_layout=True,
                need_out_dtype=True,
            ),
            wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
            name="batch_matmul.generic",
            plevel=10,
        )
    else:
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.x86.batch_matmul, need_out_dtype=True),
            wrap_topi_schedule(topi.x86.schedule_batch_matmul),
            name="batch_matmul.x86",
            plevel=10,
        )
    # External BLAS backends, when enabled on the target, are registered at a
    # higher priority level (plevel=15) so they win over the schedules above.
    blas_impls = (
        ("cblas", topi.x86.batch_matmul_cblas,
         topi.x86.schedule_batch_matmul_cblas, "batch_matmul_cblas.x86"),
        ("mkl", topi.x86.batch_matmul_mkl,
         topi.x86.schedule_batch_matmul_mkl, "batch_matmul_mkl.x86"),
    )
    for lib, compute, schedule, impl_name in blas_impls:
        if lib in target.libs:
            strategy.add_implementation(
                wrap_compute_batch_matmul(compute),
                wrap_topi_schedule(schedule),
                name=impl_name,
                plevel=15,
            )
    return strategy
def _vm_wrapper(*args, **kwargs):
    # Normalize positional/keyword arguments against the "main" function's
    # signature before dispatching to the VM.
    args = self._convert_args(main, args, kwargs)
    ret_type = self.mod["main"].checked_type.ret_type
    # Dynamic (shape-undetermined) outputs are only supported on CPU-like
    # targets (llvm / arm); reject other targets up front.
    target_str = str(self.target)
    is_cpu_like = "llvm" in target_str or "arm" in target_str
    if is_dynamic(ret_type) and not is_cpu_like:
        raise ValueError(
            "Virtual Machine only supports dynamic graphs on CPU, got output type",
            ret_type, "on target", self.target)
    return self.vm.run(*args)
def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
    """batch_matmul x86 strategy"""
    strategy = _op.OpStrategy()
    mcpu = Target.current().mcpu
    need_auto_scheduler_layout = is_auto_scheduler_enabled()
    need_meta_schedule_layout = is_meta_schedule_enabled()

    # The VNNI int8 kernel requires: A not transposed, B transposed,
    # uint8 x int8 operands, and B's inner dimensions tiled to 16x4.
    vnni_applicable = (
        not attrs.transpose_a
        and attrs.transpose_b
        and target_has_vnni(mcpu)
        and inputs[0].dtype == "uint8"
        and inputs[1].dtype == "int8"
        and inputs[1].shape[-2] % 16 == 0
        and inputs[1].shape[-1] % 4 == 0
    )
    if vnni_applicable:
        strategy.add_implementation(
            wrap_compute_batch_matmul(
                topi.x86.batch_matmul_vnni_compute, need_out_dtype=True
            ),
            wrap_topi_schedule(topi.x86.schedule_batch_matmul_vnni),
            name="batch_matmul_vnni.x86",
            plevel=10,
        )
    elif (is_dynamic(out_type)
          or need_auto_scheduler_layout
          or need_meta_schedule_layout):
        # Dynamic shapes and scheduler-driven tuning use the generic kernel.
        strategy.add_implementation(
            wrap_compute_batch_matmul(
                topi.nn.batch_matmul,
                need_out_dtype=True,
                need_auto_scheduler_layout=need_auto_scheduler_layout,
                need_meta_schedule_layout=need_meta_schedule_layout,
            ),
            wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
            name="batch_matmul.generic",
            plevel=10,
        )
    else:
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.x86.batch_matmul, need_out_dtype=True),
            wrap_topi_schedule(topi.x86.schedule_batch_matmul),
            name="batch_matmul.x86",
            plevel=10,
        )

    # External BLAS backends, when enabled on the target, are registered at a
    # higher priority level (plevel=15) so they win over the schedules above.
    blas_impls = (
        ("cblas", topi.x86.batch_matmul_cblas,
         topi.x86.schedule_batch_matmul_cblas, "batch_matmul_cblas.x86"),
        ("mkl", topi.x86.batch_matmul_mkl,
         topi.x86.schedule_batch_matmul_mkl, "batch_matmul_mkl.x86"),
    )
    for lib, compute, schedule, impl_name in blas_impls:
        if lib in target.libs:
            strategy.add_implementation(
                wrap_compute_batch_matmul(compute),
                wrap_topi_schedule(schedule),
                name=impl_name,
                plevel=15,
            )
    return strategy