def generate_numba_apply_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
):
    """
    Build a numba-jitted rolling apply function around a user function.

    The user function is jitted first and then called from inside a jitted
    window loop. The ``nopython``, ``nogil`` and ``parallel`` settings
    extracted from ``engine_kwargs`` are used for both the user function
    and the window loop.

    Parameters
    ----------
    args : tuple
        Positional arguments forwarded to ``func`` after the window values.
    kwargs : dict
        Keyword arguments intended for ``func``. They are only checked
        together with ``nopython`` via ``check_kwargs_and_nopython``; the
        jitted call itself receives ``*args`` only.
    func : function
        Function evaluated on every window; it gets jitted.
    engine_kwargs : dict or None
        Options from which the ``numba.jit`` arguments are derived.

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods)`` returning one
        value per window.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

    check_kwargs_and_nopython(kwargs, nopython)

    jitted_udf = jit_user_function(func, nopython, nogil, parallel)

    numba = import_optional_dependency("numba")

    # prange only when a parallel loop was requested, plain range otherwise.
    window_range = numba.prange if parallel else range

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int,
    ) -> np.ndarray:
        n_windows = len(begin)
        result = np.empty(n_windows)
        for i in window_range(n_windows):
            window = values[begin[i] : end[i]]
            # Number of non-NaN observations inside the current window.
            observed = len(window) - np.sum(np.isnan(window))
            if observed < minimum_periods:
                # Too few valid observations: the window yields NaN.
                result[i] = np.nan
            else:
                result[i] = jitted_udf(window, *args)
        return result

    return roll_apply
def _aggregate_series_pure_python(
    self,
    obj: Series,
    func: F,
    *args,
    engine: str = "cython",
    engine_kwargs=None,
    **kwargs,
):
    """
    Aggregate ``obj`` one group at a time in a Python-level loop.

    Parameters
    ----------
    obj : Series
        Values to aggregate; split into groups with ``get_splitter`` using
        ``self.group_info``.
    func : callable
        Reducer applied to each group. With ``engine="numba"`` it is jitted
        and called as ``func(values, index, *args)``; otherwise it is called
        as ``func(group, *args, **kwargs)``.
    *args
        Extra positional arguments forwarded to ``func``.
    engine : str, default "cython"
        ``"numba"`` selects the jitted code path; any other value calls
        ``func`` directly.
    engine_kwargs : dict or None
        Options from which the numba jit arguments are derived.
    **kwargs
        Keyword arguments for ``func`` (non-numba path); on the numba path
        they are only checked via ``check_kwargs_and_nopython``.

    Returns
    -------
    tuple
        ``(result, counts)``: the per-group results after
        ``lib.maybe_convert_objects`` and the number of rows in each group.

    Raises
    ------
    ValueError
        If the first group's result is a Series, Index or ndarray whose
        length is not 1 ("Function does not reduce").
    """
    if engine == "numba":
        nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
        check_kwargs_and_nopython(kwargs, nopython)
        validate_udf(func)
        cache_key = (func, "groupby_agg")
        # Look the function up first and only jit on a miss. Supplying the
        # jit call as the ``.get`` default would evaluate it on every call,
        # even when a cached function is available.
        numba_func = NUMBA_FUNC_CACHE.get(cache_key)
        if numba_func is None:
            numba_func = jit_user_function(func, nopython, nogil, parallel)

    group_index, _, ngroups = self.group_info

    counts = np.zeros(ngroups, dtype=int)
    result = None

    splitter = get_splitter(obj, group_index, ngroups, axis=0)

    for label, group in splitter:
        if engine == "numba":
            values, index = split_for_numba(group)
            res = numba_func(values, index, *args)
            # Cache only after a call has succeeded.
            if cache_key not in NUMBA_FUNC_CACHE:
                NUMBA_FUNC_CACHE[cache_key] = numba_func
        else:
            res = func(group, *args, **kwargs)

        if result is None:
            # The reduction check and the output allocation happen once,
            # on the first group.
            if isinstance(res, (Series, Index, np.ndarray)):
                if len(res) == 1:
                    # e.g. test_agg_lambda_with_timezone lambda e: e.head(1)
                    # FIXME: are we potentially losing important res.index info?
                    res = res.item()
                else:
                    raise ValueError("Function does not reduce")
            result = np.empty(ngroups, dtype="O")

        counts[label] = group.shape[0]
        result[label] = res

    assert result is not None
    result = lib.maybe_convert_objects(result, try_float=0)
    # TODO: maybe_cast_to_extension_array?

    return result, counts
def generate_numba_func(
    func: Callable,
    engine_kwargs: Optional[Dict[str, bool]],
    kwargs: dict,
    cache_key_str: str,
) -> Tuple[Callable, Tuple[Callable, str]]:
    """
    Return a JITed function and cache key for the NUMBA_FUNC_CACHE

    This _may_ be specific to groupby (as it's only used there currently).

    The function is taken from ``NUMBA_FUNC_CACHE`` when an entry for the
    cache key exists; otherwise ``func`` is jitted. This function does not
    write to the cache itself.

    Parameters
    ----------
    func : function
        user defined function
    engine_kwargs : dict or None
        numba.jit arguments
    kwargs : dict
        kwargs for func
    cache_key_str : str
        string representing the second part of the cache key tuple

    Returns
    -------
    (JITed function, cache key)
        The cache key is the tuple ``(func, cache_key_str)``.

    Raises
    ------
    NumbaUtilError
        Presumably raised by the validation helpers called here
        (``check_kwargs_and_nopython`` / ``validate_udf``) -- confirm
        against their definitions.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
    check_kwargs_and_nopython(kwargs, nopython)
    validate_udf(func)
    cache_key = (func, cache_key_str)
    # Jit only on a cache miss. Passing the jit call as the ``.get`` default
    # would evaluate it unconditionally and discard the result on a hit.
    numba_func = NUMBA_FUNC_CACHE.get(cache_key)
    if numba_func is None:
        numba_func = jit_user_function(func, nopython, nogil, parallel)
    return numba_func, cache_key
def generate_numba_transform_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
    """
    Build a numba-jitted groupby transform function around a user function.

    The user function is validated and jitted, then called from inside a
    jitted loop over groups and columns. The ``nopython``, ``nogil`` and
    ``parallel`` settings extracted from ``engine_kwargs`` are used for both
    the user function and the group loop.

    Parameters
    ----------
    args : tuple
        Positional arguments forwarded to ``func`` after the group values
        and the group index.
    kwargs : dict
        Keyword arguments intended for ``func``. They are only checked
        together with ``nopython`` via ``check_kwargs_and_nopython``; the
        jitted call itself receives ``*args`` only.
    func : function
        Function evaluated on each group of each column; it gets jitted.
    engine_kwargs : dict or None
        Options from which the ``numba.jit`` arguments are derived.

    Returns
    -------
    Numba function
        ``group_transform(values, index, begin, end, num_groups,
        num_columns)`` returning an array of shape
        ``(len(values), num_columns)``.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

    check_kwargs_and_nopython(kwargs, nopython)

    validate_udf(func)

    jitted_udf = jit_user_function(func, nopython, nogil, parallel)

    numba = import_optional_dependency("numba")

    # prange only when a parallel loop was requested, plain range otherwise.
    iter_range = numba.prange if parallel else range

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_groups: int,
        num_columns: int,
    ) -> np.ndarray:
        result = np.empty((len(values), num_columns))
        for i in iter_range(num_groups):
            # Row bounds of group i; shared by every column.
            lo = begin[i]
            hi = end[i]
            group_index = index[lo:hi]
            for j in iter_range(num_columns):
                column_group = values[lo:hi, j]
                result[lo:hi, j] = jitted_udf(column_group, group_index, *args)
        return result

    return group_transform