Beispiel #1
0
    def gen_fallback_code(self, cpp_sig, native_cpp_sig):
        """Emit the C++ text that forwards a call to the operator in ``cpp_sig``.

        Args:
            cpp_sig: Signature object; ``def_name`` and ``input_params`` are
                read here. Every parameter must have a non-empty ``name``.
                A parameter is spelled with its ``ipex_name`` unless that
                attribute equals ``''``, in which case ``name`` is used.
            native_cpp_sig: ``None`` to keep the argument order of
                ``cpp_sig``. Otherwise the parameter names of both
                signatures are handed to ``utils.reorder_params_idx`` and
                the returned index table selects the argument order.

        Returns:
            The generated source as a string: an optional line that builds
            an ``at::TensorOptions`` value, followed by the call statement.
            Each line ends with a newline.

        Raises:
            AssertionError: If a parameter name is empty, or the operator is
                reported as a tensor member function but ``_ipex_self`` is
                not among the arguments.
        """
        op_name = cpp_sig.def_name

        assert all(p.name for p in cpp_sig.input_params)

        def _spelling(p):
            # Use ``ipex_name`` when it is set, otherwise the plain name.
            if p.ipex_name != '':
                return p.ipex_name
            return p.name

        if native_cpp_sig is not None:
            own_names = [p.name for p in cpp_sig.input_params]
            native_names = [p.name for p in native_cpp_sig.input_params]
            order = utils.reorder_params_idx(own_names, native_names)
            sig_params = cpp_sig.input_params
            # ``order`` is subscripted by position 0..len-1 rather than
            # iterated, so this works for any container that supports
            # ``len`` and integer lookup.
            args = [
                _spelling(sig_params[order[pos]])
                for pos in range(len(order))
            ]
        else:
            args = [_spelling(p) for p in cpp_sig.input_params]

        pieces = []
        # NOTE(review): these indices are computed on cpp_sig's own parameter
        # order, while ``args`` may have been reordered above -- confirm the
        # tensor-option slots sit at the same positions in both orders.
        first_opt, last_opt = utils.query_tensor_options(cpp_sig.input_params)
        if 0 <= first_opt < last_opt:
            options_var = 'ipex_wrapped_options'
            pieces.append(
                '  auto&& {} = at::TensorOptions().dtype(dtype).device(at::DeviceType::CPU).layout(layout).pinned_memory(pin_memory);\n'
                .format(options_var))
            # The whole first_opt..last_opt range collapses into the single
            # wrapped-options argument.
            args[first_opt:last_opt + 1] = [options_var]

        is_member = self._native_funcs.is_tensor_member_function(op_name)
        if is_member:
            # Member calls are emitted as ``_ipex_self.<op>(...)``, so the
            # receiver must not also appear in the argument list.
            assert '_ipex_self' in args
            args.remove('_ipex_self')

        returns_nothing = self.is_void_func(cpp_sig)
        arg_list = ', '.join(args)

        if is_member:
            fields = ('_ipex_self', op_name, arg_list)
            if returns_nothing:
                template = '  {}.{}({});\n'
            else:
                template = '  auto&& {} = {}.{}({});\n'
        else:
            fields = (op_name, arg_list)
            if returns_nothing:
                template = '  at::{}({});\n'
            else:
                template = '  auto&& {} = at::{}({});\n'

        if not returns_nothing:
            # Non-void calls bind their result to the shared result name.
            fields = (_RESULT_NAME,) + fields

        pieces.append(template.format(*fields))
        return ''.join(pieces)
    def gen_fallback_code(self, cpp_sig):
        """Emit the C++ text that forwards a call to the operator in ``cpp_sig``.

        Args:
            cpp_sig: Signature object; ``def_name``, ``input_params`` and
                ``is_tensor_member_func`` are read here. Every parameter
                must have a non-empty ``name``. A parameter is spelled with
                its ``ipex_name`` unless that attribute equals ``''``, in
                which case ``name`` is used.

        Returns:
            The generated source as a string: an optional line that builds
            an ``at::TensorOptions`` value, followed by the call statement.
            Each line ends with a newline.

        Raises:
            AssertionError: If a parameter name is empty, or
                ``cpp_sig.is_tensor_member_func`` is truthy but
                ``_ipex_self`` is not among the arguments.
        """
        assert all(p.name for p in cpp_sig.input_params)

        call_args = [
            p.name if p.ipex_name == '' else p.ipex_name
            for p in cpp_sig.input_params
        ]

        prologue = ''
        opt_begin, opt_end = utils.query_tensor_options(cpp_sig.input_params)
        if 0 <= opt_begin < opt_end:
            bundle = 'ipex_wrapped_options'
            prologue = '  auto&& {} = at::TensorOptions().dtype(dtype).device(at::DeviceType::CPU).layout(layout).pinned_memory(pin_memory);\n'.format(
                bundle)
            # Replace the opt_begin..opt_end arguments with the single
            # wrapped-options variable declared in the prologue.
            call_args[opt_begin:opt_end + 1] = [bundle]

        as_member = cpp_sig.is_tensor_member_func
        if as_member:
            # The receiver is written in front of the call, so it is taken
            # out of the argument list.
            assert '_ipex_self' in call_args
            call_args.remove('_ipex_self')

        is_void = self.is_void_func(cpp_sig)
        joined = ', '.join(call_args)

        if as_member:
            if is_void:
                return prologue + '  {}.{}({});\n'.format(
                    '_ipex_self', cpp_sig.def_name, joined)
            return prologue + '  auto&& {} = {}.{}({});\n'.format(
                _RESULT_NAME, '_ipex_self', cpp_sig.def_name, joined)

        if is_void:
            return prologue + '  at::{}({});\n'.format(cpp_sig.def_name, joined)
        return prologue + '  auto&& {} = at::{}({});\n'.format(
            _RESULT_NAME, cpp_sig.def_name, joined)