def gen_fallback_code(self, cpp_sig, native_cpp_sig):
    """Emit C++ code that falls back to the stock ATen implementation.

    Builds either a tensor member call (``_ipex_self.op(...)``) or a free
    function call (``at::op(...)``), binding the result to ``_RESULT_NAME``
    unless the op is void. When the signature contains a run of tensor-option
    parameters, they are collapsed into a single ``at::TensorOptions``
    expression (NOTE(review): the emitted expression references ``dtype``,
    ``layout`` and ``pin_memory`` by name — presumably those variables are in
    scope at the emission site; confirm against the callers).

    Args:
        cpp_sig: parsed IPEX-side C++ signature of the op.
        native_cpp_sig: parsed native signature used to reorder parameters,
            or None when no reordering is needed.

    Returns:
        The generated C++ statements as a string.
    """
    func_name = cpp_sig.def_name
    for p in cpp_sig.input_params:
        assert p.name

    def preferred_name(param):
        # Prefer the IPEX-specific alias when one has been assigned.
        return param.ipex_name if param.ipex_name != '' else param.name

    if native_cpp_sig is None:
        params_name = [preferred_name(p) for p in cpp_sig.input_params]
    else:
        # Reorder the IPEX parameters to match the native signature's order.
        new_idxs = utils.reorder_params_idx(
            [p.name for p in cpp_sig.input_params],
            [p.name for p in native_cpp_sig.input_params])
        params_name = [
            preferred_name(cpp_sig.input_params[new_idxs[i]])
            for i in range(len(new_idxs))
        ]

    code = ''
    # Wrap the input parameters as a tensor option.
    start_idx, end_idx = utils.query_tensor_options(cpp_sig.input_params)
    if start_idx >= 0 and end_idx > start_idx:
        wrapped_options = 'ipex_wrapped_options'
        code += ' auto&& {} = at::TensorOptions().dtype(dtype).device(at::DeviceType::CPU).layout(layout).pinned_memory(pin_memory);\n'.format(wrapped_options)
        # Replace the original run of option parameters with the wrapper.
        params_name = (params_name[:start_idx] + [wrapped_options]
                       + params_name[end_idx + 1:])

    if self._native_funcs.is_tensor_member_function(func_name):
        # Member call: the receiver must not also appear in the argument list.
        assert "_ipex_self" in params_name
        params_name.remove('_ipex_self')
        arg_list = ', '.join(params_name)
        if self.is_void_func(cpp_sig):
            code += ' {}.{}({});\n'.format('_ipex_self', cpp_sig.def_name,
                                           arg_list)
        else:
            code += ' auto&& {} = {}.{}({});\n'.format(
                _RESULT_NAME, '_ipex_self', cpp_sig.def_name, arg_list)
    else:
        arg_list = ', '.join(params_name)
        if self.is_void_func(cpp_sig):
            code += ' at::{}({});\n'.format(cpp_sig.def_name, arg_list)
        else:
            code += ' auto&& {} = at::{}({});\n'.format(
                _RESULT_NAME, cpp_sig.def_name, arg_list)
    return code
def gen_dnnl_code(self, cpp_sig, native_cpp_sig, aten_func_sig_str):
    """Emit C++ code that dispatches an op to its DNNL implementation.

    The generated snippet is guarded by ``check_auto_dnnl()`` and, for plain
    DNNL ops, by a ``dbl::chk::dnnl_*support_the_tensors`` check before
    calling ``AtenIpexCPUDev::dil_<op>``. Any ``std::exception`` raised by
    those checks is swallowed (surfaced via ``TORCH_WARN`` only in ``_DEBUG``
    builds) so the surrounding generated function can fall through to the
    fallback path.

    Improvements over the previous revision (behavior unchanged):
    the non-inplace branch built ``param_seq_str_vec`` with a dead
    element-by-element copy of ``param_vars``; both emission branches were
    duplicated differing only in the support-check name; ``', '.join``
    wrapped its argument in a redundant ``list()``.

    Args:
        cpp_sig: parsed IPEX-side C++ signature of the op.
        native_cpp_sig: parsed native signature used to reorder parameters,
            or None when no reordering is needed.
        aten_func_sig_str: textual ATen signature used to classify the op.

    Returns:
        The generated C++ code, or ``''`` when the op is not a DNNL op.
    """
    if not self.is_dnnl_func(aten_func_sig_str):
        return ''

    input_params = cpp_sig.input_params
    # Reorder the input parameters to match the native signature's order.
    if native_cpp_sig is not None:
        new_idxs = utils.reorder_params_idx(
            [param.name for param in cpp_sig.input_params],
            [param.name for param in native_cpp_sig.input_params])
        input_params = [
            cpp_sig.input_params[new_idxs[idx]]
            for idx in range(len(new_idxs))
        ]

    param_vars = []
    dnnl_tensor_param_vars = []
    for param in input_params:
        if param.core_type == 'Tensor':
            dnnl_tensor_param_vars.append(param)
            if param.is_optional:
                # Unwrap optional tensors; an empty at::Tensor stands in
                # for a missing value.
                param_vars.append(
                    "{}.has_value() ? {}.value() : at::Tensor()".format(
                        param.name, param.name))
                continue
        param_vars.append(param.name)

    code = ''
    code += ' try {\n'
    code += ' if (check_auto_dnnl()) {\n'
    if not self.is_ipex_func(aten_func_sig_str):
        # There are two different kinds of DevOPs in IPEX:
        #   1. DNNL operators.
        #   2. CPU BF16/INT8 operators from vanilla PyTorch, which IPEX
        #      integrates for mixed precision.
        # For kind 2 IPEX does not need to check whether DNNL supports the
        # tensors, so the tensor collection below is emitted only for kind 1.
        code += ' std::vector<at::Tensor> dnnl_input_tensors;\n'
        for tensor_param in dnnl_tensor_param_vars:
            if tensor_param.is_optional:
                code += ' if ({}.has_value()) dnnl_input_tensors.push_back({}.value());\n'.format(
                    tensor_param.name, tensor_param.name)
            else:
                code += ' dnnl_input_tensors.push_back({});\n'.format(
                    tensor_param.name)

    fname = cpp_sig.def_name
    if fname.endswith('_'):
        # In-place variant: it must have at least one tensor operand.
        assert len(dnnl_tensor_param_vars) > 0
        support_check = 'dnnl_inplace_support_the_tensors'
    else:
        support_check = 'dnnl_support_the_tensors'

    if self.is_ipex_func(aten_func_sig_str):
        code += self.gen_ipex_func_code(fname, param_vars)
    else:
        code += ' if (dbl::chk::{}(dnnl_input_tensors)) {{\n'.format(
            support_check)
        code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(
            fname, ', '.join(param_vars))
        code += ' }\n'  # Check support tensors
    code += ' }\n'  # Check auto dnnl
    code += ' } catch (std::exception& e) {\n'
    code += '#if defined(_DEBUG)\n'
    code += ' TORCH_WARN(e.what());\n'
    code += '#endif\n'
    code += ' }\n\n'
    return code