def numba_funcify_ScalarOp(op, node, **kwargs):
    # TODO: Do we need to cache these functions so that we don't end up
    # compiling the same Numba function over and over again?

    scalar_func_name = op.nfunc_spec[0]

    if scalar_func_name.startswith("scipy."):
        func_package = scipy
        scalar_func_name = scalar_func_name.split(".", 1)[-1]
    else:
        func_package = np

    if "." in scalar_func_name:
        scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
    else:
        scalar_func = getattr(func_package, scalar_func_name)

    input_names = ", ".join([v.auto_name for v in node.inputs])

    global_env = {"scalar_func": scalar_func}

    scalar_op_fn_name = get_name_for_object(scalar_func)
    scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
    return scalar_func({input_names})
    """

    scalar_op_fn = compile_function_src(scalar_op_src, scalar_op_fn_name, global_env)

    return numba.njit(scalar_op_fn)
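# For illustration: assuming `op.nfunc_spec[0]` is `"add"`, that
# `get_name_for_object(np.add)` returns `"add"`, and that the two inputs have
# the (hypothetical) `auto_name`s `auto_0` and `auto_1`, the generated source
# would be
#
#     def add(auto_0, auto_1):
#         return scalar_func(auto_0, auto_1)
#
# with `np.add` bound to `scalar_func` in the compiled function's globals.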
def numba_funcify_Alloc(op, node, **kwargs): global_env = {"np": np, "to_scalar": to_scalar} shape_var_names = [v.auto_name for v in node.inputs[1:]] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join([ f"{item_name} = to_scalar({shape_name})" for item_name, shape_name in zip(shape_var_item_names, shape_var_names) ]), " " * 4, ) alloc_def_src = f""" def alloc(val, {", ".join(shape_var_names)}): val_np = np.asarray(val) {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} res = np.empty(scalar_shape, dtype=val_np.dtype) res[...] = val_np return res """ alloc_fn = compile_function_src(alloc_def_src, "alloc", global_env) return numba.njit(alloc_fn)
def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs):
    scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)

    input_names = ", ".join([v.auto_name for v in node.inputs])

    if use_signature:
        signature = [create_numba_signature(node, force_scalar=True)]
    else:
        signature = []

    numba_vectorize = numba.vectorize(signature, identity=identity)
    global_env = {
        "scalar_op": scalar_op_fn,
        "numba_vectorize": numba_vectorize,
    }

    elemwise_fn_name = f"elemwise_{get_name_for_object(scalar_op_fn)}"
    elemwise_src = f"""
@numba_vectorize
def {elemwise_fn_name}({input_names}):
    return scalar_op({input_names})
    """

    elemwise_fn = compile_function_src(elemwise_src, elemwise_fn_name, global_env)

    return elemwise_fn
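# For a scalar `add` op this would produce, roughly (hypothetical input names):
#
#     @numba_vectorize
#     def elemwise_add(auto_0, auto_1):
#         return scalar_op(auto_0, auto_1)
#
# i.e. a Numba `vectorize`d ufunc that broadcasts the scalar kernel over the
# inputs; the signature list is only populated when `use_signature` is set.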
def numba_funcify_Alloc(op, node, **kwargs): global_env = {"np": np, "to_scalar": numba_basic.to_scalar} unique_names = unique_name_generator( ["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"], suffix_sep="_", ) shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( [ f"{item_name} = to_scalar({shape_name})" for item_name, shape_name in zip(shape_var_item_names, shape_var_names) ] ), " " * 4, ) alloc_def_src = f""" def alloc(val, {", ".join(shape_var_names)}): val_np = np.asarray(val) {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} res = np.empty(scalar_shape, dtype=val_np.dtype) res[...] = val_np return res """ alloc_fn = compile_function_src(alloc_def_src, "alloc", global_env) return numba.njit(alloc_fn)
def create_multiaxis_reducer(reduce_fn, identity, axes, ndim, dtype, input_name="input"): careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}" careduce_axes_fns = () to_reduce = reversed(sorted(axes)) careduce_lines_src = [] var_name = input_name for i, axis in enumerate(to_reduce): careduce_axes_fns += (create_axis_reducer(reduce_fn, identity, axis - i, ndim, dtype), ) ndim -= 1 last_var_name = var_name var_name = f"axis_{i}_res" careduce_lines_src.append( f"{var_name} = careduce_axes_fns[{i}]({last_var_name})") careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4) careduce_def_src = f""" def {careduce_fn_name}({input_name}): {careduce_assign_lines} return {var_name} """ global_env = {"careduce_axes_fns": careduce_axes_fns} careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env) return careduce_fn
def numba_funcify_MakeVector(op, node, **kwargs):
    dtype = np.dtype(op.dtype)
    global_env = {"np": np, "to_scalar": numba_basic.to_scalar}

    unique_names = unique_name_generator(
        ["np", "to_scalar"],
        suffix_sep="_",
    )
    input_names = [unique_names(v, force_unique=True) for v in node.inputs]

    def create_list_string(x):
        args = ", ".join(
            [f"to_scalar({i})" for i in x] + ([""] if len(x) == 1 else [])
        )
        return f"[{args}]"

    makevector_def_src = f"""
def makevector({", ".join(input_names)}):
    return np.array({create_list_string(input_names)}, dtype=np.{dtype})
    """

    makevector_fn = compile_function_src(
        makevector_def_src, "makevector", {**globals(), **global_env}
    )

    return numba_basic.numba_njit(makevector_fn)
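# A sketch for three (hypothetically named) inputs with `op.dtype == "int64"`:
#
#     def makevector(auto_0, auto_1, auto_2):
#         return np.array([to_scalar(auto_0), to_scalar(auto_1), to_scalar(auto_2)], dtype=np.int64)
#
# For a single input, `create_list_string` appends an empty element so the
# literal becomes `[to_scalar(auto_0), ]`, which is still a valid list.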
def numba_funcify_AllocEmpty(op, node, **kwargs): global_env = { "np": np, "to_scalar": numba_basic.to_scalar, "dtype": np.dtype(op.dtype), } unique_names = unique_name_generator( ["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_") shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join([ f"{item_name} = to_scalar({shape_name})" for item_name, shape_name in zip(shape_var_item_names, shape_var_names) ]), " " * 4, ) alloc_def_src = f""" def allocempty({", ".join(shape_var_names)}): {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} return np.empty(scalar_shape, dtype) """ alloc_fn = compile_function_src(alloc_def_src, "allocempty", { **globals(), **global_env }) return numba_basic.numba_njit(alloc_fn)
def numba_funcify_Elemwise(op, node, **kwargs):
    scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
    elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
    elemwise_fn_name = elemwise_fn.__name__

    if op.inplace_pattern:
        input_idx = op.inplace_pattern[0]
        sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
        input_names = list(sign_obj.parameters.keys())

        unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
        input_names = [unique_names(i, force_unique=True) for i in input_names]

        updated_input_name = input_names[input_idx]

        inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}

        inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"

        input_signature_str = ", ".join(input_names)

        if node.inputs[input_idx].ndim > 0:
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
    """
        else:
            # We can't perform in-place updates on Numba scalars, so we need to
            # convert them to NumPy scalars.
            # TODO: We should really prevent the rewrites from creating
            # in-place updates on scalars when the Numba mode is selected (or
            # in general?).
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    {updated_input_name}_scalar = np.asarray({updated_input_name})
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
    """

        inplace_elemwise_fn = compile_function_src(
            inplace_elemwise_src,
            inplace_elemwise_fn_name,
            {**globals(), **inplace_global_env},
        )

        return numba_basic.numba_njit(
            inline="always", fastmath=config.numba__fastmath
        )(inplace_elemwise_fn)

    return elemwise_fn
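# Assuming an in-place pattern `{0: 0}` on an `add` Elemwise whose vectorized
# function takes inputs `x` and `y` (hypothetical names), the array branch
# would generate roughly:
#
#     def elemwise_add_inplace(x, y):
#         return elemwise_add(x, y, x)
#
# i.e. the updated input is passed again as the explicit output argument, so
# the generated ufunc writes its result directly into that input's storage.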
def numba_funcify_Subtensor(op, node, **kwargs):
    subtensor_def_src = create_index_func(
        node, objmode=isinstance(op, AdvancedSubtensor)
    )

    global_env = {"np": np, "objmode": numba.objmode}

    subtensor_fn = compile_function_src(subtensor_def_src, "subtensor", global_env)

    return numba.njit(subtensor_fn)
def numba_funcify_IncSubtensor(op, node, **kwargs):
    incsubtensor_def_src = create_index_func(
        node, objmode=isinstance(op, AdvancedIncSubtensor)
    )

    global_env = {"np": np, "objmode": numba.objmode}

    incsubtensor_fn = compile_function_src(
        incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
    )

    return numba_njit(incsubtensor_fn)
def binary_to_nary_func(inputs: List[Variable], binary_op_name: str, binary_op: str):
    """Create a Numba-compatible N-ary function from a binary function."""
    unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
    input_names = [unique_names(v, force_unique=True) for v in inputs]
    input_signature = ", ".join(input_names)
    output_expr = binary_op.join(input_names)

    nary_src = f"""
def {binary_op_name}({input_signature}):
    return {output_expr}
    """
    nary_fn = compile_function_src(nary_src, binary_op_name, globals())

    return nary_fn
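# A self-contained sketch of the same n-ary expansion technique, with
# hypothetical stand-ins for `unique_name_generator` and
# `compile_function_src` (the real helpers may differ):
def _binary_to_nary_sketch(input_names, fn_name, op_str):
    # Build the source, e.g. "def add3(a, b, c):\n    return a + b + c"
    src = f"def {fn_name}({', '.join(input_names)}):\n    return {op_str.join(input_names)}\n"
    env = {}
    exec(compile(src, fn_name, "exec"), env)
    return env[fn_name]

# Usage: _binary_to_nary_sketch(["a", "b", "c"], "add3", " + ")(1, 2, 3) == 6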
def numba_funcify_Elemwise(op, node, **kwargs):
    scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)

    input_names = ", ".join([v.auto_name for v in node.inputs])

    global_env = {"scalar_op": scalar_op_fn, "vectorize": numba.vectorize}

    elemwise_fn_name = f"elemwise_{scalar_op_fn.__name__}"
    elemwise_src = f"""
@vectorize
def {elemwise_fn_name}({input_names}):
    return scalar_op({input_names})
    """

    elemwise_fn = compile_function_src(elemwise_src, elemwise_fn_name, global_env)

    return elemwise_fn
def create_numba_random_fn(
    op: Op,
    node: Apply,
    scalar_fn: Callable[[str], str],
    global_env: Optional[Dict[str, Any]] = None,
) -> Callable:
    """Create a vectorized function from a callable that generates the ``str`` function body.

    TODO: This could/should be generalized for other simple function
    construction cases that need unique-ified symbol names.
    """
    np_random_fn_name = f"aesara_random_{get_name_for_object(op.name)}"

    if global_env:
        np_global_env = global_env.copy()
    else:
        np_global_env = {}

    np_global_env["np"] = np
    np_global_env["numba_vectorize"] = numba_basic.numba_vectorize

    unique_names = unique_name_generator(
        [np_random_fn_name] + list(np_global_env.keys()) + ["rng", "size", "dtype"],
        suffix_sep="_",
    )
    np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
    np_input_names = ", ".join(np_names)
    np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
    {scalar_fn(*np_names)}
    """
    np_random_fn = compile_function_src(
        np_random_fn_src, np_random_fn_name, {**globals(), **np_global_env}
    )

    return make_numba_random_fn(node, np_random_fn)
def numba_funcify_SpecifyShape(op, node, **kwargs):
    shape_inputs = node.inputs[1:]
    shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

    # Skip the check for any dimension that is left unspecified (i.e. `NoneConst`)
    func_conditions = [
        f"assert x.shape[{i}] == {shape_input_name}"
        for i, (shape_input, shape_input_name) in enumerate(
            zip(shape_inputs, shape_input_names)
        )
        if shape_input is not NoneConst
    ]

    func = dedent(
        f"""
        def specify_shape(x, {create_arg_string(shape_input_names)}):
            {"; ".join(func_conditions)}
            return x
        """
    )

    specify_shape = compile_function_src(func, "specify_shape", globals())

    return numba_njit(specify_shape)
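# Assuming two checked shape inputs (neither is `NoneConst`), the generated
# source would look like:
#
#     def specify_shape(x, shape_0, shape_1):
#         assert x.shape[0] == shape_0; assert x.shape[1] == shape_1
#         return x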
def numba_funcify_ScalarOp(op, node, **kwargs):
    # TODO: Do we need to cache these functions so that we don't end up
    # compiling the same Numba function over and over again?

    scalar_func_name = op.nfunc_spec[0]

    if scalar_func_name.startswith("scipy."):
        func_package = scipy
        scalar_func_name = scalar_func_name.split(".", 1)[-1]
    else:
        func_package = np

    if "." in scalar_func_name:
        scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
    else:
        scalar_func = getattr(func_package, scalar_func_name)

    scalar_op_fn_name = get_name_for_object(scalar_func)
    unique_names = unique_name_generator(
        [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
    )
    input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs])

    global_env = {"scalar_func": scalar_func}

    scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
    return scalar_func({input_names})
    """

    scalar_op_fn = compile_function_src(
        scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
    )

    signature = create_numba_signature(node, force_scalar=True)

    return numba_basic.numba_njit(
        signature, inline="always", fastmath=config.numba__fastmath
    )(scalar_op_fn)
def numba_funcify_AllocEmpty(op, node, **kwargs): global_env = {"np": np, "to_scalar": to_scalar, "dtype": op.dtype} shape_var_names = [v.auto_name for v in node.inputs] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join([ f"{item_name} = to_scalar({shape_name})" for item_name, shape_name in zip(shape_var_item_names, shape_var_names) ]), " " * 4, ) alloc_def_src = f""" def allocempty({", ".join(shape_var_names)}): {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} return np.empty(scalar_shape, dtype) """ alloc_fn = compile_function_src(alloc_def_src, "allocempty", global_env) return numba.njit(alloc_fn)
def create_multiaxis_reducer(
    scalar_op, identity, axes, ndim, dtype, input_name="input"
):
    r"""Construct a function that reduces multiple axes.

    The functions generated by this function take the following form:

    .. code-block:: python

        def careduce_maximum(input):
            axis_0_res = careduce_axes_fn_0(input)
            axis_1_res = careduce_axes_fn_1(axis_0_res)
            ...
            axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
            return axis_N_res

    The range 0-N is determined by the `axes` argument (i.e. the axes to be
    reduced).

    Parameters
    ==========
    scalar_op:
        The scalar :class:`Op` that performs the desired reduction.
    identity:
        The identity value for the reduction.
    axes:
        The axes to reduce.
    ndim:
        The number of dimensions of the result.
    dtype:
        The data type of the result.

    Returns
    =======
    A Python function that can be JITed.

    """
    if len(axes) == 1:
        return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)

    careduce_fn_name = f"careduce_{scalar_op}"
    global_env = {}
    to_reduce = reversed(sorted(axes))
    careduce_lines_src = []
    var_name = input_name

    for i, axis in enumerate(to_reduce):
        careducer_axes_fn_name = f"careduce_axes_fn_{i}"
        reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype)
        reducer_fn = numba_basic.numba_njit(
            boundscheck=False, fastmath=config.numba__fastmath
        )(reducer_py_fn)

        global_env[careducer_axes_fn_name] = reducer_fn

        ndim -= 1
        last_var_name = var_name
        var_name = f"axis_{i}_res"
        careduce_lines_src.append(
            f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
        )

    careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
    careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
    return {var_name}
    """

    careduce_fn = compile_function_src(
        careduce_def_src, careduce_fn_name, {**globals(), **global_env}
    )

    return careduce_fn
def create_multiaxis_reducer(
    reduce_fn, identity, axes, ndim, dtype, input_name="input"
):
    r"""Construct a function that reduces multiple axes.

    The functions generated by this function take the following form:

    .. code-block:: python

        def careduce_maximum(input):
            axis_0_res = careduce_axes_fn_0(input)
            axis_1_res = careduce_axes_fn_1(axis_0_res)
            ...
            axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
            return axis_N_res

    The range 0-N is determined by the `axes` argument (i.e. the axes to be
    reduced).

    Parameters
    ==========
    reduce_fn:
        The Numba ``ufunc`` representing a binary op that can perform the
        reduction on arbitrary ``ndarray``\s.
    identity:
        The identity value for the reduction.
    axes:
        The axes to reduce.
    ndim:
        The number of dimensions of the result.
    dtype:
        The data type of the result.

    """
    if len(axes) == 1:
        return create_axis_reducer(reduce_fn, identity, axes[0], ndim, dtype)

    careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}"
    global_env = {}
    # Reduce the highest axes first, so that the remaining (lower) axis
    # indices stay valid as the intermediate results lose dimensions.
    to_reduce = reversed(sorted(axes))
    careduce_lines_src = []
    var_name = input_name

    for i, axis in enumerate(to_reduce):
        careducer_axes_fn_name = f"careduce_axes_fn_{i}"
        global_env[careducer_axes_fn_name] = create_axis_reducer(
            reduce_fn, identity, axis, ndim, dtype
        )
        ndim -= 1
        last_var_name = var_name
        var_name = f"axis_{i}_res"
        careduce_lines_src.append(
            f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
        )

    careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
    careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
    return {var_name}
    """

    careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)

    return numba.njit(careduce_fn)
def make_numba_random_fn(node, np_random_func):
    """Create Numba implementations for existing Numba-supported ``np.random`` functions.

    The functions generated here add parameter broadcasting and the ``size``
    argument to the Numba-supported scalar ``np.random`` functions.
    """
    tuple_size = int(get_vector_length(node.inputs[1]))
    size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])

    # Make a broadcast-capable version of the Numba-supported scalar sampling
    # function
    bcast_fn_name = f"aesara_random_{get_name_for_object(np_random_func)}"
    sized_fn_name = "sized_random_variable"

    unique_names = unique_name_generator(
        [
            bcast_fn_name,
            sized_fn_name,
            "np",
            "np_random_func",
            "numba_vectorize",
            "to_fixed_tuple",
            "tuple_size",
            "size_dims",
            "rng",
            "size",
            "dtype",
        ],
        suffix_sep="_",
    )

    bcast_fn_input_names = ", ".join(
        [unique_names(i, force_unique=True) for i in node.inputs[3:]]
    )
    bcast_fn_global_env = {
        "np_random_func": np_random_func,
        "numba_vectorize": numba.vectorize,
    }

    bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
    return np_random_func({bcast_fn_input_names})
    """
    bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_fn_global_env)

    random_fn_input_names = ", ".join(
        ["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
    )

    # Now, create a Numba-JITable function that implements the `size` parameter
    out_dtype = node.outputs[1].type.numpy_dtype

    random_fn_global_env = {
        bcast_fn_name: bcast_fn,
        "out_dtype": out_dtype,
    }

    if tuple_size > 0:
        random_fn_body = dedent(
            f"""
            size = to_fixed_tuple(size, tuple_size)
            data = np.empty(size, dtype=out_dtype)
            for i in np.ndindex(size[:size_dims]):
                data[i] = {bcast_fn_name}({bcast_fn_input_names})
            """
        )
        random_fn_global_env.update(
            {
                "np": np,
                "to_fixed_tuple": numba_ndarray.to_fixed_tuple,
                "tuple_size": tuple_size,
                "size_dims": size_dims,
            }
        )
    else:
        random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""

    sized_fn_src = dedent(
        f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
    return (rng, data)
    """
    )
    random_fn = compile_function_src(sized_fn_src, sized_fn_name, random_fn_global_env)
    random_fn = numba.njit(random_fn)

    return random_fn
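# A sketch of the two generated pieces for a two-parameter sampler (e.g.
# `np.random.normal`, with hypothetical input names `loc` and `scale`) and a
# non-empty `size` tuple:
#
#     @numba_vectorize
#     def aesara_random_normal(loc, scale):
#         return np_random_func(loc, scale)
#
#     def sized_random_variable(rng, size, dtype, loc, scale):
#         size = to_fixed_tuple(size, tuple_size)
#         data = np.empty(size, dtype=out_dtype)
#         for i in np.ndindex(size[:size_dims]):
#             data[i] = aesara_random_normal(loc, scale)
#         return (rng, data)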
def create_axis_reducer(
    scalar_op: Op,
    identity: Union[np.ndarray, Number],
    axis: int,
    ndim: int,
    dtype: numba.types.Type,
    keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher:
    r"""Create Python function that performs a NumPy-like reduction on a given axis.

    The functions generated by this function take the following form:

    .. code-block:: python

        def careduce_axis(x):
            res_shape = tuple(
                shape[i] if i < axis else shape[i + 1] for i in range(ndim - 1)
            )
            res = np.full(res_shape, identity, dtype=dtype)

            x_axis_first = x.transpose(reaxis_first)

            for m in range(x.shape[axis]):
                reduce_fn(res, x_axis_first[m], res)

            if keepdims:
                return np.expand_dims(res, axis)
            else:
                return res

    This can be removed/replaced when
    https://github.com/numba/numba/issues/4504 is implemented.

    Parameters
    ==========
    scalar_op:
        The scalar :class:`Op` that performs the desired reduction.
    identity:
        The identity value for the reduction.
    axis:
        The axis to reduce.
    ndim:
        The number of dimensions of the result.
    dtype:
        The data type of the result.
    keepdims:
        Determines whether or not the reduced dimension is retained.

    Returns
    =======
    A Python function that can be JITed.

    """
    reduce_elemwise_fn_name = "careduce_axis"

    identity = str(identity)
    if identity == "inf":
        identity = "np.inf"
    elif identity == "-inf":
        identity = "-np.inf"

    global_env = {
        "np": np,
        "numba_basic": numba_basic,
        "out_dtype": dtype,
    }

    if ndim > 1:
        res_shape_tuple_ctor = create_tuple_creator(
            lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
        )
        global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor

        res_indices = []
        arr_indices = []
        count = 0

        for i in range(ndim):
            if i == axis:
                arr_indices.append("i")
            else:
                res_indices.append(f"idx_arr[{count}]")
                arr_indices.append(f"idx_arr[{count}]")
                count = count + 1

        res_indices = ", ".join(res_indices)
        arr_indices = ", ".join(arr_indices)

        inplace_update_statement = scalar_in_place_fn(
            scalar_op, res_indices, "res", f"x[{arr_indices}]"
        )
        inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)

        return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res"
        reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):

    x_shape = np.shape(x)
    res_shape = res_shape_tuple_ctor(x_shape)
    res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype)

    axis_shape = x.shape[{axis}]

    for idx_arr in np.ndindex(res_shape):
        for i in range(axis_shape):
{inplace_update_statement}

    return {return_expr}
    """
    else:
        inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]")
        inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)

        return_expr = "res" if keepdims else "res.item()"
        reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):

    res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype)

    axis_shape = x.shape[{axis}]

    for i in range(axis_shape):
{inplace_update_statement}

    return {return_expr}
    """

    reduce_elemwise_fn_py = compile_function_src(
        reduce_elemwise_def_src, reduce_elemwise_fn_name, {**globals(), **global_env}
    )

    return reduce_elemwise_fn_py
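# For the 1-d (`ndim == 1`) branch of a sum-like reduction with identity `0`,
# and assuming `scalar_in_place_fn` renders an update such as
# `res[0] += x[i]`, the emitted source would be roughly:
#
#     def careduce_axis(x):
#         res = np.full(1, numba_basic.to_scalar(0), dtype=out_dtype)
#         axis_shape = x.shape[0]
#         for i in range(axis_shape):
#             res[0] += x[i]
#         return res.item()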
def numba_funcify_Scan(op, node, **kwargs):
    inner_fg = FunctionGraph(op.inner_inputs, op.inner_outputs)
    numba_at_inner_func = numba_basic.numba_njit(numba_funcify(inner_fg, **kwargs))

    n_seqs = op.info.n_seqs
    n_mit_mot = op.info.n_mit_mot
    n_mit_sot = op.info.n_mit_sot
    n_nit_sot = op.info.n_nit_sot
    n_sit_sot = op.info.n_sit_sot
    tap_array = op.info.tap_array
    n_shared_outs = op.info.n_shared_outs
    mit_mot_in_taps = tuple(tap_array[:n_mit_mot])
    mit_sot_in_taps = tuple(tap_array[n_mit_mot : n_mit_mot + n_mit_sot])

    p_in_mit_mot = n_seqs
    p_in_mit_sot = p_in_mit_mot + n_mit_mot
    p_in_sit_sot = p_in_mit_sot + n_mit_sot
    p_outer_in_shared = p_in_sit_sot + n_sit_sot
    p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
    p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot

    input_names = [n.auto_name for n in node.inputs[1:]]
    outer_in_seqs_names = input_names[:n_seqs]
    outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot]
    outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
    outer_in_sit_sot_names = input_names[p_in_sit_sot : p_in_sit_sot + n_sit_sot]
    outer_in_shared_names = input_names[
        p_outer_in_shared : p_outer_in_shared + n_shared_outs
    ]
    outer_in_nit_sot_names = input_names[
        p_outer_in_nit_sot : p_outer_in_nit_sot + n_nit_sot
    ]
    outer_in_feedback_names = input_names[n_seqs:p_outer_in_non_seqs]
    outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:]

    inner_in_indexed = []
    allocate_mem_to_nit_sot = ""

    for _name in outer_in_seqs_names:
        # A sequence with multiple taps is provided as multiple modified
        # input sequences to the Scan `Op`, each sliced appropriately, so
        # the logic of a normal single-tap sequence still applies.
        index = "[i]"
        inner_in_indexed.append(_name + index)

    name_to_input_map = dict(zip(input_names, node.inputs[1:]))
    mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps))
    inner_out_name_to_index = {}
    for _name in outer_in_feedback_names:
        if _name in outer_in_mit_sot_names:
            curr_taps = mit_sot_name_to_taps[_name]
            min_tap = min(curr_taps)

            for _tap in curr_taps:
                index = idx_to_str(_tap - min_tap)
                inner_in_indexed.append(_name + index)

            inner_out_name_to_index[_name] = -min_tap

        if _name in outer_in_sit_sot_names:
            # Note that outputs with a single tap that is not -1 (for
            # instance, taps = [-2]) are classified as mit-sot, so the
            # code for handling sit-sots remains as follows
            index = "[i]"
            inner_in_indexed.append(_name + index)
            inner_out_name_to_index[_name] = 1

        if _name in outer_in_nit_sot_names:
            inner_out_name_to_index[_name] = 0
            # In the case of nit-sots we are provided with the shape of the
            # array instead of an actual array (as in the other cases), so
            # we allocate space for the results accordingly.
            curr_nit_sot_position = input_names.index(_name) - n_seqs
            curr_nit_sot = inner_fg.outputs[curr_nit_sot_position]
            mem_shape = ["1"] * curr_nit_sot.ndim
            curr_dtype = curr_nit_sot.type.numpy_dtype.name
            allocate_mem_to_nit_sot += f"""
    {_name} = [np.zeros(({create_arg_string(mem_shape)}), dtype=np.{curr_dtype})]*{_name}.item()
"""

    # The non_seqs are passed to the inner function as-is
    inner_in_indexed += outer_in_non_seqs_names
    inner_out_indexed = [
        _name + idx_to_str(idx) for _name, idx in inner_out_name_to_index.items()
    ]

    while_logic = ""
    if op.info.as_while:
        # The inner function will be returning a boolean as its last value
        inner_out_indexed.append("while_flag")
        while_logic += """
        if while_flag:
"""
        for _name, idx in inner_out_name_to_index.items():
            while_logic += f"""
            {_name} = {_name}[:i+{idx+1}]
"""
        while_logic += """
            break
"""

    global_env = locals()
    global_env["np"] = np

    scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}):
{allocate_mem_to_nit_sot}

    for i in range(n_steps):
        inner_args = {create_tuple_string(inner_in_indexed)}
        {create_tuple_string(inner_out_indexed)} = numba_at_inner_func(*inner_args)
{while_logic}
    return {create_arg_string(
        outer_in_mit_sot_names
        + outer_in_sit_sot_names
        + outer_in_nit_sot_names
    )}
    """

    scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})

    return numba_basic.numba_njit(scan_op_fn)
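# A simplified sketch of the generated driver for a single sit-sot with the
# (hypothetical) outer input name `auto_0`, assuming `idx_to_str(1)` renders
# `"[i + 1]"` and `create_tuple_string` emits tuple literals:
#
#     def scan(n_steps, auto_0):
#         for i in range(n_steps):
#             inner_args = (auto_0[i],)
#             (auto_0[i + 1],) = numba_at_inner_func(*inner_args)
#         return auto_0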
def numba_funcify_ScalarOp(op, node, **kwargs):
    # TODO: Do we need to cache these functions so that we don't end up
    # compiling the same Numba function over and over again?

    scalar_func_name = op.nfunc_spec[0]

    if scalar_func_name.startswith("scipy."):
        func_package = scipy
        scalar_func_name = scalar_func_name.split(".", 1)[-1]
    else:
        func_package = np

    if "." in scalar_func_name:
        scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
    else:
        scalar_func = getattr(func_package, scalar_func_name)

    scalar_op_fn_name = get_name_for_object(scalar_func)
    unique_names = unique_name_generator(
        [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
    )

    global_env = {"scalar_func": scalar_func}

    input_tmp_dtypes = None
    if func_package == scipy and hasattr(scalar_func, "types"):
        # The `numba-scipy` bindings don't provide implementations for all
        # inputs types, so we need to convert the inputs to floats and back.
        inp_dtype_kinds = tuple(np.dtype(inp.type.dtype).kind for inp in node.inputs)
        accepted_inp_kinds = tuple(
            sig_type.split("->")[0] for sig_type in scalar_func.types
        )
        if not any(
            all(dk == ik for dk, ik in zip(inp_dtype_kinds, ok_kinds))
            for ok_kinds in accepted_inp_kinds
        ):
            # They're usually ordered from lower-to-higher precision, so
            # we pick the last acceptable input types
            #
            # XXX: We should pick the first acceptable float/int types in
            # reverse, excluding all the incompatible ones (e.g. `"0"`).
            # The assumption is that this is only used by `numba-scipy`-exposed
            # functions, although it's possible for this to be triggered by
            # something else from the `scipy` package
            input_tmp_dtypes = tuple(np.dtype(k) for k in accepted_inp_kinds[-1])

    if input_tmp_dtypes is None:
        unique_names = unique_name_generator(
            [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
        )
        input_names = ", ".join(
            [unique_names(v, force_unique=True) for v in node.inputs]
        )
        scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
    return scalar_func({input_names})
    """
    else:
        global_env["direct_cast"] = numba_basic.direct_cast
        global_env["output_dtype"] = np.dtype(node.outputs[0].type.dtype)
        input_tmp_dtype_names = {
            f"inp_tmp_dtype_{i}": i_dtype for i, i_dtype in enumerate(input_tmp_dtypes)
        }
        global_env.update(input_tmp_dtype_names)

        unique_names = unique_name_generator(
            [scalar_op_fn_name, "scalar_func"] + list(global_env.keys()),
            suffix_sep="_",
        )

        input_names = [unique_names(v, force_unique=True) for v in node.inputs]
        converted_call_args = ", ".join(
            [
                f"direct_cast({i_name}, {i_tmp_dtype_name})"
                for i_name, i_tmp_dtype_name in zip(
                    input_names, input_tmp_dtype_names.keys()
                )
            ]
        )
        scalar_op_src = f"""
def {scalar_op_fn_name}({', '.join(input_names)}):
    return direct_cast(scalar_func({converted_call_args}), output_dtype)
    """

    scalar_op_fn = compile_function_src(
        scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
    )

    signature = create_numba_signature(node, force_scalar=True)

    return numba_basic.numba_njit(
        signature, inline="always", fastmath=config.numba__fastmath
    )(scalar_op_fn)
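# In the `numba-scipy` fallback branch, a single-input function such as
# `scipy.special.gammaln` (assuming its accepted kinds force a float
# temporary) would generate source along the lines of:
#
#     def gammaln(auto_0):
#         return direct_cast(scalar_func(direct_cast(auto_0, inp_tmp_dtype_0)), output_dtype)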