Example #1
0
    def gen_fallback_code(self, cpp_sig, native_cpp_sig):
        """Generate C++ code that falls back to the stock ATen operator.

        Emits a call to ``at::<op>(...)`` — or ``_ipex_self.<op>(...)`` when
        the operator is a tensor member function — using the parameter names
        from *cpp_sig*, reordered to follow *native_cpp_sig* when one is
        given.  Runs of TensorOptions-like parameters are collapsed into a
        single wrapped ``at::TensorOptions`` argument.

        Returns the generated C++ code fragment as a string.
        """
        func_name = cpp_sig.def_name

        # Every input parameter must carry a name to be emitted.
        for param in cpp_sig.input_params:
            assert param.name

        def _emit_name(param):
            # Prefer the IPEX-specific alias when one was assigned.
            return param.ipex_name if param.ipex_name != '' else param.name

        if native_cpp_sig is None:
            params_name = [_emit_name(p) for p in cpp_sig.input_params]
        else:
            # Reorder this signature's parameters to the native ordering.
            cur_names = [p.name for p in cpp_sig.input_params]
            native_names = [p.name for p in native_cpp_sig.input_params]
            new_idxs = utils.reorder_params_idx(cur_names, native_names)
            params_name = [
                _emit_name(cpp_sig.input_params[new_idxs[i]])
                for i in range(len(new_idxs))
            ]

        code = ''
        # Wrap the input parameters as tensor option
        start_idx, end_idx = utils.query_tensor_options(cpp_sig.input_params)
        if start_idx >= 0 and end_idx > start_idx:
            wrapped_options = 'ipex_wrapped_options'
            code += ('  auto&& {} = at::TensorOptions().dtype(dtype)'
                     '.device(at::DeviceType::CPU).layout(layout)'
                     '.pinned_memory(pin_memory);\n').format(wrapped_options)
            # Replace the individual option params with the wrapped object.
            params_name = (params_name[:start_idx] + [wrapped_options]
                           + params_name[end_idx + 1:])

        if self._native_funcs.is_tensor_member_function(func_name):
            # Member call: `self` is the receiver, not an argument.
            assert "_ipex_self" in params_name
            params_name.remove('_ipex_self')
            call_expr = '{}.{}({})'.format('_ipex_self', cpp_sig.def_name,
                                           ', '.join(params_name))
        else:
            call_expr = 'at::{}({})'.format(cpp_sig.def_name,
                                            ', '.join(params_name))

        if self.is_void_func(cpp_sig):
            code += '  {};\n'.format(call_expr)
        else:
            code += '  auto&& {} = {};\n'.format(_RESULT_NAME, call_expr)

        return code
Example #2
0
    def gen_dnnl_code(self, cpp_sig, native_cpp_sig, aten_func_sig_str):
        """Generate the DNNL dispatch prologue for an ATen operator.

        Emits a ``try { if (check_auto_dnnl()) { ... } }`` guard that routes
        the call to ``AtenIpexCPUDev::dil_<op>(...)`` (or to IPEX devop code
        via ``gen_ipex_func_code``) when the input tensors are supported by
        DNNL; on an unsupported case or an exception, execution falls
        through to the caller-generated fallback code.

        Args:
            cpp_sig: parsed C++ signature of the operator being generated.
            native_cpp_sig: native signature used to reorder parameters,
                or ``None`` when no reordering is needed.
            aten_func_sig_str: full ATen signature string, used to classify
                the operator (DNNL func vs. IPEX func).

        Returns:
            The generated C++ code fragment as a string; empty when the
            operator is not a DNNL function.
        """
        code = ''

        if not self.is_dnnl_func(aten_func_sig_str):
            return code

        param_vars = []
        dnnl_tensor_param_vars = []

        input_params = cpp_sig.input_params
        # Reorder the input parameters to follow the native signature.
        if native_cpp_sig is not None:
            params1_name = [param.name for param in cpp_sig.input_params]
            params2_name = [
                param.name for param in native_cpp_sig.input_params
            ]
            new_idxs = utils.reorder_params_idx(params1_name, params2_name)
            input_params = [
                cpp_sig.input_params[new_idxs[idx]]
                for idx in range(len(new_idxs))
            ]

        for param in input_params:
            if param.core_type == 'Tensor':
                dnnl_tensor_param_vars.append(param)

            # Optional tensors are unwrapped to a plain (possibly empty)
            # at::Tensor at the generated call site.
            if param.core_type == 'Tensor' and param.is_optional:
                param_vars.append(
                    "{}.has_value() ? {}.value() : at::Tensor()".format(
                        param.name, param.name))
            else:
                param_vars.append(param.name)

        code += '  try {\n'

        code += '    if (check_auto_dnnl()) {\n'

        if not self.is_ipex_func(aten_func_sig_str):
            # There are two different kinds of DevOPs in IPEX
            #    1. DNNL Operator
            #    2. CPU BF16/INT8 Operator in Vanilla PyTorch. IPEX integrates
            #       this kind of operator in IPEX for mixture precision.
            # For type 2, IPEX does not need to check if DNNL supports these
            # tensors, so the tensor collection below is only emitted for
            # type 1.  (The loop handles an empty list naturally.)
            code += '      std::vector<at::Tensor> dnnl_input_tensors;\n'
            for dnnl_tensor_param_var in dnnl_tensor_param_vars:
                if dnnl_tensor_param_var.is_optional:
                    code += '      if ({}.has_value()) dnnl_input_tensors.push_back({}.value());\n'.format(
                        dnnl_tensor_param_var.name,
                        dnnl_tensor_param_var.name)
                else:
                    code += '      dnnl_input_tensors.push_back({});\n'.format(
                        dnnl_tensor_param_var.name)

        fname = cpp_sig.def_name
        if fname.endswith('_'):
            # In-place variant: there must be at least one tensor to mutate.
            assert len(dnnl_tensor_param_vars) > 0
            if self.is_ipex_func(aten_func_sig_str):
                code += self.gen_ipex_func_code(fname, param_vars)
            else:
                code += '      if (dbl::chk::dnnl_inplace_support_the_tensors(dnnl_input_tensors)) {\n'
                code += '        return AtenIpexCPUDev::dil_{}({});\n'.format(
                    fname, ', '.join(param_vars))
                code += '      }\n'  # Check support tensors
        else:
            if self.is_ipex_func(aten_func_sig_str):
                code += self.gen_ipex_func_code(fname, param_vars)
            else:
                code += '      if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors)) {\n'
                code += '        return AtenIpexCPUDev::dil_{}({});\n'.format(
                    fname, ', '.join(param_vars))
                code += '      }\n'  # Check support tensors
        code += '    }\n'  # Check auto dnnl
        code += '  } catch (std::exception& e) {\n'
        code += '#if defined(_DEBUG)\n'
        code += '    TORCH_WARN(e.what());\n'
        code += '#endif\n'
        code += '  }\n\n'

        return code