def numba_funcify_QRFull(op, node, **kwargs):
    """Build a Numba-jitted callable that applies ``np.linalg.qr`` for ``op``.

    Parameters
    ----------
    op
        The op being converted; only its ``mode`` attribute is read.
    node
        The apply node; its ``inputs`` and ``outputs`` supply the type
        information used to build the compiled function.
    **kwargs
        Accepted but not used here.

    Returns
    -------
    A Numba-jitted function taking one argument ``x``.

    Notes
    -----
    For any ``mode`` other than ``"reduced"`` a ``UserWarning`` is emitted and
    the ``np.linalg.qr`` call is executed inside a ``numba.objmode`` block so
    that the ``mode`` keyword can be forwarded.
    """
    mode = op.mode

    if mode == "reduced":
        # Only the input is passed to `np.linalg.qr` here, so the call can
        # stay in compiled code.
        out_dtype = node.outputs[0].type.numpy_dtype
        # NOTE(review): judging by its name, `int_to_float_fn` returns a
        # caster that converts integer inputs to `out_dtype` -- confirm.
        inputs_cast = int_to_float_fn(node.inputs, out_dtype)

        @numba.njit(inline="always")
        def qr_full(x):
            x_cast = inputs_cast(x)
            return np.linalg.qr(x_cast)

        return qr_full

    warnings.warn(
        (
            "Numba will use object mode to allow the "
            "`mode` argument to `numpy.linalg.qr`."
        ),
        UserWarning,
    )

    # The objmode block needs an explicit return type: a tuple type when the
    # node has several outputs, otherwise the single output's type.
    output_types = [get_numba_type(out.type) for out in node.outputs]
    if len(output_types) > 1:
        ret_sig = numba.types.Tuple(output_types)
    else:
        ret_sig = output_types[0]

    @numba.njit
    def qr_full(x):
        with numba.objmode(ret=ret_sig):
            ret = np.linalg.qr(x, mode=mode)
        return ret

    return qr_full
def numba_funcify_SVD(op, node, **kwargs):
    """Build a Numba-jitted callable that applies ``np.linalg.svd`` for ``op``.

    Parameters
    ----------
    op
        The op being converted; its ``full_matrices`` and ``compute_uv``
        attributes are read.
    node
        The apply node; its ``inputs`` and first output supply the type
        information used to build the compiled function.
    **kwargs
        Accepted but not used here.

    Returns
    -------
    A Numba-jitted function taking one argument ``x``.

    Notes
    -----
    When ``op.compute_uv`` is falsy a ``UserWarning`` is emitted and the
    ``np.linalg.svd`` call is executed inside a ``numba.objmode`` block so
    that the ``compute_uv`` argument can be forwarded.
    """
    full_matrices = op.full_matrices
    compute_uv = op.compute_uv

    if compute_uv:
        # Only `full_matrices` has to be forwarded, so the call can stay in
        # compiled code.
        out_dtype = node.outputs[0].type.numpy_dtype
        # NOTE(review): judging by its name, `int_to_float_fn` returns a
        # caster that converts integer inputs to `out_dtype` -- confirm.
        inputs_cast = int_to_float_fn(node.inputs, out_dtype)

        @numba.njit(inline="always")
        def svd(x):
            x_cast = inputs_cast(x)
            return np.linalg.svd(x_cast, full_matrices)

        return svd

    warnings.warn(
        (
            "Numba will use object mode to allow the "
            "`compute_uv` argument to `numpy.linalg.svd`."
        ),
        UserWarning,
    )

    # The objmode block needs an explicit return type; only the first
    # output's type is used for it.
    first_output = node.outputs[0]
    ret_sig = get_numba_type(first_output.type)

    @numba.njit
    def svd(x):
        with numba.objmode(ret=ret_sig):
            ret = np.linalg.svd(x, full_matrices, compute_uv)
        return ret

    return svd
def numba_funcify_MatrixPinv(op, node, **kwargs):
    """Build a Numba-jitted callable that applies ``np.linalg.pinv``.

    Parameters
    ----------
    op
        The op being converted; not read here.
    node
        The apply node; ``node.inputs`` and the dtype of ``node.outputs[0]``
        are used to build the compiled function.
    **kwargs
        Accepted but not used here.

    Returns
    -------
    A Numba-jitted function taking one argument ``x`` and returning
    ``np.linalg.pinv`` of the cast input, converted to the first output's
    NumPy dtype.
    """
    out_dtype = node.outputs[0].type.numpy_dtype
    # NOTE(review): judging by its name, `int_to_float_fn` returns a caster
    # that converts integer inputs to `out_dtype` -- confirm.
    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

    @numba.njit(inline="always")
    def matrixpinv(x):
        x_cast = inputs_cast(x)
        pinv_result = np.linalg.pinv(x_cast)
        # Force the result to the dtype declared on the node's output.
        return pinv_result.astype(out_dtype)

    return matrixpinv
def numba_funcify_Det(op, node, **kwargs):
    """Build a Numba-jitted callable that applies ``np.linalg.det``.

    Parameters
    ----------
    op
        The op being converted; not read here.
    node
        The apply node; ``node.inputs`` and the dtype of ``node.outputs[0]``
        are used to build the compiled function.
    **kwargs
        Accepted but not used here.

    Returns
    -------
    A Numba-jitted function taking one argument ``x`` and returning
    ``np.linalg.det`` of the cast input, passed through
    ``numba_basic.direct_cast`` with the first output's NumPy dtype.
    """
    out_dtype = node.outputs[0].type.numpy_dtype
    # NOTE(review): judging by its name, `int_to_float_fn` returns a caster
    # that converts integer inputs to `out_dtype` -- confirm.
    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

    @numba.njit(inline="always")
    def det(x):
        x_cast = inputs_cast(x)
        det_value = np.linalg.det(x_cast)
        # `astype` is not used here; the value goes through the project's
        # `direct_cast` helper instead.
        return numba_basic.direct_cast(det_value, out_dtype)

    return det
def numba_funcify_Eig(op, node, **kwargs):
    """Build a Numba-jitted callable that applies ``np.linalg.eig``.

    Parameters
    ----------
    op
        The op being converted; not read here.
    node
        The apply node; ``node.inputs`` and the dtypes of its first two
        outputs are used to build the compiled function.
    **kwargs
        Accepted but not used here.

    Returns
    -------
    A Numba-jitted function taking one argument ``x`` and returning a
    2-tuple: the two results of ``np.linalg.eig``, each converted to the
    NumPy dtype of the corresponding node output.
    """
    out_dtype_1 = node.outputs[0].type.numpy_dtype
    out_dtype_2 = node.outputs[1].type.numpy_dtype

    # The input caster is built from the first output's dtype only.
    # NOTE(review): judging by its name, `int_to_float_fn` returns a caster
    # that converts integer inputs to that dtype -- confirm.
    inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)

    @numba.njit
    def eig(x):
        eigenvalues, eigenvectors = np.linalg.eig(inputs_cast(x))
        return (
            eigenvalues.astype(out_dtype_1),
            eigenvectors.astype(out_dtype_2),
        )

    return eig