def numba_funcify_MakeVector(op, node, **kwargs):
    """Create a Numba-jitted function that packs its inputs into a vector.

    The generated ``makevector`` function takes one parameter per entry of
    ``node.inputs``, passes each argument through ``to_scalar`` and collects
    the results with ``np.array`` using the dtype named by ``op.dtype``.
    """
    out_dtype = np.dtype(op.dtype)

    namer = unique_name_generator(
        ["np", "to_scalar"],
        suffix_sep="_",
    )
    arg_names = [namer(inp, force_unique=True) for inp in node.inputs]

    def create_list_string(names):
        # A single element is followed by a trailing separator, i.e.
        # `[to_scalar(a), ]`.
        items = [f"to_scalar({name})" for name in names]
        if len(names) == 1:
            items.append("")
        return "[" + ", ".join(items) + "]"

    makevector_def_src = f"""
def makevector({", ".join(arg_names)}):
    return np.array({create_list_string(arg_names)}, dtype=np.{out_dtype})
    """

    # Names referenced by the generated source take precedence over the
    # module globals.
    namespace = {**globals(), "np": np, "to_scalar": numba_basic.to_scalar}
    makevector_fn = compile_function_src(
        makevector_def_src, "makevector", namespace
    )

    return numba_basic.numba_njit(makevector_fn)
def numba_funcify_AllocEmpty(op, node, **kwargs):
    """Create a Numba-jitted function that allocates an uninitialized array.

    The generated ``allocempty`` function takes one parameter per entry of
    ``node.inputs`` (the shape entries), converts each one with ``to_scalar``
    and calls ``np.empty`` with the resulting tuple and ``op.dtype``.
    """
    out_dtype = np.dtype(op.dtype)

    namer = unique_name_generator(
        ["np", "to_scalar", "dtype", "allocempty", "scalar_shape"],
        suffix_sep="_",
    )
    shape_names = [namer(inp, force_unique=True) for inp in node.inputs]
    item_names = [f"{shape_name}_item" for shape_name in shape_names]

    # One `<name>_item = to_scalar(<name>)` statement per shape argument
    conversion_lines = [
        f"{item_name} = to_scalar({shape_name})"
        for item_name, shape_name in zip(item_names, shape_names)
    ]
    shapes_to_items_src = indent("\n".join(conversion_lines), " " * 4)

    alloc_def_src = f"""
def allocempty({", ".join(shape_names)}):
{shapes_to_items_src}
    scalar_shape = {create_tuple_string(item_names)}
    return np.empty(scalar_shape, dtype)
    """

    namespace = {
        **globals(),
        "np": np,
        "to_scalar": numba_basic.to_scalar,
        "dtype": out_dtype,
    }
    alloc_fn = compile_function_src(alloc_def_src, "allocempty", namespace)

    return numba_basic.numba_njit(alloc_fn)
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
    """Create a Numba-compiled function computing ``dy * sm - sum(dy * sm) * sm``.

    When ``op.axis`` is ``None`` the sum is ``np.sum``; otherwise it is a
    jitted single-axis reducer built by ``create_axis_reducer`` with
    ``keepdims=True``.
    """
    sm_var = node.inputs[1]
    sm_numba_dtype = numba.np.numpy_support.from_dtype(sm_var.type.numpy_dtype)

    axis = op.axis

    if axis is None:
        reduce_sum = np.sum
    else:
        reduce_sum_py = create_axis_reducer(
            add_as, 0.0, axis, sm_var.ndim, sm_numba_dtype, keepdims=True
        )
        jit_fn = numba_basic.numba_njit(
            boundscheck=False, fastmath=config.numba__fastmath
        )
        reduce_sum = jit_fn(reduce_sum_py)

    def softmax_grad_py_fn(dy, sm):
        dy_times_sm = dy * sm
        sum_dy_times_sm = reduce_sum(dy_times_sm)
        return dy_times_sm - sum_dy_times_sm * sm

    return jit_compile_reducer(node, softmax_grad_py_fn)
def jit_compile_reducer(node, fn, **kwds):
    """Compile Python source for reduction loops using additional optimizations.

    Parameters
    ==========
    node
        A node from which the signature can be derived.
    fn
        The Python function object to compile.
    kwds
        Extra keywords to be added to the :func:`numba.njit` function.

    Returns
    =======
    A :func:`numba.njit`-compiled function.

    """
    signature = create_numba_signature(node, reduce_to_scalar=True)

    # The function is compiled eagerly (a signature is supplied) while the
    # optimized cheap pass is active; this should help improve nested loop
    # reductions.
    with use_optimized_cheap_pass():
        return numba_basic.numba_njit(
            signature,
            boundscheck=False,
            fastmath=config.numba__fastmath,
            **kwds,
        )(fn)
def numba_funcify_LogSoftmax(op, node, **kwargs):
    """Create a Numba-compiled log-softmax function.

    The compiled function subtracts the maximum from its input and then
    subtracts the log of the sum of exponentials.  When ``op.axis`` is
    ``None`` the reductions are ``np.max`` and ``np.sum``; otherwise they are
    jitted single-axis reducers built with ``keepdims=True``.
    """
    x_var = node.inputs[0]
    x_numba_dtype = numba.np.numpy_support.from_dtype(x_var.type.numpy_dtype)

    axis = op.axis

    if axis is None:
        reduce_max = np.max
        reduce_sum = np.sum
    else:
        reduce_max_py = create_axis_reducer(
            scalar_maximum, -np.inf, axis, x_var.ndim, x_numba_dtype, keepdims=True
        )
        reduce_sum_py = create_axis_reducer(
            add_as, 0.0, axis, x_var.ndim, x_numba_dtype, keepdims=True
        )

        jit_fn = numba_basic.numba_njit(
            boundscheck=False, fastmath=config.numba__fastmath
        )
        reduce_max = jit_fn(reduce_max_py)
        reduce_sum = jit_fn(reduce_sum_py)

    def log_softmax_py_fn(x):
        xdev = x - reduce_max(x)
        return xdev - np.log(reduce_sum(np.exp(xdev)))

    return jit_compile_reducer(node, log_softmax_py_fn)
def numba_funcify_Composite(op, node, **kwargs):
    """Create a Numba-jitted function for the graph stored in ``op.fgraph``.

    The graph is converted with ``numba_funcify(..., squeeze_output=True)``
    and compiled with a signature derived from ``node`` using
    ``force_scalar=True``.
    """
    signature = create_numba_signature(node, force_scalar=True)
    jit_decorator = numba_basic.numba_njit(
        signature, fastmath=config.numba__fastmath
    )
    composite_py_fn = numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
    return jit_decorator(composite_py_fn)
def numba_funcify_Elemwise(op, node, **kwargs):
    """Create a vectorized Numba function for an ``Elemwise`` node.

    The scalar function for ``op.scalar_op`` is vectorized with
    ``create_vectorize_func``.  If ``op.inplace_pattern`` is non-empty, the
    vectorized function is wrapped so that the input whose index is stored
    under key ``0`` of ``op.inplace_pattern`` is passed once more as an
    additional trailing argument.
    """
    scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
    elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)

    if not op.inplace_pattern:
        return elemwise_fn

    elemwise_fn_name = elemwise_fn.__name__
    input_idx = op.inplace_pattern[0]

    # Reuse the scalar function's parameter names, made unique with respect
    # to the names used by the generated wrapper.
    scalar_params = inspect.signature(elemwise_fn.py_scalar_func).parameters
    namer = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
    input_names = [
        namer(param_name, force_unique=True)
        for param_name in list(scalar_params.keys())
    ]

    updated_input_name = input_names[input_idx]
    inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"
    input_signature_str = ", ".join(input_names)
    call_args_str = input_signature_str + ", " + updated_input_name

    if node.inputs[input_idx].ndim > 0:
        inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    return {elemwise_fn_name}({call_args_str})
        """
    else:
        # We can't perform in-place updates on Numba scalars, so we need to
        # convert them to NumPy scalars.
        # TODO: We should really prevent the rewrites from creating
        # in-place updates on scalars when the Numba mode is selected (or
        # in general?).
        inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    {updated_input_name}_scalar = np.asarray({updated_input_name})
    return {elemwise_fn_name}({call_args_str}_scalar).item()
        """

    namespace = {**globals(), elemwise_fn_name: elemwise_fn, "np": np}
    inplace_elemwise_fn = compile_function_src(
        inplace_elemwise_src,
        inplace_elemwise_fn_name,
        namespace,
    )
    jit_decorator = numba_basic.numba_njit(
        inline="always", fastmath=config.numba__fastmath
    )
    return jit_decorator(inplace_elemwise_fn)
def numba_funcify_Mul(op, node, **kwargs):
    """Create a Numba-jitted n-ary multiplication for ``node.inputs``.

    The Python function comes from ``binary_to_nary_func`` with the ``*``
    operator and is compiled with a signature derived from ``node`` using
    ``force_scalar=True``.
    """
    signature = create_numba_signature(node, force_scalar=True)
    nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")

    jit_decorator = numba_basic.numba_njit(
        signature, inline="always", fastmath=config.numba__fastmath
    )
    return jit_decorator(nary_mul_fn)
def numba_funcify_ScalarOp(op, node, **kwargs):
    """Create a Numba-jitted wrapper around the function named by ``op.nfunc_spec``.

    The first entry of ``op.nfunc_spec`` names a function that is looked up
    in ``scipy`` (when the name starts with ``"scipy."``) or in ``numpy``.
    A thin wrapper with one parameter per entry of ``node.inputs`` is
    generated and compiled with a signature derived from ``node``.
    """
    # TODO: Do we need to cache these functions so that we don't end up
    # compiling the same Numba function over and over again?

    scalar_func_name = op.nfunc_spec[0]

    is_scipy_name = scalar_func_name.startswith("scipy.")
    if is_scipy_name:
        scalar_func_name = scalar_func_name.split(".", 1)[-1]
    func_package = scipy if is_scipy_name else np

    if "." in scalar_func_name:
        # NOTE(review): dotted names are always resolved starting from
        # `scipy`, even when the name had no "scipy." prefix -- confirm that
        # this is intended.
        scalar_func = scipy
        for attr_name in scalar_func_name.split("."):
            scalar_func = getattr(scalar_func, attr_name)
    else:
        scalar_func = getattr(func_package, scalar_func_name)

    scalar_op_fn_name = get_name_for_object(scalar_func)
    namer = unique_name_generator(
        [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
    )

    input_names = ", ".join(
        [namer(inp, force_unique=True) for inp in node.inputs]
    )

    scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
    return scalar_func({input_names})
    """
    namespace = {**globals(), "scalar_func": scalar_func}
    scalar_op_fn = compile_function_src(
        scalar_op_src, scalar_op_fn_name, namespace
    )

    signature = create_numba_signature(node, force_scalar=True)

    jit_decorator = numba_basic.numba_njit(
        signature, inline="always", fastmath=config.numba__fastmath
    )
    return jit_decorator(scalar_op_fn)
def numba_funcify_Scan(op, node, **kwargs):
    """Create a Numba-jitted implementation of a ``Scan`` node.

    The inner graph (``op.inner_inputs`` -> ``op.inner_outputs``) is converted
    and jitted, and the source of a ``scan`` function is generated around it.
    The generated function takes ``n_steps`` followed by one parameter per
    entry of ``node.inputs[1:]`` (named after each input's ``auto_name``),
    calls the inner function once per step with indexed views of its
    arguments, and returns the mit-sot, sit-sot and nit-sot variables, in
    that order.

    Only ``np`` and the jitted inner function are injected into the namespace
    of the generated source.
    """
    inner_fg = FunctionGraph(op.inner_inputs, op.inner_outputs)
    numba_at_inner_func = numba_basic.numba_njit(numba_funcify(inner_fg, **kwargs))

    n_seqs = op.info.n_seqs
    n_mit_mot = op.info.n_mit_mot
    n_mit_sot = op.info.n_mit_sot
    n_nit_sot = op.info.n_nit_sot
    n_sit_sot = op.info.n_sit_sot
    tap_array = op.info.tap_array
    n_shared_outs = op.info.n_shared_outs
    mit_sot_in_taps = tuple(tap_array[n_mit_mot : n_mit_mot + n_mit_sot])

    # Start offsets of each group of inputs within `node.inputs[1:]`
    p_in_mit_mot = n_seqs
    p_in_mit_sot = p_in_mit_mot + n_mit_mot
    p_in_sit_sot = p_in_mit_sot + n_mit_sot
    p_outer_in_shared = p_in_sit_sot + n_sit_sot
    p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
    p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot

    input_names = [n.auto_name for n in node.inputs[1:]]
    outer_in_seqs_names = input_names[:n_seqs]
    outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
    outer_in_sit_sot_names = input_names[p_in_sit_sot : p_in_sit_sot + n_sit_sot]
    outer_in_nit_sot_names = input_names[
        p_outer_in_nit_sot : p_outer_in_nit_sot + n_nit_sot
    ]
    outer_in_feedback_names = input_names[n_seqs:p_outer_in_non_seqs]
    outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:]

    inner_in_indexed = []
    allocate_mem_to_nit_sot = ""

    for _name in outer_in_seqs_names:
        # A sequence with multiple taps is provided as multiple modified
        # input sequences to the Scan Op sliced appropriately
        # to keep following the logic of a normal sequence.
        index = "[i]"
        inner_in_indexed.append(_name + index)

    mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps))
    inner_out_name_to_index = {}
    for _name in outer_in_feedback_names:
        if _name in outer_in_mit_sot_names:
            curr_taps = mit_sot_name_to_taps[_name]
            min_tap = min(curr_taps)

            for _tap in curr_taps:
                index = idx_to_str(_tap - min_tap)
                inner_in_indexed.append(_name + index)

            inner_out_name_to_index[_name] = -min_tap

        if _name in outer_in_sit_sot_names:
            # Note that the outputs with single taps which are not
            # -1 (for instance taps = [-2]) are classified
            # as mit-sot so the code for handling sit-sots remains
            # constant as follows
            index = "[i]"
            inner_in_indexed.append(_name + index)
            inner_out_name_to_index[_name] = 1

        if _name in outer_in_nit_sot_names:
            inner_out_name_to_index[_name] = 0
            # In case of nit-sots we are provided shape of the array
            # instead of actual arrays like other cases, hence we
            # allocate space for the results accordingly.
            curr_nit_sot_position = input_names.index(_name) - n_seqs
            curr_nit_sot = inner_fg.outputs[curr_nit_sot_position]
            mem_shape = ["1"] * curr_nit_sot.ndim
            curr_dtype = curr_nit_sot.type.numpy_dtype.name
            # The statement is emitted at the indentation of the body of the
            # generated `scan` function.
            allocate_mem_to_nit_sot += f"""
    {_name} = [np.zeros(({create_arg_string(mem_shape)}), dtype=np.{curr_dtype})]*{_name}.item()
"""

    # The non_seqs are passed to inner function as-is
    inner_in_indexed += outer_in_non_seqs_names
    inner_out_indexed = [
        _name + idx_to_str(idx) for _name, idx in inner_out_name_to_index.items()
    ]

    while_logic = ""
    if op.info.as_while:
        # The inner function will be returning a boolean as last argument
        inner_out_indexed.append("while_flag")
        while_logic += """
        if while_flag:
        """
        for _name, idx in inner_out_name_to_index.items():
            while_logic += f"""
            {_name} = {_name}[:i+{idx+1}]
            """
        while_logic += """
            break
        """

    returned_names = create_arg_string(
        outer_in_mit_sot_names + outer_in_sit_sot_names + outer_in_nit_sot_names
    )

    # Only expose what the generated source actually references, instead of
    # every local of this function.
    global_env = {"np": np, "numba_at_inner_func": numba_at_inner_func}

    scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}):
{allocate_mem_to_nit_sot}
    for i in range(n_steps):
        inner_args = {create_tuple_string(inner_in_indexed)}
        {create_tuple_string(inner_out_indexed)} = numba_at_inner_func(*inner_args)
{while_logic}
    return {returned_names}
    """
    scalar_op_fn = compile_function_src(
        scan_op_src, "scan", {**globals(), **global_env}
    )

    return numba_basic.numba_njit(scalar_op_fn)
def make_numba_random_fn(node, np_random_func):
    """Create Numba implementations for existing Numba-supported ``np.random`` functions.

    The functions generated here add parameter broadcasting and the ``size``
    argument to the Numba-supported scalar ``np.random`` functions.

    Parameters
    ==========
    node
        The node being converted.  ``node.inputs[1]`` provides the length of
        the ``size`` tuple, ``node.inputs[3:]`` are the arguments forwarded to
        `np_random_func`, and ``node.outputs[1]`` provides the output dtype.
    np_random_func
        The scalar sampling function to wrap.

    Returns
    =======
    A Numba-jitted function taking ``rng``, ``size``, ``dtype`` and the
    sampling arguments, and returning ``(rng, data)``.

    """
    dist_params = node.inputs[3:]

    tuple_size = int(get_vector_length(node.inputs[1]))
    size_dims = tuple_size - max(param.ndim for param in dist_params)

    # Make a broadcast-capable version of the Numba supported scalar sampling
    # function
    bcast_fn_name = f"aesara_random_{get_name_for_object(np_random_func)}"

    sized_fn_name = "sized_random_variable"

    reserved_names = [
        bcast_fn_name,
        sized_fn_name,
        "np",
        "np_random_func",
        "numba_vectorize",
        "to_fixed_tuple",
        "tuple_size",
        "size_dims",
        "rng",
        "size",
        "dtype",
    ]
    namer = unique_name_generator(reserved_names, suffix_sep="_")

    bcast_fn_input_names = ", ".join(
        [namer(param, force_unique=True) for param in dist_params]
    )

    bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
    return np_random_func({bcast_fn_input_names})
    """
    bcast_namespace = {
        **globals(),
        "np_random_func": np_random_func,
        "numba_vectorize": numba_basic.numba_vectorize,
    }
    bcast_fn = compile_function_src(bcast_fn_src, bcast_fn_name, bcast_namespace)

    random_fn_input_names = ", ".join(
        ["rng", "size", "dtype"] + [namer(param) for param in dist_params]
    )

    # Now, create a Numba JITable function that implements the `size` parameter

    random_fn_global_env = {
        bcast_fn_name: bcast_fn,
        "out_dtype": node.outputs[1].type.numpy_dtype,
    }

    if tuple_size > 0:
        random_fn_body = dedent(
            f"""
        size = to_fixed_tuple(size, tuple_size)

        data = np.empty(size, dtype=out_dtype)
        for i in np.ndindex(size[:size_dims]):
            data[i] = {bcast_fn_name}({bcast_fn_input_names})

        """
        )
        random_fn_global_env["np"] = np
        random_fn_global_env["to_fixed_tuple"] = numba_ndarray.to_fixed_tuple
        random_fn_global_env["tuple_size"] = tuple_size
        random_fn_global_env["size_dims"] = size_dims
    else:
        random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""

    sized_fn_src = dedent(
        f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
    return (rng, data)
    """
    )
    sized_fn = compile_function_src(
        sized_fn_src, sized_fn_name, {**globals(), **random_fn_global_env}
    )

    return numba_basic.numba_njit(sized_fn)
def create_multiaxis_reducer(scalar_op, identity, axes, ndim, dtype, input_name="input"):
    r"""Construct a function that reduces multiple axes.

    The functions generated by this function take the following form:

    .. code-block:: python

        def careduce_maximum(input):
            axis_0_res = careduce_axes_fn_0(input)
            axis_1_res = careduce_axes_fn_1(axis_0_res)
            ...
            axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
            return axis_N_res

    The range 0-N is determined by the `axes` argument (i.e. the
    axes to be reduced).  Axes are reduced from the largest to the smallest.

    Parameters
    ==========
    scalar_op:
        The scalar :class:`Op` that performs the desired reduction.
    identity:
        The identity value for the reduction.
    axes:
        The axes to reduce.
    ndim:
        The number of dimensions handed to the first single-axis reducer; it
        is decremented by one after each axis is reduced.
    dtype:
        The data type of the result.
    input_name:
        The name of the generated function's parameter.

    Returns
    =======
    A Python function that can be JITed.

    """
    if len(axes) == 1:
        return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)

    careduce_fn_name = f"careduce_{scalar_op}"
    jit_decorator_kwargs = None  # placeholder removed below; see jit call
    del jit_decorator_kwargs

    reducers_env = {}
    assignment_lines = []
    result_name = input_name

    for step, axis in enumerate(sorted(axes, reverse=True)):
        step_fn_name = f"careduce_axes_fn_{step}"

        reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype)
        reducers_env[step_fn_name] = numba_basic.numba_njit(
            boundscheck=False, fastmath=config.numba__fastmath
        )(reducer_py_fn)

        # The next reducer receives an array with one fewer dimension
        ndim -= 1

        previous_name, result_name = result_name, f"axis_{step}_res"
        assignment_lines.append(f"{result_name} = {step_fn_name}({previous_name})")

    careduce_assign_lines = indent("\n".join(assignment_lines), " " * 4)

    careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
    return {result_name}
    """

    careduce_fn = compile_function_src(
        careduce_def_src, careduce_fn_name, {**globals(), **reducers_env}
    )
    return careduce_fn
def numba_funcify_ScalarOp(op, node, **kwargs):
    """Create a Numba-jitted wrapper around the function named by ``op.nfunc_spec``.

    The first entry of ``op.nfunc_spec`` names a function that is looked up
    in ``scipy`` (when the name starts with ``"scipy."``) or in ``numpy``.

    When the function comes from ``scipy``, exposes a ``types`` attribute and
    none of its listed input signatures matches the dtype kinds of
    ``node.inputs``, the generated wrapper casts every input with
    ``direct_cast`` to the input types of the last listed signature and casts
    the result back to the dtype of ``node.outputs[0]``.  Otherwise the
    wrapper forwards its arguments unchanged.

    Returns
    =======
    A Numba-jitted function with one parameter per entry of ``node.inputs``,
    compiled with a signature derived from ``node``.

    """
    # TODO: Do we need to cache these functions so that we don't end up
    # compiling the same Numba function over and over again?

    scalar_func_name = op.nfunc_spec[0]

    if scalar_func_name.startswith("scipy."):
        func_package = scipy
        scalar_func_name = scalar_func_name.split(".", 1)[-1]
    else:
        func_package = np

    if "." in scalar_func_name:
        scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
    else:
        scalar_func = getattr(func_package, scalar_func_name)

    scalar_op_fn_name = get_name_for_object(scalar_func)

    global_env = {"scalar_func": scalar_func}

    input_tmp_dtypes = None
    if func_package is scipy and hasattr(scalar_func, "types"):
        # The `numba-scipy` bindings don't provide implementations for all
        # inputs types, so we need to convert the inputs to floats and back.
        inp_dtype_kinds = tuple(np.dtype(inp.type.dtype).kind for inp in node.inputs)
        accepted_inp_kinds = tuple(
            sig_type.split("->")[0] for sig_type in scalar_func.types
        )
        if not any(
            all(dk == ik for dk, ik in zip(inp_dtype_kinds, ok_kinds))
            for ok_kinds in accepted_inp_kinds
        ):
            # They're usually ordered from lower-to-higher precision, so
            # we pick the last acceptable input types
            #
            # XXX: We should pick the first acceptable float/int types in
            # reverse, excluding all the incompatible ones (e.g. `"0"`).
            # The assumption is that this is only used by `numba-scipy`-exposed
            # functions, although it's possible for this to be triggered by
            # something else from the `scipy` package
            input_tmp_dtypes = tuple(np.dtype(k) for k in accepted_inp_kinds[-1])

    if input_tmp_dtypes is None:
        # No conversion needed: forward the arguments as they are
        unique_names = unique_name_generator(
            [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
        )
        input_names = ", ".join(
            [unique_names(v, force_unique=True) for v in node.inputs]
        )
        scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
    return scalar_func({input_names})
        """
    else:
        global_env["direct_cast"] = numba_basic.direct_cast
        global_env["output_dtype"] = np.dtype(node.outputs[0].type.dtype)

        # Each temporary dtype is exposed to the generated source under its
        # own global name
        input_tmp_dtype_names = {
            f"inp_tmp_dtype_{i}": i_dtype for i, i_dtype in enumerate(input_tmp_dtypes)
        }
        global_env.update(input_tmp_dtype_names)

        # Parameter names must not clash with any of the injected globals
        unique_names = unique_name_generator(
            [scalar_op_fn_name, "scalar_func"] + list(global_env.keys()),
            suffix_sep="_",
        )

        input_names = [unique_names(v, force_unique=True) for v in node.inputs]

        converted_call_args = ", ".join(
            [
                f"direct_cast({i_name}, {i_tmp_dtype_name})"
                for i_name, i_tmp_dtype_name in zip(
                    input_names, input_tmp_dtype_names.keys()
                )
            ]
        )
        scalar_op_src = f"""
def {scalar_op_fn_name}({', '.join(input_names)}):
    return direct_cast(scalar_func({converted_call_args}), output_dtype)
        """

    scalar_op_fn = compile_function_src(
        scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
    )

    signature = create_numba_signature(node, force_scalar=True)

    return numba_basic.numba_njit(
        signature, inline="always", fastmath=config.numba__fastmath
    )(scalar_op_fn)